@@ -841,7 +894,7 @@ Python Module Index
- © Copyright 2022, Meta.
+ © Copyright 2024, Meta.
@@ -878,6 +931,9 @@ Python Module Index
+
+
+
diff --git a/search.html b/search.html
index d5e24ce05..1ee7e48b0 100644
--- a/search.html
+++ b/search.html
@@ -9,7 +9,7 @@
- Search — TorchRec 0.9.0 documentation
+ Search — TorchRec 1.1.0 documentation
@@ -29,6 +29,9 @@
+
+
+
@@ -265,7 +268,7 @@
- 0.9.0.dev20240801+cpu
+ 1.1.0.dev20240924+cpu
@@ -414,7 +417,7 @@
- © Copyright 2022, Meta.
+ © Copyright 2024, Meta.
@@ -451,6 +454,9 @@
+
+
+
diff --git a/searchindex.js b/searchindex.js
index 798e03f61..9b80fbef5 100644
--- a/searchindex.js
+++ b/searchindex.js
@@ -1 +1 @@
-Search.setIndex({"docnames": ["index", "overview", "torchrec.datasets", "torchrec.datasets.scripts", "torchrec.distributed", "torchrec.distributed.planner", "torchrec.distributed.sharding", "torchrec.fx", "torchrec.inference", "torchrec.metrics", "torchrec.models", "torchrec.modules", "torchrec.optim", "torchrec.quant", "torchrec.sparse"], "filenames": ["index.rst", "overview.rst", "torchrec.datasets.rst", "torchrec.datasets.scripts.rst", "torchrec.distributed.rst", "torchrec.distributed.planner.rst", "torchrec.distributed.sharding.rst", "torchrec.fx.rst", "torchrec.inference.rst", "torchrec.metrics.rst", "torchrec.models.rst", "torchrec.modules.rst", "torchrec.optim.rst", "torchrec.quant.rst", "torchrec.sparse.rst"], "titles": ["Welcome to the TorchRec documentation!", "TorchRec Overview", "torchrec.datasets", "torchrec.datasets.scripts", "torchrec.distributed", "torchrec.distributed.planner", "torchrec.distributed.sharding", "torchrec.fx", "torchrec.inference", "torchrec.metrics", "torchrec.models", "torchrec.modules", "torchrec.optim", "torchrec.quant", "torchrec.sparse"], "terms": {"pytorch": [0, 1, 4, 11, 12, 14], "domain": 0, "librari": [0, 1], "built": [0, 1, 11], "provid": [0, 1, 4, 5, 6, 8, 9, 11, 13], "common": [0, 1, 11, 14], "sparsiti": 0, "parallel": [0, 1, 4, 6], "primit": [0, 1, 4, 6], "need": [0, 4, 6, 7, 8, 9, 11, 12, 13, 14], "larg": [0, 1, 5], "scale": [0, 1], "recommend": [0, 1, 9], "system": [0, 1, 4, 5], "recsi": [0, 10, 12], "It": [0, 4, 5, 6, 8, 9, 11, 12, 13, 14], "allow": [0, 1, 4, 5, 7, 9, 11, 12], "author": [0, 1, 4], "train": [0, 1, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "model": [0, 1, 4, 5, 6, 7, 8, 9, 11, 12, 13], "embed": [0, 1, 5, 6, 7, 10, 11, 13, 14], "shard": [0, 1, 4, 5, 8, 11, 12, 13], "across": [0, 4, 5, 6, 9], "mani": [0, 1, 4, 6], "gpu": [0, 4, 5], "For": [0, 4, 5, 6, 9, 10, 11, 12, 13, 14], "instal": 0, "instruct": 0, "visit": 0, "http": [0, 4, 5, 10, 11, 14], "github": [0, 11], "com": [0, 11], "readm": 0, "In": [0, 4, 5, 11, 12, 14], "thi": [0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "we": [0, 4, 5, 6, 7, 9, 11, 12, 13, 14], "introduc": [0, 12], "primari": [0, 8], "call": [0, 4, 5, 6, 8, 9, 11, 12, 13], "distributedmodelparallel": [0, 4], "dmp": [0, 4], "like": [0, 4, 5, 6, 7, 11, 12, 14], "s": [0, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14], "distributeddataparallel": 0, "wrap": [0, 4, 6, 9, 12], "enabl": [0, 4, 5, 9, 12], "distribut": [0, 1, 8, 9, 11, 12, 14], "sourc": [0, 10, 11], "open": 0, "googl": 0, "colab": 0, "index": [0, 11, 14], "modul": [0, 1, 4, 5, 6, 9], "search": [0, 5], "page": 0, "design": [1, 4, 8, 9, 11], "creat": [1, 4, 7, 8, 9, 11, 12, 14], "state": [1, 4, 8, 9, 11, 12], "art": 1, "person": 1, "path": [1, 4, 5, 8], "product": 1, "wide": 1, "adopt": 1, "meta": [1, 4, 5], "infer": [1, 4, 5, 6, 13, 14], "workflow": 1, "address": [1, 4], "uniqu": [1, 5, 9, 11], "challeng": 1, "build": [1, 5], "deploi": [1, 8], "massiv": 1, "which": [1, 4, 5, 6, 8, 9, 11, 12, 14], "focu": [1, 9], "regular": 1, "more": [1, 4, 5, 6, 9, 11], "specif": [1, 4, 5, 8, 12], "follow": [1, 4, 5, 6, 9, 10, 11, 12, 14], "gener": [1, 4, 5, 7, 8, 10, 11, 12, 14], "special": [1, 7, 9, 11, 12], "compon": [1, 9, 11], "simplist": 1, "ar": [1, 4, 5, 6, 8, 9, 11, 12, 13, 14], "tabl": [1, 4, 5, 6, 7, 10, 11, 13], "advanc": [1, 12], "techniqu": 1, "flexibl": [1, 11], "customiz": [1, 5], "method": [1, 4, 7, 8, 9, 11], "row": [1, 4, 5, 6], "wise": [1, 4, 5, 6, 11], "column": [1, 5, 6], "so": [1, 4, 5, 9, 12, 14], "can": [1, 4, 5, 9, 11, 12, 14], "automat": [1, 4, 5, 9, 14], "determin": [1, 4, 5, 6], "best": [1, 5], "plan": [1, 4, 5, 11], "devic": [1, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14], "topolog": [1, 4, 5, 6], "effici": [1, 5, 11], "memori": [1, 4, 5, 9, 12], "balanc": [1, 5], "while": [1, 4, 6, 7, 8, 11], "support": [1, 4, 5, 6, 7, 9, 11, 12], "basic": [1, 10, 14], "extend": [1, 4], "capabl": 1, "sophist": 1, "incred": 1, "optim": [1, 4, 5, 9, 11, 13], "top": [1, 4, 9, 11], "fbgemm": [1, 4, 5, 6, 13], "after": [1, 4, 5, 6, 9, 11], "all": [1, 4, 5, 6, 8, 9, 10, 11, 12, 14], "power": 1, "some": [1, 4, 9, 14], "largest": [1, 5], "frictionless": 1, "deploy": 1, "simpl": [1, 10], "api": [1, 4, 6, 7, 9, 11], "transform": [1, 4, 8, 11], "load": [1, 4, 5, 6, 12], "c": [1, 4, 6, 8, 14], "environ": [1, 4, 8], "most": [1, 4, 12], "integr": 1, "ecosystem": 1, "mean": [1, 4, 5, 9, 11], "seamlessli": 1, "exist": [1, 4, 6, 11, 14], "code": [1, 4, 11], "tool": 1, "develop": 1, "leverag": [1, 11], "knowledg": [1, 4, 5, 9], "codebas": 1, "util": [1, 6], "featur": [1, 4, 5, 6, 9, 10, 11, 13, 14], "By": 1, "being": [1, 4, 5, 8, 9, 11], "part": [1, 4, 5, 6, 11, 12], "benefit": 1, "from": [1, 4, 5, 6, 7, 8, 9, 11, 12, 14], "robust": 1, "commun": [1, 4, 5, 6, 9], "continu": 1, "updat": [1, 4, 5, 6, 8, 9, 11, 12], "improv": [1, 12], "come": [1, 11], "necessari": [4, 5, 6, 9], "oper": [4, 5, 6, 7, 11, 14], "These": [4, 5, 9, 11], "includ": [4, 5, 7, 8, 9, 11, 14], "through": [4, 7, 9, 12], "collect": [4, 6, 10, 11, 12, 13], "reduc": [4, 6, 11, 13], "scatter": [4, 6], "wrapper": [4, 12], "spars": [4, 6, 10, 11, 13], "kjt": [4, 5, 6, 10, 11, 13, 14], "variou": [4, 8, 11], "implement": [4, 5, 6, 9, 11, 12, 14], "shardedembeddingbag": 4, "nn": [4, 5, 7, 11, 13], "shardedembeddingbagcollect": [4, 11, 13], "embeddingbagcollect": [4, 10, 11, 13], "sharder": [4, 5, 8], "defin": [4, 6, 8, 9, 10, 11], "ani": [4, 5, 6, 7, 8, 9, 11, 12, 14], "comput": [4, 5, 6, 8, 9, 10, 11, 13], "kernel": [4, 5, 11], "cpu": [4, 5, 9], "mai": [4, 14], "batch": [4, 5, 6, 7, 8, 9, 10, 11, 13, 14], "togeth": [4, 11], "fusion": 4, "pipelin": [4, 5, 11, 14], "trainpipelinesparsedist": 4, "overlap": 4, "dataload": 4, "transfer": 4, "copi": [4, 6, 8, 9, 11, 12, 14], "inter": [4, 11], "input_dist": [4, 11], "forward": [4, 5, 6, 8, 9, 10, 11, 13, 14], "backward": [4, 5, 7, 12], "increas": [4, 9], "perform": [4, 5, 6, 8, 9, 11, 12, 13], "quantiz": [4, 6, 7, 8, 13], "precis": [4, 11, 13], "file": 4, "contain": [4, 5, 6, 8, 9, 11, 12, 13], "construct": [4, 7, 11, 14], "base": [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "control": [4, 7], "flow": [4, 7], "invoke_on_rank_and_broadcast_result": 4, "pg": [4, 5, 6, 9], "processgroup": [4, 5, 6, 9], "rank": [4, 5, 6, 9, 11, 12, 14], "int": [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "func": 4, "callabl": [4, 6, 7, 11, 12, 13], "t": [4, 5, 6, 7, 8, 11, 12, 14], "arg": [4, 5, 8, 9, 11, 13, 14], "kwarg": [4, 9, 11, 14], "invok": [4, 5], "function": [4, 5, 6, 7, 8, 11, 12, 14], "broadcast": [4, 5], "result": [4, 5, 6, 8, 9, 11, 13], "member": [4, 11], "within": [4, 5, 6, 8, 11, 14], "group": [4, 5, 6, 9, 11, 12, 14], "exampl": [4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "id": [4, 5, 6, 11], "0": [4, 5, 6, 9, 10, 11, 12, 13, 14], "allocate_id": 4, "is_lead": 4, "option": [4, 5, 6, 7, 8, 9, 11, 12, 13, 14], "leader_rank": 4, "bool": [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "check": [4, 5, 9, 11, 12, 14], "current": [4, 5, 6, 8, 9, 11], "processs": 4, "leader": [4, 9], "paramet": [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "dist": [4, 6, 9], "process": [4, 5, 6, 9, 10, 11, 13], "us": [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "none": [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "impli": 4, "onli": [4, 5, 6, 9, 11, 14], "e": [4, 5, 6, 7, 8, 9, 10, 11, 12], "g": [4, 5, 8, 9, 11, 12], "singl": [4, 5, 6, 11, 12], "program": 4, "definit": [4, 7, 8], "default": [4, 5, 7, 8, 9, 10, 11, 12, 13, 14], "The": [4, 5, 6, 7, 8, 9, 10, 11, 12, 14], "caller": 4, "overrid": [4, 5, 7, 8, 9], "context": [4, 6, 14], "run_on_lead": 4, "get_group_rank": 4, "world_siz": [4, 5, 6, 8, 9], "get": [4, 5, 6], "worker": 4, "also": [4, 5, 8, 9, 11, 12], "avail": [4, 5, 6], "group_rank": 4, "varibl": 4, "A": [4, 5, 6, 7, 8, 9, 12, 14], "number": [4, 5, 6, 9, 10, 11, 14], "between": [4, 5, 9, 10, 11, 14], "get_num_group": 4, "see": [4, 5, 6, 7, 9, 11, 14], "org": [4, 5, 10, 11, 14], "doc": [4, 11, 14], "stabl": [4, 11, 14], "elast": 4, "run": [4, 5, 6, 8, 9, 11, 12], "html": [4, 11, 14], "get_local_rank": 4, "local": [4, 5, 6, 9, 11], "usual": [4, 5, 6, 9, 11], "its": [4, 5, 6, 9, 11, 12, 14], "node": [4, 7], "get_local_s": 4, "equival": 4, "max_nnod": 4, "intra_and_cross_node_pg": 4, "backend": [4, 6, 8, 9], "str": [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "tupl": [4, 5, 6, 7, 8, 11, 12, 13, 14], "sub": 4, "intra": 4, "cross": [4, 11], "class": [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "all2alldenseinfo": 4, "output_split": [4, 6], "list": [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "batch_siz": [4, 5, 6, 9, 11, 14], "input_shap": 4, "input_split": [4, 6], "object": [4, 5, 8, 9, 11, 12], "data": [4, 5, 6, 7, 8, 9, 11, 12, 13, 14], "attribut": [4, 5, 9, 12], "when": [4, 5, 7, 9, 11, 12], "alltoall_dens": 4, "all2allpooledinfo": 4, "batch_size_per_rank": [4, 6], "dim_sum_per_rank": [4, 6], "dim_sum_per_rank_tensor": 4, "tensor": [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "cumsum_dim_sum_per_rank_tensor": 4, "codec": [4, 6], "quantizedcommcodec": [4, 6], "alltoall_pool": [4, 6], "size": [4, 5, 6, 9, 10, 11, 13, 14], "each": [4, 5, 6, 9, 10, 11, 13, 14], "sum": [4, 5, 6, 11], "dimens": [4, 5, 6, 10, 11, 13, 14], "version": [4, 11, 13], "fast": 4, "_recat_pooled_embedding_grad_out": 4, "cumul": [4, 9, 14], "all2allsequenceinfo": 4, "embedding_dim": [4, 5, 6, 10, 11, 13], "lengths_after_sparse_data_all2al": 4, "forward_recat_tensor": 4, "backward_recat_tensor": 4, "variable_batch_s": 4, "fals": [4, 5, 6, 8, 9, 11, 12, 13, 14], "permuted_lengths_after_sparse_data_all2al": 4, "alltoall_sequ": 4, "length": [4, 5, 6, 10, 11, 13, 14], "alltoal": [4, 6], "recat": [4, 6, 11, 14], "input": [4, 5, 6, 7, 8, 9, 10, 11, 13, 14], "split": [4, 5, 6, 8, 14], "output": [4, 5, 6, 8, 9, 10, 11, 13, 14], "whether": [4, 5, 7, 9, 11, 13], "variabl": [4, 6, 9, 11, 13, 14], "befor": [4, 6, 9, 11, 12], "all2allvinfo": 4, "dims_sum_per_rank": 4, "b_global": 4, "b_local": 4, "b_local_list": 4, "d_local_list": 4, "input_split_s": 4, "factori": [4, 5, 11], "output_split_s": 4, "alltoallv": 4, "global": [4, 5, 6, 9], "my": 4, "rememb": [4, 14], "how": [4, 5, 6, 12], "do": [4, 5, 9, 11, 12, 14], "all_to_all_singl": 4, "fill": 4, "all2all_pooled_req": 4, "static": [4, 5, 9, 12, 14], "ctx": 4, "unus": [4, 11], "formula": 4, "differenti": 4, "mode": [4, 5, 9], "overridden": [4, 6, 8, 9, 11], "subclass": [4, 6, 8, 11, 12], "vjp": 4, "must": [4, 5, 6, 8, 9, 11], "accept": [4, 5, 8, 9, 11], "first": [4, 5, 6, 11, 12, 14], "argument": [4, 7, 8, 9, 11], "return": [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "pass": [4, 5, 6, 8, 9, 11, 12, 13, 14], "non": [4, 5, 6, 7, 9, 11, 13], "should": [4, 5, 6, 8, 9, 10, 11, 12, 14], "were": 4, "gradient": [4, 5, 12], "w": [4, 6, 11, 14], "r": [4, 11], "given": [4, 5, 6, 7, 11], "valu": [4, 5, 6, 7, 9, 10, 11, 12, 13, 14], "correspond": [4, 5, 6, 8, 9, 11, 14], "If": [4, 5, 8, 9, 11, 12, 14], "an": [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "requir": [4, 5, 9, 11, 12], "grad": [4, 12], "you": [4, 6, 7, 14], "just": [4, 5, 10, 11, 14], "retriev": 4, "save": [4, 5, 11, 12], "dure": [4, 5, 9, 12], "ha": [4, 5, 9, 11, 14], "needs_input_grad": 4, "boolean": 4, "repres": [4, 5, 8, 10, 11, 13, 14], "have": [4, 5, 6, 9, 10, 11, 12, 14], "true": [4, 5, 8, 9, 11, 12, 14], "myreq": 4, "request": [4, 8, 12], "a2ai": 4, "input_embed": [4, 11], "custom": [4, 5, 7, 11], "autograd": [4, 8, 9, 11], "There": 4, "two": [4, 5, 9, 11, 14], "wai": [4, 5], "usag": [4, 5, 9], "1": [4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "combin": [4, 11, 12], "staticmethod": 4, "def": [4, 11], "other": [4, 5, 6, 9, 12], "detail": [4, 5, 6, 9, 11], "2": [4, 5, 6, 9, 10, 11, 12, 13, 14], "separ": 4, "setup_context": 4, "longer": [4, 5], "instead": [4, 6, 8, 11, 12], "torch": [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "handl": [4, 5, 6, 7, 11, 12], "set": [4, 5, 6, 8, 9, 11, 12], "up": [4, 5, 13], "store": [4, 5, 6, 14], "arbitrari": 4, "directli": [4, 12], "though": 4, "enforc": [4, 8, 9, 11], "compat": [4, 7, 12], "either": [4, 5, 9, 11], "save_for_backward": 4, "thei": [4, 14], "intend": 4, "save_for_forward": 4, "jvp": 4, "all2all_pooled_wait": 4, "grad_output": 4, "dummy_tensor": 4, "all2all_seq_req": 4, "sharded_input_embed": 4, "all2all_seq_req_wait": 4, "sharded_grad_output": 4, "all2allv_req": 4, "all2allv_wait": 4, "allgatherbaseinfo": 4, "input_s": [4, 5, 11], "all_gatther_base_pool": 4, "allgatherbase_req": 4, "agi": 4, "allgatherbase_wait": 4, "reducescatterbaseinfo": 4, "reduce_scatter_base_pool": 4, "flatten": [4, 6, 11], "reducescatterbase_req": 4, "rsi": 4, "reducescatterbase_wait": 4, "reducescatterinfo": 4, "reduce_scatter_pool": 4, "produc": [4, 5], "reducescattervinfo": 4, "equal_split": 4, "total_input_s": 4, "reduce_scatter_v_pool": 4, "along": [4, 6, 9, 12, 14], "dim": [4, 6], "total": [4, 5, 6, 9], "reducescatterv_req": 4, "reducescatterv_wait": 4, "reducescatter_req": 4, "reducescatter_wait": 4, "await": [4, 6, 7], "variablebatchall2allpooledinfo": 4, "batch_size_per_rank_per_featur": [4, 6], "batch_size_per_feature_pre_a2a": [4, 6], "emb_dim_per_rank_per_featur": [4, 6], "variable_batch_alltoall_pool": [4, 6], "per": [4, 5, 6, 9, 11, 14], "variable_batch_all2all_pooled_req": 4, "variable_batch_all2all_pooled_wait": 4, "all2all_pooled_sync": 4, "all2all_sequence_sync": 4, "all2allv_sync": 4, "all_gather_base_pool": 4, "gather": [4, 6], "form": [4, 11, 13], "pool": [4, 5, 6, 10, 11, 13, 14], "output_tensor_s": 4, "work": [4, 5, 8, 9, 11, 14], "async": [4, 6], "wait": [4, 6], "later": [4, 11], "experiment": [4, 11], "subject": 4, "chang": [4, 11, 12], "all_gather_base_sync": 4, "all_gather_into_tensor_backward": 4, "all_gather_into_tensor_fak": 4, "gather_dim": 4, "group_siz": 4, "group_nam": 4, "gradient_divis": 4, "all_gather_into_tensor_setup_context": 4, "all_to_all_single_backward": 4, "all_to_all_single_fak": 4, "all_to_all_single_setup_context": 4, "a2a_pooled_embs_tensor": 4, "world": [4, 6], "Then": 4, "concaten": [4, 6, 11, 14], "receiv": [4, 6, 12], "Its": 4, "shape": [4, 6, 9, 11, 14], "b": [4, 5, 6, 10, 11, 13, 14], "x": [4, 5, 6, 10, 11, 13, 14], "d_local_sum": 4, "where": [4, 5, 6, 9, 11, 13], "a2a_sequence_embs_tensor": 4, "sequenc": [4, 5, 6], "doe": [4, 10, 11, 12, 14], "mix": 4, "out_split": 4, "per_rank_split_length": 4, "one": [4, 5, 6, 8, 9, 10, 11, 12], "differ": [4, 5, 6, 11, 12, 14], "specifi": [4, 5, 6, 7, 9, 11, 12], "assumpt": [4, 14], "emb": 4, "same": [4, 5, 6, 8, 9, 10, 11, 14], "get_gradient_divis": 4, "get_use_sync_collect": 4, "pg_name": 4, "reduce_scatter_base_sync": 4, "chunk": [4, 6], "reduce_scatter_sync": 4, "reduce_scatter_tensor_backward": 4, "reduce_scatter_tensor_fak": 4, "reduceop": 4, "reduce_scatter_tensor_setup_context": 4, "reduce_scatter_v_per_feature_pool": 4, "v": [4, 6, 11, 14], "d": [4, 5, 10, 11, 13, 14], "unevenli": 4, "accord": [4, 5, 6, 8, 10, 12, 14], "reduce_scatter_v_sync": 4, "set_gradient_divis": 4, "val": 4, "set_use_sync_collect": 4, "torchrec_use_sync_collect": 4, "variable_batch_all2all_pooled_sync": 4, "embeddingsalltoon": [4, 6], "cat_dim": [4, 6, 14], "merg": [4, 6], "buffer": [4, 6, 8, 9, 11], "alloc": [4, 6, 8], "would": [4, 6, 14], "alltoon": [4, 6], "set_devic": [4, 6], "device_str": [4, 6], "embeddingsalltoonereduc": [4, 6], "jaggedtensoralltoal": [4, 6], "jt": [4, 6, 11, 14], "jaggedtensor": [4, 6, 11, 13, 14], "num_items_to_send": [4, 6], "num_items_to_rec": [4, 6], "redistribut": [4, 6], "item": [4, 6], "send": [4, 6], "known": [4, 5, 6, 11], "ahead": [4, 6], "time": [4, 5, 6, 8, 9, 11], "keyedjaggedtensorpool": [4, 6], "lookup": [4, 5, 6, 10, 11, 13], "via": [4, 6], "anoth": [4, 6], "kjtalltoal": [4, 6], "stagger": [4, 6, 14], "keyedjaggedtensor": [4, 6, 10, 11, 13, 14], "kjtalltoallsplitsawait": [4, 6], "transmit": [4, 6], "correct": [4, 6, 14], "space": [4, 5, 6, 10], "kjtalltoalltensorsawait": [4, 6], "actual": [4, 5, 6, 8, 9, 11], "asynchron": [4, 6], "len": [4, 6, 10], "indic": [4, 6, 8, 11, 12, 13, 14], "assum": [4, 5, 6, 8, 9, 10, 12], "order": [4, 5, 6, 8, 9, 11, 14], "destin": [4, 6, 8, 9, 11], "appli": [4, 6, 10, 11], "_get_recat": [4, 6], "kei": [4, 5, 6, 8, 9, 10, 11, 13, 14], "kjta2a": [4, 6], "rank0_input": [4, 6], "hold": [4, 5, 6, 12, 14], "v0": [4, 6, 14], "v1": [4, 6, 11, 14], "v2": [4, 6, 11, 14], "rank1_input": [4, 6], "v3": [4, 6, 14], "v4": [4, 6, 14], "rank0_output": [4, 6], "3": [4, 5, 6, 9, 10, 11, 12, 13, 14], "4": [4, 5, 6, 9, 10, 11, 13, 14], "5": [4, 6, 9, 10, 11, 13, 14], "rank1_output": [4, 6], "relev": [4, 5, 6], "issu": [4, 6, 11], "second": [4, 5, 6, 9, 11, 14], "label": [4, 6, 9], "tensor_split": [4, 6], "input_tensor": [4, 6], "dict": [4, 5, 6, 7, 8, 9, 11, 12, 13, 14], "ie": [4, 5, 6, 11, 14], "stride_per_rank": [4, 6, 14], "stride": [4, 6, 14], "case": [4, 5, 6, 9, 11, 12, 14], "kjtonetoal": [4, 6], "onetoal": [4, 6], "essenti": [4, 6, 14], "p2p": [4, 6], "keyjaggedtensor": [4, 6], "them": [4, 6, 8, 11, 12], "kjtlist": [4, 6], "slice": [4, 6, 7, 14], "mergepooledembeddingsmodul": [4, 6], "merge_pooled_embedding_optim": [4, 6], "_mergepooledembeddingsmoduleimpl": [4, 6], "merge_pooled_embed": [4, 6], "pooledembeddingsallgath": [4, 6], "layout": [4, 6, 7], "want": [4, 6, 9], "nccl": [4, 6], "happen": [4, 6], "init_distribut": [4, 6], "new_group": [4, 6, 9], "randn": [4, 6, 10, 11], "m": [4, 6, 7, 11], "local_emb": [4, 6], "pooledembeddingsawait": [4, 6], "num_bucket": [4, 6], "pooledembeddingsalltoal": [4, 6], "callback": [4, 6], "a2a": [4, 6], "t0": [4, 6], "rand": [4, 6, 10], "6": [4, 5, 6, 10, 11, 13, 14], "t1": [4, 6, 10, 11, 13], "print": [4, 6, 11, 13], "properti": [4, 5, 6, 8, 9, 11, 12, 13], "tensor_await": [4, 6], "pooledembeddingsreducescatt": [4, 6], "twrw": [4, 5, 6], "over": [4, 6, 11, 12], "unequ": [4, 6], "bucket": [4, 6], "seqembeddingsalltoon": [4, 6], "concat": [4, 6, 11, 14], "sequenceembeddingsalltoal": [4, 6], "features_per_rank": [4, 6], "sharding_ctx": [4, 6], "sequenceshardingcontext": [4, 6], "lengths_after_input_dist": [4, 6], "unbucketize_permute_tensor": [4, 6], "sparse_features_recat": [4, 6], "sequenceembeddingsawait": [4, 6], "permut": [4, 6, 14], "splitsalltoallawait": [4, 6], "tensoralltoal": [4, 6], "1d": [4, 5, 6], "tensoralltoallsplitsawait": [4, 6], "tensoralltoallvaluesawait": [4, 6], "tensor_a2a": [4, 6], "rank0": [4, 6], "rank1": [4, 6], "v5": [4, 6, 14], "v6": [4, 6, 14], "v7": [4, 6, 14], "v8": [4, 6], "v9": [4, 6], "v10": [4, 6], "v11": [4, 6], "v12": [4, 6], "tensorvaluesalltoal": [4, 6], "tensor_vals_a2a": [4, 6], "v13": [4, 6], "v14": [4, 6], "v15": [4, 6], "sent": [4, 6], "equal": [4, 6, 11, 14], "self": [4, 5, 6, 11, 14], "_pg": [4, 6], "variablebatchpooledembeddingsalltoal": [4, 6], "kjt_split": [4, 6], "24": [4, 6], "r0_batch_siz": [4, 6], "r1_batch_siz": [4, 6], "f_0": [4, 6], "f_1": [4, 6], "f_2": [4, 6], "r0_batch_size_per_rank_per_featur": [4, 6], "r1_batch_size_per_rank_per_featur": [4, 6], "r0_batch_size_per_feature_pre_a2a": [4, 6], "r1_batch_size_per_feature_pre_a2a": [4, 6], "r0": [4, 6], "r1": [4, 6], "16": [4, 6, 11, 13], "14": [4, 6], "post": [4, 6], "rank_0": [4, 6], "rank_1": [4, 6], "variablebatchpooledembeddingsreducescatt": [4, 6], "rw": [4, 5, 6, 11], "multipli": [4, 5, 6], "batch_size_r0_f0": [4, 6], "emb_dim_f0": [4, 6], "embeddingcollectionawait": 4, "lazyawait": 4, "embeddingcollectioncontext": 4, "sharding_context": 4, "input_featur": 4, "reverse_indic": [4, 11], "seq_vbe_ctx": [4, 11], "sequencevbecontext": [4, 11], "multistream": [4, 11], "record_stream": [4, 11, 14], "stream": [4, 11, 14], "embeddingcollectionshard": 4, "fused_param": [4, 6, 8], "qcomm_codecs_registri": [4, 6], "use_index_dedup": 4, "baseembeddingshard": 4, "embeddingcollect": [4, 11, 13], "module_typ": [4, 13], "param": [4, 5, 9, 12], "parametershard": 4, "env": [4, 6], "shardingenv": [4, 6], "shardedembeddingcollect": [4, 11, 13], "locat": 4, "replic": [4, 5, 6], "embeddingmoduleshardingplan": 4, "fulli": [4, 5, 12], "qualifi": 4, "name": [4, 5, 8, 9, 10, 11, 12, 13, 14], "spec": 4, "shardedmodul": 4, "shardable_paramet": 4, "sharding_typ": [4, 5, 11], "compute_device_typ": 4, "shardingtyp": [4, 5, 11], "well": [4, 5, 11], "table_name_to_parameter_shard": 4, "shardedembeddingmodul": 4, "fusedoptimizermodul": [4, 12], "public": [4, 11], "manual": [4, 12], "dist_input": 4, "compute_and_output_dist": 4, "multipl": [4, 5, 9, 11, 12], "make": [4, 11, 12], "sens": [4, 12], "initi": [4, 11, 12], "distibut": 4, "soon": 4, "complet": [4, 5], "create_context": 4, "fused_optim": [4, 12], "keyedoptim": [4, 12], "output_dist": 4, "reset_paramet": [4, 11], "create_embedding_shard": 4, "sharding_info": [4, 6], "embeddingshardinginfo": [4, 6], "embeddingshard": [4, 6], "create_sharding_infos_by_shard": 4, "embeddingcollectioninterfac": [4, 11, 13], "create_sharding_infos_by_sharding_device_group": 4, "get_device_from_parameter_shard": 4, "ps": [4, 5], "get_ec_index_dedup": 4, "pad_vbe_kjt_length": 4, "set_ec_index_dedup": 4, "commopgradientsc": 4, "functionctx": 4, "scale_gradient_factor": 4, "groupedembeddingslookup": 4, "grouped_config": 4, "groupedembeddingconfig": [4, 6], "baseembeddinglookup": [4, 6], "i": [4, 5, 6, 7, 9, 10, 11], "flush": 4, "sparse_featur": [4, 6, 10], "everi": [4, 5, 6, 8, 11], "although": [4, 6, 8, 11], "recip": [4, 6, 8, 11], "instanc": [4, 6, 7, 8, 9, 11], "afterward": [4, 6, 8, 11], "sinc": [4, 6, 8, 11], "former": [4, 6, 8, 11], "take": [4, 5, 6, 8, 11, 12], "care": [4, 6, 8, 11], "regist": [4, 6, 7, 8, 11], "hook": [4, 6, 8, 11], "latter": [4, 6, 8, 11], "silent": [4, 6, 8, 11], "ignor": [4, 5, 6, 8, 11], "load_state_dict": [4, 12], "state_dict": [4, 8, 9, 11, 12], "ordereddict": [4, 8, 9, 11], "union": [4, 5, 7, 8, 9, 11, 12], "shardedtensor": [4, 12], "strict": [4, 12], "_incompatiblekei": 4, "descend": [4, 5], "exactli": 4, "match": [4, 5, 8, 9, 11], "assign": [4, 14], "unless": [4, 12], "get_swap_module_params_on_convers": 4, "persist": [4, 8, 9, 11], "strictli": [4, 11], "preserv": [4, 11], "except": [4, 5, 9, 11], "requires_grad": 4, "field": [4, 11, 12, 14], "missing_kei": 4, "expect": [4, 5, 10, 11], "miss": [4, 5], "unexpected_kei": 4, "present": [4, 12], "namedtupl": 4, "rais": 4, "runtimeerror": 4, "named_buff": [4, 11], "prefix": [4, 8, 9, 11], "recurs": [4, 11], "remove_dupl": [4, 11], "iter": [4, 5, 11, 12], "yield": [4, 11], "both": [4, 8, 9, 10, 11, 12, 14], "itself": [4, 11], "prepend": [4, 11], "submodul": [4, 11, 12], "otherwis": [4, 5, 8, 9, 11, 12, 14], "direct": [4, 11], "remov": [4, 7, 11], "duplic": [4, 11, 12], "xdoctest": [4, 8, 9, 11], "skip": [4, 8, 9, 11, 12], "undefin": [4, 8, 9, 11], "var": [4, 8, 9, 11], "buf": [4, 11], "running_var": [4, 11], "named_paramet": 4, "bia": [4, 8, 9, 11], "named_parameters_by_t": 4, "tablebatchedembeddingslic": 4, "table_nam": 4, "embedding_weight": 4, "cw": [4, 5], "weight": [4, 5, 6, 8, 9, 11, 12, 13, 14], "compos": [4, 8, 9, 11], "prefetch": [4, 5], "forward_stream": 4, "purg": 4, "keep_var": [4, 8, 9, 11], "dictionari": [4, 8, 9, 11], "refer": [4, 8, 9, 11, 14], "whole": [4, 8, 9, 11], "averag": [4, 5, 8, 9, 11], "shallow": [4, 8, 9, 11], "posit": [4, 5, 6, 8, 9, 11], "howev": [4, 8, 9, 11, 12], "deprec": [4, 8, 9, 11], "keyword": [4, 8, 9, 11], "futur": [4, 8, 9, 11], "releas": [4, 8, 9, 11], "pleas": [4, 5, 8, 9, 11, 14], "avoid": [4, 8, 9, 11, 12], "end": [4, 5, 8, 9, 11], "user": [4, 5, 8, 9, 11, 12], "ad": [4, 8, 9, 11, 12], "detach": [4, 8, 9, 11], "groupedpooledembeddingslookup": 4, "feature_processor": [4, 6, 13], "basegroupedfeatureprocessor": [4, 6, 11], "scale_weight_gradi": 4, "infercpugroupedembeddingslookup": 4, "grouped_configs_per_rank": 4, "infergroupedlookupmixin": 4, "inputdistoutput": [4, 6], "tbetoregistermixin": 4, "get_tbes_to_regist": 4, "intnbittablebatchedembeddingbagscodegen": 4, "infergroupedembeddingslookup": 4, "abc": [4, 5, 8, 9, 11, 12], "input_dist_output": 4, "infergroupedpooledembeddingslookup": 4, "metainfergroupedembeddingslookup": 4, "tbe": [4, 5, 13], "op": [4, 5, 6, 12, 13], "metainfergroupedpooledembeddingslookup": 4, "bag": [4, 6, 7, 10, 11], "dtype": [4, 5, 6, 7, 8, 11, 13, 14], "embeddings_cat_empty_rank_handl": 4, "dummy_embs_tensor": 4, "embeddings_cat_empty_rank_handle_infer": 4, "fx_wrap_tensor_view2d": 4, "dim0": 4, "dim1": 4, "baseembeddingdist": [4, 6], "convert": [4, 7, 8, 14], "embeddinglookup": 4, "abstract": [4, 5, 8, 9, 11, 12], "basesparsefeaturesdist": [4, 6], "f": [4, 5, 6, 10, 11, 13], "featureshardingmixin": 4, "table_wis": [4, 11], "create_input_dist": [4, 6], "create_lookup": [4, 6], "create_output_dist": [4, 6], "embedding_nam": [4, 6, 11], "embedding_names_per_rank": [4, 6], "embedding_shard_metadata": [4, 6], "shardmetadata": [4, 6], "embedding_t": [4, 6], "shardedembeddingt": [4, 6], "uncombined_embedding_dim": [4, 6], "uncombined_embedding_nam": [4, 6], "embeddingshardingcontext": [4, 6], "variable_batch_per_featur": 4, "embedding_config": [4, 13], "embeddingtableconfig": [4, 11], "param_shard": 4, "nonetyp": [4, 9, 11], "fusedkjtlistsplitsawait": 4, "kjtlistsplitsawait": 4, "kjtlistawait": 4, "info": [4, 11], "metadata": [4, 8, 11], "kjtsplitsalltoallmeta": 4, "distributed_c10d": 4, "_input": 4, "jagged_tensor": 4, "splits_tensor": 4, "listofkjtlistawait": 4, "listofkjtlist": 4, "listofkjtlistsplitsawait": 4, "bucketize_kjt_before_all2al": 4, "block_siz": [4, 6], "output_permut": 4, "bucketize_po": 4, "block_bucketize_row_po": 4, "readjust": 4, "note": [4, 5, 6, 11, 14], "map": [4, 9, 11, 12, 13], "unbucket": 4, "offset": [4, 5, 10, 11, 13, 14], "bucketize_kjt_infer": 4, "is_sequ": [4, 6], "group_tabl": 4, "tables_per_rank": 4, "datatyp": [4, 5, 11, 13, 14], "poolingtyp": [4, 11], "embeddingcomputekernel": [4, 5], "consist": 4, "weighted": 4, "interfac": [4, 8, 9, 11], "reli": [4, 8, 11, 13], "etc": [4, 8, 12, 14], "moduleshard": [4, 5, 8], "compute_kernel": [4, 5], "storage_usag": 4, "resourc": 4, "processor": [4, 6, 11], "basequantembeddingshard": 4, "shardable_param": 4, "dtensormetadata": 4, "mesh": 4, "device_mesh": 4, "devicemesh": 4, "placement": [4, 5], "_tensor": 4, "placement_typ": 4, "embeddingattribut": 4, "dens": [4, 5, 10, 11, 14], "enum": [4, 5, 11, 12], "enumer": [4, 11, 12], "fuse": [4, 6, 9], "fused_uvm": 4, "fused_uvm_cach": 4, "key_valu": 4, "quant": 4, "quant_uvm": 4, "quant_uvm_cach": 4, "awar": [4, 14], "feature_nam": [4, 5, 6, 10, 11, 13], "feature_names_per_rank": [4, 6], "data_typ": [4, 11], "is_weight": [4, 5, 11, 13, 14], "has_feature_processor": [4, 6, 11], "dim_sum": 4, "feature_hash_s": [4, 6], "num_featur": [4, 6, 10, 11], "bucket_mapping_tensor": 4, "bucketized_length": 4, "moduleshardingmixin": 4, "access": [4, 5, 12, 14], "scheme": 4, "optimtyp": 4, "adagrad": [4, 12], "adam": [4, 12], "adamw": 4, "lamb": 4, "lars_sgd": 4, "lion": 4, "partial_rowwise_adam": 4, "partial_rowwise_lamb": 4, "rowwise_adagrad": 4, "sgd": 4, "shampoo": 4, "shampoo_v2": 4, "shardedconfig": 4, "local_row": [4, 5], "local_col": [4, 5], "compin": 4, "distout": 4, "out": [4, 11, 14], "shrdctx": 4, "commop": 4, "extra_repr": 4, "pretti": 4, "represent": [4, 5, 7, 11, 14], "num_embed": [4, 5, 10, 11, 13], "fp32": [4, 5, 11], "weight_init_max": [4, 11], "float": [4, 5, 7, 9, 11, 12, 14], "weight_init_min": [4, 11], "pruning_indices_remap": [4, 11], "init_fn": [4, 11], "need_po": [4, 6, 11], "local_metadata": 4, "_shard": 4, "global_metadata": 4, "sharded_tensor": 4, "shardedtensormetadata": 4, "dtensor_metadata": 4, "shardedmetaconfig": 4, "compute_kernel_to_embedding_loc": 4, "embeddingloc": 4, "embeddingawait": 4, "embeddingbagcollectionawait": 4, "lazygetitemmixin": 4, "keyedtensor": [4, 10, 11, 13, 14], "embeddingbagcollectioncontext": 4, "inverse_indic": [4, 11, 14], "divisor": 4, "embeddingbagcollectionshard": 4, "embeddingbagshard": 4, "nullshardedmodulecontext": 4, "per_sample_weight": 4, "named_modul": 4, "memo": 4, "network": [4, 5, 11, 12], "alreadi": [4, 6, 8, 12], "onc": [4, 11], "l": [4, 11, 13], "linear": [4, 5, 11, 12], "net": [4, 11], "sequenti": [4, 5, 11], "idx": 4, "in_featur": [4, 10, 11], "out_featur": [4, 11], "sharded_parameter_nam": 4, "embeddingbagcollectioninterfac": [4, 11, 13], "variablebatchembeddingbagcollectionawait": 4, "construct_output_kt": 4, "create_embedding_bag_shard": 4, "permute_embed": [4, 6], "suffix": 4, "replace_placement_with_meta_devic": 4, "could": [4, 5, 14], "unmatch": 4, "scenario": [4, 11, 13], "cuda": [4, 5, 8], "embeddingshardingplann": [4, 5], "planner": 4, "groupedpositionweightedmodul": 4, "max_feature_length": [4, 11], "dataparallelwrapp": 4, "defaultdataparallelwrapp": 4, "bucket_cap_mb": 4, "25": [4, 9], "static_graph": 4, "find_unused_paramet": 4, "allreduce_comm_precis": 4, "params_to_ignor": 4, "unshard": [4, 5, 11, 13], "shardingplan": [4, 5, 8], "init_data_parallel": 4, "init_paramet": 4, "data_parallel_wrapp": 4, "entri": 4, "point": [4, 5], "collective_plan": [4, 5], "lazi": [4, 11, 12], "delai": 4, "until": 4, "still": [4, 14], "no_grad": [4, 11], "init_weight": [4, 11], "isinst": 4, "fill_": [4, 11], "elif": 4, "init": 4, "kaiming_normal_": 4, "mymodel": 4, "bare_named_paramet": 4, "new": [4, 5, 9], "origin": [4, 5], "tor": 4, "safe": 4, "ddp": 4, "fsdp": 4, "sparse_grad_parameter_nam": [4, 12], "get_modul": 4, "unwrap": 4, "get_unwrapped_modul": 4, "quantembeddingbagcollectionshard": 4, "shardedquantembeddingbagcollect": 4, "quantfeatureprocessedembeddingbagcollectionshard": 4, "featureprocessedembeddingbagcollect": [4, 13], "shardedquantebcinputdist": 4, "sharding_type_device_group_to_shard": 4, "nullshardingcontext": [4, 6], "sharding_type_to_shard": 4, "sqebc_input_dist": 4, "infertwsequenceembeddingshard": 4, "f1": [4, 10, 11, 13], "f2": [4, 10, 11, 13], "7": [4, 9, 10, 11, 13, 14], "8": [4, 5, 10, 11, 13, 14], "shardedquantembeddingmodulest": 4, "embedding_bag_config": [4, 11, 13], "embeddingbagconfig": [4, 10, 11, 13], "execut": [4, 5, 8, 11, 13], "step": [4, 5, 12], "sharding_type_device_group_to_sharding_info": 4, "tbes_config": 4, "shardedquantfeatureprocessedembeddingbagcollect": 4, "featureprocessorscollect": [4, 13], "apply_feature_processor": 4, "kjt_list": [4, 14], "embedding_bag": [4, 13], "moduledict": [4, 11, 13], "modulelist": [4, 9, 11, 13], "create_infer_embedding_bag_shard": 4, "flatten_feature_length": 4, "get_device_from_sharding_info": 4, "emb_shard_info": 4, "cacheparam": [4, 5], "algorithm": 4, "cachealgorithm": 4, "load_factor": [4, 5], "reserved_memori": 4, "prefetch_pipelin": [4, 5], "stat": 4, "cachestatist": [4, 5], "multipass_prefetch_config": 4, "multipassprefetchconfig": 4, "cach": [4, 5], "relat": [4, 5, 9], "uvm": [4, 5], "lru": [4, 5], "lfu": 4, "factor": [4, 5, 11], "decid": 4, "crucial": 4, "reserv": [4, 5], "ideal": 4, "aka": 4, "statist": [4, 5], "better": [4, 5], "tune": [4, 12], "cacheabl": [4, 5], "summar": [4, 5], "measur": [4, 5, 9], "difficulti": [4, 5], "dataset": [4, 5], "independ": [4, 5], "score": [4, 5, 6, 11], "veri": [4, 5], "high": [4, 5, 9, 11], "difficult": [4, 5], "expected_lookup": [4, 5], "distinct": [4, 5], "expected_miss_r": [4, 5], "clf": [4, 5], "rate": [4, 5, 9, 12], "100": [4, 5, 9, 10, 11], "hit": [4, 5], "extrem": [4, 5], "estim": [4, 5, 9], "pooled_embeddings_all_to_al": 4, "pooled_embeddings_reduce_scatt": 4, "sequence_embeddings_all_to_al": 4, "computekernel": 4, "moduleshardingplan": 4, "describ": 4, "genericmeta": 4, "getitemlazyawait": 4, "parentw": 4, "kt": [4, 14], "__getitem__": 4, "parent": 4, "keyvalueparam": [4, 5], "ssd_storage_directori": 4, "ps_host": 4, "ssd_rocksdb_write_buffer_s": 4, "ssd_rocksdb_shard": 4, "gather_ssd_cache_stat": 4, "stats_reporter_config": 4, "tbestatsreporterconfig": 4, "use_passed_in_path": 4, "ssd": [4, 5], "ssdtablebatchedembeddingbag": 4, "directori": 4, "data00_nvidia": 4, "local_rank": 4, "host": [4, 5, 6], "ip": 4, "port": 4, "2000": 4, "2001": 4, "2002": 4, "reason": [4, 12], "hashabl": 4, "rocksdb": 4, "write": 4, "relav": 4, "compact": 4, "frequenc": 4, "std": 4, "report": [4, 9], "od": 4, "report_interv": 4, "interv": [4, 9, 11], "ods_prefix": 4, "expos": [4, 12], "concret": 4, "behavior": [4, 7, 12], "achiev": 4, "late": 4, "possibl": [4, 5, 9], "__torch_function__": 4, "below": 4, "help": 4, "doesn": [4, 11, 12], "python": [4, 7], "magic": 4, "__getattr__": 4, "caveat": 4, "arbitari": 4, "mechan": [4, 11], "ensur": [4, 11, 14], "perfect": 4, "quickli": 4, "long": [4, 5, 11], "kwd": 4, "vt_co": 4, "augment": 4, "trigger": [4, 11], "keyedlazyawait": 4, "defer": 4, "mixin": 4, "inherit": [4, 9, 11], "mro": 4, "properli": [4, 11], "select": [4, 5, 6, 14], "lazynowait": 4, "classmethod": [4, 5, 8, 13], "noopquantizedcommcodec": 4, "quantizationcontext": 4, "No": [4, 6, 9], "calc_quantized_s": 4, "input_len": 4, "decod": 4, "input_grad": 4, "encod": 4, "quantized_dtyp": 4, "nowait": [4, 7], "obj": 4, "objectpoolshardingplan": 4, "objectpoolshardingtyp": 4, "replicated_row_wis": 4, "row_wis": [4, 11], "sharding_spec": 4, "shardingspec": 4, "cache_param": [4, 5], "enforce_hbm": [4, 5], "stochastic_round": [4, 5], "bounds_check_mod": [4, 5], "boundscheckmod": [4, 5], "output_dtyp": [4, 5, 8, 13], "key_value_param": [4, 5], "hbm": [4, 5], "stochast": [4, 5], "round": [4, 5], "bound": [4, 5], "place": [4, 5, 6, 12, 14], "column_wis": [4, 11], "seen": [4, 7], "individu": [4, 5], "table_row_wis": [4, 11], "data_parallel": [4, 5, 11], "parameterstorag": 4, "physic": 4, "constraint": [4, 5, 8], "shardingplann": [4, 5], "ddr": [4, 5], "pipelinetyp": [4, 5], "py": 4, "about": 4, "train_bas": 4, "train_prefetch_sparse_dist": 4, "train_sparse_dist": 4, "pooled_all_to_al": 4, "reduce_scatt": 4, "float32": [4, 8, 11, 13], "quantized_tensor": 4, "quantized_comm_codec": 4, "collective_cal": 4, "output_tensor": 4, "assert_clos": 4, "int8": [4, 8], "addit": [4, 5, 7, 8, 11, 12, 14], "carri": 4, "session": 4, "respect": [4, 11], "sequence_all_to_al": 4, "modulenocopymixin": [4, 13], "respons": 4, "vise": [4, 12], "versa": [4, 12], "practic": 4, "from_loc": 4, "typic": [4, 5, 7, 11, 12, 14], "from_process_group": 4, "fqn": [4, 5], "larger": [4, 5], "desir": 4, "get_plan_for_modul": 4, "module_path": 4, "re": [4, 12], "stabil": 4, "table_column_wis": [4, 11], "get_tensor_size_byt": 4, "rank_devic": 4, "device_typ": 4, "scope": 4, "copyablemixin": 4, "target": [4, 10], "mymodul": 4, "forkedpdb": 4, "completekei": 4, "tab": 4, "stdin": 4, "stdout": 4, "nosigint": 4, "readrc": 4, "pdb": 4, "fork": 4, "multiprocess": 4, "child": 4, "debug": [4, 5, 9], "multiprocessing_util": 4, "import": [4, 5, 8, 11, 13], "get_rank": 4, "set_trac": 4, "barrier": 4, "interact": [4, 10, 11], "add_params_from_parameter_shard": 4, "parameter_shard": 4, "extract": 4, "add": [4, 7, 11, 12], "ones": 4, "add_prefix_to_state_dict": 4, "filter": [4, 11], "append_prefix": 4, "append": 4, "convert_to_fbgemm_typ": 4, "copy_to_devic": 4, "current_devic": [4, 8], "to_devic": 4, "filter_state_dict": 4, "start": [4, 11, 14], "strip": 4, "begin": [4, 12], "get_unsharded_module_nam": 4, "level": [4, 6], "don": [4, 8, 11], "merge_fused_param": 4, "param_fused_param": 4, "configur": 4, "cache_precis": 4, "preset": 4, "table_level_fused_param": 4, "precid": 4, "grouped_fused_param": 4, "null": 4, "none_throw": 4, "_t": 4, "messag": [4, 5], "unexpect": 4, "assertionerror": 4, "optimizer_type_to_emb_opt_typ": 4, "optimizer_class": 4, "emboptimtyp": 4, "sharded_model_copi": 4, "m_cpu": 4, "deepcopi": 4, "managedcollisioncollectionawait": 4, "managedcollisioncollectioncontext": 4, "managedcollisioncollectionshard": 4, "managedcollisioncollect": [4, 11], "shardedmanagedcollisioncollect": 4, "evict": [4, 11], "open_slot": [4, 11], "create_mc_shard": 4, "managedcollisionembeddingbagcollectioncontext": 4, "evictions_per_t": 4, "remapped_kjt": 4, "managedcollisionembeddingbagcollectionshard": 4, "ebc_shard": 4, "mc_sharder": 4, "basemanagedcollisionembeddingcollectionshard": 4, "managedcollisionembeddingbagcollect": [4, 11], "shardedmanagedcollisionembeddingbagcollect": 4, "baseshardedmanagedcollisionembeddingcollect": 4, "managedcollisionembeddingcollectioncontext": 4, "managedcollisionembeddingcollectionshard": 4, "ec_shard": 4, "managedcollisionembeddingcollect": [4, 11], "shardedmanagedcollisionembeddingcollect": 4, "consid": [5, 11, 13, 14], "perf": 5, "storag": [5, 14], "peak": 5, "elimin": 5, "might": [5, 14], "oom": [5, 9], "partit": [5, 6], "kernel_bw_lookup": 5, "compute_devic": [5, 8], "hbm_mem_bw": 5, "ddr_mem_bw": 5, "caching_ratio": 5, "calcul": [5, 9], "bandwidth": 5, "ratio": [5, 9], "embeddingenumer": 5, "parameterconstraint": [5, 8], "shardestim": 5, "use_exact_enumerate_ord": 5, "shardabl": 5, "exact": 5, "name_children": 5, "shardingopt": 5, "valid": [5, 11, 14], "popul": [5, 11], "populate_estim": 5, "sharding_opt": 5, "descript": [5, 9], "get_partition_by_typ": 5, "string": [5, 8, 11], "partitionbytyp": 5, "greedyperfpartition": 5, "sort_bi": 5, "sortbi": 5, "balance_modul": 5, "greedi": 5, "sort": [5, 11], "smaller": 5, "effect": [5, 11], "storage_constraint": 5, "partition_bi": 5, "uniform": [5, 11], "strategi": 5, "final": [5, 9, 10, 11, 13, 14], "docstr": [5, 9, 14], "partition_by_devic": 5, "done": [5, 11, 12, 14], "clariti": 5, "memorybalancedpartition": 5, "max_search_count": 5, "10": [5, 10, 11, 13, 14], "toler": 5, "02": 5, "maximum": [5, 9, 11], "greedypartition": 5, "reject": 5, "200": 5, "wors": 5, "repeatedli": 5, "find": 5, "least": 5, "amount": 5, "ordereddevicehardwar": 5, "devicehardwar": 5, "local_world_s": 5, "shardingoptiongroup": 5, "storage_sum": 5, "perf_sum": 5, "param_count": 5, "set_hbm_per_devic": 5, "hbm_per_devic": 5, "noopperfmodel": 5, "perfmodel": 5, "among": [5, 10], "here": 5, "without": [5, 9, 14], "noopstoragemodel": 5, "storagereserv": 5, "performance_model": 5, "heteroembeddingshardingplann": 5, "topology_group": 5, "embeddingoffloadscaleuppropos": 5, "use_depth": 5, "allocate_budget": 5, "budget": 5, "allocation_prior": 5, "build_affine_storage_model": 5, "uvm_caching_sharding_opt": 5, "clf_to_byt": 5, "feedback": 5, "perf_rat": 5, "get_budget": 5, "get_cach": 5, "get_expected_lookup": 5, "search_spac": 5, "next_plan": 5, "starting_propos": 5, "promote_high_prefetch_overheaad_table_to_hbm": 5, "overhead": 5, "io": 5, "than": [5, 11, 12], "offload": 5, "undo": 5, "promot": 5, "greedypropos": 5, "threshold": [5, 9, 11], "fashion": [5, 6], "On": [5, 11], "tri": [5, 12], "next": 5, "max": [5, 11, 12], "earli": 5, "stop": 5, "consecut": 5, "best_perf_r": 5, "gridsearchpropos": 5, "max_propos": 5, "10000": 5, "uniformpropos": 5, "proposers_to_proposals_list": 5, "proposers_list": 5, "static_feedback": 5, "embeddingoffloadstat": 5, "mrc_hist_count": 5, "height": 5, "uvm_fused_cach": 5, "cachebl": 5, "area": [5, 9], "under": [5, 9], "curv": [5, 9], "n": [5, 8, 11, 14], "histogram": 5, "bin": 5, "nth": 5, "wa": [5, 8], "estimate_cache_miss_r": 5, "cache_s": 5, "hist": 5, "mrc": 5, "embeddingperfestim": 5, "is_infer": 5, "wall": 5, "sharder_map": 5, "perf_func_emb_wall_tim": 5, "shard_siz": 5, "input_length": 5, "input_data_type_s": 5, "table_data_type_s": 5, "output_data_type_s": 5, "fwd_a2a_comm_data_type_s": 5, "bwd_a2a_comm_data_type_s": 5, "fwd_sr_comm_data_type_s": 5, "bwd_sr_comm_data_type_s": 5, "num_pool": 5, "intra_host_bw": 5, "inter_host_bw": 5, "bwd_compute_multipli": 5, "weighted_feature_bwd_compute_multipli": 5, "is_pool": 5, "expected_cache_fetch": 5, "uneven_sharding_perf_multipli": 5, "attempt": 5, "rel": [5, 11], "tw": 5, "dp": 5, "queri": 5, "fwd_comm_data_type_s": 5, "bwd_comm_data_type_s": 5, "sampl": [5, 9, 11], "thread": 5, "machin": [5, 11], "embeddingbag": [5, 7, 10, 11, 13], "unpool": 5, "ebc": [5, 10, 11, 13], "signifi": 5, "fetch": 5, "embeddingstorageestim": 5, "pipeline_typ": 5, "calculate_pipeline_io_cost": 5, "output_s": [5, 11], "prefetch_s": 5, "multipass_prefetch_max_pass": 5, "calculate_shard_storag": 5, "compris": 5, "synonym": 5, "byte": [5, 8, 9], "embeddingstat": 5, "log": [5, 9], "sharding_plan": 5, "num_propos": 5, "num_plan": 5, "run_tim": 5, "best_plan": 5, "tabular": 5, "view": 5, "chosen": [5, 11], "evalu": [5, 11], "successfulli": 5, "taken": 5, "noopembeddingstat": 5, "noop": 5, "round_to_one_sigfig": 5, "fixedpercentagestoragereserv": 5, "percentag": 5, "heuristicalstoragereserv": 5, "parameter_multipli": 5, "dense_tensor_estim": 5, "heurist": 5, "extra": 5, "percent": 5, "act": 5, "margin": 5, "error": [5, 9, 11, 14], "beyond": 5, "inferencestoragereserv": 5, "customtopologydata": 5, "get_data": 5, "has_data": 5, "supported_field": 5, "ddr_cap": 5, "hbm_cap": 5, "512": [5, 9], "min_partit": 5, "pooling_factor": 5, "fbgemm_gpu": 5, "split_table_batched_embeddings_ops_common": 5, "device_group": 5, "around": 5, "lower": [5, 7, 8, 12, 13], "rang": [5, 7, 11], "divid": [5, 9], "divis": 5, "optionallist": 5, "momentum": 5, "determinist": 5, "maintain": 5, "accuraci": [5, 11], "term": [5, 11], "fp16": 5, "exce": 5, "todai": 5, "bldm": 5, "fwd_comput": 5, "fwd_comm": 5, "bwd_comput": 5, "bwd_comm": 5, "prefetch_comput": 5, "breakdown": 5, "plannererror": 5, "error_typ": 5, "plannererrortyp": 5, "classifi": 5, "insufficient_storag": 5, "strict_constraint": 5, "prospos": 5, "paritit": 5, "subset": 5, "much": [5, 12], "depend": [5, 8, 11], "One": [5, 9, 11], "eval": 5, "job": 5, "tower": [5, 11], "cache_load_factor": 5, "module_pool": 5, "sharding_option_nam": 5, "num_input": 5, "num_shard": 5, "total_perf": 5, "total_storag": 5, "capac": 5, "hardwar": 5, "fits_in": 5, "963146416": 5, "128": [5, 9], "54760833": 5, "024": 5, "644245094": 5, "13421772": 5, "custom_topology_data": 5, "binarysearchpred": 5, "extern": [5, 10], "predic": 5, "discov": 5, "binari": [5, 9], "minim": 5, "invoc": 5, "try": 5, "prior_result": 5, "probe": 5, "prior": 5, "entir": [5, 6], "explor": 5, "reach": [5, 9], "luusjaakolasearch": 5, "max_iter": 5, "seed": 5, "42": 5, "left_cost": 5, "clamp": 5, "variant": 5, "luu": 5, "jaakola": 5, "en": 5, "wikipedia": 5, "wiki": 5, "far": 5, "associ": 5, "cost": [5, 11], "left": [5, 14], "right": [5, 9, 11], "fy": 5, "y": [5, 11], "previou": 5, "subsequ": 5, "been": [5, 11], "shrink_right": 5, "shrink": 5, "boundari": 5, "infin": [5, 12], "random": 5, "bytes_to_gb": 5, "num_byt": 5, "bytes_to_mb": 5, "gb_to_byt": 5, "gb": 5, "local_s": [5, 6], "format": [5, 8, 14], "prod": 5, "reset_shard_rank": 5, "sharder_nam": 5, "storage_repr_in_gb": 5, "basecwembeddingshard": 6, "basetwembeddingshard": 6, "cwpooledembeddingshard": 6, "infercwpooledembeddingdist": 6, "infercwpooledembeddingdistwithpermut": 6, "infercwpooledembeddingshard": 6, "type": [6, 7, 8, 9, 10, 11, 12, 13, 14], "basedpembeddingshard": 6, "dppooledembeddingdist": 6, "dppooledembeddingshard": 6, "dpsparsefeaturesdist": 6, "sparsefeatur": 6, "baserwembeddingshard": 6, "inferrwpooledembeddingdist": 6, "inferrwpooledembeddingshard": 6, "inferrwsparsefeaturesdist": 6, "rwpooledembeddingdist": 6, "share": [6, 11], "rwpooledembeddingshard": 6, "evenli": 6, "rwsparsefeaturesdist": 6, "intra_pg": 6, "hash": [6, 11], "get_block_sizes_runtime_devic": 6, "runtime_devic": 6, "tensor_cach": 6, "int32": [6, 14], "get_embedding_shard_metadata": 6, "grouped_embedding_configs_per_rank": 6, "infertwembeddingshard": 6, "infertwpooledembeddingdist": 6, "infertwsparsefeaturesdist": 6, "twpooledembeddingdist": 6, "twpooledembeddingshard": 6, "twsparsefeaturesdist": 6, "twcwpooledembeddingshard": 6, "basetwrwembeddingshard": 6, "twrwpooledembeddingdist": 6, "cross_pg": 6, "dim_sum_per_nod": 6, "emb_dim_per_node_per_featur": 6, "twrwpooledembeddingshard": 6, "twrwsparsefeaturesdist": 6, "id_list_features_per_rank": 6, "id_score_list_features_per_rank": 6, "id_list_feature_hash_s": 6, "id_score_list_feature_hash_s": 6, "shuffl": 6, "look": [6, 7, 14], "reorder": 6, "document": [7, 10], "leaf_modul": 7, "trace": [7, 8], "torchscript": 7, "create_arg": 7, "complex": 7, "memory_format": 7, "opoverload": 7, "symint": 7, "symbool": 7, "symfloat": 7, "prepar": [7, 11], "graph": 7, "emit": 7, "appropri": 7, "is_leaf_modul": 7, "module_qualified_nam": 7, "path_of_modul": 7, "mod": 7, "abil": 7, "made": [7, 12], "root": 7, "concrete_arg": 7, "guarante": [7, 12], "is_fx_trac": 7, "symbolic_trac": 7, "graphmodul": 7, "symbol": 7, "record": [7, 11], "partial": 7, "your": [7, 9], "structur": [7, 12], "predictfactorypackag": 8, "save_predict_factori": 8, "predict_factori": 8, "predictfactori": 8, "config": [8, 9, 11], "pathlib": 8, "binaryio": 8, "extra_fil": 8, "loader_cod": 8, "nimport": 8, "packag": 8, "nmodule_factori": 8, "package_import": 8, "_sysimport": 8, "set_extern_modul": 8, "decor": 8, "abstractmethod": 8, "set_mocked_modul": 8, "load_config_text": 8, "load_pickle_config": 8, "clazz": 8, "batchingmetadata": 8, "pin": 8, "kept": [8, 11], "sync": [8, 9, 14], "learn": [8, 10, 11, 12], "batching_metadata": 8, "infom": 8, "batching_metadata_json": 8, "serial": 8, "json": 8, "eas": [8, 11], "pars": 8, "create_predict_modul": 8, "transformmodul": 8, "transform_state_dict": 8, "init_process_group": 8, "get_world_s": 8, "model_inputs_data": 8, "benchmark": 8, "qualname_metadata": 8, "qualnamemetadata": 8, "qualnam": 8, "inform": [8, 9, 14], "qualname_metadata_json": 8, "result_metadata": 8, "run_weights_dependent_transform": 8, "predict_modul": 8, "predict": [8, 9], "run_weights_independent_tranform": 8, "fx": 8, "predictmodul": 8, "predict_forward": 8, "need_preproc": 8, "quantize_dens": 8, "additional_embedding_module_typ": 8, "quantize_embed": 8, "inplac": [8, 13], "additional_qconfig_spec_kei": 8, "additional_map": 8, "per_table_weight_dtyp": [8, 11], "quantize_featur": 8, "quantize_inference_model": 8, "quantization_map": 8, "fp_weight_dtyp": 8, "shard_quant_model": 8, "device_memory_s": 8, "trim_torch_package_prefix_from_typenam": 8, "typenam": 8, "accuracymetr": 9, "my_rank": 9, "task": 9, "rectaskinfo": 9, "compute_mod": 9, "reccomputemod": 9, "unfused_tasks_comput": 9, "window_s": 9, "fused_update_limit": 9, "compute_on_all_rank": 9, "should_validate_upd": 9, "process_group": 9, "recmetr": 9, "accuracymetriccomput": 9, "recmetriccomput": 9, "constructor": [9, 11], "cut": [9, 11], "off": [9, 11], "compute_accuraci": 9, "accuracy_sum": 9, "weighted_num_sampl": 9, "compute_accuracy_sum": 9, "get_accuracy_st": 9, "aucmetr": 9, "aucmetriccomput": 9, "grouped_auc": 9, "apply_bin": 9, "grouping_kei": 9, "reset": [9, 11, 12], "n_task": 9, "n_exampl": 9, "compute_auc": 9, "classif": 9, "compute_auc_per_group": 9, "auprcmetr": 9, "auprcmetriccomput": 9, "grouped_auprc": 9, "pr": 9, "compute_auprc": 9, "compute_auprc_per_group": 9, "calibrationmetr": 9, "calibrationmetriccomput": 9, "convers": 9, "compute_calibr": 9, "calibration_num": 9, "calibration_denom": 9, "get_calibration_st": 9, "ctrmetric": 9, "ctrmetriccomput": 9, "click": 9, "compute_ctr": 9, "ctr_num": 9, "ctr_denom": 9, "get_ctr_stat": 9, "maemetr": 9, "maemetriccomput": 9, "absolut": 9, "compute_error_sum": 9, "compute_ma": 9, "error_sum": 9, "get_mae_st": 9, "msemetr": 9, "msemetriccomput": 9, "squar": [9, 11], "compute_ms": 9, "compute_rms": 9, "get_mse_st": 9, "multiclassrecallmetr": 9, "multiclassrecallmetriccomput": 9, "compute_multiclass_recall_at_k": 9, "tp_at_k": 9, "total_weight": 9, "compute_true_positives_at_k": 9, "n_class": 9, "k": [9, 11], "tp": 9, "count": [9, 11], "1st": 9, "2nd": [9, 11], "n_sampl": 9, "ground": 9, "truth": 9, "true_positives_list": 9, "9": [9, 10], "15": 9, "compute_multiclass_k_sum": 9, "5000": 9, "7500": 9, "0000": [9, 11], "get_multiclass_recall_st": 9, "ndcgcomput": 9, "exponential_gain": 9, "session_kei": 9, "session_id": 9, "report_ndcg_as_decreasing_curv": 9, "remove_single_length_sess": 9, "scale_by_weights_tensor": 9, "is_negative_task_mask": 9, "normal": [9, 11], "discount": 9, "gain": 9, "tensorboard": 9, "captur": 9, "decreas": 9, "loss": [9, 12], "oppos": 9, "visual": [9, 14], "similarli": 9, "entropi": 9, "pointwis": 9, "noth": 9, "ndcgmetric": 9, "nemetr": 9, "nemetriccomput": 9, "include_logloss": 9, "allow_missing_label_with_zero_weight": 9, "vanilla": 9, "logloss": 9, "compute_cross_entropi": 9, "eta": 9, "compute_logloss": 9, "ce_sum": 9, "pos_label": 9, "neg_label": 9, "compute_n": 9, "get_ne_st": 9, "recallmetr": 9, "recallmetriccomput": 9, "compute_false_neg_sum": 9, "compute_recal": 9, "num_true_posit": 9, "num_false_negit": 9, "compute_true_pos_sum": 9, "get_recall_st": 9, "precisionmetr": 9, "precisionmetriccomput": 9, "compute_false_pos_sum": 9, "compute_precis": 9, "num_false_posit": 9, "get_precision_st": 9, "raucmetr": 9, "raucmetriccomput": 9, "grouped_rauc": 9, "regress": 9, "compute_rauc": 9, "compute_rauc_per_group": 9, "conquer_and_count": 9, "left_index": 9, "mid_index": 9, "right_index": 9, "count_reverse_pairs_divide_and_conqu": 9, "low": [9, 11], "throughputmetr": 9, "window_second": 9, "warmup_step": 9, "32": [9, 11], "time_to_train_one_step": 9, "trainer": 9, "window": 9, "window_throughput": 9, "warmup": 9, "Not": 9, "weightedavgmetr": 9, "weightedavgmetriccomput": 9, "get_mean": 9, "value_sum": 9, "num_sampl": 9, "xaucmetr": 9, "xaucmetriccomput": 9, "compute_weighted_num_pair": 9, "compute_xauc": 9, "weighted_num_pair": 9, "get_xauc_st": 9, "recmetricmodul": 9, "rec_task": 9, "recmetriclist": 9, "throughput_metr": 9, "state_metr": 9, "statemetr": 9, "compute_interval_step": 9, "min_compute_interv": 9, "max_compute_interv": 9, "inf": [9, 12], "memory_usage_limit_mb": 9, "three": 9, "standalon": 9, "characterist": 9, "componenet": 9, "intern": [9, 11, 14], "logic": [9, 11], "unit": [9, 11], "limit": [9, 11], "dataclass": 9, "replac": [9, 12], "defaultmetricsconfig": 9, "statemetricenum": 9, "metricmodul": 9, "generate_metric_modul": 9, "metric_class": 9, "metrics_config": 9, "64": [9, 11], "state_metrics_map": 9, "mock_optim": 9, "check_memory_usag": 9, "compute_count": 9, "sink": 9, "get_memory_usag": 9, "get_required_input": 9, "last_compute_tim": 9, "local_comput": 9, "memory_usage_mb_avg": 9, "oom_count": 9, "should_comput": 9, "unsync": [9, 14], "model_out": 9, "model_output": 9, "due": 9, "slide": 9, "qat": 9, "get_metr": 9, "metricsconfig": 9, "metriccomputationreport": 9, "metrics_namespac": 9, "metricnamebas": 9, "metric_prefix": 9, "metricprefix": 9, "main": 9, "templat": 9, "signal": 9, "mathemat": 9, "own": 9, "__init__": 9, "_namespac": 9, "_metrics_comput": 9, "consum": 9, "invalid": 9, "Will": 9, "defaulttaskinfo": 9, "rec": 9, "underli": 9, "overwrit": 9, "synchron": 9, "get_window_st": 9, "state_nam": 9, "get_window_state_nam": 9, "pre_comput": 9, "pre": [9, 11, 12], "torchmetr": 9, "aggreg": 9, "recmetricexcept": 9, "encapul": 9, "required_input": 9, "windowbuff": 9, "max_siz": 9, "max_buffer_count": 9, "aggregate_st": 9, "window_st": 9, "curr_stat": 9, "dequ": 9, "densearch": 10, "hidden_layer_s": 10, "deepfmnn": 10, "layer": [10, 11, 12], "embedding_dimens": 10, "dimension": 10, "hidden": [10, 11], "sparsearch": 10, "20": [10, 11], "dense_arch": 10, "dense_arch_input": 10, "dense_embed": 10, "fminteractionarch": 10, "fm_in_featur": 10, "sparse_feature_nam": 10, "deep_fm_dimens": 10, "dense_featur": [10, 11], "paper": [10, 11], "arxiv": 10, "pdf": 10, "1703": 10, "04247": 10, "cat": [10, 11], "dense_modul": [10, 11], "deep": [10, 11], "di": 10, "arch": 10, "fm_inter_arch": 10, "length_per_kei": [10, 14], "cat_fm_output": 10, "overarch": 10, "mlp": 10, "over_arch": 10, "logit": 10, "simpledeepfmnn": 10, "num_dense_featur": 10, "embedding_bag_collect": [10, 11], "relationship": 10, "project": 10, "those": [10, 11], "deep_fm": 10, "propos": 10, "notat": 10, "throughout": 10, "eb1_config": [10, 13], "f3": 10, "eb2_config": [10, 13], "t2": [10, 11, 13], "sparse_nn": 10, "over_embedding_dim": 10, "from_offsets_sync": [10, 11, 13, 14], "sparse_arch": 10, "extens": 11, "establish": 11, "pattern": 11, "swishlayernorm": 11, "positionweightedmodul": 11, "lazymoduleextensionmixin": 11, "embeddingtow": 11, "embeddingtowercollect": 11, "input_dim": 11, "swish": 11, "sigmoid": 11, "layernorm": 11, "d1": 11, "d2": 11, "d3": 11, "last": [11, 14], "sln": 11, "num_lay": 11, "stack": 11, "learnabl": 11, "polynom": 11, "full": [11, 12, 14], "matrix": 11, "nxn": 11, "cover": 11, "bit": 11, "x_": 11, "x_0": 11, "w_l": 11, "cdot": 11, "x_l": 11, "b_l": 11, "element": 11, "dcn": 11, "lowrankcrossnet": 11, "low_rank": 11, "highli": 11, "matric": 11, "simplifi": 11, "v_l": 11, "vector": 11, "smartli": 11, "setup": 11, "alwai": [11, 14], "lowrankmixturecrossnet": 11, "num_expert": 11, "relu": 11, "mixtur": 11, "expert": 11, "compar": [11, 14], "subspac": 11, "adapt": 11, "gate": 11, "moe": 11, "expert_i": 11, "k_": 11, "u_": 11, "li": 11, "c_": 11, "v_": 11, "vectorcrossnet": 11, "keep": 11, "nx1": 11, "dot": 11, "thu": [11, 12], "further": [11, 14], "implent": 11, "framework": 11, "factorizationmachin": 11, "fm": 11, "abov": [11, 14], "publish": 11, "learnt": 11, "To": 11, "raw": 11, "architectur": 11, "90": 11, "30": 11, "40": 11, "fb": 11, "lazymlp": 11, "output_dim": 11, "192": 11, "deep_fm_output": 11, "common_spars": 11, "specialized_spars": 11, "embedding_featur": 11, "raw_embedding_featur": 11, "nativ": 11, "trained_embed": 11, "native_embed": 11, "ident": 11, "mention": 11, "baseembeddingconfig": 11, "get_weight_init_max": 11, "get_weight_init_min": 11, "embeddingconfig": [11, 13], "quantconfig": 11, "placeholderobserv": [11, 13], "alia": 11, "data_type_to_dtyp": 11, "data_type_to_sparse_typ": 11, "sparsetyp": 11, "dtype_to_data_typ": 11, "pooling_type_to_pooling_mod": 11, "pooling_typ": 11, "poolingmod": 11, "pooling_type_to_str": 11, "sensit": [11, 13], "jag": [11, 13, 14], "table_0": [11, 13], "table_1": [11, 13], "pooled_embed": 11, "8899": 11, "1342": 11, "9060": 11, "0905": 11, "2814": 11, "9369": 11, "7783": 11, "1598": 11, "0695": 11, "3265": 11, "1011": 11, "4256": 11, "1846": 11, "1648": 11, "0893": 11, "3590": 11, "9784": 11, "7681": 11, "grad_fn": [11, 13], "catbackward0": 11, "offset_per_kei": [11, 14], "need_indic": [11, 13], "e1_config": [11, 13], "e2_config": [11, 13], "ec": [11, 13], "feature_embed": [11, 13], "2050": [11, 13], "5478": [11, 13], "6054": [11, 13], "7352": [11, 13], "3210": [11, 13], "0399": [11, 13], "1279": [11, 13], "1756": [11, 13], "4130": [11, 13], "7519": [11, 13], "4341": [11, 13], "0499": [11, 13], "9329": [11, 13], "0697": [11, 13], "8095": [11, 13], "embeddingbackward": [11, 13], "embedding_names_by_t": [11, 13], "get_embedding_names_by_t": 11, "process_pooled_embed": 11, "reorder_inverse_indic": 11, "basefeatureprocessor": 11, "max_length": 11, "truncat": 11, "positionweightedprocessor": 11, "feature_length": 11, "feature0": [11, 14], "feature1": [11, 14], "feature2": 11, "from_lengths_sync": [11, 14], "pw": 11, "featureprocessorcollect": 11, "feature_processor_modul": 11, "positionweightedfeatureprocessor": 11, "fp_featur": 11, "non_fp_featur": 11, "non_fp": 11, "feature_process": 11, "And": 11, "offsets_to_range_tracebl": 11, "position_weighted_module_update_featur": 11, "weighted_featur": 11, "lazymodulemixin": 11, "temporari": 11, "upstream": 11, "59923": 11, "testlazymoduleextensionmixin": 11, "test": 11, "_infer_paramet": 11, "pariti": 11, "_call_impl": 11, "fn": 11, "children": 11, "uniniti": 11, "dummi": [11, 12], "lazylinear": 11, "fail": [11, 14], "becaus": [11, 12], "hasn": 11, "yet": 11, "now": [11, 14], "lazy_appli": 11, "attach": 11, "numer": 11, "immedi": 11, "seq": 11, "in_siz": 11, "layer_s": 11, "perceptron": 11, "multi": 11, "out_siz": 11, "swish_layernorm": 11, "won": 11, "mlp_modul": 11, "assert": 11, "o": 11, "channel": 11, "unpadded_length": 11, "reindexed_length": 11, "reindexed_length_per_kei": 11, "reindexed_valu": 11, "check_module_output_dimens": 11, "verifi": 11, "construct_jagged_tensor": 11, "features_to_permute_indic": 11, "original_featur": 11, "construct_jagged_tensors_infer": 11, "construct_modulelist_from_single_modul": 11, "nest": 11, "reiniti": 11, "convert_list_of_modules_to_modulelist": 11, "deterministic_dedup": 11, "race": 11, "condit": 11, "conflict": 11, "extract_module_or_tensor_cal": 11, "module_or_cal": 11, "get_module_output_dimens": 11, "init_mlp_weights_xavier_uniform": 11, "jagged_index_select_with_empti": 11, "output_offset": 11, "distancelfu_evictionpolici": 11, "decay_expon": 11, "threshold_filtering_func": 11, "mchevictionpolici": 11, "coalesce_history_metadata": 11, "current_it": 11, "history_metadata": 11, "unique_ids_count": 11, "unique_inverse_map": 11, "additional_id": 11, "threshold_mask": 11, "histori": 11, "invers": [11, 14], "history_accumul": 11, "coalesc": 11, "metadata_info": 11, "mchevictionpolicymetadatainfo": 11, "record_history_metadata": 11, "incoming_id": 11, "incom": 11, "polici": [11, 12], "update_metadata_and_generate_eviction_scor": 11, "mch_size": 11, "coalesced_history_argsort_map": 11, "coalesced_history_sorted_unique_ids_count": 11, "coalesced_history_mch_matching_elements_mask": 11, "coalesced_history_mch_matching_indic": 11, "mch_metadata": 11, "coalesced_history_metadata": 11, "evicted_indic": 11, "selected_new_indic": 11, "mch": 11, "lfu_evictionpolici": 11, "lru_evictionpolici": 11, "metadata_nam": 11, "is_mch_metadata": 11, "is_history_metadata": 11, "mchmanagedcollisionmodul": 11, "zch_size": 11, "eviction_polici": 11, "eviction_interv": 11, "input_hash_s": 11, "9223372036854775807": 11, "input_hash_func": 11, "mch_hash_func": 11, "output_global_offset": 11, "managedcollisionmodul": 11, "zch": 11, "manag": 11, "collis": 11, "output_size_offset": 11, "drive": 11, "greater": 11, "residu": 11, "legaci": 11, "shift": 11, "zch_output_rang": 11, "down": 11, "applic": 11, "slot": 11, "assumptionn": 11, "downstream": 11, "modifi": [11, 12], "rtype": 11, "vs": 11, "preprocess": 11, "profil": 11, "rebuild_with_output_id_rang": 11, "output_id_rang": 11, "mc": 11, "hack": 11, "remap": 11, "managed_collision_modul": 11, "mcc": 11, "embedding_confg": 11, "collsion": 11, "max_output_id": 11, "remapping_range_start_index": 11, "mcm": 11, "mcm_jt": 11, "fp": 11, "apply_mc_method_to_jt_dict": 11, "features_dict": 11, "table_to_featur": 11, "managed_collis": 11, "average_threshold_filt": 11, "id_count": 11, "dynamic_threshold_filt": 11, "threshold_skew_multipli": 11, "total_count": 11, "num_id": 11, "probabilistic_threshold_filt": 11, "per_id_prob": 11, "01": 11, "probabl": 11, "appear": 11, "60": 11, "randomli": 11, "chanc": 11, "basemanagedcollisionembeddingcollect": 11, "managed_collision_collect": 11, "return_remapped_featur": 11, "embedding_collect": 11, "meaning": 12, "prohibit": 12, "empti": [12, 14], "sever": 12, "combinedoptim": 12, "optimizerwrapp": 12, "rowwis": 12, "gradientclip": 12, "norm": 12, "gradientclippingoptim": 12, "max_gradi": 12, "norm_typ": 12, "p": 12, "closur": 12, "reevalu": 12, "emptyfusedoptim": 12, "fusedoptim": 12, "zero_grad": 12, "set_to_non": 12, "zero": [12, 14], "footprint": 12, "modestli": 12, "certain": 12, "0s": 12, "behav": 12, "did": 12, "altogeth": 12, "param_group": 12, "meant": 12, "post_load_state_dict": 12, "prepend_opt_kei": 12, "opt_kei": 12, "save_param_group": 12, "set_optimizer_step": 12, "stricter": 12, "old": 12, "switch": 12, "flag": 12, "identifi": 12, "littl": 12, "add_param_group": 12, "fine": 12, "frozen": 12, "trainabl": 12, "progress": 12, "what": 12, "init_st": 12, "checkpoint": 12, "usabl": 12, "sure": 12, "sd": 12, "load_checkpoint": 12, "protocol": 12, "keyedoptimizerwrapp": 12, "optim_factori": 12, "conveni": 12, "warmupoptim": 12, "stage": 12, "warmupstag": 12, "lr": 12, "lr_param": 12, "param_nam": 12, "__warmup": 12, "adjust": 12, "schedul": 12, "go": 12, "fake": 12, "warmuppolici": 12, "constant": 12, "cosine_annealing_warm_restart": 12, "invsqrt": 12, "inv_sqrt": 12, "poli": 12, "max_it": 12, "lr_scale": 12, "decay_it": 12, "sgdr_period": 12, "speed": 13, "trec_quant": 13, "trec": 13, "qconfig": 13, "activ": 13, "with_arg": 13, "qint8": 13, "quantize_dynam": 13, "qconfig_spec": 13, "table_name_to_quantized_weight": 13, "register_tb": 13, "quant_state_dict_split_scale_bia": 13, "row_align": 13, "qebc": 13, "quantembeddingbagcollect": 13, "from_float": 13, "quantized_embed": 13, "use_precomputed_fake_qu": 13, "for_each_module_of_type_do": 13, "pruned_num_embed": 13, "pruning_indices_map": 13, "quant_prep_customize_row_align": 13, "quant_prep_enable_quant_state_dict_split_scale_bia": 13, "quant_prep_enable_quant_state_dict_split_scale_bias_for_typ": 13, "quant_prep_enable_register_tb": 13, "quantize_state_dict": 13, "table_name_to_data_typ": 13, "table_name_to_pruning_indices_map": 13, "whose": 14, "dimes": 14, "computejtdicttokjt": 14, "jt_dict": 14, "dim_1": 14, "dim_0": 14, "computekjttojtdict": 14, "keyed_jagged_tensor": 14, "jit": 14, "script": 14, "abl": 14, "NOT": 14, "expens": 14, "values_dtyp": 14, "weights_dtyp": 14, "lengths_dtyp": 14, "from_dens": 14, "2d": 14, "11": 14, "12": 14, "j1": 14, "from_dense_length": 14, "lengths_or_non": 14, "offsets_or_non": 14, "non_block": 14, "new_devic": 14, "to_dens": 14, "inttensor": 14, "values_list": 14, "to_dense_weight": 14, "weights_list": 14, "to_padded_dens": 14, "desired_length": 14, "padding_valu": 14, "longest": 14, "pad": 14, "dt": 14, "to_padded_dense_weight": 14, "d_wt": 14, "weights_or_non": 14, "jaggedtensormeta": 14, "namespac": 14, "abcmeta": 14, "proxyableclassmeta": 14, "stride_per_key_per_rank": 14, "outer": 14, "inner": 14, "index_per_kei": 14, "expand": 14, "dedupl": 14, "dim_2": 14, "w0": 14, "w1": 14, "w2": 14, "w3": 14, "w4": 14, "w5": 14, "w6": 14, "w7": 14, "dist_init": 14, "variable_stride_per_kei": 14, "num_work": 14, "dist_label": 14, "dist_split": 14, "key_split": 14, "dist_tensor": 14, "empty_lik": 14, "flatten_length": 14, "from_jt_dict": 14, "implicit": 14, "variable_feature_dim": 14, "But": 14, "That": 14, "didn": 14, "notic": 14, "correctli": 14, "technic": 14, "know": 14, "violat": 14, "precondit": 14, "fix": 14, "inverse_indices_or_non": 14, "length_per_key_or_non": 14, "lengths_offset_per_kei": 14, "offset_per_key_or_non": 14, "indices_tensor": 14, "pin_memori": 14, "segment": 14, "stride_per_kei": 14, "to_dict": 14, "key_dim": 14, "tensor_list": 14, "from_tensor_list": 14, "regroup": 14, "keyed_tensor": 14, "regroup_as_dict": 14, "flatten_kjt_list": 14, "kjt_arr": 14, "jt_is_equ": 14, "jt_1": 14, "jt_2": 14, "comparison": 14, "themselv": 14, "treat": 14, "kjt_is_equ": 14, "kjt_1": 14, "kjt_2": 14, "permute_multi_embed": 14, "regroup_kt": 14, "unflatten_kjt_list": 14}, "objects": {"torchrec": [[4, 0, 0, "-", "distributed"], [7, 0, 0, "module-0", "fx"], [8, 0, 0, "module-0", "inference"], [9, 0, 0, "-", "metrics"], [11, 0, 0, "-", "modules"], [12, 0, 0, "module-0", "optim"], [13, 0, 0, "module-0", "quant"], [14, 0, 0, "module-0", "sparse"]], "torchrec.distributed": [[4, 0, 0, "-", "collective_utils"], [4, 0, 0, "-", "comm"], [4, 0, 0, "-", "comm_ops"], [6, 0, 0, "-", "dist_data"], [4, 0, 0, "-", "embedding"], [4, 0, 0, "-", "embedding_lookup"], [4, 0, 0, "-", "embedding_sharding"], [4, 0, 0, "-", "embedding_types"], [4, 0, 0, "-", "embeddingbag"], [4, 0, 0, "-", "grouped_position_weighted"], [4, 0, 0, "-", "mc_embedding"], [4, 0, 0, "-", "mc_embeddingbag"], [4, 0, 0, "-", "mc_modules"], [4, 0, 0, "-", "model_parallel"], [5, 0, 0, "-", "planner"], [4, 0, 0, "-", "quant_embeddingbag"], [6, 0, 0, "-", "sharding"], [4, 0, 0, "-", "train_pipeline"], [4, 0, 0, "-", "types"], [4, 0, 0, "-", "utils"]], "torchrec.distributed.collective_utils": [[4, 1, 1, "", "invoke_on_rank_and_broadcast_result"], [4, 1, 1, "", "is_leader"], [4, 1, 1, "", "run_on_leader"]], "torchrec.distributed.comm": [[4, 1, 1, "", "get_group_rank"], [4, 1, 1, "", "get_local_rank"], [4, 1, 1, "", "get_local_size"], [4, 1, 1, "", "get_num_groups"], [4, 1, 1, "", "intra_and_cross_node_pg"]], "torchrec.distributed.comm_ops": [[4, 2, 1, "", "All2AllDenseInfo"], [4, 2, 1, "", "All2AllPooledInfo"], [4, 2, 1, "", "All2AllSequenceInfo"], [4, 2, 1, "", "All2AllVInfo"], [4, 2, 1, "", "All2All_Pooled_Req"], [4, 2, 1, "", "All2All_Pooled_Wait"], [4, 2, 1, "", "All2All_Seq_Req"], [4, 2, 1, "", "All2All_Seq_Req_Wait"], [4, 2, 1, "", "All2Allv_Req"], [4, 2, 1, "", "All2Allv_Wait"], [4, 2, 1, "", "AllGatherBaseInfo"], [4, 2, 1, "", "AllGatherBase_Req"], [4, 2, 1, "", "AllGatherBase_Wait"], [4, 2, 1, "", "ReduceScatterBaseInfo"], [4, 2, 1, "", "ReduceScatterBase_Req"], [4, 2, 1, "", "ReduceScatterBase_Wait"], [4, 2, 1, "", "ReduceScatterInfo"], [4, 2, 1, "", "ReduceScatterVInfo"], [4, 2, 1, "", "ReduceScatterV_Req"], [4, 2, 1, "", "ReduceScatterV_Wait"], [4, 2, 1, "", "ReduceScatter_Req"], [4, 2, 1, "", "ReduceScatter_Wait"], [4, 2, 1, "", "Request"], [4, 2, 1, "", "VariableBatchAll2AllPooledInfo"], [4, 2, 1, "", "Variable_Batch_All2All_Pooled_Req"], [4, 2, 1, "", "Variable_Batch_All2All_Pooled_Wait"], [4, 1, 1, "", "all2all_pooled_sync"], [4, 1, 1, "", "all2all_sequence_sync"], [4, 1, 1, "", "all2allv_sync"], [4, 1, 1, "", "all_gather_base_pooled"], [4, 1, 1, "", "all_gather_base_sync"], [4, 1, 1, "", "all_gather_into_tensor_backward"], [4, 1, 1, "", "all_gather_into_tensor_fake"], [4, 1, 1, "", "all_gather_into_tensor_setup_context"], [4, 1, 1, "", "all_to_all_single_backward"], [4, 1, 1, "", "all_to_all_single_fake"], [4, 1, 1, "", "all_to_all_single_setup_context"], [4, 1, 1, "", "alltoall_pooled"], [4, 1, 1, "", "alltoall_sequence"], [4, 1, 1, "", "alltoallv"], [4, 1, 1, "", "get_gradient_division"], [4, 1, 1, "", "get_use_sync_collectives"], [4, 1, 1, "", "pg_name"], [4, 1, 1, "", "reduce_scatter_base_pooled"], [4, 1, 1, "", "reduce_scatter_base_sync"], [4, 1, 1, "", "reduce_scatter_pooled"], [4, 1, 1, "", "reduce_scatter_sync"], [4, 1, 1, "", "reduce_scatter_tensor_backward"], [4, 1, 1, "", "reduce_scatter_tensor_fake"], [4, 1, 1, "", "reduce_scatter_tensor_setup_context"], [4, 1, 1, "", "reduce_scatter_v_per_feature_pooled"], [4, 1, 1, "", "reduce_scatter_v_pooled"], [4, 1, 1, "", "reduce_scatter_v_sync"], [4, 1, 1, "", "set_gradient_division"], [4, 1, 1, "", "set_use_sync_collectives"], [4, 1, 1, "", "torchrec_use_sync_collectives"], [4, 1, 1, "", "variable_batch_all2all_pooled_sync"], [4, 1, 1, "", "variable_batch_alltoall_pooled"]], "torchrec.distributed.comm_ops.All2AllDenseInfo": [[4, 3, 1, "", "batch_size"], [4, 3, 1, "", "input_shape"], [4, 3, 1, "", "input_splits"], [4, 3, 1, "", "output_splits"]], "torchrec.distributed.comm_ops.All2AllPooledInfo": [[4, 3, 1, "id0", "batch_size_per_rank"], [4, 3, 1, "id1", "codecs"], [4, 3, 1, "id2", "cumsum_dim_sum_per_rank_tensor"], [4, 3, 1, "id3", "dim_sum_per_rank"], [4, 3, 1, "id4", "dim_sum_per_rank_tensor"]], "torchrec.distributed.comm_ops.All2AllSequenceInfo": [[4, 3, 1, "id5", "backward_recat_tensor"], [4, 3, 1, "id6", "codecs"], [4, 3, 1, "id7", "embedding_dim"], [4, 3, 1, "id8", "forward_recat_tensor"], [4, 3, 1, "id9", "input_splits"], [4, 3, 1, "id10", "lengths_after_sparse_data_all2all"], [4, 3, 1, "id11", "output_splits"], [4, 3, 1, "id12", "permuted_lengths_after_sparse_data_all2all"], [4, 3, 1, "id13", "variable_batch_size"]], "torchrec.distributed.comm_ops.All2AllVInfo": [[4, 3, 1, "id14", "B_global"], [4, 3, 1, "id15", "B_local"], [4, 3, 1, "id16", "B_local_list"], [4, 3, 1, "id17", "D_local_list"], [4, 3, 1, "", "codecs"], [4, 3, 1, "", "dim_sum_per_rank"], [4, 3, 1, "", "dims_sum_per_rank"], [4, 3, 1, "id18", "input_split_sizes"], [4, 3, 1, "id19", "output_split_sizes"]], "torchrec.distributed.comm_ops.All2All_Pooled_Req": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Pooled_Wait": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Seq_Req": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Seq_Req_Wait": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2Allv_Req": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2Allv_Wait": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.AllGatherBaseInfo": [[4, 3, 1, "", "codecs"], [4, 3, 1, "id20", "input_size"]], "torchrec.distributed.comm_ops.AllGatherBase_Req": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.AllGatherBase_Wait": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterBaseInfo": [[4, 3, 1, "", "codecs"], [4, 3, 1, "id21", "input_sizes"]], "torchrec.distributed.comm_ops.ReduceScatterBase_Req": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterBase_Wait": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterInfo": [[4, 3, 1, "", "codecs"], [4, 3, 1, "id22", "input_sizes"]], "torchrec.distributed.comm_ops.ReduceScatterVInfo": [[4, 3, 1, "id23", "codecs"], [4, 3, 1, "id24", "equal_splits"], [4, 3, 1, "id25", "input_sizes"], [4, 3, 1, "id26", "input_splits"], [4, 3, 1, "id27", "total_input_size"]], "torchrec.distributed.comm_ops.ReduceScatterV_Req": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterV_Wait": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatter_Req": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatter_Wait": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.VariableBatchAll2AllPooledInfo": [[4, 3, 1, "id28", "batch_size_per_feature_pre_a2a"], [4, 3, 1, "id29", "batch_size_per_rank_per_feature"], [4, 3, 1, "id30", "codecs"], [4, 3, 1, "id31", "emb_dim_per_rank_per_feature"], [4, 3, 1, "id32", "input_splits"], [4, 3, 1, "id33", "output_splits"]], "torchrec.distributed.comm_ops.Variable_Batch_All2All_Pooled_Req": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.comm_ops.Variable_Batch_All2All_Pooled_Wait": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.dist_data": [[6, 2, 1, "", "EmbeddingsAllToOne"], [6, 2, 1, "", "EmbeddingsAllToOneReduce"], [6, 2, 1, "", "JaggedTensorAllToAll"], [6, 2, 1, "", "KJTAllToAll"], [6, 2, 1, "", "KJTAllToAllSplitsAwaitable"], [6, 2, 1, "", "KJTAllToAllTensorsAwaitable"], [6, 2, 1, "", "KJTOneToAll"], [6, 2, 1, "", "MergePooledEmbeddingsModule"], [6, 2, 1, "", "PooledEmbeddingsAllGather"], [6, 2, 1, "", "PooledEmbeddingsAllToAll"], [6, 2, 1, "", "PooledEmbeddingsAwaitable"], [6, 2, 1, "", "PooledEmbeddingsReduceScatter"], [6, 2, 1, "", "SeqEmbeddingsAllToOne"], [6, 2, 1, "", "SequenceEmbeddingsAllToAll"], [6, 2, 1, "", "SequenceEmbeddingsAwaitable"], [6, 2, 1, "", "SplitsAllToAllAwaitable"], [6, 2, 1, "", "TensorAllToAll"], [6, 2, 1, "", "TensorAllToAllSplitsAwaitable"], [6, 2, 1, "", "TensorAllToAllValuesAwaitable"], [6, 2, 1, "", "TensorValuesAllToAll"], [6, 2, 1, "", "VariableBatchPooledEmbeddingsAllToAll"], [6, 2, 1, "", "VariableBatchPooledEmbeddingsReduceScatter"]], "torchrec.distributed.dist_data.EmbeddingsAllToOne": [[6, 4, 1, "", "forward"], [6, 4, 1, "", "set_device"], [6, 3, 1, "", "training"]], "torchrec.distributed.dist_data.EmbeddingsAllToOneReduce": [[6, 4, 1, "", "forward"], [6, 4, 1, "", "set_device"], [6, 3, 1, "", "training"]], "torchrec.distributed.dist_data.KJTAllToAll": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.dist_data.KJTOneToAll": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.dist_data.MergePooledEmbeddingsModule": [[6, 4, 1, "", "forward"], [6, 4, 1, "", "set_device"], [6, 3, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAllGather": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAllToAll": [[6, 5, 1, "", "callbacks"], [6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAwaitable": [[6, 5, 1, "", "callbacks"]], "torchrec.distributed.dist_data.PooledEmbeddingsReduceScatter": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.dist_data.SeqEmbeddingsAllToOne": [[6, 4, 1, "", "forward"], [6, 4, 1, "", "set_device"], [6, 3, 1, "", "training"]], "torchrec.distributed.dist_data.SequenceEmbeddingsAllToAll": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.dist_data.TensorAllToAll": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.dist_data.TensorValuesAllToAll": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsAllToAll": [[6, 5, 1, "", "callbacks"], [6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsReduceScatter": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.embedding": [[4, 2, 1, "", "EmbeddingCollectionAwaitable"], [4, 2, 1, "", "EmbeddingCollectionContext"], [4, 2, 1, "", "EmbeddingCollectionSharder"], [4, 2, 1, "", "ShardedEmbeddingCollection"], [4, 1, 1, "", "create_embedding_sharding"], [4, 1, 1, "", "create_sharding_infos_by_sharding"], [4, 1, 1, "", "create_sharding_infos_by_sharding_device_group"], [4, 1, 1, "", "get_device_from_parameter_sharding"], [4, 1, 1, "", "get_ec_index_dedup"], [4, 1, 1, "", "pad_vbe_kjt_lengths"], [4, 1, 1, "", "set_ec_index_dedup"]], "torchrec.distributed.embedding.EmbeddingCollectionContext": [[4, 4, 1, "", "record_stream"]], "torchrec.distributed.embedding.EmbeddingCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 4, 1, "", "shard"], [4, 4, 1, "", "shardable_parameters"], [4, 4, 1, "", "sharding_types"]], "torchrec.distributed.embedding.ShardedEmbeddingCollection": [[4, 4, 1, "", "compute"], [4, 4, 1, "", "compute_and_output_dist"], [4, 4, 1, "", "create_context"], [4, 5, 1, "", "fused_optimizer"], [4, 4, 1, "", "input_dist"], [4, 4, 1, "", "output_dist"], [4, 4, 1, "", "reset_parameters"], [4, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup": [[4, 2, 1, "", "CommOpGradientScaling"], [4, 2, 1, "", "GroupedEmbeddingsLookup"], [4, 2, 1, "", "GroupedPooledEmbeddingsLookup"], [4, 2, 1, "", "InferCPUGroupedEmbeddingsLookup"], [4, 2, 1, "", "InferGroupedEmbeddingsLookup"], [4, 2, 1, "", "InferGroupedLookupMixin"], [4, 2, 1, "", "InferGroupedPooledEmbeddingsLookup"], [4, 2, 1, "", "MetaInferGroupedEmbeddingsLookup"], [4, 2, 1, "", "MetaInferGroupedPooledEmbeddingsLookup"], [4, 1, 1, "", "dummy_tensor"], [4, 1, 1, "", "embeddings_cat_empty_rank_handle"], [4, 1, 1, "", "embeddings_cat_empty_rank_handle_inference"], [4, 1, 1, "", "fx_wrap_tensor_view2d"]], "torchrec.distributed.embedding_lookup.CommOpGradientScaling": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.embedding_lookup.GroupedEmbeddingsLookup": [[4, 4, 1, "", "flush"], [4, 4, 1, "", "forward"], [4, 4, 1, "", "load_state_dict"], [4, 4, 1, "", "named_buffers"], [4, 4, 1, "", "named_parameters"], [4, 4, 1, "", "named_parameters_by_table"], [4, 4, 1, "", "prefetch"], [4, 4, 1, "", "purge"], [4, 4, 1, "", "state_dict"], [4, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.GroupedPooledEmbeddingsLookup": [[4, 4, 1, "", "flush"], [4, 4, 1, "", "forward"], [4, 4, 1, "", "load_state_dict"], [4, 4, 1, "", "named_buffers"], [4, 4, 1, "", "named_parameters"], [4, 4, 1, "", "named_parameters_by_table"], [4, 4, 1, "", "prefetch"], [4, 4, 1, "", "purge"], [4, 4, 1, "", "state_dict"], [4, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferCPUGroupedEmbeddingsLookup": [[4, 4, 1, "", "get_tbes_to_register"], [4, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferGroupedEmbeddingsLookup": [[4, 4, 1, "", "get_tbes_to_register"], [4, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferGroupedLookupMixin": [[4, 4, 1, "", "forward"], [4, 4, 1, "", "load_state_dict"], [4, 4, 1, "", "named_buffers"], [4, 4, 1, "", "named_parameters"], [4, 4, 1, "", "state_dict"]], "torchrec.distributed.embedding_lookup.InferGroupedPooledEmbeddingsLookup": [[4, 4, 1, "", "get_tbes_to_register"], [4, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.MetaInferGroupedEmbeddingsLookup": [[4, 4, 1, "", "flush"], [4, 4, 1, "", "forward"], [4, 4, 1, "", "get_tbes_to_register"], [4, 4, 1, "", "load_state_dict"], [4, 4, 1, "", "named_buffers"], [4, 4, 1, "", "named_parameters"], [4, 4, 1, "", "purge"], [4, 4, 1, "", "state_dict"], [4, 3, 1, "", "training"]], "torchrec.distributed.embedding_lookup.MetaInferGroupedPooledEmbeddingsLookup": [[4, 4, 1, "", "flush"], [4, 4, 1, "", "forward"], [4, 4, 1, "", "get_tbes_to_register"], [4, 4, 1, "", "load_state_dict"], [4, 4, 1, "", "named_buffers"], [4, 4, 1, "", "named_parameters"], [4, 4, 1, "", "purge"], [4, 4, 1, "", "state_dict"], [4, 3, 1, "", "training"]], "torchrec.distributed.embedding_sharding": [[4, 2, 1, "", "BaseEmbeddingDist"], [4, 2, 1, "", "BaseSparseFeaturesDist"], [4, 2, 1, "", "EmbeddingSharding"], [4, 2, 1, "", "EmbeddingShardingContext"], [4, 2, 1, "", "EmbeddingShardingInfo"], [4, 2, 1, "", "FusedKJTListSplitsAwaitable"], [4, 2, 1, "", "KJTListAwaitable"], [4, 2, 1, "", "KJTListSplitsAwaitable"], [4, 2, 1, "", "KJTSplitsAllToAllMeta"], [4, 2, 1, "", "ListOfKJTListAwaitable"], [4, 2, 1, "", "ListOfKJTListSplitsAwaitable"], [4, 1, 1, "", "bucketize_kjt_before_all2all"], [4, 1, 1, "", "bucketize_kjt_inference"], [4, 1, 1, "", "group_tables"]], "torchrec.distributed.embedding_sharding.BaseEmbeddingDist": [[4, 4, 1, "", "forward"], [4, 3, 1, "", "training"]], "torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist": [[4, 4, 1, "", "forward"], [4, 3, 1, "", "training"]], "torchrec.distributed.embedding_sharding.EmbeddingSharding": [[4, 4, 1, "", "create_input_dist"], [4, 4, 1, "", "create_lookup"], [4, 4, 1, "", "create_output_dist"], [4, 4, 1, "", "embedding_dims"], [4, 4, 1, "", "embedding_names"], [4, 4, 1, "", "embedding_names_per_rank"], [4, 4, 1, "", "embedding_shard_metadata"], [4, 4, 1, "", "embedding_tables"], [4, 5, 1, "", "qcomm_codecs_registry"], [4, 4, 1, "", "uncombined_embedding_dims"], [4, 4, 1, "", "uncombined_embedding_names"]], "torchrec.distributed.embedding_sharding.EmbeddingShardingContext": [[4, 4, 1, "", "record_stream"]], "torchrec.distributed.embedding_sharding.EmbeddingShardingInfo": [[4, 3, 1, "", "embedding_config"], [4, 3, 1, "", "fused_params"], [4, 3, 1, "", "param"], [4, 3, 1, "", "param_sharding"]], "torchrec.distributed.embedding_sharding.KJTSplitsAllToAllMeta": [[4, 3, 1, "", "device"], [4, 3, 1, "", "input_splits"], [4, 3, 1, "", "input_tensors"], [4, 3, 1, "", "keys"], [4, 3, 1, "", "labels"], [4, 3, 1, "", "pg"], [4, 3, 1, "", "splits"], [4, 3, 1, "", "splits_tensors"], [4, 3, 1, "", "stagger"]], "torchrec.distributed.embedding_types": [[4, 2, 1, "", "BaseEmbeddingLookup"], [4, 2, 1, "", "BaseEmbeddingSharder"], [4, 2, 1, "", "BaseGroupedFeatureProcessor"], [4, 2, 1, "", "BaseQuantEmbeddingSharder"], [4, 2, 1, "", "DTensorMetadata"], [4, 2, 1, "", "EmbeddingAttributes"], [4, 2, 1, "", "EmbeddingComputeKernel"], [4, 2, 1, "", "FeatureShardingMixIn"], [4, 2, 1, "", "GroupedEmbeddingConfig"], [4, 2, 1, "", "InputDistOutputs"], [4, 2, 1, "", "KJTList"], [4, 2, 1, "", "ListOfKJTList"], [4, 2, 1, "", "ModuleShardingMixIn"], [4, 2, 1, "", "OptimType"], [4, 2, 1, "", "ShardedConfig"], [4, 2, 1, "", "ShardedEmbeddingModule"], [4, 2, 1, "", "ShardedEmbeddingTable"], [4, 2, 1, "", "ShardedMetaConfig"], [4, 1, 1, "", "compute_kernel_to_embedding_location"]], "torchrec.distributed.embedding_types.BaseEmbeddingLookup": [[4, 4, 1, "", "forward"], [4, 3, 1, "", "training"]], "torchrec.distributed.embedding_types.BaseEmbeddingSharder": [[4, 4, 1, "", "compute_kernels"], [4, 5, 1, "", "fused_params"], [4, 4, 1, "", "sharding_types"], [4, 4, 1, "", "storage_usage"]], "torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor": [[4, 4, 1, "", "forward"], [4, 3, 1, "", "training"]], "torchrec.distributed.embedding_types.BaseQuantEmbeddingSharder": [[4, 4, 1, "", "compute_kernels"], [4, 5, 1, "", "fused_params"], [4, 4, 1, "", "shardable_parameters"], [4, 4, 1, "", "sharding_types"], [4, 4, 1, "", "storage_usage"]], "torchrec.distributed.embedding_types.DTensorMetadata": [[4, 3, 1, "", "mesh"], [4, 3, 1, "", "placements"], [4, 3, 1, "", "size"], [4, 3, 1, "", "stride"]], "torchrec.distributed.embedding_types.EmbeddingAttributes": [[4, 3, 1, "", "compute_kernel"]], "torchrec.distributed.embedding_types.EmbeddingComputeKernel": [[4, 3, 1, "", "DENSE"], [4, 3, 1, "", "FUSED"], [4, 3, 1, "", "FUSED_UVM"], [4, 3, 1, "", "FUSED_UVM_CACHING"], [4, 3, 1, "", "KEY_VALUE"], [4, 3, 1, "", "QUANT"], [4, 3, 1, "", "QUANT_UVM"], [4, 3, 1, "", "QUANT_UVM_CACHING"]], "torchrec.distributed.embedding_types.FeatureShardingMixIn": [[4, 4, 1, "", "feature_names"], [4, 4, 1, "", "feature_names_per_rank"], [4, 4, 1, "", "features_per_rank"]], "torchrec.distributed.embedding_types.GroupedEmbeddingConfig": [[4, 3, 1, "", "compute_kernel"], [4, 3, 1, "", "data_type"], [4, 4, 1, "", "dim_sum"], [4, 4, 1, "", "embedding_dims"], [4, 4, 1, "", "embedding_names"], [4, 4, 1, "", "embedding_shard_metadata"], [4, 3, 1, "", "embedding_tables"], [4, 4, 1, "", "feature_hash_sizes"], [4, 4, 1, "", "feature_names"], [4, 3, 1, "", "fused_params"], [4, 3, 1, "", "has_feature_processor"], [4, 3, 1, "", "is_weighted"], [4, 4, 1, "", "num_features"], [4, 3, 1, "", "pooling"], [4, 4, 1, "", "table_names"]], "torchrec.distributed.embedding_types.InputDistOutputs": [[4, 3, 1, "", "bucket_mapping_tensor"], [4, 3, 1, "", "bucketized_length"], [4, 3, 1, "", "features"], [4, 4, 1, "", "record_stream"], [4, 3, 1, "", "unbucketize_permute_tensor"]], "torchrec.distributed.embedding_types.KJTList": [[4, 4, 1, "", "record_stream"]], "torchrec.distributed.embedding_types.ListOfKJTList": [[4, 4, 1, "", "record_stream"]], "torchrec.distributed.embedding_types.ModuleShardingMixIn": [[4, 5, 1, "", "shardings"]], "torchrec.distributed.embedding_types.OptimType": [[4, 3, 1, "", "ADAGRAD"], [4, 3, 1, "", "ADAM"], [4, 3, 1, "", "ADAMW"], [4, 3, 1, "", "LAMB"], [4, 3, 1, "", "LARS_SGD"], [4, 3, 1, "", "LION"], [4, 3, 1, "", "PARTIAL_ROWWISE_ADAM"], [4, 3, 1, "", "PARTIAL_ROWWISE_LAMB"], [4, 3, 1, "", "ROWWISE_ADAGRAD"], [4, 3, 1, "", "SGD"], [4, 3, 1, "", "SHAMPOO"], [4, 3, 1, "", "SHAMPOO_V2"]], "torchrec.distributed.embedding_types.ShardedConfig": [[4, 3, 1, "", "local_cols"], [4, 3, 1, "", "local_rows"]], "torchrec.distributed.embedding_types.ShardedEmbeddingModule": [[4, 4, 1, "", "extra_repr"], [4, 4, 1, "", "prefetch"], [4, 3, 1, "", "training"]], "torchrec.distributed.embedding_types.ShardedEmbeddingTable": [[4, 3, 1, "", "fused_params"]], "torchrec.distributed.embedding_types.ShardedMetaConfig": [[4, 3, 1, "", "dtensor_metadata"], [4, 3, 1, "", "global_metadata"], [4, 3, 1, "", "local_metadata"]], "torchrec.distributed.embeddingbag": [[4, 2, 1, "", "EmbeddingAwaitable"], [4, 2, 1, "", "EmbeddingBagCollectionAwaitable"], [4, 2, 1, "", "EmbeddingBagCollectionContext"], [4, 2, 1, "", "EmbeddingBagCollectionSharder"], [4, 2, 1, "", "EmbeddingBagSharder"], [4, 2, 1, "", "ShardedEmbeddingBag"], [4, 2, 1, "", "ShardedEmbeddingBagCollection"], [4, 2, 1, "", "VariableBatchEmbeddingBagCollectionAwaitable"], [4, 1, 1, "", "construct_output_kt"], [4, 1, 1, "", "create_embedding_bag_sharding"], [4, 1, 1, "", "create_sharding_infos_by_sharding"], [4, 1, 1, "", "create_sharding_infos_by_sharding_device_group"], [4, 1, 1, "", "get_device_from_parameter_sharding"], [4, 1, 1, "", "replace_placement_with_meta_device"]], "torchrec.distributed.embeddingbag.EmbeddingBagCollectionContext": [[4, 3, 1, "", "divisor"], [4, 3, 1, "", "inverse_indices"], [4, 4, 1, "", "record_stream"], [4, 3, 1, "", "sharding_contexts"], [4, 3, 1, "", "variable_batch_per_feature"]], "torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 4, 1, "", "shard"], [4, 4, 1, "", "shardable_parameters"]], "torchrec.distributed.embeddingbag.EmbeddingBagSharder": [[4, 5, 1, "", "module_type"], [4, 4, 1, "", "shard"], [4, 4, 1, "", "shardable_parameters"]], "torchrec.distributed.embeddingbag.ShardedEmbeddingBag": [[4, 4, 1, "", "compute"], [4, 4, 1, "", "create_context"], [4, 5, 1, "", "fused_optimizer"], [4, 4, 1, "", "input_dist"], [4, 4, 1, "", "load_state_dict"], [4, 4, 1, "", "named_buffers"], [4, 4, 1, "", "named_modules"], [4, 4, 1, "", "named_parameters"], [4, 4, 1, "", "output_dist"], [4, 4, 1, "", "sharded_parameter_names"], [4, 4, 1, "", "state_dict"], [4, 3, 1, "", "training"]], "torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection": [[4, 4, 1, "", "compute"], [4, 4, 1, "", "compute_and_output_dist"], [4, 4, 1, "", "create_context"], [4, 5, 1, "", "fused_optimizer"], [4, 4, 1, "", "input_dist"], [4, 4, 1, "", "output_dist"], [4, 4, 1, "", "reset_parameters"], [4, 3, 1, "", "training"]], "torchrec.distributed.grouped_position_weighted": [[4, 2, 1, "", "GroupedPositionWeightedModule"]], "torchrec.distributed.grouped_position_weighted.GroupedPositionWeightedModule": [[4, 4, 1, "", "forward"], [4, 4, 1, "", "named_buffers"], [4, 4, 1, "", "named_parameters"], [4, 4, 1, "", "state_dict"], [4, 3, 1, "", "training"]], "torchrec.distributed.mc_embedding": [[4, 2, 1, "", "ManagedCollisionEmbeddingCollectionContext"], [4, 2, 1, "", "ManagedCollisionEmbeddingCollectionSharder"], [4, 2, 1, "", "ShardedManagedCollisionEmbeddingCollection"]], "torchrec.distributed.mc_embedding.ManagedCollisionEmbeddingCollectionContext": [[4, 4, 1, "", "record_stream"]], "torchrec.distributed.mc_embedding.ManagedCollisionEmbeddingCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 4, 1, "", "shard"]], "torchrec.distributed.mc_embedding.ShardedManagedCollisionEmbeddingCollection": [[4, 4, 1, "", "create_context"], [4, 3, 1, "", "training"]], "torchrec.distributed.mc_embeddingbag": [[4, 2, 1, "", "ManagedCollisionEmbeddingBagCollectionContext"], [4, 2, 1, "", "ManagedCollisionEmbeddingBagCollectionSharder"], [4, 2, 1, "", "ShardedManagedCollisionEmbeddingBagCollection"]], "torchrec.distributed.mc_embeddingbag.ManagedCollisionEmbeddingBagCollectionContext": [[4, 3, 1, "", "evictions_per_table"], [4, 4, 1, "", "record_stream"], [4, 3, 1, "", "remapped_kjt"]], "torchrec.distributed.mc_embeddingbag.ManagedCollisionEmbeddingBagCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 4, 1, "", "shard"]], "torchrec.distributed.mc_embeddingbag.ShardedManagedCollisionEmbeddingBagCollection": [[4, 4, 1, "", "create_context"], [4, 3, 1, "", "training"]], "torchrec.distributed.mc_modules": [[4, 2, 1, "", "ManagedCollisionCollectionAwaitable"], [4, 2, 1, "", "ManagedCollisionCollectionContext"], [4, 2, 1, "", "ManagedCollisionCollectionSharder"], [4, 2, 1, "", "ShardedManagedCollisionCollection"], [4, 1, 1, "", "create_mc_sharding"]], "torchrec.distributed.mc_modules.ManagedCollisionCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 4, 1, "", "shard"], [4, 4, 1, "", "shardable_parameters"], [4, 4, 1, "", "sharding_types"]], "torchrec.distributed.mc_modules.ShardedManagedCollisionCollection": [[4, 4, 1, "", "compute"], [4, 4, 1, "", "create_context"], [4, 4, 1, "", "evict"], [4, 4, 1, "", "input_dist"], [4, 4, 1, "", "open_slots"], [4, 4, 1, "", "output_dist"], [4, 4, 1, "", "sharded_parameter_names"], [4, 3, 1, "", "training"]], "torchrec.distributed.model_parallel": [[4, 2, 1, "", "DataParallelWrapper"], [4, 2, 1, "", "DefaultDataParallelWrapper"], [4, 2, 1, "", "DistributedModelParallel"], [4, 1, 1, "", "get_module"], [4, 1, 1, "", "get_unwrapped_module"]], "torchrec.distributed.model_parallel.DataParallelWrapper": [[4, 4, 1, "", "wrap"]], "torchrec.distributed.model_parallel.DefaultDataParallelWrapper": [[4, 4, 1, "", "wrap"]], "torchrec.distributed.model_parallel.DistributedModelParallel": [[4, 4, 1, "", "bare_named_parameters"], [4, 4, 1, "", "copy"], [4, 4, 1, "", "forward"], [4, 5, 1, "", "fused_optimizer"], [4, 4, 1, "", "init_data_parallel"], [4, 4, 1, "", "load_state_dict"], [4, 5, 1, "", "module"], [4, 4, 1, "", "named_buffers"], [4, 4, 1, "", "named_parameters"], [4, 5, 1, "", "plan"], [4, 4, 1, "", "sparse_grad_parameter_names"], [4, 4, 1, "", "state_dict"], [4, 3, 1, "", "training"]], "torchrec.distributed.planner": [[5, 0, 0, "-", "constants"], [5, 0, 0, "-", "enumerators"], [5, 0, 0, "-", "partitioners"], [5, 0, 0, "-", "perf_models"], [5, 0, 0, "-", "planners"], [5, 0, 0, "-", "proposers"], [5, 0, 0, "-", "shard_estimators"], [5, 0, 0, "-", "stats"], [5, 0, 0, "-", "storage_reservations"], [5, 0, 0, "-", "types"], [5, 0, 0, "-", "utils"]], "torchrec.distributed.planner.constants": [[5, 1, 1, "", "kernel_bw_lookup"]], "torchrec.distributed.planner.enumerators": [[5, 2, 1, "", "EmbeddingEnumerator"], [5, 1, 1, "", "get_partition_by_type"]], "torchrec.distributed.planner.enumerators.EmbeddingEnumerator": [[5, 4, 1, "", "enumerate"], [5, 4, 1, "", "populate_estimates"]], "torchrec.distributed.planner.partitioners": [[5, 2, 1, "", "GreedyPerfPartitioner"], [5, 2, 1, "", "MemoryBalancedPartitioner"], [5, 2, 1, "", "OrderedDeviceHardware"], [5, 2, 1, "", "ShardingOptionGroup"], [5, 2, 1, "", "SortBy"], [5, 1, 1, "", "set_hbm_per_device"]], "torchrec.distributed.planner.partitioners.GreedyPerfPartitioner": [[5, 4, 1, "", "partition"]], "torchrec.distributed.planner.partitioners.MemoryBalancedPartitioner": [[5, 4, 1, "", "partition"]], "torchrec.distributed.planner.partitioners.OrderedDeviceHardware": [[5, 3, 1, "", "device"], [5, 3, 1, "", "local_world_size"]], "torchrec.distributed.planner.partitioners.ShardingOptionGroup": [[5, 3, 1, "", "param_count"], [5, 3, 1, "", "perf_sum"], [5, 3, 1, "", "sharding_options"], [5, 3, 1, "", "storage_sum"]], "torchrec.distributed.planner.partitioners.SortBy": [[5, 3, 1, "", "PERF"], [5, 3, 1, "", "STORAGE"]], "torchrec.distributed.planner.perf_models": [[5, 2, 1, "", "NoopPerfModel"], [5, 2, 1, "", "NoopStorageModel"]], "torchrec.distributed.planner.perf_models.NoopPerfModel": [[5, 4, 1, "", "rate"]], "torchrec.distributed.planner.perf_models.NoopStorageModel": [[5, 4, 1, "", "rate"]], "torchrec.distributed.planner.planners": [[5, 2, 1, "", "EmbeddingShardingPlanner"], [5, 2, 1, "", "HeteroEmbeddingShardingPlanner"]], "torchrec.distributed.planner.planners.EmbeddingShardingPlanner": [[5, 4, 1, "", "collective_plan"], [5, 4, 1, "", "plan"]], "torchrec.distributed.planner.planners.HeteroEmbeddingShardingPlanner": [[5, 4, 1, "", "collective_plan"], [5, 4, 1, "", "plan"]], "torchrec.distributed.planner.proposers": [[5, 2, 1, "", "EmbeddingOffloadScaleupProposer"], [5, 2, 1, "", "GreedyProposer"], [5, 2, 1, "", "GridSearchProposer"], [5, 2, 1, "", "UniformProposer"], [5, 1, 1, "", "proposers_to_proposals_list"]], "torchrec.distributed.planner.proposers.EmbeddingOffloadScaleupProposer": [[5, 4, 1, "", "allocate_budget"], [5, 4, 1, "", "build_affine_storage_model"], [5, 4, 1, "", "clf_to_bytes"], [5, 4, 1, "", "feedback"], [5, 4, 1, "", "get_budget"], [5, 4, 1, "", "get_cacheability"], [5, 4, 1, "", "get_expected_lookups"], [5, 4, 1, "", "load"], [5, 4, 1, "", "next_plan"], [5, 4, 1, "", "promote_high_prefetch_overheaad_table_to_hbm"], [5, 4, 1, "", "propose"]], "torchrec.distributed.planner.proposers.GreedyProposer": [[5, 4, 1, "", "feedback"], [5, 4, 1, "", "load"], [5, 4, 1, "", "propose"]], "torchrec.distributed.planner.proposers.GridSearchProposer": [[5, 4, 1, "", "feedback"], [5, 4, 1, "", "load"], [5, 4, 1, "", "propose"]], "torchrec.distributed.planner.proposers.UniformProposer": [[5, 4, 1, "", "feedback"], [5, 4, 1, "", "load"], [5, 4, 1, "", "propose"]], "torchrec.distributed.planner.shard_estimators": [[5, 2, 1, "", "EmbeddingOffloadStats"], [5, 2, 1, "", "EmbeddingPerfEstimator"], [5, 2, 1, "", "EmbeddingStorageEstimator"], [5, 1, 1, "", "calculate_pipeline_io_cost"], [5, 1, 1, "", "calculate_shard_storages"]], "torchrec.distributed.planner.shard_estimators.EmbeddingOffloadStats": [[5, 5, 1, "", "cacheability"], [5, 4, 1, "", "estimate_cache_miss_rate"], [5, 5, 1, "", "expected_lookups"], [5, 4, 1, "", "expected_miss_rate"]], "torchrec.distributed.planner.shard_estimators.EmbeddingPerfEstimator": [[5, 4, 1, "", "estimate"], [5, 4, 1, "", "perf_func_emb_wall_time"]], "torchrec.distributed.planner.shard_estimators.EmbeddingStorageEstimator": [[5, 4, 1, "", "estimate"]], "torchrec.distributed.planner.stats": [[5, 2, 1, "", "EmbeddingStats"], [5, 2, 1, "", "NoopEmbeddingStats"], [5, 1, 1, "", "round_to_one_sigfig"]], "torchrec.distributed.planner.stats.EmbeddingStats": [[5, 4, 1, "", "log"]], "torchrec.distributed.planner.stats.NoopEmbeddingStats": [[5, 4, 1, "", "log"]], "torchrec.distributed.planner.storage_reservations": [[5, 2, 1, "", "FixedPercentageStorageReservation"], [5, 2, 1, "", "HeuristicalStorageReservation"], [5, 2, 1, "", "InferenceStorageReservation"]], "torchrec.distributed.planner.storage_reservations.FixedPercentageStorageReservation": [[5, 4, 1, "", "reserve"]], "torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation": [[5, 4, 1, "", "reserve"]], "torchrec.distributed.planner.storage_reservations.InferenceStorageReservation": [[5, 4, 1, "", "reserve"]], "torchrec.distributed.planner.types": [[5, 2, 1, "", "CustomTopologyData"], [5, 2, 1, "", "DeviceHardware"], [5, 2, 1, "", "Enumerator"], [5, 2, 1, "", "ParameterConstraints"], [5, 2, 1, "", "PartitionByType"], [5, 2, 1, "", "Partitioner"], [5, 2, 1, "", "Perf"], [5, 2, 1, "", "PerfModel"], [5, 6, 1, "", "PlannerError"], [5, 2, 1, "", "PlannerErrorType"], [5, 2, 1, "", "Proposer"], [5, 2, 1, "", "Shard"], [5, 2, 1, "", "ShardEstimator"], [5, 2, 1, "", "ShardingOption"], [5, 2, 1, "", "Stats"], [5, 2, 1, "", "Storage"], [5, 2, 1, "", "StorageReservation"], [5, 2, 1, "", "Topology"]], "torchrec.distributed.planner.types.CustomTopologyData": [[5, 4, 1, "", "get_data"], [5, 4, 1, "", "has_data"], [5, 3, 1, "", "supported_fields"]], "torchrec.distributed.planner.types.DeviceHardware": [[5, 3, 1, "", "perf"], [5, 3, 1, "", "rank"], [5, 3, 1, "", "storage"]], "torchrec.distributed.planner.types.Enumerator": [[5, 4, 1, "", "enumerate"], [5, 4, 1, "", "populate_estimates"]], "torchrec.distributed.planner.types.ParameterConstraints": [[5, 3, 1, "id0", "batch_sizes"], [5, 3, 1, "id1", "bounds_check_mode"], [5, 3, 1, "id2", "cache_params"], [5, 3, 1, "id3", "compute_kernels"], [5, 3, 1, "id4", "device_group"], [5, 3, 1, "id5", "enforce_hbm"], [5, 3, 1, "id6", "feature_names"], [5, 3, 1, "id7", "is_weighted"], [5, 3, 1, "id8", "key_value_params"], [5, 3, 1, "id9", "min_partition"], [5, 3, 1, "id10", "num_poolings"], [5, 3, 1, "id11", "output_dtype"], [5, 3, 1, "id12", "pooling_factors"], [5, 3, 1, "id13", "sharding_types"], [5, 3, 1, "id14", "stochastic_rounding"]], "torchrec.distributed.planner.types.PartitionByType": [[5, 3, 1, "", "DEVICE"], [5, 3, 1, "", "HOST"], [5, 3, 1, "", "UNIFORM"]], "torchrec.distributed.planner.types.Partitioner": [[5, 4, 1, "", "partition"]], "torchrec.distributed.planner.types.Perf": [[5, 3, 1, "", "bwd_comms"], [5, 3, 1, "", "bwd_compute"], [5, 3, 1, "", "fwd_comms"], [5, 3, 1, "", "fwd_compute"], [5, 3, 1, "", "prefetch_compute"], [5, 5, 1, "", "total"]], "torchrec.distributed.planner.types.PerfModel": [[5, 4, 1, "", "rate"]], "torchrec.distributed.planner.types.PlannerErrorType": [[5, 3, 1, "", "INSUFFICIENT_STORAGE"], [5, 3, 1, "", "OTHER"], [5, 3, 1, "", "PARTITION"], [5, 3, 1, "", "STRICT_CONSTRAINTS"]], "torchrec.distributed.planner.types.Proposer": [[5, 4, 1, "", "feedback"], [5, 4, 1, "", "load"], [5, 4, 1, "", "propose"]], "torchrec.distributed.planner.types.Shard": [[5, 3, 1, "", "offset"], [5, 3, 1, "", "perf"], [5, 3, 1, "", "rank"], [5, 3, 1, "", "size"], [5, 3, 1, "", "storage"]], "torchrec.distributed.planner.types.ShardEstimator": [[5, 4, 1, "", "estimate"]], "torchrec.distributed.planner.types.ShardingOption": [[5, 3, 1, "", "batch_size"], [5, 3, 1, "", "bounds_check_mode"], [5, 5, 1, "", "cache_load_factor"], [5, 3, 1, "", "cache_params"], [5, 3, 1, "", "compute_kernel"], [5, 3, 1, "", "dependency"], [5, 3, 1, "", "enforce_hbm"], [5, 3, 1, "", "feature_names"], [5, 5, 1, "", "fqn"], [5, 3, 1, "", "input_lengths"], [5, 5, 1, "id15", "is_pooled"], [5, 3, 1, "", "key_value_params"], [5, 5, 1, "id16", "module"], [5, 4, 1, "", "module_pooled"], [5, 3, 1, "", "name"], [5, 5, 1, "", "num_inputs"], [5, 5, 1, "", "num_shards"], [5, 3, 1, "", "output_dtype"], [5, 5, 1, "", "path"], [5, 3, 1, "", "sharding_type"], [5, 3, 1, "", "shards"], [5, 3, 1, "", "stochastic_rounding"], [5, 5, 1, "id17", "tensor"], [5, 5, 1, "", "total_perf"], [5, 5, 1, "", "total_storage"]], "torchrec.distributed.planner.types.Stats": [[5, 4, 1, "", "log"]], "torchrec.distributed.planner.types.Storage": [[5, 3, 1, "", "ddr"], [5, 4, 1, "", "fits_in"], [5, 3, 1, "", "hbm"]], "torchrec.distributed.planner.types.StorageReservation": [[5, 4, 1, "", "reserve"]], "torchrec.distributed.planner.types.Topology": [[5, 5, 1, "", "bwd_compute_multiplier"], [5, 5, 1, "", "compute_device"], [5, 5, 1, "", "ddr_mem_bw"], [5, 5, 1, "", "devices"], [5, 5, 1, "", "hbm_mem_bw"], [5, 5, 1, "", "inter_host_bw"], [5, 5, 1, "", "intra_host_bw"], [5, 5, 1, "", "local_world_size"], [5, 5, 1, "", "uneven_sharding_perf_multiplier"], [5, 5, 1, "", "weighted_feature_bwd_compute_multiplier"], [5, 5, 1, "", "world_size"]], "torchrec.distributed.planner.utils": [[5, 2, 1, "", "BinarySearchPredicate"], [5, 2, 1, "", "LuusJaakolaSearch"], [5, 1, 1, "", "bytes_to_gb"], [5, 1, 1, "", "bytes_to_mb"], [5, 1, 1, "", "gb_to_bytes"], [5, 1, 1, "", "placement"], [5, 1, 1, "", "prod"], [5, 1, 1, "", "reset_shard_rank"], [5, 1, 1, "", "sharder_name"], [5, 1, 1, "", "storage_repr_in_gb"]], "torchrec.distributed.planner.utils.BinarySearchPredicate": [[5, 4, 1, "", "next"]], "torchrec.distributed.planner.utils.LuusJaakolaSearch": [[5, 4, 1, "", "best"], [5, 4, 1, "", "clamp"], [5, 4, 1, "", "next"], [5, 4, 1, "", "shrink_right"], [5, 4, 1, "", "uniform"]], "torchrec.distributed.quant_embeddingbag": [[4, 2, 1, "", "QuantEmbeddingBagCollectionSharder"], [4, 2, 1, "", "QuantFeatureProcessedEmbeddingBagCollectionSharder"], [4, 2, 1, "", "ShardedQuantEbcInputDist"], [4, 2, 1, "", "ShardedQuantEmbeddingBagCollection"], [4, 2, 1, "", "ShardedQuantFeatureProcessedEmbeddingBagCollection"], [4, 1, 1, "", "create_infer_embedding_bag_sharding"], [4, 1, 1, "", "flatten_feature_lengths"], [4, 1, 1, "", "get_device_from_parameter_sharding"], [4, 1, 1, "", "get_device_from_sharding_infos"]], "torchrec.distributed.quant_embeddingbag.QuantEmbeddingBagCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 4, 1, "", "shard"]], "torchrec.distributed.quant_embeddingbag.QuantFeatureProcessedEmbeddingBagCollectionSharder": [[4, 4, 1, "", "compute_kernels"], [4, 5, 1, "", "module_type"], [4, 4, 1, "", "shard"], [4, 4, 1, "", "sharding_types"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantEbcInputDist": [[4, 4, 1, "", "forward"], [4, 3, 1, "", "training"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantEmbeddingBagCollection": [[4, 4, 1, "", "compute"], [4, 4, 1, "", "compute_and_output_dist"], [4, 4, 1, "", "copy"], [4, 4, 1, "", "create_context"], [4, 4, 1, "", "embedding_bag_configs"], [4, 4, 1, "", "forward"], [4, 4, 1, "", "input_dist"], [4, 4, 1, "", "output_dist"], [4, 4, 1, "", "sharding_type_device_group_to_sharding_infos"], [4, 5, 1, "", "shardings"], [4, 4, 1, "", "tbes_configs"], [4, 3, 1, "", "training"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantFeatureProcessedEmbeddingBagCollection": [[4, 4, 1, "", "apply_feature_processor"], [4, 4, 1, "", "compute"], [4, 3, 1, "", "embedding_bags"], [4, 3, 1, "", "tbes"], [4, 3, 1, "", "training"]], "torchrec.distributed.sharding": [[6, 0, 0, "-", "cw_sharding"], [6, 0, 0, "-", "dp_sharding"], [6, 0, 0, "-", "rw_sharding"], [6, 0, 0, "-", "tw_sharding"], [6, 0, 0, "-", "twcw_sharding"], [6, 0, 0, "-", "twrw_sharding"]], "torchrec.distributed.sharding.cw_sharding": [[6, 2, 1, "", "BaseCwEmbeddingSharding"], [6, 2, 1, "", "CwPooledEmbeddingSharding"], [6, 2, 1, "", "InferCwPooledEmbeddingDist"], [6, 2, 1, "", "InferCwPooledEmbeddingDistWithPermute"], [6, 2, 1, "", "InferCwPooledEmbeddingSharding"]], "torchrec.distributed.sharding.cw_sharding.BaseCwEmbeddingSharding": [[6, 4, 1, "", "embedding_dims"], [6, 4, 1, "", "embedding_names"], [6, 4, 1, "", "uncombined_embedding_dims"], [6, 4, 1, "", "uncombined_embedding_names"]], "torchrec.distributed.sharding.cw_sharding.CwPooledEmbeddingSharding": [[6, 4, 1, "", "create_input_dist"], [6, 4, 1, "", "create_lookup"], [6, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingDist": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingDistWithPermute": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingSharding": [[6, 4, 1, "", "create_input_dist"], [6, 4, 1, "", "create_lookup"], [6, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.dp_sharding": [[6, 2, 1, "", "BaseDpEmbeddingSharding"], [6, 2, 1, "", "DpPooledEmbeddingDist"], [6, 2, 1, "", "DpPooledEmbeddingSharding"], [6, 2, 1, "", "DpSparseFeaturesDist"]], "torchrec.distributed.sharding.dp_sharding.BaseDpEmbeddingSharding": [[6, 4, 1, "", "embedding_dims"], [6, 4, 1, "", "embedding_names"], [6, 4, 1, "", "embedding_names_per_rank"], [6, 4, 1, "", "embedding_shard_metadata"], [6, 4, 1, "", "embedding_tables"], [6, 4, 1, "", "feature_names"]], "torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingDist": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingSharding": [[6, 4, 1, "", "create_input_dist"], [6, 4, 1, "", "create_lookup"], [6, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.dp_sharding.DpSparseFeaturesDist": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding": [[6, 2, 1, "", "BaseRwEmbeddingSharding"], [6, 2, 1, "", "InferRwPooledEmbeddingDist"], [6, 2, 1, "", "InferRwPooledEmbeddingSharding"], [6, 2, 1, "", "InferRwSparseFeaturesDist"], [6, 2, 1, "", "RwPooledEmbeddingDist"], [6, 2, 1, "", "RwPooledEmbeddingSharding"], [6, 2, 1, "", "RwSparseFeaturesDist"], [6, 1, 1, "", "get_block_sizes_runtime_device"], [6, 1, 1, "", "get_embedding_shard_metadata"]], "torchrec.distributed.sharding.rw_sharding.BaseRwEmbeddingSharding": [[6, 4, 1, "", "embedding_dims"], [6, 4, 1, "", "embedding_names"], [6, 4, 1, "", "embedding_names_per_rank"], [6, 4, 1, "", "embedding_shard_metadata"], [6, 4, 1, "", "embedding_tables"], [6, 4, 1, "", "feature_names"]], "torchrec.distributed.sharding.rw_sharding.InferRwPooledEmbeddingDist": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.InferRwPooledEmbeddingSharding": [[6, 4, 1, "", "create_input_dist"], [6, 4, 1, "", "create_lookup"], [6, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.rw_sharding.InferRwSparseFeaturesDist": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingDist": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingSharding": [[6, 4, 1, "", "create_input_dist"], [6, 4, 1, "", "create_lookup"], [6, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.rw_sharding.RwSparseFeaturesDist": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding": [[6, 2, 1, "", "BaseTwEmbeddingSharding"], [6, 2, 1, "", "InferTwEmbeddingSharding"], [6, 2, 1, "", "InferTwPooledEmbeddingDist"], [6, 2, 1, "", "InferTwSparseFeaturesDist"], [6, 2, 1, "", "TwPooledEmbeddingDist"], [6, 2, 1, "", "TwPooledEmbeddingSharding"], [6, 2, 1, "", "TwSparseFeaturesDist"]], "torchrec.distributed.sharding.tw_sharding.BaseTwEmbeddingSharding": [[6, 4, 1, "", "embedding_dims"], [6, 4, 1, "", "embedding_names"], [6, 4, 1, "", "embedding_names_per_rank"], [6, 4, 1, "", "embedding_shard_metadata"], [6, 4, 1, "", "embedding_tables"], [6, 4, 1, "", "feature_names"], [6, 4, 1, "", "feature_names_per_rank"], [6, 4, 1, "", "features_per_rank"]], "torchrec.distributed.sharding.tw_sharding.InferTwEmbeddingSharding": [[6, 4, 1, "", "create_input_dist"], [6, 4, 1, "", "create_lookup"], [6, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.tw_sharding.InferTwPooledEmbeddingDist": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.InferTwSparseFeaturesDist": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingDist": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding": [[6, 4, 1, "", "create_input_dist"], [6, 4, 1, "", "create_lookup"], [6, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.tw_sharding.TwSparseFeaturesDist": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.sharding.twcw_sharding": [[6, 2, 1, "", "TwCwPooledEmbeddingSharding"]], "torchrec.distributed.sharding.twrw_sharding": [[6, 2, 1, "", "BaseTwRwEmbeddingSharding"], [6, 2, 1, "", "TwRwPooledEmbeddingDist"], [6, 2, 1, "", "TwRwPooledEmbeddingSharding"], [6, 2, 1, "", "TwRwSparseFeaturesDist"]], "torchrec.distributed.sharding.twrw_sharding.BaseTwRwEmbeddingSharding": [[6, 4, 1, "", "embedding_dims"], [6, 4, 1, "", "embedding_names"], [6, 4, 1, "", "embedding_names_per_rank"], [6, 4, 1, "", "embedding_shard_metadata"], [6, 4, 1, "", "feature_names"]], "torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingDist": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingSharding": [[6, 4, 1, "", "create_input_dist"], [6, 4, 1, "", "create_lookup"], [6, 4, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.twrw_sharding.TwRwSparseFeaturesDist": [[6, 4, 1, "", "forward"], [6, 3, 1, "", "training"]], "torchrec.distributed.types": [[4, 2, 1, "", "Awaitable"], [4, 2, 1, "", "CacheParams"], [4, 2, 1, "", "CacheStatistics"], [4, 2, 1, "", "CommOp"], [4, 2, 1, "", "ComputeKernel"], [4, 2, 1, "", "EmbeddingModuleShardingPlan"], [4, 2, 1, "", "GenericMeta"], [4, 2, 1, "", "GetItemLazyAwaitable"], [4, 2, 1, "", "KeyValueParams"], [4, 2, 1, "", "LazyAwaitable"], [4, 2, 1, "", "LazyGetItemMixin"], [4, 2, 1, "", "LazyNoWait"], [4, 2, 1, "", "ModuleSharder"], [4, 2, 1, "", "ModuleShardingPlan"], [4, 2, 1, "", "NoOpQuantizedCommCodec"], [4, 2, 1, "", "NoWait"], [4, 2, 1, "", "NullShardedModuleContext"], [4, 2, 1, "", "NullShardingContext"], [4, 2, 1, "", "ObjectPoolShardingPlan"], [4, 2, 1, "", "ObjectPoolShardingType"], [4, 2, 1, "", "ParameterSharding"], [4, 2, 1, "", "ParameterStorage"], [4, 2, 1, "", "PipelineType"], [4, 2, 1, "", "QuantizedCommCodec"], [4, 2, 1, "", "QuantizedCommCodecs"], [4, 2, 1, "", "ShardedModule"], [4, 2, 1, "", "ShardingEnv"], [4, 2, 1, "", "ShardingPlan"], [4, 2, 1, "", "ShardingPlanner"], [4, 2, 1, "", "ShardingType"], [4, 1, 1, "", "get_tensor_size_bytes"], [4, 1, 1, "", "rank_device"], [4, 1, 1, "", "scope"]], "torchrec.distributed.types.Awaitable": [[4, 5, 1, "", "callbacks"], [4, 4, 1, "", "wait"]], "torchrec.distributed.types.CacheParams": [[4, 3, 1, "id34", "algorithm"], [4, 3, 1, "id35", "load_factor"], [4, 3, 1, "", "multipass_prefetch_config"], [4, 3, 1, "id36", "precision"], [4, 3, 1, "id37", "prefetch_pipeline"], [4, 3, 1, "id38", "reserved_memory"], [4, 3, 1, "id39", "stats"]], "torchrec.distributed.types.CacheStatistics": [[4, 5, 1, "", "cacheability"], [4, 5, 1, "", "expected_lookups"], [4, 4, 1, "", "expected_miss_rate"]], "torchrec.distributed.types.CommOp": [[4, 3, 1, "", "POOLED_EMBEDDINGS_ALL_TO_ALL"], [4, 3, 1, "", "POOLED_EMBEDDINGS_REDUCE_SCATTER"], [4, 3, 1, "", "SEQUENCE_EMBEDDINGS_ALL_TO_ALL"]], "torchrec.distributed.types.ComputeKernel": [[4, 3, 1, "", "DEFAULT"]], "torchrec.distributed.types.KeyValueParams": [[4, 3, 1, "id40", "gather_ssd_cache_stats"], [4, 3, 1, "", "ods_prefix"], [4, 3, 1, "id41", "ps_hosts"], [4, 3, 1, "", "report_interval"], [4, 3, 1, "id42", "ssd_rocksdb_shards"], [4, 3, 1, "id43", "ssd_rocksdb_write_buffer_size"], [4, 3, 1, "id44", "ssd_storage_directory"], [4, 3, 1, "", "stats_reporter_config"], [4, 3, 1, "", "use_passed_in_path"]], "torchrec.distributed.types.ModuleSharder": [[4, 4, 1, "", "compute_kernels"], [4, 5, 1, "", "module_type"], [4, 5, 1, "", "qcomm_codecs_registry"], [4, 4, 1, "", "shard"], [4, 4, 1, "", "shardable_parameters"], [4, 4, 1, "", "sharding_types"], [4, 4, 1, "", "storage_usage"]], "torchrec.distributed.types.NoOpQuantizedCommCodec": [[4, 4, 1, "", "calc_quantized_size"], [4, 4, 1, "", "create_context"], [4, 4, 1, "", "decode"], [4, 4, 1, "", "encode"], [4, 4, 1, "", "quantized_dtype"]], "torchrec.distributed.types.NullShardedModuleContext": [[4, 4, 1, "", "record_stream"]], "torchrec.distributed.types.NullShardingContext": [[4, 4, 1, "", "record_stream"]], "torchrec.distributed.types.ObjectPoolShardingPlan": [[4, 3, 1, "", "inference"], [4, 3, 1, "", "sharding_type"]], "torchrec.distributed.types.ObjectPoolShardingType": [[4, 3, 1, "", "REPLICATED_ROW_WISE"], [4, 3, 1, "", "ROW_WISE"]], "torchrec.distributed.types.ParameterSharding": [[4, 3, 1, "", "bounds_check_mode"], [4, 3, 1, "", "cache_params"], [4, 3, 1, "", "compute_kernel"], [4, 3, 1, "", "enforce_hbm"], [4, 3, 1, "", "key_value_params"], [4, 3, 1, "", "output_dtype"], [4, 3, 1, "", "ranks"], [4, 3, 1, "", "sharding_spec"], [4, 3, 1, "", "sharding_type"], [4, 3, 1, "", "stochastic_rounding"]], "torchrec.distributed.types.ParameterStorage": [[4, 3, 1, "", "DDR"], [4, 3, 1, "", "HBM"]], "torchrec.distributed.types.PipelineType": [[4, 3, 1, "", "NONE"], [4, 3, 1, "", "TRAIN_BASE"], [4, 3, 1, "", "TRAIN_PREFETCH_SPARSE_DIST"], [4, 3, 1, "", "TRAIN_SPARSE_DIST"]], "torchrec.distributed.types.QuantizedCommCodec": [[4, 4, 1, "", "calc_quantized_size"], [4, 4, 1, "", "create_context"], [4, 4, 1, "", "decode"], [4, 4, 1, "", "encode"], [4, 5, 1, "", "quantized_dtype"]], "torchrec.distributed.types.QuantizedCommCodecs": [[4, 3, 1, "", "backward"], [4, 3, 1, "", "forward"]], "torchrec.distributed.types.ShardedModule": [[4, 4, 1, "", "compute"], [4, 4, 1, "", "compute_and_output_dist"], [4, 4, 1, "", "create_context"], [4, 4, 1, "", "forward"], [4, 4, 1, "", "input_dist"], [4, 4, 1, "", "output_dist"], [4, 5, 1, "", "qcomm_codecs_registry"], [4, 4, 1, "", "sharded_parameter_names"], [4, 3, 1, "", "training"]], "torchrec.distributed.types.ShardingEnv": [[4, 4, 1, "", "from_local"], [4, 4, 1, "", "from_process_group"]], "torchrec.distributed.types.ShardingPlan": [[4, 4, 1, "", "get_plan_for_module"], [4, 3, 1, "id45", "plan"]], "torchrec.distributed.types.ShardingPlanner": [[4, 4, 1, "", "collective_plan"], [4, 4, 1, "", "plan"]], "torchrec.distributed.types.ShardingType": [[4, 3, 1, "", "COLUMN_WISE"], [4, 3, 1, "", "DATA_PARALLEL"], [4, 3, 1, "", "ROW_WISE"], [4, 3, 1, "", "TABLE_COLUMN_WISE"], [4, 3, 1, "", "TABLE_ROW_WISE"], [4, 3, 1, "", "TABLE_WISE"]], "torchrec.distributed.utils": [[4, 2, 1, "", "CopyableMixin"], [4, 2, 1, "", "ForkedPdb"], [4, 1, 1, "", "add_params_from_parameter_sharding"], [4, 1, 1, "", "add_prefix_to_state_dict"], [4, 1, 1, "", "append_prefix"], [4, 1, 1, "", "convert_to_fbgemm_types"], [4, 1, 1, "", "copy_to_device"], [4, 1, 1, "", "filter_state_dict"], [4, 1, 1, "", "get_unsharded_module_names"], [4, 1, 1, "", "init_parameters"], [4, 1, 1, "", "merge_fused_params"], [4, 1, 1, "", "none_throws"], [4, 1, 1, "", "optimizer_type_to_emb_opt_type"], [4, 2, 1, "", "sharded_model_copy"]], "torchrec.distributed.utils.CopyableMixin": [[4, 4, 1, "", "copy"], [4, 3, 1, "", "training"]], "torchrec.distributed.utils.ForkedPdb": [[4, 4, 1, "", "interaction"]], "torchrec.fx": [[7, 0, 0, "-", "tracer"]], "torchrec.fx.tracer": [[7, 2, 1, "", "Tracer"], [7, 1, 1, "", "is_fx_tracing"], [7, 1, 1, "", "symbolic_trace"]], "torchrec.fx.tracer.Tracer": [[7, 4, 1, "", "create_arg"], [7, 4, 1, "", "is_leaf_module"], [7, 4, 1, "", "path_of_module"], [7, 4, 1, "", "trace"]], "torchrec.inference": [[8, 0, 0, "-", "model_packager"], [8, 0, 0, "-", "modules"]], "torchrec.inference.model_packager": [[8, 2, 1, "", "PredictFactoryPackager"], [8, 1, 1, "", "load_config_text"], [8, 1, 1, "", "load_pickle_config"]], "torchrec.inference.model_packager.PredictFactoryPackager": [[8, 4, 1, "", "save_predict_factory"], [8, 4, 1, "", "set_extern_modules"], [8, 4, 1, "", "set_mocked_modules"]], "torchrec.inference.modules": [[8, 2, 1, "", "BatchingMetadata"], [8, 2, 1, "", "PredictFactory"], [8, 2, 1, "", "PredictModule"], [8, 2, 1, "", "QualNameMetadata"], [8, 1, 1, "", "quantize_dense"], [8, 1, 1, "", "quantize_embeddings"], [8, 1, 1, "", "quantize_feature"], [8, 1, 1, "", "quantize_inference_model"], [8, 1, 1, "", "shard_quant_model"], [8, 1, 1, "", "trim_torch_package_prefix_from_typename"]], "torchrec.inference.modules.BatchingMetadata": [[8, 3, 1, "", "device"], [8, 3, 1, "", "pinned"], [8, 3, 1, "", "type"]], "torchrec.inference.modules.PredictFactory": [[8, 4, 1, "", "batching_metadata"], [8, 4, 1, "", "batching_metadata_json"], [8, 4, 1, "", "create_predict_module"], [8, 4, 1, "", "model_inputs_data"], [8, 4, 1, "", "qualname_metadata"], [8, 4, 1, "", "qualname_metadata_json"], [8, 4, 1, "", "result_metadata"], [8, 4, 1, "", "run_weights_dependent_transformations"], [8, 4, 1, "", "run_weights_independent_tranformations"]], "torchrec.inference.modules.PredictModule": [[8, 4, 1, "", "forward"], [8, 4, 1, "", "predict_forward"], [8, 5, 1, "", "predict_module"], [8, 4, 1, "", "state_dict"], [8, 3, 1, "", "training"]], "torchrec.inference.modules.QualNameMetadata": [[8, 3, 1, "", "need_preproc"]], "torchrec.metrics": [[9, 0, 0, "-", "accuracy"], [9, 0, 0, "-", "auc"], [9, 0, 0, "-", "auprc"], [9, 0, 0, "-", "calibration"], [9, 0, 0, "-", "ctr"], [9, 0, 0, "-", "mae"], [9, 0, 0, "-", "metric_module"], [9, 0, 0, "-", "mse"], [9, 0, 0, "-", "multiclass_recall"], [9, 0, 0, "-", "ndcg"], [9, 0, 0, "-", "ne"], [9, 0, 0, "-", "precision"], [9, 0, 0, "-", "rauc"], [9, 0, 0, "-", "rec_metric"], [9, 0, 0, "-", "recall"], [9, 0, 0, "-", "throughput"], [9, 0, 0, "-", "weighted_avg"], [9, 0, 0, "-", "xauc"]], "torchrec.metrics.accuracy": [[9, 2, 1, "", "AccuracyMetric"], [9, 2, 1, "", "AccuracyMetricComputation"], [9, 1, 1, "", "compute_accuracy"], [9, 1, 1, "", "compute_accuracy_sum"], [9, 1, 1, "", "get_accuracy_states"]], "torchrec.metrics.accuracy.AccuracyMetricComputation": [[9, 4, 1, "", "update"]], "torchrec.metrics.auc": [[9, 2, 1, "", "AUCMetric"], [9, 2, 1, "", "AUCMetricComputation"], [9, 1, 1, "", "compute_auc"], [9, 1, 1, "", "compute_auc_per_group"]], "torchrec.metrics.auc.AUCMetricComputation": [[9, 4, 1, "", "reset"], [9, 4, 1, "", "update"]], "torchrec.metrics.auprc": [[9, 2, 1, "", "AUPRCMetric"], [9, 2, 1, "", "AUPRCMetricComputation"], [9, 1, 1, "", "compute_auprc"], [9, 1, 1, "", "compute_auprc_per_group"]], "torchrec.metrics.auprc.AUPRCMetricComputation": [[9, 4, 1, "", "reset"], [9, 4, 1, "", "update"]], "torchrec.metrics.calibration": [[9, 2, 1, "", "CalibrationMetric"], [9, 2, 1, "", "CalibrationMetricComputation"], [9, 1, 1, "", "compute_calibration"], [9, 1, 1, "", "get_calibration_states"]], "torchrec.metrics.calibration.CalibrationMetricComputation": [[9, 4, 1, "", "update"]], "torchrec.metrics.ctr": [[9, 2, 1, "", "CTRMetric"], [9, 2, 1, "", "CTRMetricComputation"], [9, 1, 1, "", "compute_ctr"], [9, 1, 1, "", "get_ctr_states"]], "torchrec.metrics.ctr.CTRMetricComputation": [[9, 4, 1, "", "update"]], "torchrec.metrics.mae": [[9, 2, 1, "", "MAEMetric"], [9, 2, 1, "", "MAEMetricComputation"], [9, 1, 1, "", "compute_error_sum"], [9, 1, 1, "", "compute_mae"], [9, 1, 1, "", "get_mae_states"]], "torchrec.metrics.mae.MAEMetricComputation": [[9, 4, 1, "", "update"]], "torchrec.metrics.metric_module": [[9, 2, 1, "", "RecMetricModule"], [9, 2, 1, "", "StateMetric"], [9, 1, 1, "", "generate_metric_module"]], "torchrec.metrics.metric_module.RecMetricModule": [[9, 3, 1, "", "batch_size"], [9, 4, 1, "", "check_memory_usage"], [9, 4, 1, "", "compute"], [9, 3, 1, "", "compute_count"], [9, 4, 1, "", "get_memory_usage"], [9, 4, 1, "", "get_required_inputs"], [9, 3, 1, "", "last_compute_time"], [9, 4, 1, "", "local_compute"], [9, 3, 1, "", "memory_usage_limit_mb"], [9, 3, 1, "", "memory_usage_mb_avg"], [9, 3, 1, "", "oom_count"], [9, 3, 1, "", "rec_metrics"], [9, 3, 1, "", "rec_tasks"], [9, 4, 1, "", "reset"], [9, 4, 1, "", "should_compute"], [9, 3, 1, "", "state_metrics"], [9, 4, 1, "", "sync"], [9, 3, 1, "", "throughput_metric"], [9, 4, 1, "", "unsync"], [9, 4, 1, "", "update"], [9, 3, 1, "", "world_size"]], "torchrec.metrics.metric_module.StateMetric": [[9, 4, 1, "", "get_metrics"]], "torchrec.metrics.mse": [[9, 2, 1, "", "MSEMetric"], [9, 2, 1, "", "MSEMetricComputation"], [9, 1, 1, "", "compute_error_sum"], [9, 1, 1, "", "compute_mse"], [9, 1, 1, "", "compute_rmse"], [9, 1, 1, "", "get_mse_states"]], "torchrec.metrics.mse.MSEMetricComputation": [[9, 4, 1, "", "update"]], "torchrec.metrics.multiclass_recall": [[9, 2, 1, "", "MulticlassRecallMetric"], [9, 2, 1, "", "MulticlassRecallMetricComputation"], [9, 1, 1, "", "compute_multiclass_recall_at_k"], [9, 1, 1, "", "compute_true_positives_at_k"], [9, 1, 1, "", "get_multiclass_recall_states"]], "torchrec.metrics.multiclass_recall.MulticlassRecallMetricComputation": [[9, 4, 1, "", "update"]], "torchrec.metrics.ndcg": [[9, 2, 1, "", "NDCGComputation"], [9, 2, 1, "", "NDCGMetric"]], "torchrec.metrics.ndcg.NDCGComputation": [[9, 4, 1, "", "update"]], "torchrec.metrics.ne": [[9, 2, 1, "", "NEMetric"], [9, 2, 1, "", "NEMetricComputation"], [9, 1, 1, "", "compute_cross_entropy"], [9, 1, 1, "", "compute_logloss"], [9, 1, 1, "", "compute_ne"], [9, 1, 1, "", "get_ne_states"]], "torchrec.metrics.ne.NEMetricComputation": [[9, 4, 1, "", "update"]], "torchrec.metrics.precision": [[9, 2, 1, "", "PrecisionMetric"], [9, 2, 1, "", "PrecisionMetricComputation"], [9, 1, 1, "", "compute_false_pos_sum"], [9, 1, 1, "", "compute_precision"], [9, 1, 1, "", "compute_true_pos_sum"], [9, 1, 1, "", "get_precision_states"]], "torchrec.metrics.precision.PrecisionMetricComputation": [[9, 4, 1, "", "update"]], "torchrec.metrics.rauc": [[9, 2, 1, "", "RAUCMetric"], [9, 2, 1, "", "RAUCMetricComputation"], [9, 1, 1, "", "compute_rauc"], [9, 1, 1, "", "compute_rauc_per_group"], [9, 1, 1, "", "conquer_and_count"], [9, 1, 1, "", "count_reverse_pairs_divide_and_conquer"], [9, 1, 1, "", "divide"]], "torchrec.metrics.rauc.RAUCMetricComputation": [[9, 4, 1, "", "reset"], [9, 4, 1, "", "update"]], "torchrec.metrics.rec_metric": [[9, 2, 1, "", "MetricComputationReport"], [9, 2, 1, "", "RecMetric"], [9, 2, 1, "", "RecMetricComputation"], [9, 6, 1, "", "RecMetricException"], [9, 2, 1, "", "RecMetricList"], [9, 2, 1, "", "WindowBuffer"]], "torchrec.metrics.rec_metric.MetricComputationReport": [[9, 3, 1, "", "description"], [9, 3, 1, "", "metric_prefix"], [9, 3, 1, "", "name"], [9, 3, 1, "", "value"]], "torchrec.metrics.rec_metric.RecMetric": [[9, 3, 1, "", "LABELS"], [9, 3, 1, "", "PREDICTIONS"], [9, 3, 1, "", "WEIGHTS"], [9, 4, 1, "", "compute"], [9, 4, 1, "", "get_memory_usage"], [9, 4, 1, "", "get_required_inputs"], [9, 4, 1, "", "local_compute"], [9, 4, 1, "", "reset"], [9, 4, 1, "", "state_dict"], [9, 4, 1, "", "sync"], [9, 4, 1, "", "unsync"], [9, 4, 1, "", "update"]], "torchrec.metrics.rec_metric.RecMetricComputation": [[9, 4, 1, "", "compute"], [9, 4, 1, "", "get_window_state"], [9, 4, 1, "", "get_window_state_name"], [9, 4, 1, "", "local_compute"], [9, 4, 1, "", "pre_compute"], [9, 4, 1, "", "reset"], [9, 4, 1, "", "update"]], "torchrec.metrics.rec_metric.RecMetricList": [[9, 4, 1, "", "compute"], [9, 4, 1, "", "get_required_inputs"], [9, 4, 1, "", "local_compute"], [9, 3, 1, "", "rec_metrics"], [9, 3, 1, "", "required_inputs"], [9, 4, 1, "", "reset"], [9, 4, 1, "", "sync"], [9, 4, 1, "", "unsync"], [9, 4, 1, "", "update"]], "torchrec.metrics.rec_metric.WindowBuffer": [[9, 4, 1, "", "aggregate_state"], [9, 5, 1, "", "buffers"]], "torchrec.metrics.recall": [[9, 2, 1, "", "RecallMetric"], [9, 2, 1, "", "RecallMetricComputation"], [9, 1, 1, "", "compute_false_neg_sum"], [9, 1, 1, "", "compute_recall"], [9, 1, 1, "", "compute_true_pos_sum"], [9, 1, 1, "", "get_recall_states"]], "torchrec.metrics.recall.RecallMetricComputation": [[9, 4, 1, "", "update"]], "torchrec.metrics.throughput": [[9, 2, 1, "", "ThroughputMetric"]], "torchrec.metrics.throughput.ThroughputMetric": [[9, 4, 1, "", "compute"], [9, 4, 1, "", "update"]], "torchrec.metrics.weighted_avg": [[9, 2, 1, "", "WeightedAvgMetric"], [9, 2, 1, "", "WeightedAvgMetricComputation"], [9, 1, 1, "", "get_mean"]], "torchrec.metrics.weighted_avg.WeightedAvgMetricComputation": [[9, 4, 1, "", "update"]], "torchrec.metrics.xauc": [[9, 2, 1, "", "XAUCMetric"], [9, 2, 1, "", "XAUCMetricComputation"], [9, 1, 1, "", "compute_error_sum"], [9, 1, 1, "", "compute_weighted_num_pairs"], [9, 1, 1, "", "compute_xauc"], [9, 1, 1, "", "get_xauc_states"]], "torchrec.metrics.xauc.XAUCMetricComputation": [[9, 4, 1, "", "update"]], "torchrec.models": [[10, 0, 0, "-", "deepfm"]], "torchrec.models.deepfm": [[10, 2, 1, "", "DenseArch"], [10, 2, 1, "", "FMInteractionArch"], [10, 2, 1, "", "OverArch"], [10, 2, 1, "", "SimpleDeepFMNN"], [10, 2, 1, "", "SparseArch"]], "torchrec.models.deepfm.DenseArch": [[10, 4, 1, "", "forward"], [10, 3, 1, "", "training"]], "torchrec.models.deepfm.FMInteractionArch": [[10, 4, 1, "", "forward"], [10, 3, 1, "", "training"]], "torchrec.models.deepfm.OverArch": [[10, 4, 1, "", "forward"], [10, 3, 1, "", "training"]], "torchrec.models.deepfm.SimpleDeepFMNN": [[10, 4, 1, "", "forward"], [10, 3, 1, "", "training"]], "torchrec.models.deepfm.SparseArch": [[10, 4, 1, "", "forward"], [10, 3, 1, "", "training"]], "torchrec.modules": [[11, 0, 0, "-", "activation"], [11, 0, 0, "-", "crossnet"], [11, 0, 0, "-", "deepfm"], [11, 0, 0, "-", "embedding_configs"], [11, 0, 0, "-", "embedding_modules"], [11, 0, 0, "-", "feature_processor"], [11, 0, 0, "-", "lazy_extension"], [11, 0, 0, "-", "mc_embedding_modules"], [11, 0, 0, "-", "mc_modules"], [11, 0, 0, "-", "mlp"], [11, 0, 0, "-", "utils"]], "torchrec.modules.activation": [[11, 2, 1, "", "SwishLayerNorm"]], "torchrec.modules.activation.SwishLayerNorm": [[11, 4, 1, "", "forward"], [11, 3, 1, "", "training"]], "torchrec.modules.crossnet": [[11, 2, 1, "", "CrossNet"], [11, 2, 1, "", "LowRankCrossNet"], [11, 2, 1, "", "LowRankMixtureCrossNet"], [11, 2, 1, "", "VectorCrossNet"]], "torchrec.modules.crossnet.CrossNet": [[11, 4, 1, "", "forward"], [11, 3, 1, "", "training"]], "torchrec.modules.crossnet.LowRankCrossNet": [[11, 4, 1, "", "forward"], [11, 3, 1, "", "training"]], "torchrec.modules.crossnet.LowRankMixtureCrossNet": [[11, 4, 1, "", "forward"], [11, 3, 1, "", "training"]], "torchrec.modules.crossnet.VectorCrossNet": [[11, 4, 1, "", "forward"], [11, 3, 1, "", "training"]], "torchrec.modules.deepfm": [[11, 2, 1, "", "DeepFM"], [11, 2, 1, "", "FactorizationMachine"]], "torchrec.modules.deepfm.DeepFM": [[11, 4, 1, "", "forward"], [11, 3, 1, "", "training"]], "torchrec.modules.deepfm.FactorizationMachine": [[11, 4, 1, "", "forward"], [11, 3, 1, "", "training"]], "torchrec.modules.embedding_configs": [[11, 2, 1, "", "BaseEmbeddingConfig"], [11, 2, 1, "", "EmbeddingBagConfig"], [11, 2, 1, "", "EmbeddingConfig"], [11, 2, 1, "", "EmbeddingTableConfig"], [11, 2, 1, "", "PoolingType"], [11, 2, 1, "", "QuantConfig"], [11, 2, 1, "", "ShardingType"], [11, 1, 1, "", "data_type_to_dtype"], [11, 1, 1, "", "data_type_to_sparse_type"], [11, 1, 1, "", "dtype_to_data_type"], [11, 1, 1, "", "pooling_type_to_pooling_mode"], [11, 1, 1, "", "pooling_type_to_str"]], "torchrec.modules.embedding_configs.BaseEmbeddingConfig": [[11, 3, 1, "", "data_type"], [11, 3, 1, "", "embedding_dim"], [11, 3, 1, "", "feature_names"], [11, 4, 1, "", "get_weight_init_max"], [11, 4, 1, "", "get_weight_init_min"], [11, 3, 1, "", "init_fn"], [11, 3, 1, "", "name"], [11, 3, 1, "", "need_pos"], [11, 3, 1, "", "num_embeddings"], [11, 4, 1, "", "num_features"], [11, 3, 1, "", "pruning_indices_remapping"], [11, 3, 1, "", "weight_init_max"], [11, 3, 1, "", "weight_init_min"]], "torchrec.modules.embedding_configs.EmbeddingBagConfig": [[11, 3, 1, "", "pooling"]], "torchrec.modules.embedding_configs.EmbeddingConfig": [[11, 3, 1, "", "embedding_dim"], [11, 3, 1, "", "feature_names"], [11, 3, 1, "", "num_embeddings"]], "torchrec.modules.embedding_configs.EmbeddingTableConfig": [[11, 3, 1, "", "embedding_names"], [11, 3, 1, "", "has_feature_processor"], [11, 3, 1, "", "is_weighted"], [11, 3, 1, "", "pooling"]], "torchrec.modules.embedding_configs.PoolingType": [[11, 3, 1, "", "MEAN"], [11, 3, 1, "", "NONE"], [11, 3, 1, "", "SUM"]], "torchrec.modules.embedding_configs.QuantConfig": [[11, 3, 1, "", "activation"], [11, 3, 1, "", "per_table_weight_dtype"], [11, 3, 1, "", "weight"]], "torchrec.modules.embedding_configs.ShardingType": [[11, 3, 1, "", "COLUMN_WISE"], [11, 3, 1, "", "DATA_PARALLEL"], [11, 3, 1, "", "ROW_WISE"], [11, 3, 1, "", "TABLE_COLUMN_WISE"], [11, 3, 1, "", "TABLE_ROW_WISE"], [11, 3, 1, "", "TABLE_WISE"]], "torchrec.modules.embedding_modules": [[11, 2, 1, "", "EmbeddingBagCollection"], [11, 2, 1, "", "EmbeddingBagCollectionInterface"], [11, 2, 1, "", "EmbeddingCollection"], [11, 2, 1, "", "EmbeddingCollectionInterface"], [11, 1, 1, "", "get_embedding_names_by_table"], [11, 1, 1, "", "process_pooled_embeddings"], [11, 1, 1, "", "reorder_inverse_indices"]], "torchrec.modules.embedding_modules.EmbeddingBagCollection": [[11, 5, 1, "", "device"], [11, 4, 1, "", "embedding_bag_configs"], [11, 4, 1, "", "forward"], [11, 4, 1, "", "is_weighted"], [11, 4, 1, "", "reset_parameters"], [11, 3, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface": [[11, 4, 1, "", "embedding_bag_configs"], [11, 4, 1, "", "forward"], [11, 4, 1, "", "is_weighted"], [11, 3, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingCollection": [[11, 5, 1, "", "device"], [11, 4, 1, "", "embedding_configs"], [11, 4, 1, "", "embedding_dim"], [11, 4, 1, "", "embedding_names_by_table"], [11, 4, 1, "", "forward"], [11, 4, 1, "", "need_indices"], [11, 4, 1, "", "reset_parameters"], [11, 3, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingCollectionInterface": [[11, 4, 1, "", "embedding_configs"], [11, 4, 1, "", "embedding_dim"], [11, 4, 1, "", "embedding_names_by_table"], [11, 4, 1, "", "forward"], [11, 4, 1, "", "need_indices"], [11, 3, 1, "", "training"]], "torchrec.modules.feature_processor": [[11, 2, 1, "", "BaseFeatureProcessor"], [11, 2, 1, "", "BaseGroupedFeatureProcessor"], [11, 2, 1, "", "PositionWeightedModule"], [11, 2, 1, "", "PositionWeightedProcessor"], [11, 1, 1, "", "offsets_to_range_traceble"], [11, 1, 1, "", "position_weighted_module_update_features"]], "torchrec.modules.feature_processor.BaseFeatureProcessor": [[11, 4, 1, "", "forward"], [11, 3, 1, "", "training"]], "torchrec.modules.feature_processor.BaseGroupedFeatureProcessor": [[11, 4, 1, "", "forward"], [11, 3, 1, "", "training"]], "torchrec.modules.feature_processor.PositionWeightedModule": [[11, 4, 1, "", "forward"], [11, 4, 1, "", "reset_parameters"], [11, 3, 1, "", "training"]], "torchrec.modules.feature_processor.PositionWeightedProcessor": [[11, 4, 1, "", "forward"], [11, 4, 1, "", "named_buffers"], [11, 4, 1, "", "state_dict"], [11, 3, 1, "", "training"]], "torchrec.modules.lazy_extension": [[11, 2, 1, "", "LazyModuleExtensionMixin"], [11, 1, 1, "", "lazy_apply"]], "torchrec.modules.lazy_extension.LazyModuleExtensionMixin": [[11, 4, 1, "", "apply"]], "torchrec.modules.mc_embedding_modules": [[11, 2, 1, "", "BaseManagedCollisionEmbeddingCollection"], [11, 2, 1, "", "ManagedCollisionEmbeddingBagCollection"], [11, 2, 1, "", "ManagedCollisionEmbeddingCollection"], [11, 1, 1, "", "evict"]], "torchrec.modules.mc_embedding_modules.BaseManagedCollisionEmbeddingCollection": [[11, 4, 1, "", "forward"], [11, 3, 1, "", "training"]], "torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingBagCollection": [[11, 3, 1, "", "training"]], "torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingCollection": [[11, 3, 1, "", "training"]], "torchrec.modules.mc_modules": [[11, 2, 1, "", "DistanceLFU_EvictionPolicy"], [11, 2, 1, "", "LFU_EvictionPolicy"], [11, 2, 1, "", "LRU_EvictionPolicy"], [11, 2, 1, "", "MCHEvictionPolicy"], [11, 2, 1, "", "MCHEvictionPolicyMetadataInfo"], [11, 2, 1, "", "MCHManagedCollisionModule"], [11, 2, 1, "", "ManagedCollisionCollection"], [11, 2, 1, "", "ManagedCollisionModule"], [11, 1, 1, "", "apply_mc_method_to_jt_dict"], [11, 1, 1, "", "average_threshold_filter"], [11, 1, 1, "", "dynamic_threshold_filter"], [11, 1, 1, "", "probabilistic_threshold_filter"]], "torchrec.modules.mc_modules.DistanceLFU_EvictionPolicy": [[11, 4, 1, "", "coalesce_history_metadata"], [11, 5, 1, "", "metadata_info"], [11, 4, 1, "", "record_history_metadata"], [11, 4, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.LFU_EvictionPolicy": [[11, 4, 1, "", "coalesce_history_metadata"], [11, 5, 1, "", "metadata_info"], [11, 4, 1, "", "record_history_metadata"], [11, 4, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.LRU_EvictionPolicy": [[11, 4, 1, "", "coalesce_history_metadata"], [11, 5, 1, "", "metadata_info"], [11, 4, 1, "", "record_history_metadata"], [11, 4, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.MCHEvictionPolicy": [[11, 4, 1, "", "coalesce_history_metadata"], [11, 5, 1, "", "metadata_info"], [11, 4, 1, "", "record_history_metadata"], [11, 4, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.MCHEvictionPolicyMetadataInfo": [[11, 3, 1, "", "is_history_metadata"], [11, 3, 1, "", "is_mch_metadata"], [11, 3, 1, "", "metadata_name"]], "torchrec.modules.mc_modules.MCHManagedCollisionModule": [[11, 4, 1, "", "evict"], [11, 4, 1, "", "forward"], [11, 4, 1, "", "input_size"], [11, 4, 1, "", "open_slots"], [11, 4, 1, "", "output_size"], [11, 4, 1, "", "preprocess"], [11, 4, 1, "", "profile"], [11, 4, 1, "", "rebuild_with_output_id_range"], [11, 4, 1, "", "remap"], [11, 3, 1, "", "training"]], "torchrec.modules.mc_modules.ManagedCollisionCollection": [[11, 4, 1, "", "embedding_configs"], [11, 4, 1, "", "evict"], [11, 4, 1, "", "forward"], [11, 4, 1, "", "open_slots"], [11, 3, 1, "", "training"]], "torchrec.modules.mc_modules.ManagedCollisionModule": [[11, 5, 1, "", "device"], [11, 4, 1, "", "evict"], [11, 4, 1, "", "forward"], [11, 4, 1, "", "input_size"], [11, 4, 1, "", "open_slots"], [11, 4, 1, "", "output_size"], [11, 4, 1, "", "preprocess"], [11, 4, 1, "", "rebuild_with_output_id_range"], [11, 3, 1, "", "training"]], "torchrec.modules.mlp": [[11, 2, 1, "", "MLP"], [11, 2, 1, "", "Perceptron"]], "torchrec.modules.mlp.MLP": [[11, 4, 1, "", "forward"], [11, 3, 1, "", "training"]], "torchrec.modules.mlp.Perceptron": [[11, 4, 1, "", "forward"], [11, 3, 1, "", "training"]], "torchrec.modules.utils": [[11, 2, 1, "", "SequenceVBEContext"], [11, 1, 1, "", "check_module_output_dimension"], [11, 1, 1, "", "construct_jagged_tensors"], [11, 1, 1, "", "construct_jagged_tensors_inference"], [11, 1, 1, "", "construct_modulelist_from_single_module"], [11, 1, 1, "", "convert_list_of_modules_to_modulelist"], [11, 1, 1, "", "deterministic_dedup"], [11, 1, 1, "", "extract_module_or_tensor_callable"], [11, 1, 1, "", "get_module_output_dimension"], [11, 1, 1, "", "init_mlp_weights_xavier_uniform"], [11, 1, 1, "", "jagged_index_select_with_empty"]], "torchrec.modules.utils.SequenceVBEContext": [[11, 3, 1, "", "recat"], [11, 4, 1, "", "record_stream"], [11, 3, 1, "", "reindexed_length_per_key"], [11, 3, 1, "", "reindexed_lengths"], [11, 3, 1, "", "reindexed_values"], [11, 3, 1, "", "unpadded_lengths"]], "torchrec.optim": [[12, 0, 0, "-", "clipping"], [12, 0, 0, "-", "fused"], [12, 0, 0, "-", "keyed"], [12, 0, 0, "-", "warmup"]], "torchrec.optim.clipping": [[12, 2, 1, "", "GradientClipping"], [12, 2, 1, "", "GradientClippingOptimizer"]], "torchrec.optim.clipping.GradientClipping": [[12, 3, 1, "", "NONE"], [12, 3, 1, "", "NORM"], [12, 3, 1, "", "VALUE"]], "torchrec.optim.clipping.GradientClippingOptimizer": [[12, 4, 1, "", "step"]], "torchrec.optim.fused": [[12, 2, 1, "", "EmptyFusedOptimizer"], [12, 2, 1, "", "FusedOptimizer"], [12, 2, 1, "", "FusedOptimizerModule"]], "torchrec.optim.fused.EmptyFusedOptimizer": [[12, 4, 1, "", "step"], [12, 4, 1, "", "zero_grad"]], "torchrec.optim.fused.FusedOptimizer": [[12, 4, 1, "", "step"], [12, 4, 1, "", "zero_grad"]], "torchrec.optim.fused.FusedOptimizerModule": [[12, 5, 1, "", "fused_optimizer"]], "torchrec.optim.keyed": [[12, 2, 1, "", "CombinedOptimizer"], [12, 2, 1, "", "KeyedOptimizer"], [12, 2, 1, "", "KeyedOptimizerWrapper"], [12, 2, 1, "", "OptimizerWrapper"]], "torchrec.optim.keyed.CombinedOptimizer": [[12, 5, 1, "", "optimizers"], [12, 5, 1, "", "param_groups"], [12, 5, 1, "", "params"], [12, 4, 1, "", "post_load_state_dict"], [12, 4, 1, "", "prepend_opt_key"], [12, 4, 1, "", "save_param_groups"], [12, 4, 1, "", "set_optimizer_step"], [12, 5, 1, "", "state"], [12, 4, 1, "", "step"], [12, 4, 1, "", "zero_grad"]], "torchrec.optim.keyed.KeyedOptimizer": [[12, 4, 1, "", "add_param_group"], [12, 4, 1, "", "init_state"], [12, 4, 1, "", "load_state_dict"], [12, 4, 1, "", "post_load_state_dict"], [12, 4, 1, "", "save_param_groups"], [12, 4, 1, "", "state_dict"]], "torchrec.optim.keyed.KeyedOptimizerWrapper": [[12, 4, 1, "", "step"], [12, 4, 1, "", "zero_grad"]], "torchrec.optim.keyed.OptimizerWrapper": [[12, 4, 1, "", "add_param_group"], [12, 4, 1, "", "load_state_dict"], [12, 4, 1, "", "post_load_state_dict"], [12, 4, 1, "", "save_param_groups"], [12, 4, 1, "", "state_dict"], [12, 4, 1, "", "step"], [12, 4, 1, "", "zero_grad"]], "torchrec.optim.warmup": [[12, 2, 1, "", "WarmupOptimizer"], [12, 2, 1, "", "WarmupPolicy"], [12, 2, 1, "", "WarmupStage"]], "torchrec.optim.warmup.WarmupOptimizer": [[12, 4, 1, "", "post_load_state_dict"], [12, 4, 1, "", "step"]], "torchrec.optim.warmup.WarmupPolicy": [[12, 3, 1, "", "CONSTANT"], [12, 3, 1, "", "COSINE_ANNEALING_WARM_RESTARTS"], [12, 3, 1, "", "INVSQRT"], [12, 3, 1, "", "LINEAR"], [12, 3, 1, "", "NONE"], [12, 3, 1, "", "POLY"], [12, 3, 1, "", "STEP"]], "torchrec.optim.warmup.WarmupStage": [[12, 3, 1, "", "decay_iters"], [12, 3, 1, "", "lr_scale"], [12, 3, 1, "", "max_iters"], [12, 3, 1, "", "policy"], [12, 3, 1, "", "sgdr_period"], [12, 3, 1, "", "value"]], "torchrec.quant": [[13, 0, 0, "-", "embedding_modules"]], "torchrec.quant.embedding_modules": [[13, 2, 1, "", "EmbeddingBagCollection"], [13, 2, 1, "", "EmbeddingCollection"], [13, 2, 1, "", "FeatureProcessedEmbeddingBagCollection"], [13, 1, 1, "", "for_each_module_of_type_do"], [13, 1, 1, "", "pruned_num_embeddings"], [13, 1, 1, "", "quant_prep_customize_row_alignment"], [13, 1, 1, "", "quant_prep_enable_quant_state_dict_split_scale_bias"], [13, 1, 1, "", "quant_prep_enable_quant_state_dict_split_scale_bias_for_types"], [13, 1, 1, "", "quant_prep_enable_register_tbes"], [13, 1, 1, "", "quantize_state_dict"]], "torchrec.quant.embedding_modules.EmbeddingBagCollection": [[13, 5, 1, "", "device"], [13, 4, 1, "", "embedding_bag_configs"], [13, 4, 1, "", "forward"], [13, 4, 1, "", "from_float"], [13, 4, 1, "", "is_weighted"], [13, 4, 1, "", "output_dtype"], [13, 3, 1, "", "training"]], "torchrec.quant.embedding_modules.EmbeddingCollection": [[13, 5, 1, "", "device"], [13, 4, 1, "", "embedding_configs"], [13, 4, 1, "", "embedding_dim"], [13, 4, 1, "", "embedding_names_by_table"], [13, 4, 1, "", "forward"], [13, 4, 1, "", "from_float"], [13, 4, 1, "", "need_indices"], [13, 4, 1, "", "output_dtype"], [13, 3, 1, "", "training"]], "torchrec.quant.embedding_modules.FeatureProcessedEmbeddingBagCollection": [[13, 3, 1, "", "embedding_bags"], [13, 4, 1, "", "forward"], [13, 4, 1, "", "from_float"], [13, 3, 1, "", "tbes"], [13, 3, 1, "", "training"]], "torchrec.sparse": [[14, 0, 0, "-", "jagged_tensor"]], "torchrec.sparse.jagged_tensor": [[14, 2, 1, "", "ComputeJTDictToKJT"], [14, 2, 1, "", "ComputeKJTToJTDict"], [14, 2, 1, "", "JaggedTensor"], [14, 2, 1, "", "JaggedTensorMeta"], [14, 2, 1, "", "KeyedJaggedTensor"], [14, 2, 1, "", "KeyedTensor"], [14, 1, 1, "", "flatten_kjt_list"], [14, 1, 1, "", "jt_is_equal"], [14, 1, 1, "", "kjt_is_equal"], [14, 1, 1, "", "permute_multi_embedding"], [14, 1, 1, "", "regroup_kts"], [14, 1, 1, "", "unflatten_kjt_list"]], "torchrec.sparse.jagged_tensor.ComputeJTDictToKJT": [[14, 4, 1, "", "forward"], [14, 3, 1, "", "training"]], "torchrec.sparse.jagged_tensor.ComputeKJTToJTDict": [[14, 4, 1, "", "forward"], [14, 3, 1, "", "training"]], "torchrec.sparse.jagged_tensor.JaggedTensor": [[14, 4, 1, "", "device"], [14, 4, 1, "", "empty"], [14, 4, 1, "", "from_dense"], [14, 4, 1, "", "from_dense_lengths"], [14, 4, 1, "", "lengths"], [14, 4, 1, "", "lengths_or_none"], [14, 4, 1, "", "offsets"], [14, 4, 1, "", "offsets_or_none"], [14, 4, 1, "", "record_stream"], [14, 4, 1, "", "to"], [14, 4, 1, "", "to_dense"], [14, 4, 1, "", "to_dense_weights"], [14, 4, 1, "", "to_padded_dense"], [14, 4, 1, "", "to_padded_dense_weights"], [14, 4, 1, "", "values"], [14, 4, 1, "", "weights"], [14, 4, 1, "", "weights_or_none"]], "torchrec.sparse.jagged_tensor.KeyedJaggedTensor": [[14, 4, 1, "", "concat"], [14, 4, 1, "", "device"], [14, 4, 1, "", "dist_init"], [14, 4, 1, "", "dist_labels"], [14, 4, 1, "", "dist_splits"], [14, 4, 1, "", "dist_tensors"], [14, 4, 1, "", "empty"], [14, 4, 1, "", "empty_like"], [14, 4, 1, "", "flatten_lengths"], [14, 4, 1, "", "from_jt_dict"], [14, 4, 1, "", "from_lengths_sync"], [14, 4, 1, "", "from_offsets_sync"], [14, 4, 1, "", "index_per_key"], [14, 4, 1, "", "inverse_indices"], [14, 4, 1, "", "inverse_indices_or_none"], [14, 4, 1, "", "keys"], [14, 4, 1, "", "length_per_key"], [14, 4, 1, "", "length_per_key_or_none"], [14, 4, 1, "", "lengths"], [14, 4, 1, "", "lengths_offset_per_key"], [14, 4, 1, "", "lengths_or_none"], [14, 4, 1, "", "offset_per_key"], [14, 4, 1, "", "offset_per_key_or_none"], [14, 4, 1, "", "offsets"], [14, 4, 1, "", "offsets_or_none"], [14, 4, 1, "", "permute"], [14, 4, 1, "", "pin_memory"], [14, 4, 1, "", "record_stream"], [14, 4, 1, "", "split"], [14, 4, 1, "", "stride"], [14, 4, 1, "", "stride_per_key"], [14, 4, 1, "", "stride_per_key_per_rank"], [14, 4, 1, "", "sync"], [14, 4, 1, "", "to"], [14, 4, 1, "", "to_dict"], [14, 4, 1, "", "unsync"], [14, 4, 1, "", "values"], [14, 4, 1, "", "variable_stride_per_key"], [14, 4, 1, "", "weights"], [14, 4, 1, "", "weights_or_none"]], "torchrec.sparse.jagged_tensor.KeyedTensor": [[14, 4, 1, "", "device"], [14, 4, 1, "", "from_tensor_list"], [14, 4, 1, "", "key_dim"], [14, 4, 1, "", "keys"], [14, 4, 1, "", "length_per_key"], [14, 4, 1, "", "offset_per_key"], [14, 4, 1, "", "record_stream"], [14, 4, 1, "", "regroup"], [14, 4, 1, "", "regroup_as_dict"], [14, 4, 1, "", "to"], [14, 4, 1, "", "to_dict"], [14, 4, 1, "", "values"]]}, "objtypes": {"0": "py:module", "1": "py:function", "2": "py:class", "3": "py:attribute", "4": "py:method", "5": "py:property", "6": "py:exception"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "function", "Python function"], "2": ["py", "class", "Python class"], "3": ["py", "attribute", "Python attribute"], "4": ["py", "method", "Python method"], "5": ["py", "property", "Python property"], "6": ["py", "exception", "Python exception"]}, "titleterms": {"welcom": 0, "torchrec": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "document": 0, "tutori": 0, "api": 0, "indic": 0, "tabl": 0, "overview": 1, "why": 1, "dataset": [2, 3], "criteo": 2, "movielen": 2, "random": 2, "util": [2, 4, 5, 11], "script": 3, "contiguous_preproc_criteo": 3, "npy_preproc_criteo": 3, "distribut": [4, 5, 6], "collective_util": 4, "comm": 4, "comm_op": 4, "dist_data": [4, 6], "embed": 4, "embedding_lookup": 4, "embedding_shard": 4, "embedding_typ": 4, "embeddingbag": 4, "grouped_position_weight": 4, "model_parallel": 4, "quant_embeddingbag": 4, "train_pipelin": 4, "type": [4, 5], "mc_modul": [4, 11], "mc_embeddingbag": 4, "mc_embed": 4, "planner": 5, "constant": 5, "enumer": 5, "partition": 5, "perf_model": 5, "propos": 5, "shard_estim": 5, "stat": 5, "storage_reserv": 5, "shard": 6, "cw_shard": 6, "dp_shard": 6, "rw_shard": 6, "tw_shard": 6, "twcw_shard": 6, "twrw_shard": 6, "fx": 7, "tracer": 7, "modul": [7, 8, 10, 11, 12, 13, 14], "content": [7, 8, 10, 12, 13, 14], "infer": 8, "model_packag": 8, "metric": 9, "accuraci": 9, "auc": 9, "auprc": 9, "calibr": 9, "ctr": 9, "mae": 9, "mse": 9, "multiclass_recal": 9, "ndcg": 9, "ne": 9, "recal": 9, "precis": 9, "rauc": 9, "throughput": 9, "weighted_avg": 9, "xauc": 9, "metric_modul": 9, "rec_metr": 9, "model": 10, "deepfm": [10, 11], "dlrm": 10, "activ": 11, "crossnet": 11, "embedding_config": 11, "embedding_modul": [11, 13], "feature_processor": 11, "lazy_extens": 11, "mlp": 11, "mc_embedding_modul": 11, "optim": 12, "clip": 12, "fuse": 12, "kei": 12, "warmup": 12, "quant": 13, "spars": 14, "jagged_tensor": 14}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 6, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx": 56}})
\ No newline at end of file
+Search.setIndex({"docnames": ["index", "overview", "torchrec.datasets", "torchrec.datasets.scripts", "torchrec.distributed", "torchrec.distributed.planner", "torchrec.distributed.sharding", "torchrec.fx", "torchrec.inference", "torchrec.metrics", "torchrec.models", "torchrec.modules", "torchrec.optim", "torchrec.quant", "torchrec.sparse"], "filenames": ["index.rst", "overview.rst", "torchrec.datasets.rst", "torchrec.datasets.scripts.rst", "torchrec.distributed.rst", "torchrec.distributed.planner.rst", "torchrec.distributed.sharding.rst", "torchrec.fx.rst", "torchrec.inference.rst", "torchrec.metrics.rst", "torchrec.models.rst", "torchrec.modules.rst", "torchrec.optim.rst", "torchrec.quant.rst", "torchrec.sparse.rst"], "titles": ["Welcome to the TorchRec documentation!", "TorchRec Overview", "torchrec.datasets", "torchrec.datasets.scripts", "torchrec.distributed", "torchrec.distributed.planner", "torchrec.distributed.sharding", "torchrec.fx", "torchrec.inference", "torchrec.metrics", "torchrec.models", "torchrec.modules", "torchrec.optim", "torchrec.quant", "torchrec.sparse"], "terms": {"recommendation system": 0, "shard": [0, 1, 4, 5, 8, 11, 12, 13], "distributed train": 0, "special": [0, 1, 7, 9, 11, 12], "librari": [0, 1], "within": [0, 4, 5, 6, 8, 10, 11, 14], "pytorch": [0, 1, 2, 4, 11, 12, 14], "ecosystem": [0, 1], "tailor": 0, "build": [0, 1, 5], "scale": [0, 1], "deploi": [0, 1, 8], "larg": [0, 1, 2, 5], "recommend": [0, 1, 2, 9, 10], "system": [0, 1, 2, 4, 5, 10], "nich": 0, "directli": [0, 4, 12], "address": [0, 1, 4], "standard": 0, "offer": 0, "advanc": [0, 1, 12], "featur": [0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 13, 14], "complex": [0, 7], "techniqu": [0, 1], "massiv": [0, 1], "embed": [0, 1, 5, 6, 7, 10, 11, 13, 14], "tabl": [0, 1, 4, 5, 6, 7, 8, 10, 11, 13], "enhanc": 0, "distribut": [0, 1, 2, 8, 9, 11, 12, 14], "train": [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "capabl": [0, 1], "topic": 0, "thi": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "section": 0, "help": [0, 4], "you": [0, 4, 6, 7, 14], "overview": 0, "A": [0, 2, 4, 5, 6, 7, 8, 9, 12, 14], "short": 0, "intro": 0, "why": 0, "need": [0, 2, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14], "set": [0, 2, 4, 5, 6, 8, 9, 11, 12], "up": [0, 4, 5, 13], "learn": [0, 8, 10, 11, 12], "instal": 0, "us": [0, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "your": [0, 7, 9], "environ": [0, 1, 4, 8], "tutori": 0, "follow": [0, 1, 4, 5, 6, 9, 10, 11, 12, 14], "our": 0, "interact": [0, 4, 10, 11], "step": [0, 4, 5, 12], "real": [0, 5], "life": 0, "exampl": [0, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "we": [0, 2, 4, 5, 6, 7, 9, 11, 12, 13, 14], "feedback": [0, 5], "from": [0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14], "commun": [0, 1, 4, 5, 6, 9], "If": [0, 2, 4, 5, 8, 9, 11, 12, 14], "ar": [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "interest": 0, "improv": [0, 1, 12], "project": [0, 10], "here": [0, 5, 10], "can": [0, 1, 2, 4, 5, 9, 10, 11, 12, 14], "visit": 0, "github": [0, 11], "repositori": 0, "There": [0, 4], "yoou": 0, "find": [0, 5], "sourc": [0, 2, 10, 11], "code": [0, 1, 4, 11], "issu": [0, 4, 6, 11], "ongo": 0, "submit": 0, "encount": 0, "ani": [0, 2, 4, 5, 6, 7, 8, 9, 11, 12, 14], "bug": 0, "have": [0, 2, 4, 5, 6, 9, 10, 11, 12, 14], "suggest": 0, "pleas": [0, 2, 4, 5, 8, 9, 11, 14], "an": [0, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "through": [0, 4, 7, 9, 12], "tracker": 0, "propos": [0, 10], "chang": [0, 4, 11, 12], "fork": [0, 4], "pull": 0, "request": [0, 4, 8, 12], "whether": [0, 2, 4, 5, 7, 9, 11, 13], "s": [0, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14], "fix": [0, 14], "ad": [0, 2, 4, 8, 9, 11, 12], "new": [0, 4, 5, 9], "alwai": [0, 11, 14], "make": [0, 4, 11, 12], "sure": [0, 12], "review": 0, "md": 0, "go": [0, 12], "repo": 0, "design": [1, 2, 4, 8, 9, 11], "provid": [1, 4, 5, 6, 8, 9, 10, 11, 13], "common": [1, 2, 11, 14], "primit": [1, 4, 6], "creat": [1, 4, 7, 8, 9, 11, 12, 14], "state": [1, 4, 5, 8, 9, 11, 12], "art": 1, "person": [1, 10], "model": [1, 4, 5, 6, 7, 8, 9, 11, 12, 13], "path": [1, 2, 4, 5, 8], "product": [1, 10], "wide": 1, "adopt": 1, "mani": [1, 4, 6], "meta": [1, 4, 5, 8], "infer": [1, 4, 5, 6, 13, 14], "workflow": 1, "uniqu": [1, 2, 5, 9, 11], "challeng": [1, 2], "which": [1, 2, 4, 5, 6, 8, 9, 11, 12, 14], "focu": [1, 9], "regular": 1, "more": [1, 4, 5, 6, 9, 11], "specif": [1, 4, 5, 8, 12], "gener": [1, 2, 4, 5, 7, 8, 10, 11, 12, 14], "compon": [1, 9, 11], "simplist": 1, "modul": [1, 4, 5, 6, 9], "author": [1, 4], "flexibl": [1, 11], "customiz": [1, 5], "method": [1, 4, 7, 8, 9, 11], "row": [1, 2, 4, 5, 6], "wise": [1, 4, 5, 6, 11], "column": [1, 2, 5, 6], "so": [1, 2, 4, 5, 9, 10, 12, 14], "automat": [1, 4, 5, 9, 14], "determin": [1, 2, 4, 5, 6], "best": [1, 5], "plan": [1, 4, 5, 8, 11], "devic": [1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14], "topolog": [1, 4, 5, 6], "effici": [1, 5, 11], "memori": [1, 2, 4, 5, 8, 9, 12], "balanc": [1, 5], "while": [1, 2, 4, 5, 6, 7, 8, 11], "support": [1, 4, 5, 6, 7, 9, 11, 12], "basic": [1, 10, 14], "extend": [1, 4], "sophist": 1, "parallel": [1, 2, 4, 6], "incred": 1, "optim": [1, 4, 5, 8, 9, 11, 13], "top": [1, 4, 9, 11], "fbgemm": [1, 4, 5, 6, 13], "after": [1, 4, 5, 6, 9, 11], "all": [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 14], "power": 1, "some": [1, 4, 9, 10, 14], "largest": [1, 5], "frictionless": 1, "deploy": 1, "simpl": [1, 10], "api": [1, 4, 6, 7, 9, 11], "transform": [1, 2, 4, 8, 11], "load": [1, 2, 4, 5, 6, 11, 12], "c": [1, 2, 4, 6, 8, 14], "most": [1, 4, 8, 12], "integr": 1, "built": [1, 11], "mean": [1, 4, 5, 9, 11], "seamlessli": 1, "exist": [1, 4, 6, 11, 14], "tool": 1, "allow": [1, 2, 4, 5, 7, 9, 11, 12], "develop": 1, "leverag": [1, 11], "knowledg": [1, 4, 5, 9], "codebas": 1, "util": [1, 6], "By": 1, "being": [1, 4, 5, 8, 9, 11], "part": [1, 2, 4, 5, 6, 11, 12], "benefit": 1, "robust": 1, "continu": [1, 2], "updat": [1, 4, 5, 6, 8, 9, 11, 12], "come": [1, 11], "contain": [2, 4, 5, 6, 8, 9, 11, 12, 13], "two": [2, 4, 5, 9, 10, 11, 14], "popular": [2, 10], "reci": 2, "kaggl": 2, "displai": 2, "advertis": 2, "20m": 2, "addition": 2, "randomdataset": 2, "data": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "same": [2, 4, 5, 6, 8, 9, 10, 11, 14], "format": [2, 3, 5, 8, 14], "abov": [2, 11, 14], "lastli": 2, "script": [2, 14], "pre": [2, 9, 11, 12], "process": [2, 3, 4, 5, 6, 9, 10, 11, 13], "etc": [2, 4, 8, 12, 14], "import": [2, 4, 5, 8, 11, 13], "criteo_kaggl": 2, "datapip": 2, "criteo_terabyt": 2, "home": 2, "day_0": 2, "tsv": [2, 3], "day_1": 2, "dp": [2, 5], "iter": [2, 4, 5, 11, 12], "batcher": 2, "100": [2, 4, 5, 9, 10, 11], "collat": 2, "batch": [2, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14], "next": [2, 5], "class": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "binarycriteoutil": 2, "base": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "object": [2, 4, 5, 8, 9, 11, 12], "function": [2, 3, 4, 5, 6, 7, 8, 11, 12, 14], "preprocess": [2, 3, 11], "save": [2, 3, 4, 5, 11, 12], "partit": [2, 5, 6], "binari": [2, 3, 5, 9], "numpi": 2, "static": [2, 4, 5, 9, 12, 14], "get_file_row_ranges_and_remaind": 2, "length": [2, 4, 5, 6, 10, 11, 13, 14], "list": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "int": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "rank": [2, 4, 5, 6, 9, 10, 11, 12, 14], "world_siz": [2, 4, 5, 6, 8, 9], "start_row": 2, "0": [2, 4, 5, 6, 9, 10, 11, 12, 13, 14], "last_row": 2, "option": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "none": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "tupl": [2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14], "dict": [2, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14], "given": [2, 4, 5, 6, 7, 11], "number": [2, 4, 5, 6, 8, 9, 10, 11, 14], "file": [2, 3, 4], "return": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "portion": 2, "those": [2, 10, 11], "repres": [2, 4, 5, 8, 10, 11, 13, 14], "rang": [2, 5, 7, 11], "indic": [2, 4, 5, 6, 8, 11, 12, 13, 14], "inclus": 2, "should": [2, 4, 5, 6, 8, 9, 10, 11, 12, 14], "handl": [2, 4, 5, 6, 7, 11, 12], "each": [2, 4, 5, 6, 9, 10, 11, 13, 14], "assign": [2, 4, 14], "The": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14], "wai": [2, 4, 5], "deal": 2, "enabl": [2, 4, 5, 9, 12], "reduc": [2, 4, 6, 11, 13], "amount": [2, 5], "read": 2, "avoid": [2, 4, 8, 9, 11, 12], "seek": 2, "paramet": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "count": [2, 5, 9, 11], "world": [2, 4, 6], "size": [2, 4, 5, 6, 9, 10, 11, 13, 14], "first": [2, 4, 5, 6, 10, 11, 12, 14], "item": [2, 4, 6], "map": [2, 4, 8, 9, 11, 12, 13], "kei": [2, 4, 5, 6, 8, 9, 10, 11, 13, 14], "second": [2, 4, 5, 6, 9, 10, 11, 14], "remaind": 2, "type": [2, 6, 7, 8, 9, 10, 11, 12, 13, 14], "output": [2, 4, 5, 6, 8, 9, 10, 11, 13, 14], "get_shape_from_npi": 2, "str": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "path_manager_kei": 2, "shape": [2, 4, 6, 9, 10, 11, 14], "npy": [2, 3], "onli": [2, 4, 5, 6, 9, 11, 14], "its": [2, 4, 5, 6, 8, 9, 11, 12, 14], "header": 2, "input": [2, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14], "manag": [2, 11], "differ": [2, 4, 5, 6, 11, 12, 14], "filesystem": 2, "load_npy_rang": 2, "fname": 2, "num_row": 2, "mmap_mod": 2, "bool": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "fals": [2, 4, 5, 6, 8, 9, 11, 12, 13, 14], "ndarrai": 2, "note": [2, 4, 5, 6, 11, 14], "assum": [2, 4, 5, 6, 8, 9, 10, 12], "arrai": 2, "ndim": 2, "2": [2, 4, 5, 6, 9, 10, 11, 12, 13, 14], "string": [2, 5, 8, 11], "start": [2, 4, 11, 14], "get": [2, 4, 5, 6], "desir": [2, 4, 8], "suppli": 2, "np": 2, "shuffl": [2, 6], "input_dir_labels_and_dens": 2, "input_dir_spars": 2, "output_dir_shuffl": 2, "rows_per_dai": 2, "output_dir_full_set": 2, "dai": 2, "24": [2, 4, 6], "int_column": 2, "13": 2, "sparse_column": 2, "26": 2, "random_se": 2, "expect": [2, 3, 4, 5, 10, 11], "split": [2, 4, 5, 6, 8, 14], "dens": [2, 4, 5, 10, 11, 14], "spars": [2, 3, 4, 6, 10, 11, 13], "label": [2, 4, 6, 9, 10], "must": [2, 4, 5, 6, 8, 9, 10, 11], "day_x_dens": 2, "day_x_spars": 2, "day_x_label": 2, "reconstruct": 2, "back": 2, "separ": [2, 3, 4], "1": [2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "final": [2, 5, 9, 10, 11, 13, 14], "remain": 2, "untouch": 2, "valid": [2, 5, 11, 14], "directori": [2, 4], "full": [2, 11, 12, 14], "total": [2, 4, 5, 6, 9], "categor": 2, "seed": [2, 5], "oper": [2, 4, 5, 6, 7, 11, 14], "sparse_to_contigu": 2, "in_fil": 2, "output_dir": 2, "frequency_threshold": 2, "3": [2, 4, 5, 6, 9, 10, 11, 12, 13, 14], "output_file_suffix": 2, "_contig_freq": 2, "convert": [2, 4, 7, 8, 14], "contigu": [2, 3], "integ": 2, "store": [2, 4, 5, 6, 14], "togeth": [2, 4, 11], "becaus": [2, 5, 11, 12], "match": [2, 4, 5, 8, 9, 10, 11], "id": [2, 4, 5, 6, 11], "between": [2, 4, 5, 9, 10, 11, 14], "henc": 2, "thei": [2, 4, 5, 14], "also": [2, 4, 5, 8, 9, 10, 11, 12], "appear": [2, 11], "less": 2, "than": [2, 5, 11, 12], "time": [2, 4, 5, 6, 8, 9, 11], "remap": [2, 11], "valu": [2, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14], "day_0_spars": 2, "col_0": 2, "col_1": 2, "abc": [2, 4, 5, 8, 9, 11, 12], "xyz": 2, "iop": 2, "day_1_spars": 2, "tuv": 2, "lkj": 2, "day_0_sparse_contig": 2, "day_1_sparse_contig": 2, "occur": 2, "frequenc": [2, 4], "tsv_to_npi": 2, "out_dense_fil": 2, "out_sparse_fil": 2, "out_labels_fil": 2, "dataset_nam": 2, "criteo_1tb": 2, "one": [2, 4, 5, 6, 8, 9, 10, 11, 12], "three": [2, 9], "float32": [2, 4, 8, 11, 13], "int32": [2, 6, 14], "1tb": 2, "click": [2, 9], "log": [2, 5, 9], "For": [2, 4, 5, 6, 9, 10, 11, 12, 13, 14], "test": [2, 11], "filler": 2, "includ": [2, 4, 5, 7, 8, 9, 11, 14], "name": [2, 4, 5, 8, 9, 10, 11, 12, 13, 14], "criteoiterdatapip": 2, "row_mapp": 2, "callabl": [2, 4, 6, 7, 11, 12, 13], "_default_row_mapp": 2, "open_kw": 2, "iterdatapip": 2, "stream": [2, 4, 11, 14], "either": [2, 4, 5, 9, 11], "http": [2, 4, 5, 10, 11, 14], "ailab": 2, "com": [2, 11], "download": 2, "www": 2, "local": [2, 4, 5, 6, 9, 11], "constitut": 2, "appli": [2, 4, 6, 10, 11], "line": [2, 3], "pass": [2, 4, 5, 6, 8, 9, 11, 12, 13, 14], "underli": [2, 9], "invoc": [2, 5], "iopath": 2, "file_io": 2, "pathmanag": 2, "open": 2, "inmemorybinarycriteoiterdatapip": [2, 3], "stage": [2, 12], "dense_path": 2, "sparse_path": 2, "labels_path": 2, "batch_siz": [2, 4, 5, 6, 9, 11, 14], "drop_last": 2, "shuffle_batch": 2, "shuffle_training_set": 2, "shuffle_training_set_random_se": 2, "hash": [2, 6, 11], "iterabledataset": 2, "over": [2, 4, 6, 10, 11, 12], "version": [2, 4, 11, 13], "entir": [2, 5, 6], "prevent": 2, "disk": 2, "speed": [2, 13], "affect": [2, 5], "throughout": [2, 10], "respons": [2, 4], "npy_preproc_criteo": 2, "py": [2, 4], "val": [2, 4], "max": [2, 5, 11, 12], "cat_feature_count": 2, "templat": [2, 9], "1tb_binari": 2, "day_": 2, "_": [2, 8], "1024": 2, "torch": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "get_rank": [2, 4], "get_world_s": [2, 8], "train_datapip": 2, "txt": 2, "test_datapip": 2, "movielens_20m": 2, "root": [2, 7], "include_movies_data": 2, "param": [2, 4, 5, 9, 12], "true": [2, 4, 5, 8, 9, 11, 12, 14], "add": [2, 4, 7, 11, 12], "movi": 2, "ml": 2, "20": [2, 10, 11], "movielens_25m": 2, "25m": 2, "25": [2, 4, 9], "randomrecdataset": 2, "hash_siz": 2, "ids_per_featur": 2, "num_dens": 2, "50": 2, "manual_se": 2, "num_batch": 2, "num_generated_batch": 2, "10": [2, 5, 10, 11, 13, 14], "min_ids_per_featur": 2, "recsi": [2, 5, 10, 12], "current": [2, 4, 5, 6, 8, 9, 11], "produc": [2, 4, 5], "unweight": 2, "todo": 2, "weight": [2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "taken": [2, 5], "modulo": 2, "per": [2, 4, 5, 6, 9, 11, 14], "correspond": [2, 4, 5, 6, 8, 9, 11, 14], "argument": [2, 4, 7, 8, 9, 11], "ignor": [2, 4, 5, 6, 8, 11], "sampl": [2, 5, 9, 11], "determinist": [2, 5], "behavior": [2, 4, 7, 12], "num": 2, "befor": [2, 4, 5, 6, 9, 11, 12], "rais": [2, 4], "stopiter": 2, "cach": [2, 4, 5], "num_gener": 2, "cycl": 2, "neg": 2, "fly": 2, "minimum": [2, 5], "feat1": 2, "feat2": 2, "16": [2, 4, 6, 11, 13], "100_000": 2, "dense_featur": [2, 10, 11], "tensor": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "sparse_featur": [2, 4, 6, 10], "jagged_tensor": [2, 4], "keyedjaggedtensor": [2, 4, 6, 10, 11, 13, 14], "pipelin": [2, 4, 5, 11, 14], "pin_memori": [2, 14], "record_stream": [2, 4, 11, 14], "see": [2, 4, 5, 6, 7, 9, 11, 14], "org": [2, 4, 5, 10, 11, 14], "doc": [2, 4, 11, 14], "stabl": [2, 4, 11, 14], "html": [2, 4, 11, 14], "non_block": [2, 14], "awar": [2, 4, 14], "accord": [2, 4, 5, 6, 8, 10, 12, 14], "might": [2, 5, 14], "self": [2, 4, 5, 6, 8, 11, 14], "copi": [2, 4, 6, 8, 9, 11, 12, 14], "rememb": [2, 4, 14], "new_devic": [2, 14], "limit": [2, 8, 9, 11], "loadfil": 2, "mode": [2, 4, 5, 9], "b": [2, 4, 5, 6, 10, 11, 13, 14], "iobas": 2, "adapt": [2, 11], "loadfilesfromdisk": 2, "merg": [2, 4, 6], "replac": [2, 9, 12], "someth": 2, "core": 2, "lib": 2, "parallelreadconcat": 2, "dp_selector": 2, "sequenc": [2, 4, 5, 6], "_default_dp_selector": 2, "concaten": [2, 4, 6, 10, 11, 14], "multipl": [2, 4, 5, 9, 10, 11, 12], "when": [2, 4, 5, 7, 9, 11, 12], "dataload": [2, 4], "subset": [2, 5], "worker": [2, 4], "instanc": [2, 4, 6, 7, 8, 9, 11], "would": [2, 4, 6, 14], "f": [2, 4, 5, 6, 10, 11, 13], "shard_": 2, "idx": [2, 4], "4": [2, 4, 5, 6, 9, 10, 11, 13, 14], "num_work": [2, 14], "readlinesfromcsv": 2, "skip_first_lin": 2, "kw": 2, "idx_split_train_v": 2, "train_perc": 2, "float": [2, 4, 5, 7, 9, 11, 12, 14], "decimal_plac": 2, "key_fn": 2, "_default_key_fn": 2, "rand_split_train_v": 2, "via": [2, 4, 6], "uniform": [2, 5, 11], "disjoint": 2, "specifi": [2, 4, 5, 6, 7, 9, 10, 11, 12], "target": [2, 4, 10], "proport": 2, "actual": [2, 4, 5, 6, 8, 9, 11], "guarante": [2, 5, 7, 12], "exactli": [2, 4], "membership": 2, "across": [2, 4, 5, 6, 9], "call": [2, 4, 5, 6, 8, 9, 11, 12, 13], "consist": [2, 4], "val_datapip": 2, "75": 2, "train_batch": 2, "val_batch": 2, "safe_cast": 2, "t": [2, 4, 5, 6, 7, 8, 11, 12, 14], "dest_typ": 2, "default": [2, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14], "train_filt": 2, "val_filt": 2, "main": [3, 9], "argv": 3, "result": [3, 4, 5, 6, 8, 9, 11, 13], "command": 3, "arg": [3, 4, 5, 8, 9, 11, 13, 14], "parse_arg": 3, "namespac": [3, 14], "raw": [3, 11], "criteo": [3, 10], "necessari": [4, 5, 6, 8, 9], "These": [4, 5, 9, 11], "distributedmodelparallel": 4, "collect": [4, 6, 10, 11, 12, 13], "scatter": [4, 6], "wrapper": [4, 12], "kjt": [4, 5, 6, 10, 11, 13, 14], "variou": [4, 8, 11], "implement": [4, 5, 6, 8, 9, 10, 11, 12, 14], "shardedembeddingbag": 4, "nn": [4, 5, 7, 8, 10, 11, 13], "shardedembeddingbagcollect": [4, 11, 13], "embeddingbagcollect": [4, 8, 10, 11, 13], "sharder": [4, 5, 8], "defin": [4, 5, 6, 8, 9, 10, 11], "comput": [4, 5, 6, 8, 9, 10, 11, 13], "kernel": [4, 5, 11], "cpu": [4, 5, 9], "gpu": [4, 5], "mai": [4, 14], "fusion": 4, "trainpipelinesparsedist": 4, "overlap": 4, "transfer": 4, "inter": [4, 11], "input_dist": [4, 11], "forward": [4, 5, 6, 8, 9, 10, 11, 13, 14], "backward": [4, 5, 7, 12], "increas": [4, 9], "perform": [4, 5, 6, 8, 9, 11, 12, 13], "quantiz": [4, 6, 7, 8, 13], "precis": [4, 5, 11, 13], "construct": [4, 7, 11, 14], "control": [4, 7], "flow": [4, 7], "invoke_on_rank_and_broadcast_result": 4, "pg": [4, 5, 6, 9], "processgroup": [4, 5, 6, 9], "func": 4, "kwarg": [4, 9, 11, 14], "invok": [4, 5], "broadcast": [4, 5], "member": [4, 11], "group": [4, 5, 6, 9, 11, 12, 14], "allocate_id": 4, "is_lead": 4, "leader_rank": 4, "check": [4, 5, 9, 11, 12, 14], "processs": 4, "leader": [4, 9], "dist": [4, 6, 9], "impli": 4, "e": [4, 5, 6, 7, 8, 9, 10, 11, 12], "g": [4, 5, 8, 9, 10, 11, 12], "singl": [4, 5, 6, 11, 12], "program": [4, 5], "definit": [4, 7, 8], "caller": 4, "overrid": [4, 5, 7, 8, 9], "context": [4, 6, 14], "run_on_lead": 4, "get_group_rank": 4, "avail": [4, 5, 6], "group_rank": 4, "varibl": 4, "get_num_group": 4, "elast": 4, "run": [4, 5, 6, 8, 9, 11, 12], "get_local_rank": 4, "usual": [4, 5, 6, 9, 11], "node": [4, 7], "get_local_s": 4, "equival": 4, "max_nnod": 4, "intra_and_cross_node_pg": 4, "backend": [4, 6, 8, 9], "sub": 4, "intra": 4, "cross": [4, 10, 11], "all2alldenseinfo": 4, "output_split": [4, 6], "input_shap": 4, "input_split": [4, 6], "attribut": [4, 5, 9, 12], "alltoall_dens": 4, "all2allpooledinfo": 4, "batch_size_per_rank": [4, 6], "dim_sum_per_rank": [4, 6], "dim_sum_per_rank_tensor": 4, "cumsum_dim_sum_per_rank_tensor": 4, "codec": [4, 6], "quantizedcommcodec": [4, 6], "alltoall_pool": [4, 6], "sum": [4, 5, 6, 11], "dimens": [4, 5, 6, 10, 11, 13, 14], "fast": 4, "_recat_pooled_embedding_grad_out": 4, "cumul": [4, 9, 14], "all2allsequenceinfo": 4, "embedding_dim": [4, 5, 6, 10, 11, 13], "lengths_after_sparse_data_all2al": 4, "forward_recat_tensor": 4, "backward_recat_tensor": 4, "variable_batch_s": 4, "permuted_lengths_after_sparse_data_all2al": 4, "alltoall_sequ": 4, "alltoal": [4, 6], "recat": [4, 6, 11, 14], "variabl": [4, 6, 9, 11, 13, 14], "all2allvinfo": 4, "dims_sum_per_rank": 4, "b_global": 4, "b_local": 4, "b_local_list": 4, "d_local_list": 4, "input_split_s": 4, "factori": [4, 5, 11], "output_split_s": 4, "alltoallv": 4, "global": [4, 5, 6, 9], "my": 4, "how": [4, 5, 6, 12], "do": [4, 5, 9, 11, 12, 14], "all_to_all_singl": 4, "fill": 4, "all2all_pooled_req": 4, "ctx": 4, "unus": [4, 11], "formula": [4, 5], "differenti": 4, "overridden": [4, 6, 8, 9, 11], "subclass": [4, 6, 8, 11, 12], "vjp": 4, "It": [4, 5, 6, 8, 9, 11, 12, 13, 14], "accept": [4, 5, 8, 9, 11], "non": [4, 5, 6, 7, 9, 11, 13], "were": 4, "gradient": [4, 5, 12], "w": [4, 6, 11, 14], "r": [4, 11], "requir": [4, 5, 9, 11, 12], "grad": [4, 12], "just": [4, 5, 10, 11, 14], "retriev": 4, "dure": [4, 5, 9, 12], "ha": [4, 5, 9, 10, 11, 14], "needs_input_grad": 4, "boolean": 4, "myreq": 4, "a2ai": 4, "input_embed": [4, 11], "custom": [4, 5, 7, 11], "autograd": [4, 8, 9, 11], "usag": [4, 5, 9], "combin": [4, 11, 12], "staticmethod": 4, "def": [4, 11], "other": [4, 5, 6, 9, 12], "detail": [4, 5, 6, 9, 11], "setup_context": 4, "longer": [4, 5], "instead": [4, 5, 6, 8, 11, 12], "arbitrari": 4, "though": 4, "enforc": [4, 8, 9, 11], "compat": [4, 5, 7, 10, 12], "save_for_backward": 4, "intend": 4, "save_for_forward": 4, "jvp": 4, "all2all_pooled_wait": 4, "grad_output": 4, "dummy_tensor": 4, "all2all_seq_req": 4, "sharded_input_embed": 4, "all2all_seq_req_wait": 4, "sharded_grad_output": 4, "all2allv_req": 4, "all2allv_wait": 4, "allgatherbaseinfo": 4, "input_s": [4, 5, 11], "all_gatther_base_pool": 4, "allgatherbase_req": 4, "agi": 4, "allgatherbase_wait": 4, "reducescatterbaseinfo": 4, "reduce_scatter_base_pool": 4, "flatten": [4, 6, 11], "reducescatterbase_req": 4, "rsi": 4, "reducescatterbase_wait": 4, "reducescatterinfo": 4, "reduce_scatter_pool": 4, "reducescattervinfo": 4, "equal_split": 4, "total_input_s": 4, "reduce_scatter_v_pool": 4, "along": [4, 6, 9, 10, 12, 14], "dim": [4, 6], "reducescatterv_req": 4, "reducescatterv_wait": 4, "reducescatter_req": 4, "reducescatter_wait": 4, "await": [4, 6, 7], "variablebatchall2allpooledinfo": 4, "batch_size_per_rank_per_featur": [4, 6], "batch_size_per_feature_pre_a2a": [4, 6], "emb_dim_per_rank_per_featur": [4, 6], "variable_batch_alltoall_pool": [4, 6], "variable_batch_all2all_pooled_req": 4, "variable_batch_all2all_pooled_wait": 4, "all2all_pooled_sync": 4, "all2all_sequence_sync": 4, "all2allv_sync": 4, "all_gather_base_pool": 4, "gather": [4, 6], "form": [4, 11, 13], "pool": [4, 5, 6, 10, 11, 13, 14], "output_tensor_s": 4, "work": [4, 5, 8, 9, 11, 14], "async": [4, 6], "wait": [4, 6], "later": [4, 11], "experiment": [4, 11], "subject": 4, "all_gather_base_sync": 4, "all_gather_into_tensor_backward": 4, "all_gather_into_tensor_fak": 4, "gather_dim": 4, "group_siz": 4, "group_nam": 4, "gradient_divis": 4, "all_gather_into_tensor_setup_context": 4, "all_to_all_single_backward": 4, "all_to_all_single_fak": 4, "all_to_all_single_setup_context": 4, "a2a_pooled_embs_tensor": 4, "Then": 4, "receiv": [4, 6, 12], "Its": 4, "x": [4, 5, 6, 10, 11, 13, 14], "d_local_sum": 4, "where": [4, 5, 6, 9, 10, 11, 13], "a2a_sequence_embs_tensor": 4, "doe": [4, 5, 10, 11, 12, 14], "mix": 4, "out_split": 4, "per_rank_split_length": 4, "assumpt": [4, 14], "emb": 4, "get_gradient_divis": 4, "get_use_sync_collect": 4, "pg_name": 4, "reduce_scatter_base_sync": 4, "chunk": [4, 6], "reduce_scatter_sync": 4, "reduce_scatter_tensor_backward": 4, "reduce_scatter_tensor_fak": 4, "reduceop": 4, "reduce_scatter_tensor_setup_context": 4, "reduce_scatter_v_per_feature_pool": 4, "v": [4, 6, 11, 14], "d": [4, 5, 10, 11, 13, 14], "unevenli": 4, "reduce_scatter_v_sync": 4, "set_gradient_divis": 4, "set_use_sync_collect": 4, "torchrec_use_sync_collect": 4, "variable_batch_all2all_pooled_sync": 4, "embeddingsalltoon": [4, 6], "cat_dim": [4, 6, 14], "buffer": [4, 6, 8, 9, 11], "alloc": [4, 5, 6, 8], "like": [4, 5, 6, 7, 11, 12, 14], "alltoon": [4, 6], "set_devic": [4, 6], "device_str": [4, 6], "embeddingsalltoonereduc": [4, 6], "jaggedtensoralltoal": [4, 6], "jt": [4, 6, 11, 14], "jaggedtensor": [4, 6, 11, 13, 14], "num_items_to_send": [4, 6], "num_items_to_rec": [4, 6], "redistribut": [4, 6], "send": [4, 6], "known": [4, 5, 6, 11], "ahead": [4, 6], "keyedjaggedtensorpool": [4, 6], "lookup": [4, 5, 6, 10, 11, 13], "anoth": [4, 5, 6], "kjtalltoal": [4, 6], "stagger": [4, 6, 14], "kjtalltoallsplitsawait": [4, 6], "transmit": [4, 6], "correct": [4, 6, 14], "space": [4, 5, 6, 10], "kjtalltoalltensorsawait": [4, 6], "asynchron": [4, 6], "len": [4, 6, 10], "order": [4, 5, 6, 8, 9, 11, 14], "destin": [4, 6, 8, 9, 11], "_get_recat": [4, 6], "kjta2a": [4, 6], "rank0_input": [4, 6], "hold": [4, 5, 6, 12, 14], "v0": [4, 6, 14], "v1": [4, 6, 11, 14], "v2": [4, 6, 10, 11, 14], "rank1_input": [4, 6], "v3": [4, 6, 14], "v4": [4, 6, 14], "rank0_output": [4, 6], "5": [4, 6, 9, 10, 11, 13, 14], "rank1_output": [4, 6], "relev": [4, 5, 6], "tensor_split": [4, 6], "input_tensor": [4, 6], "ie": [4, 5, 6, 11, 14], "stride_per_rank": [4, 6, 14], "stride": [4, 6, 14], "case": [4, 5, 6, 9, 11, 12, 14], "kjtonetoal": [4, 6], "onetoal": [4, 6], "essenti": [4, 5, 6, 14], "p2p": [4, 6], "keyjaggedtensor": [4, 6], "them": [4, 6, 8, 10, 11, 12], "kjtlist": [4, 6], "slice": [4, 6, 7, 14], "mergepooledembeddingsmodul": [4, 6], "merge_pooled_embedding_optim": [4, 6], "_mergepooledembeddingsmoduleimpl": [4, 6], "merge_pooled_embed": [4, 6], "pooledembeddingsallgath": [4, 6], "wrap": [4, 6, 9, 10, 12], "layout": [4, 6, 7], "want": [4, 6, 9], "nccl": [4, 6], "happen": [4, 5, 6], "init_distribut": [4, 6], "new_group": [4, 6, 9], "randn": [4, 6, 10, 11], "m": [4, 5, 6, 7, 11], "local_emb": [4, 6], "pooledembeddingsawait": [4, 6], "num_bucket": [4, 6], "pooledembeddingsalltoal": [4, 6], "callback": [4, 6], "a2a": [4, 6], "t0": [4, 6], "rand": [4, 6, 10], "6": [4, 5, 6, 10, 11, 13, 14], "t1": [4, 6, 10, 11, 13], "print": [4, 6, 11, 13], "properti": [4, 5, 6, 8, 9, 10, 11, 12, 13], "tensor_await": [4, 6], "pooledembeddingsreducescatt": [4, 6], "twrw": [4, 5, 6], "unequ": [4, 6], "bucket": [4, 6], "seqembeddingsalltoon": [4, 6], "concat": [4, 6, 11, 14], "sequenceembeddingsalltoal": [4, 6], "features_per_rank": [4, 6], "sharding_ctx": [4, 6], "sequenceshardingcontext": [4, 6], "lengths_after_input_dist": [4, 6], "unbucketize_permute_tensor": [4, 6], "sparse_features_recat": [4, 6], "sequenceembeddingsawait": [4, 6], "permut": [4, 6, 14], "splitsalltoallawait": [4, 6], "tensoralltoal": [4, 6], "1d": [4, 5, 6], "tensoralltoallsplitsawait": [4, 6], "tensoralltoallvaluesawait": [4, 6], "tensor_a2a": [4, 6], "rank0": [4, 6], "rank1": [4, 6], "v5": [4, 6, 14], "v6": [4, 6, 14], "v7": [4, 6, 14], "v8": [4, 6], "v9": [4, 6], "v10": [4, 6], "v11": [4, 6], "v12": [4, 6], "tensorvaluesalltoal": [4, 6], "tensor_vals_a2a": [4, 6], "v13": [4, 6], "v14": [4, 6], "v15": [4, 6], "sent": [4, 6], "equal": [4, 5, 6, 11, 14], "_pg": [4, 6], "variablebatchpooledembeddingsalltoal": [4, 6], "kjt_split": [4, 6], "r0_batch_siz": [4, 6], "r1_batch_siz": [4, 6], "f_0": [4, 6], "f_1": [4, 6], "f_2": [4, 6], "r0_batch_size_per_rank_per_featur": [4, 6], "r1_batch_size_per_rank_per_featur": [4, 6], "r0_batch_size_per_feature_pre_a2a": [4, 6], "r1_batch_size_per_feature_pre_a2a": [4, 6], "r0": [4, 6], "r1": [4, 6], "14": [4, 6], "post": [4, 6], "rank_0": [4, 6], "rank_1": [4, 6], "variablebatchpooledembeddingsreducescatt": [4, 6], "rw": [4, 5, 6, 11], "multipli": [4, 5, 6], "batch_size_r0_f0": [4, 6], "emb_dim_f0": [4, 6], "embeddingcollectionawait": 4, "lazyawait": 4, "embeddingcollectioncontext": 4, "sharding_context": 4, "input_featur": 4, "reverse_indic": [4, 11], "seq_vbe_ctx": [4, 11], "sequencevbecontext": [4, 11], "multistream": [4, 11], "embeddingcollectionshard": 4, "fused_param": [4, 6], "qcomm_codecs_registri": [4, 6], "use_index_dedup": 4, "baseembeddingshard": 4, "embeddingcollect": [4, 8, 11, 13], "module_typ": [4, 8, 13], "parametershard": 4, "env": [4, 6], "shardingenv": [4, 6], "shardedembeddingcollect": [4, 11, 13], "locat": 4, "replic": [4, 5, 6], "embeddingmoduleshardingplan": 4, "fulli": [4, 5, 12], "qualifi": 4, "spec": 4, "shardedmodul": 4, "shardable_paramet": 4, "sharding_typ": [4, 5, 11], "compute_device_typ": 4, "shardingtyp": [4, 5, 11], "well": [4, 5, 11], "table_name_to_parameter_shard": 4, "shardedembeddingmodul": 4, "fusedoptimizermodul": [4, 12], "public": [4, 11], "manual": [4, 10, 12], "dist_input": 4, "compute_and_output_dist": 4, "In": [4, 5, 11, 12, 14], "sens": [4, 12], "initi": [4, 11, 12], "distibut": 4, "soon": 4, "complet": [4, 5], "create_context": 4, "fused_optim": [4, 12], "keyedoptim": [4, 12], "output_dist": [4, 5], "reset_paramet": [4, 11], "create_embedding_shard": 4, "sharding_info": [4, 6], "embeddingshardinginfo": [4, 6], "embeddingshard": [4, 6], "create_sharding_infos_by_shard": 4, "embeddingcollectioninterfac": [4, 11, 13], "create_sharding_infos_by_sharding_device_group": 4, "get_device_from_parameter_shard": 4, "ps": [4, 5], "get_ec_index_dedup": 4, "pad_vbe_kjt_length": 4, "set_ec_index_dedup": 4, "commopgradientsc": 4, "functionctx": 4, "scale_gradient_factor": 4, "groupedembeddingslookup": 4, "grouped_config": 4, "groupedembeddingconfig": [4, 6], "baseembeddinglookup": [4, 6], "i": [4, 5, 6, 7, 9, 10, 11], "flush": 4, "everi": [4, 5, 6, 8, 11], "although": [4, 6, 8, 11], "recip": [4, 6, 8, 11], "afterward": [4, 6, 8, 11], "sinc": [4, 5, 6, 8, 11], "former": [4, 6, 8, 11], "take": [4, 5, 6, 8, 11, 12], "care": [4, 6, 8, 11], "regist": [4, 6, 7, 8, 11], "hook": [4, 6, 8, 11], "latter": [4, 6, 8, 11], "silent": [4, 6, 8, 11], "load_state_dict": [4, 12], "state_dict": [4, 8, 9, 11, 12], "ordereddict": [4, 8, 9, 11], "union": [4, 5, 7, 8, 9, 11, 12], "shardedtensor": [4, 12], "strict": [4, 12], "_incompatiblekei": 4, "descend": [4, 5], "unless": [4, 12], "get_swap_module_params_on_convers": 4, "persist": [4, 8, 9, 11], "strictli": [4, 11], "preserv": [4, 11], "except": [4, 5, 9, 11], "requires_grad": 4, "field": [4, 11, 12, 14], "missing_kei": 4, "miss": [4, 5], "unexpected_kei": 4, "present": [4, 12], "namedtupl": 4, "runtimeerror": 4, "named_buff": [4, 11], "prefix": [4, 8, 9, 11], "recurs": [4, 11], "remove_dupl": [4, 11], "yield": [4, 11], "both": [4, 8, 9, 10, 11, 12, 14], "itself": [4, 10, 11], "prepend": [4, 11], "submodul": [4, 11, 12], "otherwis": [4, 5, 8, 9, 11, 12, 14], "direct": [4, 11], "remov": [4, 7, 11], "duplic": [4, 11, 12], "xdoctest": [4, 8, 9, 11], "skip": [4, 8, 9, 11, 12], "undefin": [4, 8, 9, 11], "var": [4, 8, 9, 11], "buf": [4, 11], "running_var": [4, 11], "named_paramet": 4, "bia": [4, 8, 9, 11], "named_parameters_by_t": 4, "tablebatchedembeddingslic": 4, "table_nam": 4, "embedding_weight": 4, "cw": [4, 5], "compos": [4, 8, 9, 11], "prefetch": [4, 5], "forward_stream": 4, "purg": 4, "keep_var": [4, 8, 9, 11], "dictionari": [4, 8, 9, 11], "refer": [4, 8, 9, 11, 14], "whole": [4, 8, 9, 11], "averag": [4, 5, 8, 9, 11], "shallow": [4, 8, 9, 11], "posit": [4, 5, 6, 8, 9, 11], "howev": [4, 8, 9, 11, 12], "deprec": [4, 8, 9, 11], "keyword": [4, 8, 9, 11], "futur": [4, 8, 9, 11], "releas": [4, 8, 9, 11], "end": [4, 5, 8, 9, 11], "user": [4, 5, 8, 9, 11, 12], "detach": [4, 8, 9, 11], "groupedpooledembeddingslookup": 4, "feature_processor": [4, 6, 13], "basegroupedfeatureprocessor": [4, 6, 11], "scale_weight_gradi": 4, "infercpugroupedembeddingslookup": 4, "grouped_configs_per_rank": 4, "infergroupedlookupmixin": 4, "inputdistoutput": [4, 6], "tbetoregistermixin": 4, "get_tbes_to_regist": 4, "intnbittablebatchedembeddingbagscodegen": 4, "infergroupedembeddingslookup": 4, "input_dist_output": 4, "infergroupedpooledembeddingslookup": 4, "metainfergroupedembeddingslookup": 4, "tbe": [4, 5, 13], "op": [4, 5, 6, 12, 13], "metainfergroupedpooledembeddingslookup": 4, "bag": [4, 6, 7, 10, 11], "dtype": [4, 5, 6, 7, 8, 11, 13, 14], "embeddings_cat_empty_rank_handl": 4, "dummy_embs_tensor": 4, "embeddings_cat_empty_rank_handle_infer": 4, "fx_wrap_tensor_view2d": 4, "dim0": 4, "dim1": 4, "baseembeddingdist": [4, 6], "embeddinglookup": 4, "abstract": [4, 5, 8, 9, 11, 12], "basesparsefeaturesdist": [4, 6], "featureshardingmixin": 4, "table_wis": [4, 11], "create_input_dist": [4, 6], "create_lookup": [4, 6], "create_output_dist": [4, 6], "embedding_nam": [4, 6, 11], "embedding_names_per_rank": [4, 6], "embedding_shard_metadata": [4, 6], "shardmetadata": [4, 6], "embedding_t": [4, 6], "shardedembeddingt": [4, 6], "uncombined_embedding_dim": [4, 6], "uncombined_embedding_nam": [4, 6], "embeddingshardingcontext": [4, 6], "variable_batch_per_featur": 4, "embedding_config": [4, 13], "embeddingtableconfig": [4, 11], "param_shard": 4, "nonetyp": [4, 9, 11], "fusedkjtlistsplitsawait": 4, "kjtlistsplitsawait": 4, "kjtlistawait": 4, "info": [4, 11], "metadata": [4, 8, 11], "kjtsplitsalltoallmeta": 4, "distributed_c10d": 4, "_input": 4, "splits_tensor": 4, "listofkjtlistawait": 4, "listofkjtlist": 4, "listofkjtlistsplitsawait": 4, "bucketize_kjt_before_all2al": 4, "block_siz": [4, 6], "output_permut": 4, "bucketize_po": 4, "block_bucketize_row_po": 4, "keep_original_indic": [4, 6], "readjust": 4, "unbucket": 4, "offset": [4, 5, 10, 11, 13, 14], "keep": [4, 5, 11], "origin": [4, 5, 8, 10], "bucketize_kjt_infer": 4, "is_sequ": [4, 6], "group_tabl": 4, "tables_per_rank": 4, "datatyp": [4, 5, 11, 13, 14], "poolingtyp": [4, 11], "embeddingcomputekernel": [4, 5], "weighted": 4, "interfac": [4, 8, 9, 11], "reli": [4, 8, 11, 13], "moduleshard": [4, 5, 8], "compute_kernel": [4, 5], "storage_usag": 4, "resourc": 4, "processor": [4, 6, 8, 11], "basequantembeddingshard": 4, "shardable_param": 4, "dtensormetadata": 4, "mesh": 4, "device_mesh": 4, "devicemesh": 4, "placement": [4, 5], "_tensor": 4, "placement_typ": 4, "embeddingattribut": 4, "enum": [4, 5, 11, 12], "enumer": [4, 11, 12], "fuse": [4, 6, 9], "fused_uvm": 4, "fused_uvm_cach": 4, "key_valu": 4, "quant": 4, "quant_uvm": 4, "quant_uvm_cach": 4, "feature_nam": [4, 5, 6, 10, 11, 13], "feature_names_per_rank": [4, 6], "data_typ": [4, 11], "is_weight": [4, 5, 11, 13, 14], "has_feature_processor": [4, 6, 11], "dim_sum": 4, "feature_hash_s": [4, 6], "num_featur": [4, 6, 10, 11], "bucket_mapping_tensor": 4, "bucketized_length": 4, "moduleshardingmixin": 4, "access": [4, 5, 12, 14], "scheme": 4, "optimtyp": 4, "adagrad": [4, 12], "adam": [4, 12], "adamw": 4, "lamb": 4, "lars_sgd": 4, "lion": 4, "partial_rowwise_adam": 4, "partial_rowwise_lamb": 4, "rowwise_adagrad": 4, "sgd": 4, "shampoo": 4, "shampoo_v2": 4, "shampoo_v2_mr": 4, "shardedconfig": 4, "local_row": [4, 5], "local_col": [4, 5], "compin": 4, "distout": 4, "out": [4, 11, 14], "shrdctx": 4, "commop": 4, "extra_repr": 4, "pretti": 4, "represent": [4, 5, 7, 11, 14], "num_embed": [4, 5, 10, 11, 13], "fp32": [4, 5, 11], "weight_init_max": [4, 11], "weight_init_min": [4, 11], "num_embeddings_post_prun": [4, 11], "init_fn": [4, 11], "need_po": [4, 6, 11], "local_metadata": 4, "_shard": 4, "global_metadata": 4, "sharded_tensor": 4, "shardedtensormetadata": 4, "dtensor_metadata": 4, "shardedmetaconfig": 4, "compute_kernel_to_embedding_loc": 4, "embeddingloc": 4, "embeddingawait": 4, "embeddingbagcollectionawait": 4, "lazygetitemmixin": 4, "keyedtensor": [4, 10, 11, 13, 14], "embeddingbagcollectioncontext": 4, "inverse_indic": [4, 11, 14], "divisor": 4, "embeddingbagcollectionshard": 4, "embeddingbagshard": 4, "nullshardedmodulecontext": 4, "per_sample_weight": 4, "named_modul": 4, "memo": 4, "network": [4, 5, 11, 12], "alreadi": [4, 6, 8, 12], "onc": [4, 11], "l": [4, 11, 13], "linear": [4, 5, 11, 12], "net": [4, 10, 11], "sequenti": [4, 5, 11], "in_featur": [4, 10, 11], "out_featur": [4, 11], "sharded_parameter_nam": 4, "embeddingbagcollectioninterfac": [4, 11, 13], "variablebatchembeddingbagcollectionawait": 4, "construct_output_kt": 4, "create_embedding_bag_shard": 4, "permute_embed": [4, 6], "suffix": 4, "replace_placement_with_meta_devic": 4, "could": [4, 5, 14], "unmatch": 4, "scenario": [4, 11, 13], "dmp": 4, "cuda": [4, 5, 8], "embeddingshardingplann": [4, 5], "planner": 4, "groupedpositionweightedmodul": 4, "max_feature_length": [4, 11], "dataparallelwrapp": 4, "defaultdataparallelwrapp": 4, "bucket_cap_mb": 4, "static_graph": 4, "find_unused_paramet": 4, "allreduce_comm_precis": 4, "params_to_ignor": 4, "ddp_kwarg": 4, "unshard": [4, 5, 11, 13], "shardingplan": [4, 5, 8], "init_data_parallel": 4, "init_paramet": 4, "data_parallel_wrapp": 4, "entri": 4, "point": [4, 5], "collective_plan": [4, 5], "lazi": [4, 11, 12], "delai": 4, "until": 4, "still": [4, 14], "no_grad": [4, 11], "init_weight": [4, 11], "isinst": 4, "fill_": [4, 11], "elif": 4, "init": 4, "kaiming_normal_": 4, "mymodel": 4, "bare_named_paramet": 4, "tor": 4, "safe": 4, "ddp": 4, "fsdp": 4, "sparse_grad_parameter_nam": [4, 12], "get_modul": 4, "unwrap": 4, "get_unwrapped_modul": 4, "quantembeddingbagcollectionshard": [4, 8], "shardedquantembeddingbagcollect": 4, "quantfeatureprocessedembeddingbagcollectionshard": [4, 8], "featureprocessedembeddingbagcollect": [4, 8, 13], "shardedquantebcinputdist": 4, "sharding_type_device_group_to_shard": 4, "nullshardingcontext": [4, 6], "sharding_type_to_shard": 4, "sqebc_input_dist": 4, "infertwsequenceembeddingshard": 4, "f1": [4, 10, 11, 13], "f2": [4, 10, 11, 13], "7": [4, 9, 10, 11, 13, 14], "8": [4, 5, 10, 11, 13, 14], "shardedquantembeddingmodulest": 4, "embedding_bag_config": [4, 11, 13], "embeddingbagconfig": [4, 10, 11, 13], "execut": [4, 5, 8, 11, 13], "sharding_type_device_group_to_sharding_info": 4, "tbes_config": 4, "shardedquantfeatureprocessedembeddingbagcollect": 4, "featureprocessorscollect": [4, 13], "apply_feature_processor": 4, "kjt_list": [4, 14], "embedding_bag": [4, 13], "moduledict": [4, 13], "modulelist": [4, 9, 11, 13], "create_infer_embedding_bag_shard": 4, "flatten_feature_length": 4, "get_device_from_sharding_info": 4, "emb_shard_info": 4, "cacheparam": [4, 5], "algorithm": 4, "cachealgorithm": 4, "load_factor": [4, 5], "reserved_memori": 4, "prefetch_pipelin": [4, 5], "stat": 4, "cachestatist": [4, 5], "multipass_prefetch_config": 4, "multipassprefetchconfig": 4, "relat": [4, 5, 9], "uvm": [4, 5], "lru": [4, 5], "lfu": 4, "factor": [4, 5, 11], "decid": 4, "crucial": 4, "reserv": [4, 5], "ideal": 4, "aka": 4, "statist": [4, 5], "better": [4, 5], "tune": [4, 12], "cacheabl": [4, 5], "summar": [4, 5], "measur": [4, 5, 9], "difficulti": [4, 5], "dataset": [4, 5, 10], "independ": [4, 5], "score": [4, 5, 6, 11], "veri": [4, 5], "high": [4, 5, 9, 11], "difficult": [4, 5], "expected_lookup": [4, 5], "distinct": [4, 5], "expected_miss_r": [4, 5], "clf": [4, 5], "rate": [4, 5, 9, 12], "hit": [4, 5], "extrem": [4, 5], "estim": [4, 5, 9], "pooled_embeddings_all_to_al": 4, "pooled_embeddings_reduce_scatt": 4, "sequence_embeddings_all_to_al": 4, "computekernel": 4, "moduleshardingplan": 4, "describ": 4, "genericmeta": 4, "getitemlazyawait": 4, "parentw": 4, "kt": [4, 14], "__getitem__": 4, "parent": 4, "keyvalueparam": [4, 5], "ssd_storage_directori": 4, "ssd_rocksdb_write_buffer_s": 4, "ssd_rocksdb_shard": 4, "gather_ssd_cache_stat": 4, "stats_reporter_config": 4, "tbestatsreporterconfig": 4, "use_passed_in_path": 4, "l2_cache_s": 4, "ps_host": 4, "ps_client_thread_num": 4, "ps_max_key_per_request": 4, "ps_max_local_index_length": 4, "ssd": [4, 5], "ssdtablebatchedembeddingbag": 4, "data00_nvidia": 4, "local_rank": 4, "rocksdb": 4, "write": 4, "relav": 4, "compact": 4, "std": 4, "report": [4, 9], "od": 4, "report_interv": 4, "interv": [4, 9, 11], "ods_prefix": 4, "server": 4, "host": [4, 5, 6], "ip": 4, "port": 4, "2000": 4, "2001": 4, "2002": 4, "reason": [4, 12], "hashabl": 4, "thread": [4, 5], "client": 4, "maximum": [4, 5, 9, 11], "index": [4, 11, 14], "expos": [4, 12], "concret": 4, "achiev": 4, "late": 4, "possibl": [4, 5, 9], "__torch_function__": 4, "below": 4, "doesn": [4, 11, 12], "python": [4, 7, 10], "magic": 4, "__getattr__": 4, "caveat": 4, "mechan": [4, 11], "ensur": [4, 11, 14], "perfect": 4, "quickli": 4, "long": [4, 5, 11], "kwd": 4, "vt_co": 4, "augment": 4, "trigger": [4, 11], "keyedlazyawait": 4, "defer": 4, "mixin": 4, "inherit": [4, 9, 11], "mro": 4, "properli": [4, 11], "select": [4, 5, 6, 14], "lazynowait": 4, "classmethod": [4, 5, 8, 13], "noopquantizedcommcodec": 4, "quantizationcontext": 4, "No": [4, 6, 9], "calc_quantized_s": 4, "input_len": 4, "decod": 4, "input_grad": 4, "encod": 4, "padded_s": 4, "dim_per_rank": 4, "my_rank": [4, 9], "qcomm_ctx": 4, "quantized_dtyp": 4, "nowait": [4, 7], "obj": 4, "objectpoolshardingplan": 4, "objectpoolshardingtyp": 4, "replicated_row_wis": 4, "row_wis": [4, 11], "sharding_spec": 4, "shardingspec": 4, "cache_param": [4, 5], "enforce_hbm": [4, 5], "stochastic_round": [4, 5], "bounds_check_mod": [4, 5], "boundscheckmod": [4, 5], "output_dtyp": [4, 5, 8, 13], "key_value_param": [4, 5], "hbm": [4, 5], "stochast": [4, 5], "round": [4, 5], "bound": [4, 5], "place": [4, 5, 6, 12, 14], "column_wis": [4, 11], "seen": [4, 7], "individu": [4, 5, 10], "table_row_wis": [4, 11], "data_parallel": [4, 5, 11], "parameterstorag": 4, "physic": 4, "constraint": [4, 5, 8], "shardingplann": [4, 5], "ddr": [4, 5], "pipelinetyp": [4, 5], "about": 4, "train_bas": 4, "train_prefetch_sparse_dist": 4, "train_sparse_dist": 4, "pooled_all_to_al": 4, "reduce_scatt": 4, "quantized_tensor": 4, "quantized_comm_codec": 4, "collective_cal": 4, "output_tensor": 4, "assert_clos": 4, "int8": [4, 8], "addit": [4, 5, 7, 8, 10, 11, 12, 14], "carri": 4, "session": 4, "padded_dim_sum": 4, "padding_s": 4, "respect": [4, 10, 11], "sequence_all_to_al": 4, "modulenocopymixin": [4, 13], "vise": [4, 12], "versa": [4, 12], "practic": 4, "from_loc": 4, "typic": [4, 5, 7, 11, 12, 14], "from_process_group": 4, "fqn": [4, 5], "larger": [4, 5], "get_plan_for_modul": 4, "module_path": 4, "re": [4, 12], "stabil": 4, "table_column_wis": [4, 11], "get_tensor_size_byt": 4, "rank_devic": 4, "device_typ": 4, "scope": 4, "copyablemixin": 4, "mymodul": 4, "forkedpdb": 4, "completekei": 4, "tab": 4, "stdin": 4, "stdout": 4, "nosigint": 4, "readrc": 4, "pdb": 4, "multiprocess": 4, "child": 4, "debug": [4, 5, 9], "multiprocessing_util": 4, "set_trac": 4, "barrier": 4, "add_params_from_parameter_shard": 4, "parameter_shard": 4, "extract": 4, "ones": 4, "add_prefix_to_state_dict": 4, "filter": [4, 11], "append_prefix": 4, "append": 4, "convert_to_fbgemm_typ": 4, "copy_to_devic": 4, "current_devic": [4, 8], "to_devic": 4, "filter_state_dict": 4, "strip": 4, "begin": [4, 5, 12], "get_unsharded_module_nam": 4, "level": [4, 6], "don": [4, 8, 11], "merge_fused_param": 4, "param_fused_param": 4, "configur": 4, "cache_precis": 4, "preset": 4, "table_level_fused_param": 4, "precid": 4, "grouped_fused_param": 4, "null": 4, "none_throw": 4, "_t": 4, "messag": [4, 5], "unexpect": 4, "assertionerror": 4, "optimizer_type_to_emb_opt_typ": 4, "optimizer_class": 4, "emboptimtyp": 4, "sharded_model_copi": 4, "m_cpu": 4, "deepcopi": 4, "managedcollisioncollectionawait": 4, "managedcollisioncollectioncontext": 4, "managedcollisioncollectionshard": 4, "managedcollisioncollect": [4, 11], "shardedmanagedcollisioncollect": 4, "evict": [4, 11], "global_to_local_index": 4, "jt_dict": [4, 14], "open_slot": [4, 11], "create_mc_shard": 4, "managedcollisionembeddingbagcollectioncontext": 4, "evictions_per_t": 4, "remapped_kjt": 4, "managedcollisionembeddingbagcollectionshard": 4, "ebc_shard": 4, "mc_sharder": 4, "basemanagedcollisionembeddingcollectionshard": 4, "managedcollisionembeddingbagcollect": [4, 11], "shardedmanagedcollisionembeddingbagcollect": 4, "baseshardedmanagedcollisionembeddingcollect": 4, "managedcollisionembeddingcollectioncontext": 4, "managedcollisionembeddingcollectionshard": 4, "ec_shard": 4, "managedcollisionembeddingcollect": [4, 11], "shardedmanagedcollisionembeddingcollect": 4, "consid": [5, 11, 13, 14], "perf": 5, "storag": [5, 14], "peak": 5, "elimin": 5, "oom": [5, 9], "kernel_bw_lookup": 5, "compute_devic": [5, 8], "hbm_mem_bw": 5, "ddr_mem_bw": 5, "caching_ratio": 5, "calcul": [5, 9], "bandwidth": 5, "ratio": [5, 9], "embeddingenumer": 5, "parameterconstraint": [5, 8], "shardestim": 5, "use_exact_enumerate_ord": 5, "shardabl": 5, "exact": 5, "name_children": 5, "shardingopt": 5, "popul": [5, 11], "populate_estim": 5, "sharding_opt": 5, "descript": [5, 9], "get_partition_by_typ": 5, "partitionbytyp": 5, "greedyperfpartition": 5, "sort_bi": 5, "sortbi": 5, "balance_modul": 5, "greedi": 5, "sort": [5, 11], "smaller": 5, "effect": [5, 11], "storage_constraint": 5, "partition_bi": 5, "strategi": 5, "docstr": [5, 9, 14], "partition_by_devic": 5, "done": [5, 11, 12, 14], "clariti": 5, "memorybalancedpartition": 5, "max_search_count": 5, "toler": 5, "02": 5, "greedypartition": 5, "reject": 5, "200": 5, "wors": 5, "repeatedli": 5, "least": 5, "ordereddevicehardwar": 5, "devicehardwar": 5, "local_world_s": 5, "shardingoptiongroup": 5, "storage_sum": 5, "perf_sum": 5, "param_count": 5, "set_hbm_per_devic": 5, "hbm_per_devic": 5, "noopperfmodel": 5, "perfmodel": 5, "among": [5, 10], "without": [5, 9, 14], "noopstoragemodel": 5, "storagereserv": 5, "performance_model": 5, "heteroembeddingshardingplann": 5, "topology_group": 5, "dynamicprogrammingpropos": 5, "hbm_bins_per_devic": 5, "dynam": 5, "fashion": [5, 6], "problem": 5, "frame": 5, "n": [5, 8, 10, 11, 14], "minim": 5, "overal": [5, 10], "k": [5, 9, 10, 11], "mathemat": [5, 9], "formul": 5, "matrix": [5, 10, 11], "let": 5, "element": [5, 11], "denot": 5, "a_": 5, "j": 5, "b_": 5, "aim": 5, "j_0": 5, "j_1": 5, "ldot": 5, "j_": 5, "condit": [5, 11], "satisfi": 5, "sum_": 5, "j_i": 5, "leq": 5, "tackl": 5, "discret": 5, "k_i": 5, "transit": 5, "min_": 5, "left": [5, 14], "right": [5, 9, 11], "simpli": 5, "fit": 5, "card": 5, "therefor": 5, "maintain": 5, "last": [5, 10, 11, 14], "layer": [5, 10, 11, 12], "under": [5, 9], "vari": 5, "hdm": 5, "bin": 5, "perf_rat": 5, "search_spac": 5, "search": 5, "embeddingoffloadscaleuppropos": 5, "use_depth": 5, "allocate_budget": 5, "budget": 5, "allocation_prior": 5, "build_affine_storage_model": 5, "uvm_caching_sharding_opt": 5, "clf_to_byt": 5, "get_budget": 5, "get_cach": 5, "get_expected_lookup": 5, "next_plan": 5, "starting_propos": 5, "promote_high_prefetch_overheaad_table_to_hbm": 5, "overhead": 5, "io": 5, "offload": 5, "undo": 5, "promot": 5, "greedypropos": 5, "threshold": [5, 9, 11], "On": [5, 11], "tri": [5, 12], "earli": 5, "stop": 5, "consecut": 5, "best_perf_r": 5, "gridsearchpropos": 5, "max_propos": 5, "10000": 5, "uniformpropos": 5, "proposers_to_proposals_list": 5, "proposers_list": 5, "static_feedback": 5, "embeddingoffloadstat": 5, "mrc_hist_count": 5, "height": 5, "uvm_fused_cach": 5, "cachebl": 5, "area": [5, 9], "curv": [5, 9], "histogram": 5, "nth": 5, "wa": [5, 8], "estimate_cache_miss_r": 5, "cache_s": 5, "hist": 5, "mrc": 5, "embeddingperfestim": 5, "is_infer": 5, "wall": 5, "sharder_map": 5, "perf_func_emb_wall_tim": 5, "shard_siz": 5, "input_length": 5, "input_data_type_s": 5, "table_data_type_s": 5, "output_data_type_s": 5, "fwd_a2a_comm_data_type_s": 5, "bwd_a2a_comm_data_type_s": 5, "fwd_sr_comm_data_type_s": 5, "bwd_sr_comm_data_type_s": 5, "num_pool": 5, "intra_host_bw": 5, "inter_host_bw": 5, "bwd_compute_multipli": 5, "weighted_feature_bwd_compute_multipli": 5, "is_pool": 5, "expected_cache_fetch": 5, "uneven_sharding_perf_multipli": 5, "attempt": 5, "rel": [5, 11], "tw": 5, "queri": 5, "fwd_comm_data_type_s": 5, "bwd_comm_data_type_s": 5, "machin": [5, 11], "embeddingbag": [5, 7, 10, 11, 13], "unpool": 5, "ebc": [5, 8, 10, 11, 13], "signifi": 5, "fetch": 5, "embeddingstorageestim": 5, "pipeline_typ": 5, "run_embedding_at_peak_memori": 5, "Will": [5, 9], "fwd": 5, "bwd": 5, "temporari": [5, 11], "toward": 5, "cost": [5, 11], "won": [5, 11], "ll": 5, "hidden": [5, 10, 11], "old": [5, 12], "agnost": 5, "forwrad": 5, "calculate_pipeline_io_cost": 5, "output_s": [5, 11], "prefetch_s": 5, "multipass_prefetch_max_pass": 5, "count_ephemeral_storage_cost": 5, "calculate_shard_storag": 5, "compris": 5, "synonym": 5, "byte": [5, 8, 9], "embeddingstat": 5, "sharding_plan": 5, "num_propos": 5, "num_plan": 5, "run_tim": 5, "best_plan": 5, "tabular": 5, "view": 5, "chosen": [5, 11], "evalu": [5, 11], "successfulli": 5, "noopembeddingstat": 5, "noop": 5, "round_to_one_sigfig": 5, "fixedpercentagestoragereserv": 5, "percentag": 5, "heuristicalstoragereserv": 5, "parameter_multipli": 5, "dense_tensor_estim": 5, "heurist": 5, "extra": 5, "percent": 5, "act": 5, "margin": 5, "error": [5, 9, 11, 14], "beyond": 5, "inferencestoragereserv": 5, "customtopologydata": 5, "get_data": 5, "has_data": 5, "supported_field": 5, "ddr_cap": 5, "hbm_cap": 5, "512": [5, 9], "min_partit": 5, "pooling_factor": 5, "fbgemm_gpu": 5, "split_table_batched_embeddings_ops_common": 5, "device_group": 5, "around": 5, "lower": [5, 7, 8, 12, 13], "divid": [5, 9], "divis": 5, "optionallist": 5, "momentum": 5, "accuraci": [5, 11], "term": [5, 11], "fp16": 5, "exce": 5, "todai": 5, "bldm": 5, "fwd_comput": 5, "fwd_comm": 5, "bwd_comput": 5, "bwd_comm": 5, "prefetch_comput": 5, "breakdown": 5, "plannererror": 5, "error_typ": 5, "plannererrortyp": 5, "classifi": 5, "insufficient_storag": 5, "strict_constraint": 5, "prospos": 5, "paritit": 5, "much": [5, 12], "depend": [5, 8, 11], "One": [5, 9, 11], "eval": 5, "job": 5, "tower": [5, 11], "cache_load_factor": 5, "module_pool": 5, "sharding_option_nam": 5, "num_input": 5, "num_shard": 5, "total_perf": 5, "total_storag": 5, "capac": 5, "hardwar": 5, "fits_in": 5, "963146416": 5, "128": [5, 9], "54760833": 5, "024": 5, "644245094": 5, "13421772": 5, "custom_topology_data": 5, "binarysearchpred": 5, "extern": [5, 10], "predic": 5, "discov": 5, "try": 5, "prior_result": 5, "probe": 5, "prior": 5, "explor": 5, "reach": [5, 9], "luusjaakolasearch": 5, "max_iter": 5, "42": 5, "left_cost": 5, "clamp": 5, "variant": 5, "luu": 5, "jaakola": 5, "en": 5, "wikipedia": 5, "wiki": 5, "far": 5, "associ": 5, "fy": 5, "y": [5, 10, 11], "previou": 5, "subsequ": 5, "been": [5, 11], "shrink_right": 5, "shrink": 5, "boundari": 5, "infin": [5, 12], "random": [5, 10], "bytes_to_gb": 5, "num_byt": 5, "bytes_to_mb": 5, "gb_to_byt": 5, "gb": 5, "local_s": [5, 6], "prod": 5, "reset_shard_rank": 5, "sharder_nam": 5, "storage_repr_in_gb": 5, "basecwembeddingshard": 6, "basetwembeddingshard": 6, "cwpooledembeddingshard": 6, "infercwpooledembeddingdist": 6, "infercwpooledembeddingdistwithpermut": 6, "infercwpooledembeddingshard": 6, "basedpembeddingshard": 6, "dppooledembeddingdist": 6, "dppooledembeddingshard": 6, "dpsparsefeaturesdist": 6, "sparsefeatur": 6, "baserwembeddingshard": 6, "inferrwpooledembeddingdist": 6, "inferrwpooledembeddingshard": 6, "inferrwsparsefeaturesdist": 6, "rwpooledembeddingdist": 6, "share": [6, 11], "rwpooledembeddingshard": 6, "evenli": 6, "rwsparsefeaturesdist": 6, "intra_pg": 6, "get_block_sizes_runtime_devic": 6, "runtime_devic": 6, "tensor_cach": 6, "get_embedding_shard_metadata": 6, "grouped_embedding_configs_per_rank": 6, "infertwembeddingshard": 6, "infertwpooledembeddingdist": 6, "infertwsparsefeaturesdist": 6, "twpooledembeddingdist": 6, "twpooledembeddingshard": 6, "twsparsefeaturesdist": 6, "twcwpooledembeddingshard": 6, "basetwrwembeddingshard": 6, "twrwpooledembeddingdist": 6, "cross_pg": 6, "dim_sum_per_nod": 6, "emb_dim_per_node_per_featur": 6, "twrwpooledembeddingshard": 6, "twrwsparsefeaturesdist": 6, "id_list_features_per_rank": 6, "id_score_list_features_per_rank": 6, "id_list_feature_hash_s": 6, "id_score_list_feature_hash_s": 6, "look": [6, 7, 14], "reorder": 6, "document": [7, 10], "leaf_modul": 7, "trace": [7, 8], "torchscript": 7, "create_arg": 7, "memory_format": 7, "opoverload": 7, "symint": 7, "symbool": 7, "symfloat": 7, "prepar": [7, 11], "graph": 7, "emit": 7, "appropri": 7, "is_leaf_modul": 7, "module_qualified_nam": 7, "path_of_modul": 7, "mod": 7, "abil": 7, "made": [7, 12], "concrete_arg": 7, "is_fx_trac": 7, "symbolic_trac": 7, "graphmodul": 7, "symbol": 7, "record": [7, 11], "partial": 7, "structur": [7, 12], "predictfactorypackag": 8, "save_predict_factori": 8, "predict_factori": 8, "predictfactori": 8, "config": [8, 9, 10, 11], "pathlib": 8, "binaryio": 8, "extra_fil": 8, "loader_cod": 8, "nimport": 8, "packag": 8, "nmodule_factori": 8, "package_import": 8, "_sysimport": 8, "set_extern_modul": 8, "decor": 8, "abstractmethod": 8, "set_mocked_modul": 8, "load_config_text": 8, "load_pickle_config": 8, "clazz": 8, "batchingmetadata": 8, "pin": 8, "kept": [8, 11], "sync": [8, 9, 14], "batching_metadata": 8, "infom": 8, "batching_metadata_json": 8, "serial": 8, "json": 8, "eas": [8, 11], "pars": 8, "create_predict_modul": 8, "transformmodul": 8, "transform_state_dict": 8, "init_process_group": 8, "model_inputs_data": 8, "benchmark": 8, "qualname_metadata": 8, "qualnamemetadata": 8, "qualnam": 8, "inform": [8, 9, 14], "qualname_metadata_json": 8, "result_metadata": 8, "run_weights_dependent_transform": 8, "predict_modul": 8, "predict": [8, 9], "run_weights_independent_tranform": 8, "fx": 8, "predictmodul": 8, "predict_forward": 8, "primari": 8, "need_preproc": 8, "assign_weights_to_tb": 8, "table_to_weight": 8, "get_table_to_weights_from_tb": 8, "quantize_dens": 8, "additional_embedding_module_typ": 8, "quantize_embed": 8, "inplac": [8, 13], "additional_qconfig_spec_kei": 8, "additional_map": 8, "per_table_weight_dtyp": [8, 11], "quantize_featur": 8, "quantize_inference_model": 8, "quantization_map": 8, "fp_weight_dtyp": 8, "quantization_dtyp": 8, "swap": 8, "counterpart": 8, "quantembeddingbagcollect": [8, 13], "quantembeddingcollect": 8, "eb_config": 8, "dlrmpredictmodul": 8, "embedding_bag_collect": [8, 10, 11], "dense_in_featur": [8, 10], "model_config": 8, "dense_arch_layer_s": [8, 10], "over_arch_layer_s": [8, 10], "id_list_features_kei": 8, "dense_devic": [8, 10], "quant_model": 8, "set_pruning_data": 8, "tables_to_rows_post_prun": 8, "shard_quant_model": 8, "sharding_devic": 8, "device_memory_s": 8, "quantembeddingcollectionshard": 8, "tablewis": 8, "sharded_model": 8, "trim_torch_package_prefix_from_typenam": 8, "typenam": 8, "accuracymetr": 9, "task": 9, "rectaskinfo": 9, "compute_mod": 9, "reccomputemod": 9, "unfused_tasks_comput": 9, "window_s": 9, "fused_update_limit": 9, "compute_on_all_rank": 9, "should_validate_upd": 9, "process_group": 9, "recmetr": 9, "accuracymetriccomput": 9, "recmetriccomput": 9, "constructor": [9, 11], "cut": [9, 11], "off": [9, 11], "compute_accuraci": 9, "accuracy_sum": 9, "weighted_num_sampl": 9, "compute_accuracy_sum": 9, "get_accuracy_st": 9, "aucmetr": 9, "aucmetriccomput": 9, "grouped_auc": 9, "apply_bin": 9, "grouping_kei": 9, "reset": [9, 11, 12], "n_task": 9, "n_exampl": 9, "compute_auc": 9, "classif": 9, "compute_auc_per_group": 9, "auprcmetr": 9, "auprcmetriccomput": 9, "grouped_auprc": 9, "pr": 9, "compute_auprc": 9, "compute_auprc_per_group": 9, "calibrationmetr": 9, "calibrationmetriccomput": 9, "convers": 9, "compute_calibr": 9, "calibration_num": 9, "calibration_denom": 9, "get_calibration_st": 9, "ctrmetric": 9, "ctrmetriccomput": 9, "compute_ctr": 9, "ctr_num": 9, "ctr_denom": 9, "get_ctr_stat": 9, "maemetr": 9, "maemetriccomput": 9, "absolut": 9, "compute_error_sum": 9, "compute_ma": 9, "error_sum": 9, "get_mae_st": 9, "msemetr": 9, "msemetriccomput": 9, "squar": [9, 11], "compute_ms": 9, "compute_rms": 9, "get_mse_st": 9, "multiclassrecallmetr": 9, "multiclassrecallmetriccomput": 9, "compute_multiclass_recall_at_k": 9, "tp_at_k": 9, "total_weight": 9, "compute_true_positives_at_k": 9, "n_class": 9, "tp": 9, "1st": 9, "2nd": [9, 11], "n_sampl": 9, "ground": 9, "truth": 9, "true_positives_list": 9, "9": [9, 10], "15": [9, 10], "compute_multiclass_k_sum": 9, "5000": 9, "7500": 9, "0000": [9, 11], "get_multiclass_recall_st": 9, "ndcgcomput": 9, "exponential_gain": 9, "session_kei": 9, "session_id": 9, "report_ndcg_as_decreasing_curv": 9, "remove_single_length_sess": 9, "scale_by_weights_tensor": 9, "is_negative_task_mask": 9, "normal": [9, 11], "discount": 9, "gain": 9, "tensorboard": 9, "captur": 9, "decreas": 9, "loss": [9, 10, 12], "oppos": 9, "visual": [9, 14], "similarli": 9, "entropi": 9, "pointwis": 9, "noth": 9, "ndcgmetric": 9, "nemetr": 9, "nemetriccomput": 9, "include_logloss": 9, "allow_missing_label_with_zero_weight": 9, "vanilla": 9, "logloss": 9, "compute_cross_entropi": 9, "eta": 9, "compute_logloss": 9, "ce_sum": 9, "pos_label": 9, "neg_label": 9, "compute_n": 9, "get_ne_st": 9, "recallmetr": 9, "recallmetriccomput": 9, "compute_false_neg_sum": 9, "compute_recal": 9, "num_true_posit": 9, "num_false_negit": 9, "compute_true_pos_sum": 9, "get_recall_st": 9, "precisionmetr": 9, "precisionmetriccomput": 9, "compute_false_pos_sum": 9, "compute_precis": 9, "num_false_posit": 9, "get_precision_st": 9, "raucmetr": 9, "raucmetriccomput": 9, "grouped_rauc": 9, "regress": 9, "compute_rauc": 9, "compute_rauc_per_group": 9, "conquer_and_count": 9, "left_index": 9, "mid_index": 9, "right_index": 9, "count_reverse_pairs_divide_and_conqu": 9, "low": [9, 10, 11], "throughputmetr": 9, "window_second": 9, "warmup_step": 9, "batch_size_stag": 9, "batchsizestag": 9, "32": [9, 11], "time_to_train_one_step": 9, "trainer": 9, "window": 9, "window_throughput": 9, "warmup": 9, "Not": 9, "weightedavgmetr": 9, "weightedavgmetriccomput": 9, "get_mean": 9, "value_sum": 9, "num_sampl": 9, "xaucmetr": 9, "xaucmetriccomput": 9, "compute_weighted_num_pair": 9, "compute_xauc": 9, "weighted_num_pair": 9, "get_xauc_st": 9, "recmetricmodul": 9, "rec_task": 9, "recmetriclist": 9, "throughput_metr": 9, "state_metr": 9, "statemetr": 9, "compute_interval_step": 9, "min_compute_interv": 9, "max_compute_interv": 9, "inf": [9, 12], "memory_usage_limit_mb": 9, "standalon": 9, "characterist": 9, "componenet": 9, "intern": [9, 11, 14], "logic": [9, 11], "unit": [9, 11], "dataclass": 9, "defaultmetricsconfig": 9, "statemetricenum": 9, "metricmodul": 9, "generate_metric_modul": 9, "metric_class": 9, "metrics_config": 9, "64": [9, 11], "state_metrics_map": 9, "mock_optim": 9, "check_memory_usag": 9, "compute_count": 9, "sink": 9, "get_memory_usag": 9, "get_required_input": 9, "last_compute_tim": 9, "local_comput": 9, "memory_usage_mb_avg": 9, "oom_count": 9, "should_comput": 9, "unsync": [9, 14], "model_out": 9, "model_output": 9, "due": 9, "slide": 9, "qat": 9, "get_metr": 9, "metricsconfig": 9, "metriccomputationreport": 9, "metrics_namespac": 9, "metricnamebas": 9, "metric_prefix": 9, "metricprefix": 9, "signal": 9, "own": 9, "__init__": 9, "_namespac": 9, "_metrics_comput": 9, "consum": 9, "invalid": 9, "defaulttaskinfo": 9, "rec": 9, "overwrit": 9, "synchron": 9, "get_window_st": 9, "state_nam": 9, "get_window_state_nam": 9, "pre_comput": 9, "torchmetr": 9, "aggreg": 9, "recmetricexcept": 9, "encapul": 9, "required_input": 9, "windowbuff": 9, "max_siz": 9, "max_buffer_count": 9, "aggregate_st": 9, "window_st": 9, "curr_stat": 9, "dequ": 9, "architectur": [10, 11], "deep": [10, 11], "sparsearch": 10, "densearch": 10, "interactionarch": 10, "overarch": 10, "found": 10, "notat": 10, "embedding_dimens": 10, "hidden_layer_s": 10, "deepfmnn": 10, "dimension": 10, "dense_arch": 10, "dense_arch_input": 10, "dense_embed": 10, "fminteractionarch": 10, "fm_in_featur": 10, "sparse_feature_nam": 10, "deep_fm_dimens": 10, "paper": [10, 11], "arxiv": 10, "pdf": 10, "1703": 10, "04247": 10, "cat": [10, 11], "dense_modul": [10, 11], "di": 10, "arch": 10, "fm_inter_arch": 10, "length_per_kei": [10, 14], "cat_fm_output": 10, "mlp": 10, "over_arch": 10, "logit": 10, "simpledeepfmnn": 10, "num_dense_featur": 10, "relationship": 10, "deep_fm": 10, "eb1_config": [10, 13], "f3": 10, "eb2_config": [10, 13], "t2": [10, 11, 13], "sparse_nn": 10, "over_embedding_dim": 10, "from_offsets_sync": [10, 11, 13, 14], "sparse_arch": 10, "ab": 10, "1906": 10, "00091": 10, "pairwis": 10, "dlrmtrain": 10, "dlrm_modul": 10, "train_pipelin": 10, "dlrm_project": 10, "dlrm_dcn": 10, "ebc_config": 10, "dlrm_model": 10, "dcn_num_lay": 10, "dcn_low_rank_dim": 10, "dcn": [10, 11], "modifi": [10, 11, 12], "similar": 10, "deepcrossnet": 10, "2008": 10, "13535": 10, "approxim": 10, "interaction_branch1_layer_s": 10, "interaction_branch2_layer_s": 10, "branch": 10, "layer_s": [10, 11], "num_sparse_featur": 10, "dot": [10, 11], "pair": 10, "inter_arch": 10, "choos": 10, "concat_dens": 10, "interactiondcnarch": 10, "crossnet": 10, "lowrankcrossnet": [10, 11], "dnc_low_rank_dim": 10, "interactionprojectionarch": 10, "interaction_branch1": 10, "interaction_branch2": 10, "z": 10, "bx": 10, "f1xd": 10, "dxf2": 10, "i1": 10, "i2": 10, "sparse_embed": 10, "math": 10, "comb": 10, "extens": 11, "establish": 11, "pattern": 11, "swishlayernorm": 11, "positionweightedmodul": 11, "lazymoduleextensionmixin": 11, "embeddingtow": 11, "embeddingtowercollect": 11, "input_dim": 11, "swish": 11, "sigmoid": 11, "layernorm": 11, "d1": 11, "d2": 11, "d3": 11, "sln": 11, "num_lay": 11, "stack": 11, "learnabl": 11, "polynom": 11, "nxn": 11, "cover": 11, "bit": 11, "x_": 11, "x_0": 11, "w_l": 11, "cdot": 11, "x_l": 11, "b_l": 11, "low_rank": 11, "highli": 11, "matric": 11, "simplifi": 11, "v_l": 11, "vector": 11, "smartli": 11, "setup": 11, "lowrankmixturecrossnet": 11, "num_expert": 11, "relu": 11, "mixtur": 11, "expert": 11, "compar": [11, 14], "subspac": 11, "gate": 11, "moe": 11, "expert_i": 11, "k_": 11, "u_": 11, "li": 11, "c_": 11, "v_": 11, "vectorcrossnet": 11, "nx1": 11, "thu": [11, 12], "further": [11, 14], "implent": 11, "framework": 11, "factorizationmachin": 11, "fm": 11, "publish": 11, "learnt": 11, "To": 11, "90": 11, "30": 11, "40": 11, "fb": 11, "lazymlp": 11, "output_dim": 11, "192": 11, "deep_fm_output": 11, "common_spars": 11, "specialized_spars": 11, "embedding_featur": 11, "raw_embedding_featur": 11, "nativ": 11, "trained_embed": 11, "native_embed": 11, "ident": 11, "mention": 11, "baseembeddingconfig": 11, "get_weight_init_max": 11, "get_weight_init_min": 11, "embeddingconfig": [11, 13], "quantconfig": 11, "placeholderobserv": [11, 13], "alia": 11, "data_type_to_dtyp": 11, "data_type_to_sparse_typ": 11, "sparsetyp": 11, "dtype_to_data_typ": 11, "pooling_type_to_pooling_mod": 11, "pooling_typ": 11, "poolingmod": 11, "pooling_type_to_str": 11, "sensit": [11, 13], "jag": [11, 13, 14], "table_0": [11, 13], "table_1": [11, 13], "pooled_embed": 11, "8899": 11, "1342": 11, "9060": 11, "0905": 11, "2814": 11, "9369": 11, "7783": 11, "1598": 11, "0695": 11, "3265": 11, "1011": 11, "4256": 11, "1846": 11, "1648": 11, "0893": 11, "3590": 11, "9784": 11, "7681": 11, "grad_fn": [11, 13], "catbackward0": 11, "offset_per_kei": [11, 14], "need_indic": [11, 13], "e1_config": [11, 13], "e2_config": [11, 13], "ec": [11, 13], "feature_embed": [11, 13], "2050": [11, 13], "5478": [11, 13], "6054": [11, 13], "7352": [11, 13], "3210": [11, 13], "0399": [11, 13], "1279": [11, 13], "1756": [11, 13], "4130": [11, 13], "7519": [11, 13], "4341": [11, 13], "0499": [11, 13], "9329": [11, 13], "0697": [11, 13], "8095": [11, 13], "embeddingbackward": [11, 13], "embedding_names_by_t": [11, 13], "get_embedding_names_by_t": 11, "process_pooled_embed": 11, "reorder_inverse_indic": 11, "basefeatureprocessor": 11, "max_length": 11, "truncat": 11, "positionweightedprocessor": 11, "feature_length": 11, "feature0": [11, 14], "feature1": [11, 14], "feature2": 11, "from_lengths_sync": [11, 14], "pw": 11, "featureprocessorcollect": 11, "feature_processor_modul": 11, "positionweightedfeatureprocessor": 11, "fp_featur": 11, "non_fp_featur": 11, "non_fp": 11, "feature_process": 11, "And": 11, "offsets_to_range_tracebl": 11, "position_weighted_module_update_featur": 11, "weighted_featur": 11, "lazymodulemixin": 11, "upstream": 11, "59923": 11, "testlazymoduleextensionmixin": 11, "_infer_paramet": 11, "pariti": 11, "_call_impl": 11, "fn": 11, "children": 11, "uniniti": 11, "dummi": [11, 12], "lazylinear": 11, "fail": [11, 14], "hasn": 11, "yet": 11, "now": [11, 14], "lazy_appli": 11, "attach": 11, "numer": 11, "immedi": 11, "seq": 11, "in_siz": 11, "perceptron": 11, "multi": 11, "out_siz": 11, "swish_layernorm": 11, "mlp_modul": 11, "assert": 11, "o": 11, "channel": 11, "unpadded_length": 11, "reindexed_length": 11, "reindexed_length_per_kei": 11, "reindexed_valu": 11, "check_module_output_dimens": 11, "verifi": 11, "construct_jagged_tensor": 11, "features_to_permute_indic": 11, "original_featur": 11, "construct_jagged_tensors_infer": 11, "construct_modulelist_from_single_modul": 11, "nest": 11, "reiniti": 11, "convert_list_of_modules_to_modulelist": 11, "deterministic_dedup": 11, "race": 11, "conflict": 11, "extract_module_or_tensor_cal": 11, "module_or_cal": 11, "get_module_output_dimens": 11, "init_mlp_weights_xavier_uniform": 11, "jagged_index_select_with_empti": 11, "output_offset": 11, "distancelfu_evictionpolici": 11, "decay_expon": 11, "threshold_filtering_func": 11, "mchevictionpolici": 11, "coalesce_history_metadata": 11, "current_it": 11, "history_metadata": 11, "unique_ids_count": 11, "unique_inverse_map": 11, "additional_id": 11, "threshold_mask": 11, "histori": 11, "invers": [11, 14], "history_accumul": 11, "coalesc": 11, "metadata_info": 11, "mchevictionpolicymetadatainfo": 11, "record_history_metadata": 11, "incoming_id": 11, "incom": 11, "polici": [11, 12], "update_metadata_and_generate_eviction_scor": 11, "mch_size": 11, "coalesced_history_argsort_map": 11, "coalesced_history_sorted_unique_ids_count": 11, "coalesced_history_mch_matching_elements_mask": 11, "coalesced_history_mch_matching_indic": 11, "mch_metadata": 11, "coalesced_history_metadata": 11, "evicted_indic": 11, "selected_new_indic": 11, "mch": 11, "lfu_evictionpolici": 11, "lru_evictionpolici": 11, "metadata_nam": 11, "is_mch_metadata": 11, "is_history_metadata": 11, "mchmanagedcollisionmodul": 11, "zch_size": 11, "eviction_polici": 11, "eviction_interv": 11, "input_hash_s": 11, "9223372036854775807": 11, "input_hash_func": 11, "mch_hash_func": 11, "output_global_offset": 11, "output_seg": 11, "managedcollisionmodul": 11, "zch": 11, "collis": 11, "output_size_offset": 11, "drive": 11, "greater": 11, "depreci": 11, "residu": 11, "legaci": 11, "shift": 11, "zch_output_rang": 11, "down": 11, "applic": 11, "slot": 11, "assumptionn": 11, "downstream": 11, "rtype": 11, "vs": 11, "profil": 11, "rebuild_with_output_id_rang": 11, "output_id_rang": 11, "mc": 11, "validate_st": 11, "checkpoint": [11, 12], "managed_collision_modul": 11, "need_preprocess": 11, "mcc": 11, "embedding_confg": 11, "collsion": 11, "skip_state_valid": 11, "max_output_id": 11, "remapping_range_start_index": 11, "mcm": 11, "mcm_jt": 11, "fp": 11, "apply_mc_method_to_jt_dict": 11, "features_dict": 11, "average_threshold_filt": 11, "id_count": 11, "dynamic_threshold_filt": 11, "threshold_skew_multipli": 11, "total_count": 11, "num_id": 11, "probabilistic_threshold_filt": 11, "per_id_prob": 11, "01": 11, "probabl": 11, "60": 11, "randomli": 11, "chanc": 11, "basemanagedcollisionembeddingcollect": 11, "managed_collision_collect": 11, "return_remapped_featur": 11, "embedding_collect": 11, "meaning": 12, "prohibit": 12, "empti": [12, 14], "sever": 12, "combinedoptim": 12, "optimizerwrapp": 12, "rowwis": 12, "gradientclip": 12, "norm": 12, "gradientclippingoptim": 12, "max_gradi": 12, "norm_typ": 12, "p": 12, "closur": 12, "reevalu": 12, "emptyfusedoptim": 12, "fusedoptim": 12, "zero_grad": 12, "set_to_non": 12, "zero": [12, 14], "footprint": 12, "modestli": 12, "certain": 12, "0s": 12, "behav": 12, "did": 12, "altogeth": 12, "param_group": 12, "meant": 12, "post_load_state_dict": 12, "prepend_opt_kei": 12, "opt_kei": 12, "save_param_group": 12, "set_optimizer_step": 12, "stricter": 12, "switch": 12, "flag": 12, "identifi": 12, "littl": 12, "add_param_group": 12, "fine": 12, "frozen": 12, "trainabl": 12, "progress": 12, "what": 12, "init_st": 12, "introduc": 12, "usabl": 12, "sd": 12, "load_checkpoint": 12, "protocol": 12, "keyedoptimizerwrapp": 12, "optim_factori": 12, "conveni": 12, "warmupoptim": 12, "warmupstag": 12, "lr": 12, "lr_param": 12, "param_nam": 12, "__warmup": 12, "adjust": 12, "schedul": 12, "fake": 12, "warmuppolici": 12, "constant": 12, "cosine_annealing_warm_restart": 12, "invsqrt": 12, "inv_sqrt": 12, "poli": 12, "max_it": 12, "lr_scale": 12, "decay_it": 12, "sgdr_period": 12, "trec_quant": 13, "trec": 13, "qconfig": 13, "activ": 13, "with_arg": 13, "qint8": 13, "quantize_dynam": 13, "qconfig_spec": 13, "table_name_to_quantized_weight": 13, "register_tb": 13, "quant_state_dict_split_scale_bia": 13, "row_align": 13, "qebc": 13, "from_float": 13, "quantized_embed": 13, "use_precomputed_fake_qu": 13, "for_each_module_of_type_do": 13, "quant_prep_customize_row_align": 13, "quant_prep_enable_quant_state_dict_split_scale_bia": 13, "quant_prep_enable_quant_state_dict_split_scale_bias_for_typ": 13, "quant_prep_enable_register_tb": 13, "quantize_state_dict": 13, "table_name_to_data_typ": 13, "table_name_to_num_embeddings_post_prun": 13, "whose": 14, "dimes": 14, "computejtdicttokjt": 14, "dim_1": 14, "dim_0": 14, "computekjttojtdict": 14, "keyed_jagged_tensor": 14, "jit": 14, "abl": 14, "NOT": 14, "expens": 14, "values_dtyp": 14, "weights_dtyp": 14, "lengths_dtyp": 14, "from_dens": 14, "2d": 14, "11": 14, "12": 14, "j1": 14, "from_dense_length": 14, "lengths_or_non": 14, "offsets_or_non": 14, "to_dens": 14, "inttensor": 14, "values_list": 14, "to_dense_weight": 14, "weights_list": 14, "to_padded_dens": 14, "desired_length": 14, "padding_valu": 14, "longest": 14, "pad": 14, "dt": 14, "to_padded_dense_weight": 14, "d_wt": 14, "weights_or_non": 14, "jaggedtensormeta": 14, "abcmeta": 14, "proxyableclassmeta": 14, "stride_per_key_per_rank": 14, "outer": 14, "inner": 14, "index_per_kei": 14, "expand": 14, "dedupl": 14, "dim_2": 14, "w0": 14, "w1": 14, "w2": 14, "w3": 14, "w4": 14, "w5": 14, "w6": 14, "w7": 14, "dist_init": 14, "variable_stride_per_kei": 14, "dist_label": 14, "dist_split": 14, "key_split": 14, "dist_tensor": 14, "empty_lik": 14, "flatten_length": 14, "from_jt_dict": 14, "implicit": 14, "variable_feature_dim": 14, "But": 14, "That": 14, "didn": 14, "notic": 14, "correctli": 14, "technic": 14, "know": 14, "violat": 14, "precondit": 14, "inverse_indices_or_non": 14, "length_per_key_or_non": 14, "lengths_offset_per_kei": 14, "offset_per_key_or_non": 14, "indices_tensor": 14, "segment": 14, "stride_per_kei": 14, "to_dict": 14, "key_dim": 14, "tensor_list": 14, "from_tensor_list": 14, "regroup": 14, "keyed_tensor": 14, "regroup_as_dict": 14, "flatten_kjt_list": 14, "kjt_arr": 14, "jt_is_equ": 14, "jt_1": 14, "jt_2": 14, "comparison": 14, "themselv": 14, "treat": 14, "kjt_is_equ": 14, "kjt_1": 14, "kjt_2": 14, "permute_multi_embed": 14, "regroup_kt": 14, "unflatten_kjt_list": 14}, "objects": {"torchrec": [[2, 0, 0, "-", "datasets"], [4, 0, 0, "-", "distributed"], [7, 0, 0, "module-0", "fx"], [8, 0, 0, "module-0", "inference"], [9, 0, 0, "-", "metrics"], [10, 0, 0, "module-0", "models"], [11, 0, 0, "-", "modules"], [12, 0, 0, "module-0", "optim"], [13, 0, 0, "module-0", "quant"], [14, 0, 0, "module-0", "sparse"]], "torchrec.datasets": [[2, 0, 0, "-", "criteo"], [2, 0, 0, "-", "movielens"], [2, 0, 0, "-", "random"], [3, 0, 0, "-", "scripts"], [2, 0, 0, "-", "utils"]], "torchrec.datasets.criteo": [[2, 1, 1, "", "BinaryCriteoUtils"], [2, 1, 1, "", "CriteoIterDataPipe"], [2, 1, 1, "", "InMemoryBinaryCriteoIterDataPipe"], [2, 3, 1, "", "criteo_kaggle"], [2, 3, 1, "", "criteo_terabyte"]], "torchrec.datasets.criteo.BinaryCriteoUtils": [[2, 2, 1, "", "get_file_row_ranges_and_remainder"], [2, 2, 1, "", "get_shape_from_npy"], [2, 2, 1, "", "load_npy_range"], [2, 2, 1, "", "shuffle"], [2, 2, 1, "", "sparse_to_contiguous"], [2, 2, 1, "", "tsv_to_npys"]], "torchrec.datasets.movielens": [[2, 3, 1, "", "movielens_20m"], [2, 3, 1, "", "movielens_25m"]], "torchrec.datasets.random": [[2, 1, 1, "", "RandomRecDataset"]], "torchrec.datasets.scripts": [[3, 0, 0, "-", "contiguous_preproc_criteo"], [3, 0, 0, "-", "npy_preproc_criteo"]], "torchrec.datasets.scripts.contiguous_preproc_criteo": [[3, 3, 1, "", "main"], [3, 3, 1, "", "parse_args"]], "torchrec.datasets.scripts.npy_preproc_criteo": [[3, 3, 1, "", "main"], [3, 3, 1, "", "parse_args"]], "torchrec.datasets.utils": [[2, 1, 1, "", "Batch"], [2, 1, 1, "", "Limit"], [2, 1, 1, "", "LoadFiles"], [2, 1, 1, "", "ParallelReadConcat"], [2, 1, 1, "", "ReadLinesFromCSV"], [2, 3, 1, "", "idx_split_train_val"], [2, 3, 1, "", "rand_split_train_val"], [2, 3, 1, "", "safe_cast"], [2, 3, 1, "", "train_filter"], [2, 3, 1, "", "val_filter"]], "torchrec.datasets.utils.Batch": [[2, 4, 1, "", "dense_features"], [2, 4, 1, "", "labels"], [2, 2, 1, "", "pin_memory"], [2, 2, 1, "", "record_stream"], [2, 4, 1, "", "sparse_features"], [2, 2, 1, "", "to"]], "torchrec.distributed": [[4, 0, 0, "-", "collective_utils"], [4, 0, 0, "-", "comm"], [4, 0, 0, "-", "comm_ops"], [6, 0, 0, "-", "dist_data"], [4, 0, 0, "-", "embedding"], [4, 0, 0, "-", "embedding_lookup"], [4, 0, 0, "-", "embedding_sharding"], [4, 0, 0, "-", "embedding_types"], [4, 0, 0, "-", "embeddingbag"], [4, 0, 0, "-", "grouped_position_weighted"], [4, 0, 0, "-", "mc_embedding"], [4, 0, 0, "-", "mc_embeddingbag"], [4, 0, 0, "-", "mc_modules"], [4, 0, 0, "-", "model_parallel"], [5, 0, 0, "-", "planner"], [4, 0, 0, "-", "quant_embeddingbag"], [6, 0, 0, "-", "sharding"], [4, 0, 0, "-", "train_pipeline"], [4, 0, 0, "-", "types"], [4, 0, 0, "-", "utils"]], "torchrec.distributed.collective_utils": [[4, 3, 1, "", "invoke_on_rank_and_broadcast_result"], [4, 3, 1, "", "is_leader"], [4, 3, 1, "", "run_on_leader"]], "torchrec.distributed.comm": [[4, 3, 1, "", "get_group_rank"], [4, 3, 1, "", "get_local_rank"], [4, 3, 1, "", "get_local_size"], [4, 3, 1, "", "get_num_groups"], [4, 3, 1, "", "intra_and_cross_node_pg"]], "torchrec.distributed.comm_ops": [[4, 1, 1, "", "All2AllDenseInfo"], [4, 1, 1, "", "All2AllPooledInfo"], [4, 1, 1, "", "All2AllSequenceInfo"], [4, 1, 1, "", "All2AllVInfo"], [4, 1, 1, "", "All2All_Pooled_Req"], [4, 1, 1, "", "All2All_Pooled_Wait"], [4, 1, 1, "", "All2All_Seq_Req"], [4, 1, 1, "", "All2All_Seq_Req_Wait"], [4, 1, 1, "", "All2Allv_Req"], [4, 1, 1, "", "All2Allv_Wait"], [4, 1, 1, "", "AllGatherBaseInfo"], [4, 1, 1, "", "AllGatherBase_Req"], [4, 1, 1, "", "AllGatherBase_Wait"], [4, 1, 1, "", "ReduceScatterBaseInfo"], [4, 1, 1, "", "ReduceScatterBase_Req"], [4, 1, 1, "", "ReduceScatterBase_Wait"], [4, 1, 1, "", "ReduceScatterInfo"], [4, 1, 1, "", "ReduceScatterVInfo"], [4, 1, 1, "", "ReduceScatterV_Req"], [4, 1, 1, "", "ReduceScatterV_Wait"], [4, 1, 1, "", "ReduceScatter_Req"], [4, 1, 1, "", "ReduceScatter_Wait"], [4, 1, 1, "", "Request"], [4, 1, 1, "", "VariableBatchAll2AllPooledInfo"], [4, 1, 1, "", "Variable_Batch_All2All_Pooled_Req"], [4, 1, 1, "", "Variable_Batch_All2All_Pooled_Wait"], [4, 3, 1, "", "all2all_pooled_sync"], [4, 3, 1, "", "all2all_sequence_sync"], [4, 3, 1, "", "all2allv_sync"], [4, 3, 1, "", "all_gather_base_pooled"], [4, 3, 1, "", "all_gather_base_sync"], [4, 3, 1, "", "all_gather_into_tensor_backward"], [4, 3, 1, "", "all_gather_into_tensor_fake"], [4, 3, 1, "", "all_gather_into_tensor_setup_context"], [4, 3, 1, "", "all_to_all_single_backward"], [4, 3, 1, "", "all_to_all_single_fake"], [4, 3, 1, "", "all_to_all_single_setup_context"], [4, 3, 1, "", "alltoall_pooled"], [4, 3, 1, "", "alltoall_sequence"], [4, 3, 1, "", "alltoallv"], [4, 3, 1, "", "get_gradient_division"], [4, 3, 1, "", "get_use_sync_collectives"], [4, 3, 1, "", "pg_name"], [4, 3, 1, "", "reduce_scatter_base_pooled"], [4, 3, 1, "", "reduce_scatter_base_sync"], [4, 3, 1, "", "reduce_scatter_pooled"], [4, 3, 1, "", "reduce_scatter_sync"], [4, 3, 1, "", "reduce_scatter_tensor_backward"], [4, 3, 1, "", "reduce_scatter_tensor_fake"], [4, 3, 1, "", "reduce_scatter_tensor_setup_context"], [4, 3, 1, "", "reduce_scatter_v_per_feature_pooled"], [4, 3, 1, "", "reduce_scatter_v_pooled"], [4, 3, 1, "", "reduce_scatter_v_sync"], [4, 3, 1, "", "set_gradient_division"], [4, 3, 1, "", "set_use_sync_collectives"], [4, 3, 1, "", "torchrec_use_sync_collectives"], [4, 3, 1, "", "variable_batch_all2all_pooled_sync"], [4, 3, 1, "", "variable_batch_alltoall_pooled"]], "torchrec.distributed.comm_ops.All2AllDenseInfo": [[4, 4, 1, "", "batch_size"], [4, 4, 1, "", "input_shape"], [4, 4, 1, "", "input_splits"], [4, 4, 1, "", "output_splits"]], "torchrec.distributed.comm_ops.All2AllPooledInfo": [[4, 4, 1, "id0", "batch_size_per_rank"], [4, 4, 1, "id1", "codecs"], [4, 4, 1, "id2", "cumsum_dim_sum_per_rank_tensor"], [4, 4, 1, "id3", "dim_sum_per_rank"], [4, 4, 1, "id4", "dim_sum_per_rank_tensor"]], "torchrec.distributed.comm_ops.All2AllSequenceInfo": [[4, 4, 1, "id5", "backward_recat_tensor"], [4, 4, 1, "id6", "codecs"], [4, 4, 1, "id7", "embedding_dim"], [4, 4, 1, "id8", "forward_recat_tensor"], [4, 4, 1, "id9", "input_splits"], [4, 4, 1, "id10", "lengths_after_sparse_data_all2all"], [4, 4, 1, "id11", "output_splits"], [4, 4, 1, "id12", "permuted_lengths_after_sparse_data_all2all"], [4, 4, 1, "id13", "variable_batch_size"]], "torchrec.distributed.comm_ops.All2AllVInfo": [[4, 4, 1, "id14", "B_global"], [4, 4, 1, "id15", "B_local"], [4, 4, 1, "id16", "B_local_list"], [4, 4, 1, "id17", "D_local_list"], [4, 4, 1, "", "codecs"], [4, 4, 1, "", "dim_sum_per_rank"], [4, 4, 1, "", "dims_sum_per_rank"], [4, 4, 1, "id18", "input_split_sizes"], [4, 4, 1, "id19", "output_split_sizes"]], "torchrec.distributed.comm_ops.All2All_Pooled_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Pooled_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Seq_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Seq_Req_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2Allv_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2Allv_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.AllGatherBaseInfo": [[4, 4, 1, "", "codecs"], [4, 4, 1, "id20", "input_size"]], "torchrec.distributed.comm_ops.AllGatherBase_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.AllGatherBase_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterBaseInfo": [[4, 4, 1, "", "codecs"], [4, 4, 1, "id21", "input_sizes"]], "torchrec.distributed.comm_ops.ReduceScatterBase_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterBase_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterInfo": [[4, 4, 1, "", "codecs"], [4, 4, 1, "id22", "input_sizes"]], "torchrec.distributed.comm_ops.ReduceScatterVInfo": [[4, 4, 1, "id23", "codecs"], [4, 4, 1, "id24", "equal_splits"], [4, 4, 1, "id25", "input_sizes"], [4, 4, 1, "id26", "input_splits"], [4, 4, 1, "id27", "total_input_size"]], "torchrec.distributed.comm_ops.ReduceScatterV_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterV_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatter_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatter_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.VariableBatchAll2AllPooledInfo": [[4, 4, 1, "id28", "batch_size_per_feature_pre_a2a"], [4, 4, 1, "id29", "batch_size_per_rank_per_feature"], [4, 4, 1, "id30", "codecs"], [4, 4, 1, "id31", "emb_dim_per_rank_per_feature"], [4, 4, 1, "id32", "input_splits"], [4, 4, 1, "id33", "output_splits"]], "torchrec.distributed.comm_ops.Variable_Batch_All2All_Pooled_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.Variable_Batch_All2All_Pooled_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.dist_data": [[6, 1, 1, "", "EmbeddingsAllToOne"], [6, 1, 1, "", "EmbeddingsAllToOneReduce"], [6, 1, 1, "", "JaggedTensorAllToAll"], [6, 1, 1, "", "KJTAllToAll"], [6, 1, 1, "", "KJTAllToAllSplitsAwaitable"], [6, 1, 1, "", "KJTAllToAllTensorsAwaitable"], [6, 1, 1, "", "KJTOneToAll"], [6, 1, 1, "", "MergePooledEmbeddingsModule"], [6, 1, 1, "", "PooledEmbeddingsAllGather"], [6, 1, 1, "", "PooledEmbeddingsAllToAll"], [6, 1, 1, "", "PooledEmbeddingsAwaitable"], [6, 1, 1, "", "PooledEmbeddingsReduceScatter"], [6, 1, 1, "", "SeqEmbeddingsAllToOne"], [6, 1, 1, "", "SequenceEmbeddingsAllToAll"], [6, 1, 1, "", "SequenceEmbeddingsAwaitable"], [6, 1, 1, "", "SplitsAllToAllAwaitable"], [6, 1, 1, "", "TensorAllToAll"], [6, 1, 1, "", "TensorAllToAllSplitsAwaitable"], [6, 1, 1, "", "TensorAllToAllValuesAwaitable"], [6, 1, 1, "", "TensorValuesAllToAll"], [6, 1, 1, "", "VariableBatchPooledEmbeddingsAllToAll"], [6, 1, 1, "", "VariableBatchPooledEmbeddingsReduceScatter"]], "torchrec.distributed.dist_data.EmbeddingsAllToOne": [[6, 2, 1, "", "forward"], [6, 2, 1, "", "set_device"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.EmbeddingsAllToOneReduce": [[6, 2, 1, "", "forward"], [6, 2, 1, "", "set_device"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.KJTAllToAll": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.KJTOneToAll": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.MergePooledEmbeddingsModule": [[6, 2, 1, "", "forward"], [6, 2, 1, "", "set_device"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAllGather": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAllToAll": [[6, 5, 1, "", "callbacks"], [6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAwaitable": [[6, 5, 1, "", "callbacks"]], "torchrec.distributed.dist_data.PooledEmbeddingsReduceScatter": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.SeqEmbeddingsAllToOne": [[6, 2, 1, "", "forward"], [6, 2, 1, "", "set_device"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.SequenceEmbeddingsAllToAll": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.TensorAllToAll": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.TensorValuesAllToAll": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsAllToAll": [[6, 5, 1, "", "callbacks"], [6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsReduceScatter": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.embedding": [[4, 1, 1, "", "EmbeddingCollectionAwaitable"], [4, 1, 1, "", "EmbeddingCollectionContext"], [4, 1, 1, "", "EmbeddingCollectionSharder"], [4, 1, 1, "", "ShardedEmbeddingCollection"], [4, 3, 1, "", "create_embedding_sharding"], [4, 3, 1, "", "create_sharding_infos_by_sharding"], [4, 3, 1, "", "create_sharding_infos_by_sharding_device_group"], [4, 3, 1, "", "get_device_from_parameter_sharding"], [4, 3, 1, "", "get_ec_index_dedup"], [4, 3, 1, "", "pad_vbe_kjt_lengths"], [4, 3, 1, "", "set_ec_index_dedup"]], "torchrec.distributed.embedding.EmbeddingCollectionContext": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.embedding.EmbeddingCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "shardable_parameters"], [4, 2, 1, "", "sharding_types"]], "torchrec.distributed.embedding.ShardedEmbeddingCollection": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "compute_and_output_dist"], [4, 2, 1, "", "create_context"], [4, 5, 1, "", "fused_optimizer"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "output_dist"], [4, 2, 1, "", "reset_parameters"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup": [[4, 1, 1, "", "CommOpGradientScaling"], [4, 1, 1, "", "GroupedEmbeddingsLookup"], [4, 1, 1, "", "GroupedPooledEmbeddingsLookup"], [4, 1, 1, "", "InferCPUGroupedEmbeddingsLookup"], [4, 1, 1, "", "InferGroupedEmbeddingsLookup"], [4, 1, 1, "", "InferGroupedLookupMixin"], [4, 1, 1, "", "InferGroupedPooledEmbeddingsLookup"], [4, 1, 1, "", "MetaInferGroupedEmbeddingsLookup"], [4, 1, 1, "", "MetaInferGroupedPooledEmbeddingsLookup"], [4, 3, 1, "", "dummy_tensor"], [4, 3, 1, "", "embeddings_cat_empty_rank_handle"], [4, 3, 1, "", "embeddings_cat_empty_rank_handle_inference"], [4, 3, 1, "", "fx_wrap_tensor_view2d"]], "torchrec.distributed.embedding_lookup.CommOpGradientScaling": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.embedding_lookup.GroupedEmbeddingsLookup": [[4, 2, 1, "", "flush"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "named_parameters_by_table"], [4, 2, 1, "", "prefetch"], [4, 2, 1, "", "purge"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.GroupedPooledEmbeddingsLookup": [[4, 2, 1, "", "flush"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "named_parameters_by_table"], [4, 2, 1, "", "prefetch"], [4, 2, 1, "", "purge"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferCPUGroupedEmbeddingsLookup": [[4, 2, 1, "", "get_tbes_to_register"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferGroupedEmbeddingsLookup": [[4, 2, 1, "", "get_tbes_to_register"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferGroupedLookupMixin": [[4, 2, 1, "", "forward"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "state_dict"]], "torchrec.distributed.embedding_lookup.InferGroupedPooledEmbeddingsLookup": [[4, 2, 1, "", "forward"], [4, 2, 1, "", "get_tbes_to_register"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.MetaInferGroupedEmbeddingsLookup": [[4, 2, 1, "", "flush"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "get_tbes_to_register"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "purge"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.MetaInferGroupedPooledEmbeddingsLookup": [[4, 2, 1, "", "flush"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "get_tbes_to_register"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "purge"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_sharding": [[4, 1, 1, "", "BaseEmbeddingDist"], [4, 1, 1, "", "BaseSparseFeaturesDist"], [4, 1, 1, "", "EmbeddingSharding"], [4, 1, 1, "", "EmbeddingShardingContext"], [4, 1, 1, "", "EmbeddingShardingInfo"], [4, 1, 1, "", "FusedKJTListSplitsAwaitable"], [4, 1, 1, "", "KJTListAwaitable"], [4, 1, 1, "", "KJTListSplitsAwaitable"], [4, 1, 1, "", "KJTSplitsAllToAllMeta"], [4, 1, 1, "", "ListOfKJTListAwaitable"], [4, 1, 1, "", "ListOfKJTListSplitsAwaitable"], [4, 3, 1, "", "bucketize_kjt_before_all2all"], [4, 3, 1, "", "bucketize_kjt_inference"], [4, 3, 1, "", "group_tables"]], "torchrec.distributed.embedding_sharding.BaseEmbeddingDist": [[4, 2, 1, "", "forward"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist": [[4, 2, 1, "", "forward"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_sharding.EmbeddingSharding": [[4, 2, 1, "", "create_input_dist"], [4, 2, 1, "", "create_lookup"], [4, 2, 1, "", "create_output_dist"], [4, 2, 1, "", "embedding_dims"], [4, 2, 1, "", "embedding_names"], [4, 2, 1, "", "embedding_names_per_rank"], [4, 2, 1, "", "embedding_shard_metadata"], [4, 2, 1, "", "embedding_tables"], [4, 5, 1, "", "qcomm_codecs_registry"], [4, 2, 1, "", "uncombined_embedding_dims"], [4, 2, 1, "", "uncombined_embedding_names"]], "torchrec.distributed.embedding_sharding.EmbeddingShardingContext": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.embedding_sharding.EmbeddingShardingInfo": [[4, 4, 1, "", "embedding_config"], [4, 4, 1, "", "fused_params"], [4, 4, 1, "", "param"], [4, 4, 1, "", "param_sharding"]], "torchrec.distributed.embedding_sharding.KJTSplitsAllToAllMeta": [[4, 4, 1, "", "device"], [4, 4, 1, "", "input_splits"], [4, 4, 1, "", "input_tensors"], [4, 4, 1, "", "keys"], [4, 4, 1, "", "labels"], [4, 4, 1, "", "pg"], [4, 4, 1, "", "splits"], [4, 4, 1, "", "splits_tensors"], [4, 4, 1, "", "stagger"]], "torchrec.distributed.embedding_types": [[4, 1, 1, "", "BaseEmbeddingLookup"], [4, 1, 1, "", "BaseEmbeddingSharder"], [4, 1, 1, "", "BaseGroupedFeatureProcessor"], [4, 1, 1, "", "BaseQuantEmbeddingSharder"], [4, 1, 1, "", "DTensorMetadata"], [4, 1, 1, "", "EmbeddingAttributes"], [4, 1, 1, "", "EmbeddingComputeKernel"], [4, 1, 1, "", "FeatureShardingMixIn"], [4, 1, 1, "", "GroupedEmbeddingConfig"], [4, 1, 1, "", "InputDistOutputs"], [4, 1, 1, "", "KJTList"], [4, 1, 1, "", "ListOfKJTList"], [4, 1, 1, "", "ModuleShardingMixIn"], [4, 1, 1, "", "OptimType"], [4, 1, 1, "", "ShardedConfig"], [4, 1, 1, "", "ShardedEmbeddingModule"], [4, 1, 1, "", "ShardedEmbeddingTable"], [4, 1, 1, "", "ShardedMetaConfig"], [4, 3, 1, "", "compute_kernel_to_embedding_location"]], "torchrec.distributed.embedding_types.BaseEmbeddingLookup": [[4, 2, 1, "", "forward"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_types.BaseEmbeddingSharder": [[4, 2, 1, "", "compute_kernels"], [4, 5, 1, "", "fused_params"], [4, 2, 1, "", "sharding_types"], [4, 2, 1, "", "storage_usage"]], "torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor": [[4, 2, 1, "", "forward"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_types.BaseQuantEmbeddingSharder": [[4, 2, 1, "", "compute_kernels"], [4, 5, 1, "", "fused_params"], [4, 2, 1, "", "shardable_parameters"], [4, 2, 1, "", "sharding_types"], [4, 2, 1, "", "storage_usage"]], "torchrec.distributed.embedding_types.DTensorMetadata": [[4, 4, 1, "", "mesh"], [4, 4, 1, "", "placements"], [4, 4, 1, "", "size"], [4, 4, 1, "", "stride"]], "torchrec.distributed.embedding_types.EmbeddingAttributes": [[4, 4, 1, "", "compute_kernel"]], "torchrec.distributed.embedding_types.EmbeddingComputeKernel": [[4, 4, 1, "", "DENSE"], [4, 4, 1, "", "FUSED"], [4, 4, 1, "", "FUSED_UVM"], [4, 4, 1, "", "FUSED_UVM_CACHING"], [4, 4, 1, "", "KEY_VALUE"], [4, 4, 1, "", "QUANT"], [4, 4, 1, "", "QUANT_UVM"], [4, 4, 1, "", "QUANT_UVM_CACHING"]], "torchrec.distributed.embedding_types.FeatureShardingMixIn": [[4, 2, 1, "", "feature_names"], [4, 2, 1, "", "feature_names_per_rank"], [4, 2, 1, "", "features_per_rank"]], "torchrec.distributed.embedding_types.GroupedEmbeddingConfig": [[4, 4, 1, "", "compute_kernel"], [4, 4, 1, "", "data_type"], [4, 2, 1, "", "dim_sum"], [4, 2, 1, "", "embedding_dims"], [4, 2, 1, "", "embedding_names"], [4, 2, 1, "", "embedding_shard_metadata"], [4, 4, 1, "", "embedding_tables"], [4, 2, 1, "", "feature_hash_sizes"], [4, 2, 1, "", "feature_names"], [4, 4, 1, "", "fused_params"], [4, 4, 1, "", "has_feature_processor"], [4, 4, 1, "", "is_weighted"], [4, 2, 1, "", "num_features"], [4, 4, 1, "", "pooling"], [4, 2, 1, "", "table_names"]], "torchrec.distributed.embedding_types.InputDistOutputs": [[4, 4, 1, "", "bucket_mapping_tensor"], [4, 4, 1, "", "bucketized_length"], [4, 4, 1, "", "features"], [4, 2, 1, "", "record_stream"], [4, 4, 1, "", "unbucketize_permute_tensor"]], "torchrec.distributed.embedding_types.KJTList": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.embedding_types.ListOfKJTList": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.embedding_types.ModuleShardingMixIn": [[4, 5, 1, "", "shardings"]], "torchrec.distributed.embedding_types.OptimType": [[4, 4, 1, "", "ADAGRAD"], [4, 4, 1, "", "ADAM"], [4, 4, 1, "", "ADAMW"], [4, 4, 1, "", "LAMB"], [4, 4, 1, "", "LARS_SGD"], [4, 4, 1, "", "LION"], [4, 4, 1, "", "PARTIAL_ROWWISE_ADAM"], [4, 4, 1, "", "PARTIAL_ROWWISE_LAMB"], [4, 4, 1, "", "ROWWISE_ADAGRAD"], [4, 4, 1, "", "SGD"], [4, 4, 1, "", "SHAMPOO"], [4, 4, 1, "", "SHAMPOO_V2"], [4, 4, 1, "", "SHAMPOO_V2_MRS"]], "torchrec.distributed.embedding_types.ShardedConfig": [[4, 4, 1, "", "local_cols"], [4, 4, 1, "", "local_rows"]], "torchrec.distributed.embedding_types.ShardedEmbeddingModule": [[4, 2, 1, "", "extra_repr"], [4, 2, 1, "", "prefetch"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_types.ShardedEmbeddingTable": [[4, 4, 1, "", "fused_params"]], "torchrec.distributed.embedding_types.ShardedMetaConfig": [[4, 4, 1, "", "dtensor_metadata"], [4, 4, 1, "", "global_metadata"], [4, 4, 1, "", "local_metadata"]], "torchrec.distributed.embeddingbag": [[4, 1, 1, "", "EmbeddingAwaitable"], [4, 1, 1, "", "EmbeddingBagCollectionAwaitable"], [4, 1, 1, "", "EmbeddingBagCollectionContext"], [4, 1, 1, "", "EmbeddingBagCollectionSharder"], [4, 1, 1, "", "EmbeddingBagSharder"], [4, 1, 1, "", "ShardedEmbeddingBag"], [4, 1, 1, "", "ShardedEmbeddingBagCollection"], [4, 1, 1, "", "VariableBatchEmbeddingBagCollectionAwaitable"], [4, 3, 1, "", "construct_output_kt"], [4, 3, 1, "", "create_embedding_bag_sharding"], [4, 3, 1, "", "create_sharding_infos_by_sharding"], [4, 3, 1, "", "create_sharding_infos_by_sharding_device_group"], [4, 3, 1, "", "get_device_from_parameter_sharding"], [4, 3, 1, "", "replace_placement_with_meta_device"]], "torchrec.distributed.embeddingbag.EmbeddingBagCollectionContext": [[4, 4, 1, "", "divisor"], [4, 4, 1, "", "inverse_indices"], [4, 2, 1, "", "record_stream"], [4, 4, 1, "", "sharding_contexts"], [4, 4, 1, "", "variable_batch_per_feature"]], "torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "shardable_parameters"]], "torchrec.distributed.embeddingbag.EmbeddingBagSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "shardable_parameters"]], "torchrec.distributed.embeddingbag.ShardedEmbeddingBag": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "create_context"], [4, 5, 1, "", "fused_optimizer"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_modules"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "output_dist"], [4, 2, 1, "", "sharded_parameter_names"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "compute_and_output_dist"], [4, 2, 1, "", "create_context"], [4, 5, 1, "", "fused_optimizer"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "output_dist"], [4, 2, 1, "", "reset_parameters"], [4, 4, 1, "", "training"]], "torchrec.distributed.grouped_position_weighted": [[4, 1, 1, "", "GroupedPositionWeightedModule"]], "torchrec.distributed.grouped_position_weighted.GroupedPositionWeightedModule": [[4, 2, 1, "", "forward"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.mc_embedding": [[4, 1, 1, "", "ManagedCollisionEmbeddingCollectionContext"], [4, 1, 1, "", "ManagedCollisionEmbeddingCollectionSharder"], [4, 1, 1, "", "ShardedManagedCollisionEmbeddingCollection"]], "torchrec.distributed.mc_embedding.ManagedCollisionEmbeddingCollectionContext": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.mc_embedding.ManagedCollisionEmbeddingCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"]], "torchrec.distributed.mc_embedding.ShardedManagedCollisionEmbeddingCollection": [[4, 2, 1, "", "create_context"], [4, 4, 1, "", "training"]], "torchrec.distributed.mc_embeddingbag": [[4, 1, 1, "", "ManagedCollisionEmbeddingBagCollectionContext"], [4, 1, 1, "", "ManagedCollisionEmbeddingBagCollectionSharder"], [4, 1, 1, "", "ShardedManagedCollisionEmbeddingBagCollection"]], "torchrec.distributed.mc_embeddingbag.ManagedCollisionEmbeddingBagCollectionContext": [[4, 4, 1, "", "evictions_per_table"], [4, 2, 1, "", "record_stream"], [4, 4, 1, "", "remapped_kjt"]], "torchrec.distributed.mc_embeddingbag.ManagedCollisionEmbeddingBagCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"]], "torchrec.distributed.mc_embeddingbag.ShardedManagedCollisionEmbeddingBagCollection": [[4, 2, 1, "", "create_context"], [4, 4, 1, "", "training"]], "torchrec.distributed.mc_modules": [[4, 1, 1, "", "ManagedCollisionCollectionAwaitable"], [4, 1, 1, "", "ManagedCollisionCollectionContext"], [4, 1, 1, "", "ManagedCollisionCollectionSharder"], [4, 1, 1, "", "ShardedManagedCollisionCollection"], [4, 3, 1, "", "create_mc_sharding"]], "torchrec.distributed.mc_modules.ManagedCollisionCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "shardable_parameters"], [4, 2, 1, "", "sharding_types"]], "torchrec.distributed.mc_modules.ShardedManagedCollisionCollection": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "create_context"], [4, 2, 1, "", "evict"], [4, 2, 1, "", "global_to_local_index"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "open_slots"], [4, 2, 1, "", "output_dist"], [4, 2, 1, "", "sharded_parameter_names"], [4, 4, 1, "", "training"]], "torchrec.distributed.model_parallel": [[4, 1, 1, "", "DataParallelWrapper"], [4, 1, 1, "", "DefaultDataParallelWrapper"], [4, 1, 1, "", "DistributedModelParallel"], [4, 3, 1, "", "get_module"], [4, 3, 1, "", "get_unwrapped_module"]], "torchrec.distributed.model_parallel.DataParallelWrapper": [[4, 2, 1, "", "wrap"]], "torchrec.distributed.model_parallel.DefaultDataParallelWrapper": [[4, 2, 1, "", "wrap"]], "torchrec.distributed.model_parallel.DistributedModelParallel": [[4, 2, 1, "", "bare_named_parameters"], [4, 2, 1, "", "copy"], [4, 2, 1, "", "forward"], [4, 5, 1, "", "fused_optimizer"], [4, 2, 1, "", "init_data_parallel"], [4, 2, 1, "", "load_state_dict"], [4, 5, 1, "", "module"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 5, 1, "", "plan"], [4, 2, 1, "", "sparse_grad_parameter_names"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.planner": [[5, 0, 0, "-", "constants"], [5, 0, 0, "-", "enumerators"], [5, 0, 0, "-", "partitioners"], [5, 0, 0, "-", "perf_models"], [5, 0, 0, "-", "planners"], [5, 0, 0, "-", "proposers"], [5, 0, 0, "-", "shard_estimators"], [5, 0, 0, "-", "stats"], [5, 0, 0, "-", "storage_reservations"], [5, 0, 0, "-", "types"], [5, 0, 0, "-", "utils"]], "torchrec.distributed.planner.constants": [[5, 3, 1, "", "kernel_bw_lookup"]], "torchrec.distributed.planner.enumerators": [[5, 1, 1, "", "EmbeddingEnumerator"], [5, 3, 1, "", "get_partition_by_type"]], "torchrec.distributed.planner.enumerators.EmbeddingEnumerator": [[5, 2, 1, "", "enumerate"], [5, 2, 1, "", "populate_estimates"]], "torchrec.distributed.planner.partitioners": [[5, 1, 1, "", "GreedyPerfPartitioner"], [5, 1, 1, "", "MemoryBalancedPartitioner"], [5, 1, 1, "", "OrderedDeviceHardware"], [5, 1, 1, "", "ShardingOptionGroup"], [5, 1, 1, "", "SortBy"], [5, 3, 1, "", "set_hbm_per_device"]], "torchrec.distributed.planner.partitioners.GreedyPerfPartitioner": [[5, 2, 1, "", "partition"]], "torchrec.distributed.planner.partitioners.MemoryBalancedPartitioner": [[5, 2, 1, "", "partition"]], "torchrec.distributed.planner.partitioners.OrderedDeviceHardware": [[5, 4, 1, "", "device"], [5, 4, 1, "", "local_world_size"]], "torchrec.distributed.planner.partitioners.ShardingOptionGroup": [[5, 4, 1, "", "param_count"], [5, 4, 1, "", "perf_sum"], [5, 4, 1, "", "sharding_options"], [5, 4, 1, "", "storage_sum"]], "torchrec.distributed.planner.partitioners.SortBy": [[5, 4, 1, "", "PERF"], [5, 4, 1, "", "STORAGE"]], "torchrec.distributed.planner.perf_models": [[5, 1, 1, "", "NoopPerfModel"], [5, 1, 1, "", "NoopStorageModel"]], "torchrec.distributed.planner.perf_models.NoopPerfModel": [[5, 2, 1, "", "rate"]], "torchrec.distributed.planner.perf_models.NoopStorageModel": [[5, 2, 1, "", "rate"]], "torchrec.distributed.planner.planners": [[5, 1, 1, "", "EmbeddingShardingPlanner"], [5, 1, 1, "", "HeteroEmbeddingShardingPlanner"]], "torchrec.distributed.planner.planners.EmbeddingShardingPlanner": [[5, 2, 1, "", "collective_plan"], [5, 2, 1, "", "plan"]], "torchrec.distributed.planner.planners.HeteroEmbeddingShardingPlanner": [[5, 2, 1, "", "collective_plan"], [5, 2, 1, "", "plan"]], "torchrec.distributed.planner.proposers": [[5, 1, 1, "", "DynamicProgrammingProposer"], [5, 1, 1, "", "EmbeddingOffloadScaleupProposer"], [5, 1, 1, "", "GreedyProposer"], [5, 1, 1, "", "GridSearchProposer"], [5, 1, 1, "", "UniformProposer"], [5, 3, 1, "", "proposers_to_proposals_list"]], "torchrec.distributed.planner.proposers.DynamicProgrammingProposer": [[5, 2, 1, "", "feedback"], [5, 2, 1, "", "load"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.proposers.EmbeddingOffloadScaleupProposer": [[5, 2, 1, "", "allocate_budget"], [5, 2, 1, "", "build_affine_storage_model"], [5, 2, 1, "", "clf_to_bytes"], [5, 2, 1, "", "feedback"], [5, 2, 1, "", "get_budget"], [5, 2, 1, "", "get_cacheability"], [5, 2, 1, "", "get_expected_lookups"], [5, 2, 1, "", "load"], [5, 2, 1, "", "next_plan"], [5, 2, 1, "", "promote_high_prefetch_overheaad_table_to_hbm"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.proposers.GreedyProposer": [[5, 2, 1, "", "feedback"], [5, 2, 1, "", "load"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.proposers.GridSearchProposer": [[5, 2, 1, "", "feedback"], [5, 2, 1, "", "load"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.proposers.UniformProposer": [[5, 2, 1, "", "feedback"], [5, 2, 1, "", "load"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.shard_estimators": [[5, 1, 1, "", "EmbeddingOffloadStats"], [5, 1, 1, "", "EmbeddingPerfEstimator"], [5, 1, 1, "", "EmbeddingStorageEstimator"], [5, 3, 1, "", "calculate_pipeline_io_cost"], [5, 3, 1, "", "calculate_shard_storages"]], "torchrec.distributed.planner.shard_estimators.EmbeddingOffloadStats": [[5, 5, 1, "", "cacheability"], [5, 2, 1, "", "estimate_cache_miss_rate"], [5, 5, 1, "", "expected_lookups"], [5, 2, 1, "", "expected_miss_rate"]], "torchrec.distributed.planner.shard_estimators.EmbeddingPerfEstimator": [[5, 2, 1, "", "estimate"], [5, 2, 1, "", "perf_func_emb_wall_time"]], "torchrec.distributed.planner.shard_estimators.EmbeddingStorageEstimator": [[5, 2, 1, "", "estimate"]], "torchrec.distributed.planner.stats": [[5, 1, 1, "", "EmbeddingStats"], [5, 1, 1, "", "NoopEmbeddingStats"], [5, 3, 1, "", "round_to_one_sigfig"]], "torchrec.distributed.planner.stats.EmbeddingStats": [[5, 2, 1, "", "log"]], "torchrec.distributed.planner.stats.NoopEmbeddingStats": [[5, 2, 1, "", "log"]], "torchrec.distributed.planner.storage_reservations": [[5, 1, 1, "", "FixedPercentageStorageReservation"], [5, 1, 1, "", "HeuristicalStorageReservation"], [5, 1, 1, "", "InferenceStorageReservation"]], "torchrec.distributed.planner.storage_reservations.FixedPercentageStorageReservation": [[5, 2, 1, "", "reserve"]], "torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation": [[5, 2, 1, "", "reserve"]], "torchrec.distributed.planner.storage_reservations.InferenceStorageReservation": [[5, 2, 1, "", "reserve"]], "torchrec.distributed.planner.types": [[5, 1, 1, "", "CustomTopologyData"], [5, 1, 1, "", "DeviceHardware"], [5, 1, 1, "", "Enumerator"], [5, 1, 1, "", "ParameterConstraints"], [5, 1, 1, "", "PartitionByType"], [5, 1, 1, "", "Partitioner"], [5, 1, 1, "", "Perf"], [5, 1, 1, "", "PerfModel"], [5, 6, 1, "", "PlannerError"], [5, 1, 1, "", "PlannerErrorType"], [5, 1, 1, "", "Proposer"], [5, 1, 1, "", "Shard"], [5, 1, 1, "", "ShardEstimator"], [5, 1, 1, "", "ShardingOption"], [5, 1, 1, "", "Stats"], [5, 1, 1, "", "Storage"], [5, 1, 1, "", "StorageReservation"], [5, 1, 1, "", "Topology"]], "torchrec.distributed.planner.types.CustomTopologyData": [[5, 2, 1, "", "get_data"], [5, 2, 1, "", "has_data"], [5, 4, 1, "", "supported_fields"]], "torchrec.distributed.planner.types.DeviceHardware": [[5, 4, 1, "", "perf"], [5, 4, 1, "", "rank"], [5, 4, 1, "", "storage"]], "torchrec.distributed.planner.types.Enumerator": [[5, 2, 1, "", "enumerate"], [5, 2, 1, "", "populate_estimates"]], "torchrec.distributed.planner.types.ParameterConstraints": [[5, 4, 1, "id0", "batch_sizes"], [5, 4, 1, "id1", "bounds_check_mode"], [5, 4, 1, "id2", "cache_params"], [5, 4, 1, "id3", "compute_kernels"], [5, 4, 1, "id4", "device_group"], [5, 4, 1, "id5", "enforce_hbm"], [5, 4, 1, "id6", "feature_names"], [5, 4, 1, "id7", "is_weighted"], [5, 4, 1, "id8", "key_value_params"], [5, 4, 1, "id9", "min_partition"], [5, 4, 1, "id10", "num_poolings"], [5, 4, 1, "id11", "output_dtype"], [5, 4, 1, "id12", "pooling_factors"], [5, 4, 1, "id13", "sharding_types"], [5, 4, 1, "id14", "stochastic_rounding"]], "torchrec.distributed.planner.types.PartitionByType": [[5, 4, 1, "", "DEVICE"], [5, 4, 1, "", "HOST"], [5, 4, 1, "", "UNIFORM"]], "torchrec.distributed.planner.types.Partitioner": [[5, 2, 1, "", "partition"]], "torchrec.distributed.planner.types.Perf": [[5, 4, 1, "", "bwd_comms"], [5, 4, 1, "", "bwd_compute"], [5, 4, 1, "", "fwd_comms"], [5, 4, 1, "", "fwd_compute"], [5, 4, 1, "", "prefetch_compute"], [5, 5, 1, "", "total"]], "torchrec.distributed.planner.types.PerfModel": [[5, 2, 1, "", "rate"]], "torchrec.distributed.planner.types.PlannerErrorType": [[5, 4, 1, "", "INSUFFICIENT_STORAGE"], [5, 4, 1, "", "OTHER"], [5, 4, 1, "", "PARTITION"], [5, 4, 1, "", "STRICT_CONSTRAINTS"]], "torchrec.distributed.planner.types.Proposer": [[5, 2, 1, "", "feedback"], [5, 2, 1, "", "load"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.types.Shard": [[5, 4, 1, "", "offset"], [5, 4, 1, "", "perf"], [5, 4, 1, "", "rank"], [5, 4, 1, "", "size"], [5, 4, 1, "", "storage"]], "torchrec.distributed.planner.types.ShardEstimator": [[5, 2, 1, "", "estimate"]], "torchrec.distributed.planner.types.ShardingOption": [[5, 4, 1, "", "batch_size"], [5, 4, 1, "", "bounds_check_mode"], [5, 5, 1, "", "cache_load_factor"], [5, 4, 1, "", "cache_params"], [5, 4, 1, "", "compute_kernel"], [5, 4, 1, "", "dependency"], [5, 4, 1, "", "enforce_hbm"], [5, 4, 1, "", "feature_names"], [5, 5, 1, "", "fqn"], [5, 4, 1, "", "input_lengths"], [5, 5, 1, "id15", "is_pooled"], [5, 4, 1, "", "key_value_params"], [5, 5, 1, "id16", "module"], [5, 2, 1, "", "module_pooled"], [5, 4, 1, "", "name"], [5, 5, 1, "", "num_inputs"], [5, 5, 1, "", "num_shards"], [5, 4, 1, "", "output_dtype"], [5, 5, 1, "", "path"], [5, 4, 1, "", "sharding_type"], [5, 4, 1, "", "shards"], [5, 4, 1, "", "stochastic_rounding"], [5, 5, 1, "id17", "tensor"], [5, 5, 1, "", "total_perf"], [5, 5, 1, "", "total_storage"]], "torchrec.distributed.planner.types.Stats": [[5, 2, 1, "", "log"]], "torchrec.distributed.planner.types.Storage": [[5, 4, 1, "", "ddr"], [5, 2, 1, "", "fits_in"], [5, 4, 1, "", "hbm"]], "torchrec.distributed.planner.types.StorageReservation": [[5, 2, 1, "", "reserve"]], "torchrec.distributed.planner.types.Topology": [[5, 5, 1, "", "bwd_compute_multiplier"], [5, 5, 1, "", "compute_device"], [5, 5, 1, "", "ddr_mem_bw"], [5, 5, 1, "", "devices"], [5, 5, 1, "", "hbm_mem_bw"], [5, 5, 1, "", "inter_host_bw"], [5, 5, 1, "", "intra_host_bw"], [5, 5, 1, "", "local_world_size"], [5, 5, 1, "", "uneven_sharding_perf_multiplier"], [5, 5, 1, "", "weighted_feature_bwd_compute_multiplier"], [5, 5, 1, "", "world_size"]], "torchrec.distributed.planner.utils": [[5, 1, 1, "", "BinarySearchPredicate"], [5, 1, 1, "", "LuusJaakolaSearch"], [5, 3, 1, "", "bytes_to_gb"], [5, 3, 1, "", "bytes_to_mb"], [5, 3, 1, "", "gb_to_bytes"], [5, 3, 1, "", "placement"], [5, 3, 1, "", "prod"], [5, 3, 1, "", "reset_shard_rank"], [5, 3, 1, "", "sharder_name"], [5, 3, 1, "", "storage_repr_in_gb"]], "torchrec.distributed.planner.utils.BinarySearchPredicate": [[5, 2, 1, "", "next"]], "torchrec.distributed.planner.utils.LuusJaakolaSearch": [[5, 2, 1, "", "best"], [5, 2, 1, "", "clamp"], [5, 2, 1, "", "next"], [5, 2, 1, "", "shrink_right"], [5, 2, 1, "", "uniform"]], "torchrec.distributed.quant_embeddingbag": [[4, 1, 1, "", "QuantEmbeddingBagCollectionSharder"], [4, 1, 1, "", "QuantFeatureProcessedEmbeddingBagCollectionSharder"], [4, 1, 1, "", "ShardedQuantEbcInputDist"], [4, 1, 1, "", "ShardedQuantEmbeddingBagCollection"], [4, 1, 1, "", "ShardedQuantFeatureProcessedEmbeddingBagCollection"], [4, 3, 1, "", "create_infer_embedding_bag_sharding"], [4, 3, 1, "", "flatten_feature_lengths"], [4, 3, 1, "", "get_device_from_parameter_sharding"], [4, 3, 1, "", "get_device_from_sharding_infos"]], "torchrec.distributed.quant_embeddingbag.QuantEmbeddingBagCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"]], "torchrec.distributed.quant_embeddingbag.QuantFeatureProcessedEmbeddingBagCollectionSharder": [[4, 2, 1, "", "compute_kernels"], [4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "sharding_types"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantEbcInputDist": [[4, 2, 1, "", "forward"], [4, 4, 1, "", "training"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantEmbeddingBagCollection": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "compute_and_output_dist"], [4, 2, 1, "", "copy"], [4, 2, 1, "", "create_context"], [4, 2, 1, "", "embedding_bag_configs"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "output_dist"], [4, 2, 1, "", "sharding_type_device_group_to_sharding_infos"], [4, 5, 1, "", "shardings"], [4, 2, 1, "", "tbes_configs"], [4, 4, 1, "", "training"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantFeatureProcessedEmbeddingBagCollection": [[4, 2, 1, "", "apply_feature_processor"], [4, 2, 1, "", "compute"], [4, 4, 1, "", "embedding_bags"], [4, 4, 1, "", "tbes"], [4, 4, 1, "", "training"]], "torchrec.distributed.sharding": [[6, 0, 0, "-", "cw_sharding"], [6, 0, 0, "-", "dp_sharding"], [6, 0, 0, "-", "rw_sharding"], [6, 0, 0, "-", "tw_sharding"], [6, 0, 0, "-", "twcw_sharding"], [6, 0, 0, "-", "twrw_sharding"]], "torchrec.distributed.sharding.cw_sharding": [[6, 1, 1, "", "BaseCwEmbeddingSharding"], [6, 1, 1, "", "CwPooledEmbeddingSharding"], [6, 1, 1, "", "InferCwPooledEmbeddingDist"], [6, 1, 1, "", "InferCwPooledEmbeddingDistWithPermute"], [6, 1, 1, "", "InferCwPooledEmbeddingSharding"]], "torchrec.distributed.sharding.cw_sharding.BaseCwEmbeddingSharding": [[6, 2, 1, "", "embedding_dims"], [6, 2, 1, "", "embedding_names"], [6, 2, 1, "", "uncombined_embedding_dims"], [6, 2, 1, "", "uncombined_embedding_names"]], "torchrec.distributed.sharding.cw_sharding.CwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingDistWithPermute": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.dp_sharding": [[6, 1, 1, "", "BaseDpEmbeddingSharding"], [6, 1, 1, "", "DpPooledEmbeddingDist"], [6, 1, 1, "", "DpPooledEmbeddingSharding"], [6, 1, 1, "", "DpSparseFeaturesDist"]], "torchrec.distributed.sharding.dp_sharding.BaseDpEmbeddingSharding": [[6, 2, 1, "", "embedding_dims"], [6, 2, 1, "", "embedding_names"], [6, 2, 1, "", "embedding_names_per_rank"], [6, 2, 1, "", "embedding_shard_metadata"], [6, 2, 1, "", "embedding_tables"], [6, 2, 1, "", "feature_names"]], "torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.dp_sharding.DpSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding": [[6, 1, 1, "", "BaseRwEmbeddingSharding"], [6, 1, 1, "", "InferRwPooledEmbeddingDist"], [6, 1, 1, "", "InferRwPooledEmbeddingSharding"], [6, 1, 1, "", "InferRwSparseFeaturesDist"], [6, 1, 1, "", "RwPooledEmbeddingDist"], [6, 1, 1, "", "RwPooledEmbeddingSharding"], [6, 1, 1, "", "RwSparseFeaturesDist"], [6, 3, 1, "", "get_block_sizes_runtime_device"], [6, 3, 1, "", "get_embedding_shard_metadata"]], "torchrec.distributed.sharding.rw_sharding.BaseRwEmbeddingSharding": [[6, 2, 1, "", "embedding_dims"], [6, 2, 1, "", "embedding_names"], [6, 2, 1, "", "embedding_names_per_rank"], [6, 2, 1, "", "embedding_shard_metadata"], [6, 2, 1, "", "embedding_tables"], [6, 2, 1, "", "feature_names"]], "torchrec.distributed.sharding.rw_sharding.InferRwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.InferRwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.rw_sharding.InferRwSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.rw_sharding.RwSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding": [[6, 1, 1, "", "BaseTwEmbeddingSharding"], [6, 1, 1, "", "InferTwEmbeddingSharding"], [6, 1, 1, "", "InferTwPooledEmbeddingDist"], [6, 1, 1, "", "InferTwSparseFeaturesDist"], [6, 1, 1, "", "TwPooledEmbeddingDist"], [6, 1, 1, "", "TwPooledEmbeddingSharding"], [6, 1, 1, "", "TwSparseFeaturesDist"]], "torchrec.distributed.sharding.tw_sharding.BaseTwEmbeddingSharding": [[6, 2, 1, "", "embedding_dims"], [6, 2, 1, "", "embedding_names"], [6, 2, 1, "", "embedding_names_per_rank"], [6, 2, 1, "", "embedding_shard_metadata"], [6, 2, 1, "", "embedding_tables"], [6, 2, 1, "", "feature_names"], [6, 2, 1, "", "feature_names_per_rank"], [6, 2, 1, "", "features_per_rank"]], "torchrec.distributed.sharding.tw_sharding.InferTwEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.tw_sharding.InferTwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.InferTwSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.tw_sharding.TwSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.twcw_sharding": [[6, 1, 1, "", "TwCwPooledEmbeddingSharding"]], "torchrec.distributed.sharding.twrw_sharding": [[6, 1, 1, "", "BaseTwRwEmbeddingSharding"], [6, 1, 1, "", "TwRwPooledEmbeddingDist"], [6, 1, 1, "", "TwRwPooledEmbeddingSharding"], [6, 1, 1, "", "TwRwSparseFeaturesDist"]], "torchrec.distributed.sharding.twrw_sharding.BaseTwRwEmbeddingSharding": [[6, 2, 1, "", "embedding_dims"], [6, 2, 1, "", "embedding_names"], [6, 2, 1, "", "embedding_names_per_rank"], [6, 2, 1, "", "embedding_shard_metadata"], [6, 2, 1, "", "feature_names"]], "torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.twrw_sharding.TwRwSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.types": [[4, 1, 1, "", "Awaitable"], [4, 1, 1, "", "CacheParams"], [4, 1, 1, "", "CacheStatistics"], [4, 1, 1, "", "CommOp"], [4, 1, 1, "", "ComputeKernel"], [4, 1, 1, "", "EmbeddingModuleShardingPlan"], [4, 1, 1, "", "GenericMeta"], [4, 1, 1, "", "GetItemLazyAwaitable"], [4, 1, 1, "", "KeyValueParams"], [4, 1, 1, "", "LazyAwaitable"], [4, 1, 1, "", "LazyGetItemMixin"], [4, 1, 1, "", "LazyNoWait"], [4, 1, 1, "", "ModuleSharder"], [4, 1, 1, "", "ModuleShardingPlan"], [4, 1, 1, "", "NoOpQuantizedCommCodec"], [4, 1, 1, "", "NoWait"], [4, 1, 1, "", "NullShardedModuleContext"], [4, 1, 1, "", "NullShardingContext"], [4, 1, 1, "", "ObjectPoolShardingPlan"], [4, 1, 1, "", "ObjectPoolShardingType"], [4, 1, 1, "", "ParameterSharding"], [4, 1, 1, "", "ParameterStorage"], [4, 1, 1, "", "PipelineType"], [4, 1, 1, "", "QuantizedCommCodec"], [4, 1, 1, "", "QuantizedCommCodecs"], [4, 1, 1, "", "ShardedModule"], [4, 1, 1, "", "ShardingEnv"], [4, 1, 1, "", "ShardingPlan"], [4, 1, 1, "", "ShardingPlanner"], [4, 1, 1, "", "ShardingType"], [4, 3, 1, "", "get_tensor_size_bytes"], [4, 3, 1, "", "rank_device"], [4, 3, 1, "", "scope"]], "torchrec.distributed.types.Awaitable": [[4, 5, 1, "", "callbacks"], [4, 2, 1, "", "wait"]], "torchrec.distributed.types.CacheParams": [[4, 4, 1, "id34", "algorithm"], [4, 4, 1, "id35", "load_factor"], [4, 4, 1, "", "multipass_prefetch_config"], [4, 4, 1, "id36", "precision"], [4, 4, 1, "id37", "prefetch_pipeline"], [4, 4, 1, "id38", "reserved_memory"], [4, 4, 1, "id39", "stats"]], "torchrec.distributed.types.CacheStatistics": [[4, 5, 1, "", "cacheability"], [4, 5, 1, "", "expected_lookups"], [4, 2, 1, "", "expected_miss_rate"]], "torchrec.distributed.types.CommOp": [[4, 4, 1, "", "POOLED_EMBEDDINGS_ALL_TO_ALL"], [4, 4, 1, "", "POOLED_EMBEDDINGS_REDUCE_SCATTER"], [4, 4, 1, "", "SEQUENCE_EMBEDDINGS_ALL_TO_ALL"]], "torchrec.distributed.types.ComputeKernel": [[4, 4, 1, "", "DEFAULT"]], "torchrec.distributed.types.KeyValueParams": [[4, 4, 1, "id40", "gather_ssd_cache_stats"], [4, 4, 1, "", "l2_cache_size"], [4, 4, 1, "", "ods_prefix"], [4, 4, 1, "id41", "ps_client_thread_num"], [4, 4, 1, "id42", "ps_hosts"], [4, 4, 1, "id43", "ps_max_key_per_request"], [4, 4, 1, "id44", "ps_max_local_index_length"], [4, 4, 1, "", "report_interval"], [4, 4, 1, "id45", "ssd_rocksdb_shards"], [4, 4, 1, "id46", "ssd_rocksdb_write_buffer_size"], [4, 4, 1, "id47", "ssd_storage_directory"], [4, 4, 1, "", "stats_reporter_config"], [4, 4, 1, "", "use_passed_in_path"]], "torchrec.distributed.types.ModuleSharder": [[4, 2, 1, "", "compute_kernels"], [4, 5, 1, "", "module_type"], [4, 5, 1, "", "qcomm_codecs_registry"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "shardable_parameters"], [4, 2, 1, "", "sharding_types"], [4, 2, 1, "", "storage_usage"]], "torchrec.distributed.types.NoOpQuantizedCommCodec": [[4, 2, 1, "", "calc_quantized_size"], [4, 2, 1, "", "create_context"], [4, 2, 1, "", "decode"], [4, 2, 1, "", "encode"], [4, 2, 1, "", "padded_size"], [4, 2, 1, "", "quantized_dtype"]], "torchrec.distributed.types.NullShardedModuleContext": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.types.NullShardingContext": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.types.ObjectPoolShardingPlan": [[4, 4, 1, "", "inference"], [4, 4, 1, "", "sharding_type"]], "torchrec.distributed.types.ObjectPoolShardingType": [[4, 4, 1, "", "REPLICATED_ROW_WISE"], [4, 4, 1, "", "ROW_WISE"]], "torchrec.distributed.types.ParameterSharding": [[4, 4, 1, "", "bounds_check_mode"], [4, 4, 1, "", "cache_params"], [4, 4, 1, "", "compute_kernel"], [4, 4, 1, "", "enforce_hbm"], [4, 4, 1, "", "key_value_params"], [4, 4, 1, "", "output_dtype"], [4, 4, 1, "", "ranks"], [4, 4, 1, "", "sharding_spec"], [4, 4, 1, "", "sharding_type"], [4, 4, 1, "", "stochastic_rounding"]], "torchrec.distributed.types.ParameterStorage": [[4, 4, 1, "", "DDR"], [4, 4, 1, "", "HBM"]], "torchrec.distributed.types.PipelineType": [[4, 4, 1, "", "NONE"], [4, 4, 1, "", "TRAIN_BASE"], [4, 4, 1, "", "TRAIN_PREFETCH_SPARSE_DIST"], [4, 4, 1, "", "TRAIN_SPARSE_DIST"]], "torchrec.distributed.types.QuantizedCommCodec": [[4, 2, 1, "", "calc_quantized_size"], [4, 2, 1, "", "create_context"], [4, 2, 1, "", "decode"], [4, 2, 1, "", "encode"], [4, 2, 1, "", "padded_size"], [4, 5, 1, "", "quantized_dtype"]], "torchrec.distributed.types.QuantizedCommCodecs": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.types.ShardedModule": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "compute_and_output_dist"], [4, 2, 1, "", "create_context"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "output_dist"], [4, 5, 1, "", "qcomm_codecs_registry"], [4, 2, 1, "", "sharded_parameter_names"], [4, 4, 1, "", "training"]], "torchrec.distributed.types.ShardingEnv": [[4, 2, 1, "", "from_local"], [4, 2, 1, "", "from_process_group"]], "torchrec.distributed.types.ShardingPlan": [[4, 2, 1, "", "get_plan_for_module"], [4, 4, 1, "id48", "plan"]], "torchrec.distributed.types.ShardingPlanner": [[4, 2, 1, "", "collective_plan"], [4, 2, 1, "", "plan"]], "torchrec.distributed.types.ShardingType": [[4, 4, 1, "", "COLUMN_WISE"], [4, 4, 1, "", "DATA_PARALLEL"], [4, 4, 1, "", "ROW_WISE"], [4, 4, 1, "", "TABLE_COLUMN_WISE"], [4, 4, 1, "", "TABLE_ROW_WISE"], [4, 4, 1, "", "TABLE_WISE"]], "torchrec.distributed.utils": [[4, 1, 1, "", "CopyableMixin"], [4, 1, 1, "", "ForkedPdb"], [4, 3, 1, "", "add_params_from_parameter_sharding"], [4, 3, 1, "", "add_prefix_to_state_dict"], [4, 3, 1, "", "append_prefix"], [4, 3, 1, "", "convert_to_fbgemm_types"], [4, 3, 1, "", "copy_to_device"], [4, 3, 1, "", "filter_state_dict"], [4, 3, 1, "", "get_unsharded_module_names"], [4, 3, 1, "", "init_parameters"], [4, 3, 1, "", "merge_fused_params"], [4, 3, 1, "", "none_throws"], [4, 3, 1, "", "optimizer_type_to_emb_opt_type"], [4, 1, 1, "", "sharded_model_copy"]], "torchrec.distributed.utils.CopyableMixin": [[4, 2, 1, "", "copy"], [4, 4, 1, "", "training"]], "torchrec.distributed.utils.ForkedPdb": [[4, 2, 1, "", "interaction"]], "torchrec.fx": [[7, 0, 0, "-", "tracer"]], "torchrec.fx.tracer": [[7, 1, 1, "", "Tracer"], [7, 3, 1, "", "is_fx_tracing"], [7, 3, 1, "", "symbolic_trace"]], "torchrec.fx.tracer.Tracer": [[7, 2, 1, "", "create_arg"], [7, 2, 1, "", "is_leaf_module"], [7, 2, 1, "", "path_of_module"], [7, 2, 1, "", "trace"]], "torchrec.inference": [[8, 0, 0, "-", "model_packager"], [8, 0, 0, "-", "modules"]], "torchrec.inference.model_packager": [[8, 1, 1, "", "PredictFactoryPackager"], [8, 3, 1, "", "load_config_text"], [8, 3, 1, "", "load_pickle_config"]], "torchrec.inference.model_packager.PredictFactoryPackager": [[8, 2, 1, "", "save_predict_factory"], [8, 2, 1, "", "set_extern_modules"], [8, 2, 1, "", "set_mocked_modules"]], "torchrec.inference.modules": [[8, 1, 1, "", "BatchingMetadata"], [8, 1, 1, "", "PredictFactory"], [8, 1, 1, "", "PredictModule"], [8, 1, 1, "", "QualNameMetadata"], [8, 3, 1, "", "assign_weights_to_tbe"], [8, 3, 1, "", "get_table_to_weights_from_tbe"], [8, 3, 1, "", "quantize_dense"], [8, 3, 1, "", "quantize_embeddings"], [8, 3, 1, "", "quantize_feature"], [8, 3, 1, "", "quantize_inference_model"], [8, 3, 1, "", "set_pruning_data"], [8, 3, 1, "", "shard_quant_model"], [8, 3, 1, "", "trim_torch_package_prefix_from_typename"]], "torchrec.inference.modules.BatchingMetadata": [[8, 4, 1, "", "device"], [8, 4, 1, "", "pinned"], [8, 4, 1, "", "type"]], "torchrec.inference.modules.PredictFactory": [[8, 2, 1, "", "batching_metadata"], [8, 2, 1, "", "batching_metadata_json"], [8, 2, 1, "", "create_predict_module"], [8, 2, 1, "", "model_inputs_data"], [8, 2, 1, "", "qualname_metadata"], [8, 2, 1, "", "qualname_metadata_json"], [8, 2, 1, "", "result_metadata"], [8, 2, 1, "", "run_weights_dependent_transformations"], [8, 2, 1, "", "run_weights_independent_tranformations"]], "torchrec.inference.modules.PredictModule": [[8, 2, 1, "", "forward"], [8, 2, 1, "", "predict_forward"], [8, 5, 1, "", "predict_module"], [8, 2, 1, "", "state_dict"], [8, 4, 1, "", "training"]], "torchrec.inference.modules.QualNameMetadata": [[8, 4, 1, "", "need_preproc"]], "torchrec.metrics": [[9, 0, 0, "-", "accuracy"], [9, 0, 0, "-", "auc"], [9, 0, 0, "-", "auprc"], [9, 0, 0, "-", "calibration"], [9, 0, 0, "-", "ctr"], [9, 0, 0, "-", "mae"], [9, 0, 0, "-", "metric_module"], [9, 0, 0, "-", "mse"], [9, 0, 0, "-", "multiclass_recall"], [9, 0, 0, "-", "ndcg"], [9, 0, 0, "-", "ne"], [9, 0, 0, "-", "precision"], [9, 0, 0, "-", "rauc"], [9, 0, 0, "-", "rec_metric"], [9, 0, 0, "-", "recall"], [9, 0, 0, "-", "throughput"], [9, 0, 0, "-", "weighted_avg"], [9, 0, 0, "-", "xauc"]], "torchrec.metrics.accuracy": [[9, 1, 1, "", "AccuracyMetric"], [9, 1, 1, "", "AccuracyMetricComputation"], [9, 3, 1, "", "compute_accuracy"], [9, 3, 1, "", "compute_accuracy_sum"], [9, 3, 1, "", "get_accuracy_states"]], "torchrec.metrics.accuracy.AccuracyMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.auc": [[9, 1, 1, "", "AUCMetric"], [9, 1, 1, "", "AUCMetricComputation"], [9, 3, 1, "", "compute_auc"], [9, 3, 1, "", "compute_auc_per_group"]], "torchrec.metrics.auc.AUCMetricComputation": [[9, 2, 1, "", "reset"], [9, 2, 1, "", "update"]], "torchrec.metrics.auprc": [[9, 1, 1, "", "AUPRCMetric"], [9, 1, 1, "", "AUPRCMetricComputation"], [9, 3, 1, "", "compute_auprc"], [9, 3, 1, "", "compute_auprc_per_group"]], "torchrec.metrics.auprc.AUPRCMetricComputation": [[9, 2, 1, "", "reset"], [9, 2, 1, "", "update"]], "torchrec.metrics.calibration": [[9, 1, 1, "", "CalibrationMetric"], [9, 1, 1, "", "CalibrationMetricComputation"], [9, 3, 1, "", "compute_calibration"], [9, 3, 1, "", "get_calibration_states"]], "torchrec.metrics.calibration.CalibrationMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.ctr": [[9, 1, 1, "", "CTRMetric"], [9, 1, 1, "", "CTRMetricComputation"], [9, 3, 1, "", "compute_ctr"], [9, 3, 1, "", "get_ctr_states"]], "torchrec.metrics.ctr.CTRMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.mae": [[9, 1, 1, "", "MAEMetric"], [9, 1, 1, "", "MAEMetricComputation"], [9, 3, 1, "", "compute_error_sum"], [9, 3, 1, "", "compute_mae"], [9, 3, 1, "", "get_mae_states"]], "torchrec.metrics.mae.MAEMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.metric_module": [[9, 1, 1, "", "RecMetricModule"], [9, 1, 1, "", "StateMetric"], [9, 3, 1, "", "generate_metric_module"]], "torchrec.metrics.metric_module.RecMetricModule": [[9, 4, 1, "", "batch_size"], [9, 2, 1, "", "check_memory_usage"], [9, 2, 1, "", "compute"], [9, 4, 1, "", "compute_count"], [9, 2, 1, "", "get_memory_usage"], [9, 2, 1, "", "get_required_inputs"], [9, 4, 1, "", "last_compute_time"], [9, 2, 1, "", "local_compute"], [9, 4, 1, "", "memory_usage_limit_mb"], [9, 4, 1, "", "memory_usage_mb_avg"], [9, 4, 1, "", "oom_count"], [9, 4, 1, "", "rec_metrics"], [9, 4, 1, "", "rec_tasks"], [9, 2, 1, "", "reset"], [9, 2, 1, "", "should_compute"], [9, 4, 1, "", "state_metrics"], [9, 2, 1, "", "sync"], [9, 4, 1, "", "throughput_metric"], [9, 2, 1, "", "unsync"], [9, 2, 1, "", "update"], [9, 4, 1, "", "world_size"]], "torchrec.metrics.metric_module.StateMetric": [[9, 2, 1, "", "get_metrics"]], "torchrec.metrics.mse": [[9, 1, 1, "", "MSEMetric"], [9, 1, 1, "", "MSEMetricComputation"], [9, 3, 1, "", "compute_error_sum"], [9, 3, 1, "", "compute_mse"], [9, 3, 1, "", "compute_rmse"], [9, 3, 1, "", "get_mse_states"]], "torchrec.metrics.mse.MSEMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.multiclass_recall": [[9, 1, 1, "", "MulticlassRecallMetric"], [9, 1, 1, "", "MulticlassRecallMetricComputation"], [9, 3, 1, "", "compute_multiclass_recall_at_k"], [9, 3, 1, "", "compute_true_positives_at_k"], [9, 3, 1, "", "get_multiclass_recall_states"]], "torchrec.metrics.multiclass_recall.MulticlassRecallMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.ndcg": [[9, 1, 1, "", "NDCGComputation"], [9, 1, 1, "", "NDCGMetric"]], "torchrec.metrics.ndcg.NDCGComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.ne": [[9, 1, 1, "", "NEMetric"], [9, 1, 1, "", "NEMetricComputation"], [9, 3, 1, "", "compute_cross_entropy"], [9, 3, 1, "", "compute_logloss"], [9, 3, 1, "", "compute_ne"], [9, 3, 1, "", "get_ne_states"]], "torchrec.metrics.ne.NEMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.precision": [[9, 1, 1, "", "PrecisionMetric"], [9, 1, 1, "", "PrecisionMetricComputation"], [9, 3, 1, "", "compute_false_pos_sum"], [9, 3, 1, "", "compute_precision"], [9, 3, 1, "", "compute_true_pos_sum"], [9, 3, 1, "", "get_precision_states"]], "torchrec.metrics.precision.PrecisionMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.rauc": [[9, 1, 1, "", "RAUCMetric"], [9, 1, 1, "", "RAUCMetricComputation"], [9, 3, 1, "", "compute_rauc"], [9, 3, 1, "", "compute_rauc_per_group"], [9, 3, 1, "", "conquer_and_count"], [9, 3, 1, "", "count_reverse_pairs_divide_and_conquer"], [9, 3, 1, "", "divide"]], "torchrec.metrics.rauc.RAUCMetricComputation": [[9, 2, 1, "", "reset"], [9, 2, 1, "", "update"]], "torchrec.metrics.rec_metric": [[9, 1, 1, "", "MetricComputationReport"], [9, 1, 1, "", "RecMetric"], [9, 1, 1, "", "RecMetricComputation"], [9, 6, 1, "", "RecMetricException"], [9, 1, 1, "", "RecMetricList"], [9, 1, 1, "", "WindowBuffer"]], "torchrec.metrics.rec_metric.MetricComputationReport": [[9, 4, 1, "", "description"], [9, 4, 1, "", "metric_prefix"], [9, 4, 1, "", "name"], [9, 4, 1, "", "value"]], "torchrec.metrics.rec_metric.RecMetric": [[9, 4, 1, "", "LABELS"], [9, 4, 1, "", "PREDICTIONS"], [9, 4, 1, "", "WEIGHTS"], [9, 2, 1, "", "compute"], [9, 2, 1, "", "get_memory_usage"], [9, 2, 1, "", "get_required_inputs"], [9, 2, 1, "", "local_compute"], [9, 2, 1, "", "reset"], [9, 2, 1, "", "state_dict"], [9, 2, 1, "", "sync"], [9, 2, 1, "", "unsync"], [9, 2, 1, "", "update"]], "torchrec.metrics.rec_metric.RecMetricComputation": [[9, 2, 1, "", "compute"], [9, 2, 1, "", "get_window_state"], [9, 2, 1, "", "get_window_state_name"], [9, 2, 1, "", "local_compute"], [9, 2, 1, "", "pre_compute"], [9, 2, 1, "", "reset"], [9, 2, 1, "", "update"]], "torchrec.metrics.rec_metric.RecMetricList": [[9, 2, 1, "", "compute"], [9, 2, 1, "", "get_required_inputs"], [9, 2, 1, "", "local_compute"], [9, 4, 1, "", "rec_metrics"], [9, 4, 1, "", "required_inputs"], [9, 2, 1, "", "reset"], [9, 2, 1, "", "sync"], [9, 2, 1, "", "unsync"], [9, 2, 1, "", "update"]], "torchrec.metrics.rec_metric.WindowBuffer": [[9, 2, 1, "", "aggregate_state"], [9, 5, 1, "", "buffers"]], "torchrec.metrics.recall": [[9, 1, 1, "", "RecallMetric"], [9, 1, 1, "", "RecallMetricComputation"], [9, 3, 1, "", "compute_false_neg_sum"], [9, 3, 1, "", "compute_recall"], [9, 3, 1, "", "compute_true_pos_sum"], [9, 3, 1, "", "get_recall_states"]], "torchrec.metrics.recall.RecallMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.throughput": [[9, 1, 1, "", "ThroughputMetric"]], "torchrec.metrics.throughput.ThroughputMetric": [[9, 2, 1, "", "compute"], [9, 2, 1, "", "update"]], "torchrec.metrics.weighted_avg": [[9, 1, 1, "", "WeightedAvgMetric"], [9, 1, 1, "", "WeightedAvgMetricComputation"], [9, 3, 1, "", "get_mean"]], "torchrec.metrics.weighted_avg.WeightedAvgMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.xauc": [[9, 1, 1, "", "XAUCMetric"], [9, 1, 1, "", "XAUCMetricComputation"], [9, 3, 1, "", "compute_error_sum"], [9, 3, 1, "", "compute_weighted_num_pairs"], [9, 3, 1, "", "compute_xauc"], [9, 3, 1, "", "get_xauc_states"]], "torchrec.metrics.xauc.XAUCMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.models": [[10, 0, 0, "-", "deepfm"], [10, 0, 0, "-", "dlrm"]], "torchrec.models.deepfm": [[10, 1, 1, "", "DenseArch"], [10, 1, 1, "", "FMInteractionArch"], [10, 1, 1, "", "OverArch"], [10, 1, 1, "", "SimpleDeepFMNN"], [10, 1, 1, "", "SparseArch"]], "torchrec.models.deepfm.DenseArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.deepfm.FMInteractionArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.deepfm.OverArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.deepfm.SimpleDeepFMNN": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.deepfm.SparseArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm": [[10, 1, 1, "", "DLRM"], [10, 1, 1, "", "DLRMTrain"], [10, 1, 1, "", "DLRM_DCN"], [10, 1, 1, "", "DLRM_Projection"], [10, 1, 1, "", "DenseArch"], [10, 1, 1, "", "InteractionArch"], [10, 1, 1, "", "InteractionDCNArch"], [10, 1, 1, "", "InteractionProjectionArch"], [10, 1, 1, "", "OverArch"], [10, 1, 1, "", "SparseArch"], [10, 3, 1, "", "choose"]], "torchrec.models.dlrm.DLRM": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.DLRMTrain": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.DLRM_DCN": [[10, 4, 1, "", "sparse_arch"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.DLRM_Projection": [[10, 4, 1, "", "sparse_arch"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.DenseArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.InteractionArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.InteractionDCNArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.InteractionProjectionArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.OverArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.SparseArch": [[10, 2, 1, "", "forward"], [10, 5, 1, "", "sparse_feature_names"], [10, 4, 1, "", "training"]], "torchrec.modules": [[11, 0, 0, "-", "activation"], [11, 0, 0, "-", "crossnet"], [11, 0, 0, "-", "deepfm"], [11, 0, 0, "-", "embedding_configs"], [11, 0, 0, "-", "embedding_modules"], [11, 0, 0, "-", "feature_processor"], [11, 0, 0, "-", "lazy_extension"], [11, 0, 0, "-", "mc_embedding_modules"], [11, 0, 0, "-", "mc_modules"], [11, 0, 0, "-", "mlp"], [11, 0, 0, "-", "utils"]], "torchrec.modules.activation": [[11, 1, 1, "", "SwishLayerNorm"]], "torchrec.modules.activation.SwishLayerNorm": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.crossnet": [[11, 1, 1, "", "CrossNet"], [11, 1, 1, "", "LowRankCrossNet"], [11, 1, 1, "", "LowRankMixtureCrossNet"], [11, 1, 1, "", "VectorCrossNet"]], "torchrec.modules.crossnet.CrossNet": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.crossnet.LowRankCrossNet": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.crossnet.LowRankMixtureCrossNet": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.crossnet.VectorCrossNet": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.deepfm": [[11, 1, 1, "", "DeepFM"], [11, 1, 1, "", "FactorizationMachine"]], "torchrec.modules.deepfm.DeepFM": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.deepfm.FactorizationMachine": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.embedding_configs": [[11, 1, 1, "", "BaseEmbeddingConfig"], [11, 1, 1, "", "EmbeddingBagConfig"], [11, 1, 1, "", "EmbeddingConfig"], [11, 1, 1, "", "EmbeddingTableConfig"], [11, 1, 1, "", "PoolingType"], [11, 1, 1, "", "QuantConfig"], [11, 1, 1, "", "ShardingType"], [11, 3, 1, "", "data_type_to_dtype"], [11, 3, 1, "", "data_type_to_sparse_type"], [11, 3, 1, "", "dtype_to_data_type"], [11, 3, 1, "", "pooling_type_to_pooling_mode"], [11, 3, 1, "", "pooling_type_to_str"]], "torchrec.modules.embedding_configs.BaseEmbeddingConfig": [[11, 4, 1, "", "data_type"], [11, 4, 1, "", "embedding_dim"], [11, 4, 1, "", "feature_names"], [11, 2, 1, "", "get_weight_init_max"], [11, 2, 1, "", "get_weight_init_min"], [11, 4, 1, "", "init_fn"], [11, 4, 1, "", "name"], [11, 4, 1, "", "need_pos"], [11, 4, 1, "", "num_embeddings"], [11, 4, 1, "", "num_embeddings_post_pruning"], [11, 2, 1, "", "num_features"], [11, 4, 1, "", "weight_init_max"], [11, 4, 1, "", "weight_init_min"]], "torchrec.modules.embedding_configs.EmbeddingBagConfig": [[11, 4, 1, "", "pooling"]], "torchrec.modules.embedding_configs.EmbeddingConfig": [[11, 4, 1, "", "embedding_dim"], [11, 4, 1, "", "feature_names"], [11, 4, 1, "", "num_embeddings"]], "torchrec.modules.embedding_configs.EmbeddingTableConfig": [[11, 4, 1, "", "embedding_names"], [11, 4, 1, "", "has_feature_processor"], [11, 4, 1, "", "is_weighted"], [11, 4, 1, "", "pooling"]], "torchrec.modules.embedding_configs.PoolingType": [[11, 4, 1, "", "MEAN"], [11, 4, 1, "", "NONE"], [11, 4, 1, "", "SUM"]], "torchrec.modules.embedding_configs.QuantConfig": [[11, 4, 1, "", "activation"], [11, 4, 1, "", "per_table_weight_dtype"], [11, 4, 1, "", "weight"]], "torchrec.modules.embedding_configs.ShardingType": [[11, 4, 1, "", "COLUMN_WISE"], [11, 4, 1, "", "DATA_PARALLEL"], [11, 4, 1, "", "ROW_WISE"], [11, 4, 1, "", "TABLE_COLUMN_WISE"], [11, 4, 1, "", "TABLE_ROW_WISE"], [11, 4, 1, "", "TABLE_WISE"]], "torchrec.modules.embedding_modules": [[11, 1, 1, "", "EmbeddingBagCollection"], [11, 1, 1, "", "EmbeddingBagCollectionInterface"], [11, 1, 1, "", "EmbeddingCollection"], [11, 1, 1, "", "EmbeddingCollectionInterface"], [11, 3, 1, "", "get_embedding_names_by_table"], [11, 3, 1, "", "process_pooled_embeddings"], [11, 3, 1, "", "reorder_inverse_indices"]], "torchrec.modules.embedding_modules.EmbeddingBagCollection": [[11, 5, 1, "", "device"], [11, 2, 1, "", "embedding_bag_configs"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "is_weighted"], [11, 2, 1, "", "reset_parameters"], [11, 4, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface": [[11, 2, 1, "", "embedding_bag_configs"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "is_weighted"], [11, 4, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingCollection": [[11, 5, 1, "", "device"], [11, 2, 1, "", "embedding_configs"], [11, 2, 1, "", "embedding_dim"], [11, 2, 1, "", "embedding_names_by_table"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "need_indices"], [11, 2, 1, "", "reset_parameters"], [11, 4, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingCollectionInterface": [[11, 2, 1, "", "embedding_configs"], [11, 2, 1, "", "embedding_dim"], [11, 2, 1, "", "embedding_names_by_table"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "need_indices"], [11, 4, 1, "", "training"]], "torchrec.modules.feature_processor": [[11, 1, 1, "", "BaseFeatureProcessor"], [11, 1, 1, "", "BaseGroupedFeatureProcessor"], [11, 1, 1, "", "PositionWeightedModule"], [11, 1, 1, "", "PositionWeightedProcessor"], [11, 3, 1, "", "offsets_to_range_traceble"], [11, 3, 1, "", "position_weighted_module_update_features"]], "torchrec.modules.feature_processor.BaseFeatureProcessor": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.feature_processor.BaseGroupedFeatureProcessor": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.feature_processor.PositionWeightedModule": [[11, 2, 1, "", "forward"], [11, 2, 1, "", "reset_parameters"], [11, 4, 1, "", "training"]], "torchrec.modules.feature_processor.PositionWeightedProcessor": [[11, 2, 1, "", "forward"], [11, 2, 1, "", "named_buffers"], [11, 2, 1, "", "state_dict"], [11, 4, 1, "", "training"]], "torchrec.modules.lazy_extension": [[11, 1, 1, "", "LazyModuleExtensionMixin"], [11, 3, 1, "", "lazy_apply"]], "torchrec.modules.lazy_extension.LazyModuleExtensionMixin": [[11, 2, 1, "", "apply"]], "torchrec.modules.mc_embedding_modules": [[11, 1, 1, "", "BaseManagedCollisionEmbeddingCollection"], [11, 1, 1, "", "ManagedCollisionEmbeddingBagCollection"], [11, 1, 1, "", "ManagedCollisionEmbeddingCollection"], [11, 3, 1, "", "evict"]], "torchrec.modules.mc_embedding_modules.BaseManagedCollisionEmbeddingCollection": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingBagCollection": [[11, 4, 1, "", "training"]], "torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingCollection": [[11, 4, 1, "", "training"]], "torchrec.modules.mc_modules": [[11, 1, 1, "", "DistanceLFU_EvictionPolicy"], [11, 1, 1, "", "LFU_EvictionPolicy"], [11, 1, 1, "", "LRU_EvictionPolicy"], [11, 1, 1, "", "MCHEvictionPolicy"], [11, 1, 1, "", "MCHEvictionPolicyMetadataInfo"], [11, 1, 1, "", "MCHManagedCollisionModule"], [11, 1, 1, "", "ManagedCollisionCollection"], [11, 1, 1, "", "ManagedCollisionModule"], [11, 3, 1, "", "apply_mc_method_to_jt_dict"], [11, 3, 1, "", "average_threshold_filter"], [11, 3, 1, "", "dynamic_threshold_filter"], [11, 3, 1, "", "probabilistic_threshold_filter"]], "torchrec.modules.mc_modules.DistanceLFU_EvictionPolicy": [[11, 2, 1, "", "coalesce_history_metadata"], [11, 5, 1, "", "metadata_info"], [11, 2, 1, "", "record_history_metadata"], [11, 2, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.LFU_EvictionPolicy": [[11, 2, 1, "", "coalesce_history_metadata"], [11, 5, 1, "", "metadata_info"], [11, 2, 1, "", "record_history_metadata"], [11, 2, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.LRU_EvictionPolicy": [[11, 2, 1, "", "coalesce_history_metadata"], [11, 5, 1, "", "metadata_info"], [11, 2, 1, "", "record_history_metadata"], [11, 2, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.MCHEvictionPolicy": [[11, 2, 1, "", "coalesce_history_metadata"], [11, 5, 1, "", "metadata_info"], [11, 2, 1, "", "record_history_metadata"], [11, 2, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.MCHEvictionPolicyMetadataInfo": [[11, 4, 1, "", "is_history_metadata"], [11, 4, 1, "", "is_mch_metadata"], [11, 4, 1, "", "metadata_name"]], "torchrec.modules.mc_modules.MCHManagedCollisionModule": [[11, 2, 1, "", "evict"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "input_size"], [11, 2, 1, "", "open_slots"], [11, 2, 1, "", "output_size"], [11, 2, 1, "", "preprocess"], [11, 2, 1, "", "profile"], [11, 2, 1, "", "rebuild_with_output_id_range"], [11, 2, 1, "", "remap"], [11, 4, 1, "", "training"], [11, 2, 1, "", "validate_state"]], "torchrec.modules.mc_modules.ManagedCollisionCollection": [[11, 2, 1, "", "embedding_configs"], [11, 2, 1, "", "evict"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "open_slots"]], "torchrec.modules.mc_modules.ManagedCollisionModule": [[11, 5, 1, "", "device"], [11, 2, 1, "", "evict"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "input_size"], [11, 2, 1, "", "open_slots"], [11, 2, 1, "", "output_size"], [11, 2, 1, "", "preprocess"], [11, 2, 1, "", "profile"], [11, 2, 1, "", "rebuild_with_output_id_range"], [11, 2, 1, "", "remap"], [11, 4, 1, "", "training"], [11, 2, 1, "", "validate_state"]], "torchrec.modules.mlp": [[11, 1, 1, "", "MLP"], [11, 1, 1, "", "Perceptron"]], "torchrec.modules.mlp.MLP": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.mlp.Perceptron": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.utils": [[11, 1, 1, "", "SequenceVBEContext"], [11, 3, 1, "", "check_module_output_dimension"], [11, 3, 1, "", "construct_jagged_tensors"], [11, 3, 1, "", "construct_jagged_tensors_inference"], [11, 3, 1, "", "construct_modulelist_from_single_module"], [11, 3, 1, "", "convert_list_of_modules_to_modulelist"], [11, 3, 1, "", "deterministic_dedup"], [11, 3, 1, "", "extract_module_or_tensor_callable"], [11, 3, 1, "", "get_module_output_dimension"], [11, 3, 1, "", "init_mlp_weights_xavier_uniform"], [11, 3, 1, "", "jagged_index_select_with_empty"]], "torchrec.modules.utils.SequenceVBEContext": [[11, 4, 1, "", "recat"], [11, 2, 1, "", "record_stream"], [11, 4, 1, "", "reindexed_length_per_key"], [11, 4, 1, "", "reindexed_lengths"], [11, 4, 1, "", "reindexed_values"], [11, 4, 1, "", "unpadded_lengths"]], "torchrec.optim": [[12, 0, 0, "-", "clipping"], [12, 0, 0, "-", "fused"], [12, 0, 0, "-", "keyed"], [12, 0, 0, "-", "warmup"]], "torchrec.optim.clipping": [[12, 1, 1, "", "GradientClipping"], [12, 1, 1, "", "GradientClippingOptimizer"]], "torchrec.optim.clipping.GradientClipping": [[12, 4, 1, "", "NONE"], [12, 4, 1, "", "NORM"], [12, 4, 1, "", "VALUE"]], "torchrec.optim.clipping.GradientClippingOptimizer": [[12, 2, 1, "", "step"]], "torchrec.optim.fused": [[12, 1, 1, "", "EmptyFusedOptimizer"], [12, 1, 1, "", "FusedOptimizer"], [12, 1, 1, "", "FusedOptimizerModule"]], "torchrec.optim.fused.EmptyFusedOptimizer": [[12, 2, 1, "", "step"], [12, 2, 1, "", "zero_grad"]], "torchrec.optim.fused.FusedOptimizer": [[12, 2, 1, "", "step"], [12, 2, 1, "", "zero_grad"]], "torchrec.optim.fused.FusedOptimizerModule": [[12, 5, 1, "", "fused_optimizer"]], "torchrec.optim.keyed": [[12, 1, 1, "", "CombinedOptimizer"], [12, 1, 1, "", "KeyedOptimizer"], [12, 1, 1, "", "KeyedOptimizerWrapper"], [12, 1, 1, "", "OptimizerWrapper"]], "torchrec.optim.keyed.CombinedOptimizer": [[12, 5, 1, "", "optimizers"], [12, 5, 1, "", "param_groups"], [12, 5, 1, "", "params"], [12, 2, 1, "", "post_load_state_dict"], [12, 2, 1, "", "prepend_opt_key"], [12, 2, 1, "", "save_param_groups"], [12, 2, 1, "", "set_optimizer_step"], [12, 5, 1, "", "state"], [12, 2, 1, "", "step"], [12, 2, 1, "", "zero_grad"]], "torchrec.optim.keyed.KeyedOptimizer": [[12, 2, 1, "", "add_param_group"], [12, 2, 1, "", "init_state"], [12, 2, 1, "", "load_state_dict"], [12, 2, 1, "", "post_load_state_dict"], [12, 2, 1, "", "save_param_groups"], [12, 2, 1, "", "state_dict"]], "torchrec.optim.keyed.KeyedOptimizerWrapper": [[12, 2, 1, "", "step"], [12, 2, 1, "", "zero_grad"]], "torchrec.optim.keyed.OptimizerWrapper": [[12, 2, 1, "", "add_param_group"], [12, 2, 1, "", "load_state_dict"], [12, 2, 1, "", "post_load_state_dict"], [12, 2, 1, "", "save_param_groups"], [12, 2, 1, "", "state_dict"], [12, 2, 1, "", "step"], [12, 2, 1, "", "zero_grad"]], "torchrec.optim.warmup": [[12, 1, 1, "", "WarmupOptimizer"], [12, 1, 1, "", "WarmupPolicy"], [12, 1, 1, "", "WarmupStage"]], "torchrec.optim.warmup.WarmupOptimizer": [[12, 2, 1, "", "post_load_state_dict"], [12, 2, 1, "", "step"]], "torchrec.optim.warmup.WarmupPolicy": [[12, 4, 1, "", "CONSTANT"], [12, 4, 1, "", "COSINE_ANNEALING_WARM_RESTARTS"], [12, 4, 1, "", "INVSQRT"], [12, 4, 1, "", "LINEAR"], [12, 4, 1, "", "NONE"], [12, 4, 1, "", "POLY"], [12, 4, 1, "", "STEP"]], "torchrec.optim.warmup.WarmupStage": [[12, 4, 1, "", "decay_iters"], [12, 4, 1, "", "lr_scale"], [12, 4, 1, "", "max_iters"], [12, 4, 1, "", "policy"], [12, 4, 1, "", "sgdr_period"], [12, 4, 1, "", "value"]], "torchrec.quant": [[13, 0, 0, "-", "embedding_modules"]], "torchrec.quant.embedding_modules": [[13, 1, 1, "", "EmbeddingBagCollection"], [13, 1, 1, "", "EmbeddingCollection"], [13, 1, 1, "", "FeatureProcessedEmbeddingBagCollection"], [13, 3, 1, "", "for_each_module_of_type_do"], [13, 3, 1, "", "quant_prep_customize_row_alignment"], [13, 3, 1, "", "quant_prep_enable_quant_state_dict_split_scale_bias"], [13, 3, 1, "", "quant_prep_enable_quant_state_dict_split_scale_bias_for_types"], [13, 3, 1, "", "quant_prep_enable_register_tbes"], [13, 3, 1, "", "quantize_state_dict"]], "torchrec.quant.embedding_modules.EmbeddingBagCollection": [[13, 5, 1, "", "device"], [13, 2, 1, "", "embedding_bag_configs"], [13, 2, 1, "", "forward"], [13, 2, 1, "", "from_float"], [13, 2, 1, "", "is_weighted"], [13, 2, 1, "", "output_dtype"], [13, 4, 1, "", "training"]], "torchrec.quant.embedding_modules.EmbeddingCollection": [[13, 5, 1, "", "device"], [13, 2, 1, "", "embedding_configs"], [13, 2, 1, "", "embedding_dim"], [13, 2, 1, "", "embedding_names_by_table"], [13, 2, 1, "", "forward"], [13, 2, 1, "", "from_float"], [13, 2, 1, "", "need_indices"], [13, 2, 1, "", "output_dtype"], [13, 4, 1, "", "training"]], "torchrec.quant.embedding_modules.FeatureProcessedEmbeddingBagCollection": [[13, 4, 1, "", "embedding_bags"], [13, 2, 1, "", "forward"], [13, 2, 1, "", "from_float"], [13, 4, 1, "", "tbes"], [13, 4, 1, "", "training"]], "torchrec.sparse": [[14, 0, 0, "-", "jagged_tensor"]], "torchrec.sparse.jagged_tensor": [[14, 1, 1, "", "ComputeJTDictToKJT"], [14, 1, 1, "", "ComputeKJTToJTDict"], [14, 1, 1, "", "JaggedTensor"], [14, 1, 1, "", "JaggedTensorMeta"], [14, 1, 1, "", "KeyedJaggedTensor"], [14, 1, 1, "", "KeyedTensor"], [14, 3, 1, "", "flatten_kjt_list"], [14, 3, 1, "", "jt_is_equal"], [14, 3, 1, "", "kjt_is_equal"], [14, 3, 1, "", "permute_multi_embedding"], [14, 3, 1, "", "regroup_kts"], [14, 3, 1, "", "unflatten_kjt_list"]], "torchrec.sparse.jagged_tensor.ComputeJTDictToKJT": [[14, 2, 1, "", "forward"], [14, 4, 1, "", "training"]], "torchrec.sparse.jagged_tensor.ComputeKJTToJTDict": [[14, 2, 1, "", "forward"], [14, 4, 1, "", "training"]], "torchrec.sparse.jagged_tensor.JaggedTensor": [[14, 2, 1, "", "device"], [14, 2, 1, "", "empty"], [14, 2, 1, "", "from_dense"], [14, 2, 1, "", "from_dense_lengths"], [14, 2, 1, "", "lengths"], [14, 2, 1, "", "lengths_or_none"], [14, 2, 1, "", "offsets"], [14, 2, 1, "", "offsets_or_none"], [14, 2, 1, "", "record_stream"], [14, 2, 1, "", "to"], [14, 2, 1, "", "to_dense"], [14, 2, 1, "", "to_dense_weights"], [14, 2, 1, "", "to_padded_dense"], [14, 2, 1, "", "to_padded_dense_weights"], [14, 2, 1, "", "values"], [14, 2, 1, "", "weights"], [14, 2, 1, "", "weights_or_none"]], "torchrec.sparse.jagged_tensor.KeyedJaggedTensor": [[14, 2, 1, "", "concat"], [14, 2, 1, "", "device"], [14, 2, 1, "", "dist_init"], [14, 2, 1, "", "dist_labels"], [14, 2, 1, "", "dist_splits"], [14, 2, 1, "", "dist_tensors"], [14, 2, 1, "", "empty"], [14, 2, 1, "", "empty_like"], [14, 2, 1, "", "flatten_lengths"], [14, 2, 1, "", "from_jt_dict"], [14, 2, 1, "", "from_lengths_sync"], [14, 2, 1, "", "from_offsets_sync"], [14, 2, 1, "", "index_per_key"], [14, 2, 1, "", "inverse_indices"], [14, 2, 1, "", "inverse_indices_or_none"], [14, 2, 1, "", "keys"], [14, 2, 1, "", "length_per_key"], [14, 2, 1, "", "length_per_key_or_none"], [14, 2, 1, "", "lengths"], [14, 2, 1, "", "lengths_offset_per_key"], [14, 2, 1, "", "lengths_or_none"], [14, 2, 1, "", "offset_per_key"], [14, 2, 1, "", "offset_per_key_or_none"], [14, 2, 1, "", "offsets"], [14, 2, 1, "", "offsets_or_none"], [14, 2, 1, "", "permute"], [14, 2, 1, "", "pin_memory"], [14, 2, 1, "", "record_stream"], [14, 2, 1, "", "split"], [14, 2, 1, "", "stride"], [14, 2, 1, "", "stride_per_key"], [14, 2, 1, "", "stride_per_key_per_rank"], [14, 2, 1, "", "sync"], [14, 2, 1, "", "to"], [14, 2, 1, "", "to_dict"], [14, 2, 1, "", "unsync"], [14, 2, 1, "", "values"], [14, 2, 1, "", "variable_stride_per_key"], [14, 2, 1, "", "weights"], [14, 2, 1, "", "weights_or_none"]], "torchrec.sparse.jagged_tensor.KeyedTensor": [[14, 2, 1, "", "device"], [14, 2, 1, "", "from_tensor_list"], [14, 2, 1, "", "key_dim"], [14, 2, 1, "", "keys"], [14, 2, 1, "", "length_per_key"], [14, 2, 1, "", "offset_per_key"], [14, 2, 1, "", "record_stream"], [14, 2, 1, "", "regroup"], [14, 2, 1, "", "regroup_as_dict"], [14, 2, 1, "", "to"], [14, 2, 1, "", "to_dict"], [14, 2, 1, "", "values"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:method", "3": "py:function", "4": "py:attribute", "5": "py:property", "6": "py:exception"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "method", "Python method"], "3": ["py", "function", "Python function"], "4": ["py", "attribute", "Python attribute"], "5": ["py", "property", "Python property"], "6": ["py", "exception", "Python exception"]}, "titleterms": {"welcom": 0, "torchrec": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "document": 0, "get": 0, "start": 0, "how": 0, "contribut": 0, "overview": 1, "why": 1, "dataset": [2, 3], "criteo": 2, "movielen": 2, "random": 2, "util": [2, 4, 5, 11], "script": 3, "contiguous_preproc_criteo": 3, "npy_preproc_criteo": 3, "distribut": [4, 5, 6], "collective_util": 4, "comm": 4, "comm_op": 4, "dist_data": [4, 6], "embed": 4, "embedding_lookup": 4, "embedding_shard": 4, "embedding_typ": 4, "embeddingbag": 4, "grouped_position_weight": 4, "model_parallel": 4, "quant_embeddingbag": 4, "train_pipelin": 4, "type": [4, 5], "mc_modul": [4, 11], "mc_embeddingbag": 4, "mc_embed": 4, "planner": 5, "constant": 5, "enumer": 5, "partition": 5, "perf_model": 5, "propos": 5, "shard_estim": 5, "stat": 5, "storage_reserv": 5, "shard": 6, "cw_shard": 6, "dp_shard": 6, "rw_shard": 6, "tw_shard": 6, "twcw_shard": 6, "twrw_shard": 6, "fx": 7, "tracer": 7, "modul": [7, 8, 10, 11, 12, 13, 14], "content": [7, 8, 10, 12, 13, 14], "infer": 8, "model_packag": 8, "metric": 9, "accuraci": 9, "auc": 9, "auprc": 9, "calibr": 9, "ctr": 9, "mae": 9, "mse": 9, "multiclass_recal": 9, "ndcg": 9, "ne": 9, "recal": 9, "precis": 9, "rauc": 9, "throughput": 9, "weighted_avg": 9, "xauc": 9, "metric_modul": 9, "rec_metr": 9, "model": 10, "deepfm": [10, 11], "dlrm": 10, "activ": 11, "crossnet": 11, "embedding_config": 11, "embedding_modul": [11, 13], "feature_processor": 11, "lazy_extens": 11, "mlp": 11, "mc_embedding_modul": 11, "optim": 12, "clip": 12, "fuse": 12, "kei": 12, "warmup": 12, "quant": 13, "spars": 14, "jagged_tensor": 14}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 6, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx": 56}})
\ No newline at end of file
diff --git a/torchrec.datasets.html b/torchrec.datasets.html
index 4590ba296..81cac1f2c 100644
--- a/torchrec.datasets.html
+++ b/torchrec.datasets.html
@@ -10,7 +10,7 @@
- torchrec.datasets — TorchRec 0.9.0 documentation
+ torchrec.datasets — TorchRec 1.1.0 documentation
@@ -30,6 +30,9 @@
+
+
+
@@ -268,7 +271,7 @@
- 0.9.0.dev20240801+cpu
+ 1.1.0.dev20240924+cpu
@@ -391,19 +394,567 @@
-
-torchrec.datasets
-
-torchrec.datasets.criteo
+
+torchrec.datasets
+Torchrec Datasets
+Torchrec contains two popular recys datasets, the Kaggle/Criteo Display Advertising Dataset
+and the MovieLens 20M Dataset.
+Additionally, it contains a RandomDataset, which is useful to generate random data in the same format as the above.
+Lastly, it contains scripts and utilities for pre-processing, loading, etc.
+Example:
+from torchrec.datasets.criteo import criteo_kaggle
+datapipe = criteo_terabyte (
+ ( "/home/datasets/criteo/day_0.tsv" , "/home/datasets/criteo/day_1.tsv" )
+)
+datapipe = dp . iter . Batcher ( datapipe , 100 )
+datapipe = dp . iter . Collator ( datapipe )
+batch = next ( iter ( datapipe ))
+
+
+
+torchrec.datasets.criteo
+
+
+class torchrec.datasets.criteo. BinaryCriteoUtils
+Bases: object
+Utility functions used to preprocess, save, load, partition, etc. the Criteo
+dataset in a binary (numpy) format.
+
+
+static get_file_row_ranges_and_remainder ( lengths : List [ int ] , rank : int , world_size : int , start_row : int = 0 , last_row : Optional [ int ] = None ) → Tuple [ Dict [ int , Tuple [ int , int ] ] , int ]
+Given a rank, world_size, and the lengths (number of rows) for a list of files,
+return which files and which portions of those files (represented as row ranges
+- all range indices are inclusive) should be handled by the rank. Each rank
+will be assigned the same number of rows.
+The ranges are determined in such a way that each rank deals with large
+continuous ranges of files. This enables each rank to reduce the amount of data
+it needs to read while avoiding seeks.
+
+Parameters:
+
+
+Returns:
+First item is a mapping of files
+to the range in those files to be handled by the rank. The keys of this dict are indices.
+The second item is the remainder of dataset length / world size.
+
+Return type:
+output (Tuple[Dict[int, Tuple[int, int]], int])
+
+
+
+
+
+
+static get_shape_from_npy ( path : str , path_manager_key : str = 'torchrec' ) → Tuple [ int , ... ]
+Returns the shape of an npy file using only its header.
+
+Parameters:
+
+
+Returns:
+Shape tuple.
+
+Return type:
+shape (Tuple[int, …])
+
+
+
+
+
+
+static load_npy_range ( fname : str , start_row : int , num_rows : int , path_manager_key : str = 'torchrec' , mmap_mode : bool = False ) → ndarray
+Load part of an npy file.
+NOTE: Assumes npy represents a numpy array of ndim 2.
+
+Parameters:
+
+fname (str ) – path string to npy file.
+start_row (int ) – starting row from the npy file.
+num_rows (int ) – number of rows to get from the npy file.
+path_manager_key (str ) – Path manager key used to load from different
+filesystems.
+
+
+Returns:
+
+numpy array with the desired range of data from the supplied npy file.
+
+
+
+
+Return type:
+output (np.ndarray)
+
+
+
+
+
+
+static shuffle ( input_dir_labels_and_dense : str , input_dir_sparse : str , output_dir_shuffled : str , rows_per_day : Dict [ int , int ] , output_dir_full_set : Optional [ str ] = None , days : int = 24 , int_columns : int = 13 , sparse_columns : int = 26 , path_manager_key : str = 'torchrec' , random_seed : int = 0 ) → None
+Shuffle the dataset. Expects the files to be in .npy format and the data
+to be split by day and by dense, sparse and label data.
+Dense data must be in: day_x_dense.npy
+Sparse data must be in: day_x_sparse.npy
+Labels data must be in: day_x_labels.npy
+The dataset will be reconstructed, shuffled and then split back into
+separate dense, sparse and labels files.
+This will only shuffle the first DAYS-1 days as the training set. The final day will remain
+untouched as the validation, and training set.
+
+Parameters:
+
+input_dir_labels_and_dense (str ) – Input directory of labels and dense npy files.
+input_dir_sparse (str ) – Input directory of sparse npy files.
+output_dir_shuffled (str ) – Output directory for shuffled labels, dense and sparse npy files.
+Dict [ int (rows_per_day ) – Number of rows in each file.
+int ] – Number of rows in each file.
+output_dir_full_set (str ) – Output directory of the full dataset, if desired.
+days (int ) – Number of day files.
+int_columns (int ) – Number of columns with dense features.
+sparse_columns (int ) – Total number of categorical columns.
+path_manager_key (str ) – Path manager key used to load from different filesystems.
+random_seed (int ) – Random seed used for the random.shuffle operator.
+
+
+
+
+
+
+
+static sparse_to_contiguous ( in_files : List [ str ] , output_dir : str , frequency_threshold : int = 3 , columns : int = 26 , path_manager_key : str = 'torchrec' , output_file_suffix : str = '_contig_freq.npy' ) → None
+Convert all sparse .npy files to have contiguous integers. Store in a separate
+.npy file. All input files must be processed together because columns
+can have matching IDs between files. Hence, they must be transformed
+together. Also, the transformed IDs are not unique between columns. IDs
+that appear less than frequency_threshold amount of times will be remapped
+to have a value of 1.
+Example transformation, frequency_threshold of 2:
+day_0_sparse.npy
+| col_0 | col_1 |
+—————–
+| abc | xyz |
+| iop | xyz |
+day_1_sparse.npy
+| col_0 | col_1 |
+—————–
+| iop | tuv |
+| lkj | xyz |
+day_0_sparse_contig.npy
+| col_0 | col_1 |
+—————–
+| 1 | 2 |
+| 2 | 2 |
+day_1_sparse_contig.npy
+| col_0 | col_1 |
+—————–
+| 2 | 1 |
+| 1 | 2 |
+
+Parameters:
+
+List [ str ] (in_files ) – Input directory of npy files.
+output_dir (str ) – Output directory of processed npy files.
+frequency_threshold – IDs occurring less than this frequency will be remapped to a value of 1.
+path_manager_key (str ) – Path manager key used to load from different filesystems.
+
+
+Returns:
+None.
+
+
+
+
+
+
+static tsv_to_npys ( in_file : str , out_dense_file : str , out_sparse_file : str , out_labels_file : str , dataset_name : str = 'criteo_1tb' , path_manager_key : str = 'torchrec' ) → None
+Convert one Criteo tsv file to three npy files: one for dense (np.float32), one
+for sparse (np.int32), and one for labels (np.int32).
+The tsv file is expected to be part of the Criteo 1TB Click Logs Dataset (“criteo_1tb”)
+or the Criteo Kaggle Display Advertising Challenge dataset (“criteo_kaggle”).
+For the “criteo_kaggle” test set, we set the labels to -1 representing filler data,
+because label data is not included in the “criteo_kaggle” test set.
+
+Parameters:
+
+in_file (str ) – Input tsv file path.
+out_dense_file (str ) – Output dense npy file path.
+out_sparse_file (str ) – Output sparse npy file path.
+out_labels_file (str ) – Output labels npy file path.
+dataset_name (str ) – The dataset name. “criteo_1tb” or “criteo_kaggle” is expected.
+path_manager_key (str ) – Path manager key used to load from different
+filesystems.
+
+
+Returns:
+None.
+
+
+
+
+
+
+
+
+class torchrec.datasets.criteo. CriteoIterDataPipe ( paths: ~typing.Iterable[str], *, row_mapper: ~typing.Optional[~typing.Callable[[~typing.List[str]], ~typing.Any]] = <function _default_row_mapper>, **open_kw )
+Bases: IterDataPipe
+IterDataPipe that can be used to stream either the Criteo 1TB Click Logs Dataset
+(https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/ ) or the
+Kaggle/Criteo Display Advertising Dataset
+(https://www.kaggle.com/c/criteo-display-ad-challenge/ ) from the source TSV
+files.
+
+Parameters:
+
+paths (Iterable [ str ] ) – local paths to TSV files that constitute the Criteo
+dataset.
+row_mapper (Optional [ Callable [ [ List [ str ] ] , Any ] ] ) – function to apply to each
+split TSV line.
+open_kw – options to pass to underlying invocation of
+iopath.common.file_io.PathManager.open.
+
+
+
+Example:
+datapipe = CriteoIterDataPipe (
+ ( "/home/datasets/criteo/day_0.tsv" , "/home/datasets/criteo/day_1.tsv" )
+)
+datapipe = dp . iter . Batcher ( datapipe , 100 )
+datapipe = dp . iter . Collator ( datapipe )
+batch = next ( iter ( datapipe ))
+
+
+
+
+
+
+class torchrec.datasets.criteo. InMemoryBinaryCriteoIterDataPipe ( stage : str , dense_paths : List [ str ] , sparse_paths : List [ str ] , labels_paths : List [ str ] , batch_size : int , rank : int , world_size : int , drop_last : Optional [ bool ] = False , shuffle_batches : bool = False , shuffle_training_set : bool = False , shuffle_training_set_random_seed : int = 0 , mmap_mode : bool = False , hashes : Optional [ List [ int ] ] = None , path_manager_key : str = 'torchrec' )
+Bases: IterableDataset
+Datapipe designed to operate over binary (npy) versions of Criteo datasets. Loads
+the entire dataset into memory to prevent disk speed from affecting throughout. Each
+rank reads only the data for the portion of the dataset it is responsible for.
+The torchrec/datasets/scripts/npy_preproc_criteo.py script can be used to convert
+the Criteo tsv files to the npy files expected by this dataset.
+
+Parameters:
+
+stage (str ) – “train”, “val”, or “test”.
+dense_paths (List [ str ] ) – List of path strings to dense npy files.
+sparse_paths (List [ str ] ) – List of path strings to sparse npy files.
+labels_paths (List [ str ] ) – List of path strings to labels npy files.
+batch_size (int ) – batch size.
+rank (int ) – rank.
+world_size (int ) – world size.
+shuffle_batches (bool ) – Whether to shuffle batches
+hashes (Optional [ int ] ) – List of max categorical feature value for each feature.
+Length of this list should be CAT_FEATURE_COUNT.
+path_manager_key (str ) – Path manager key used to load from different
+filesystems.
+
+
+
+Example:
+template = "/home/datasets/criteo/1tb_binary/day_ {} _ {} .npy"
+datapipe = InMemoryBinaryCriteoIterDataPipe (
+ dense_paths = [ template . format ( 0 , "dense" ), template . format ( 1 , "dense" )],
+ sparse_paths = [ template . format ( 0 , "sparse" ), template . format ( 1 , "sparse" )],
+ labels_paths = [ template . format ( 0 , "labels" ), template . format ( 1 , "labels" )],
+ batch_size = 1024 ,
+ rank = torch . distributed . get_rank (),
+ world_size = torch . distributed . get_world_size (),
+)
+batch = next ( iter ( datapipe ))
+
+
+
+
+
+
+torchrec.datasets.criteo. criteo_kaggle ( path: str, *, row_mapper: ~typing.Optional[~typing.Callable[[~typing.List[str]], ~typing.Any]] = <function _default_row_mapper>, **open_kw ) → IterDataPipe
+Kaggle/Criteo Display Advertising Dataset
+
+Parameters:
+
+path (str ) – local path to train or test dataset file.
+row_mapper (Optional [ Callable [ [ List [ str ] ] , Any ] ] ) – function to apply to each split TSV line.
+open_kw – options to pass to underlying invocation of iopath.common.file_io.PathManager.open.
+
+
+
+Example:
+train_datapipe = criteo_kaggle (
+ "/home/datasets/criteo_kaggle/train.txt" ,
+)
+example = next ( iter ( train_datapipe ))
+test_datapipe = criteo_kaggle (
+ "/home/datasets/criteo_kaggle/test.txt" ,
+)
+example = next ( iter ( test_datapipe ))
+
+
+
+
+
+
+torchrec.datasets.criteo. criteo_terabyte ( paths: ~typing.Iterable[str], *, row_mapper: ~typing.Optional[~typing.Callable[[~typing.List[str]], ~typing.Any]] = <function _default_row_mapper>, **open_kw ) → IterDataPipe
+Criteo 1TB Click Logs Dataset
+
+Parameters:
+
+paths (Iterable [ str ] ) – local paths to TSV files that constitute the Criteo 1TB
+dataset.
+row_mapper (Optional [ Callable [ [ List [ str ] ] , Any ] ] ) – function to apply to each
+split TSV line.
+open_kw – options to pass to underlying invocation of
+iopath.common.file_io.PathManager.open.
+
+
+
+Example:
+datapipe = criteo_terabyte (
+ ( "/home/datasets/criteo/day_0.tsv" , "/home/datasets/criteo/day_1.tsv" )
+)
+datapipe = dp . iter . Batcher ( datapipe , 100 )
+datapipe = dp . iter . Collator ( datapipe )
+batch = next ( iter ( datapipe ))
+
+
+
+
-
-torchrec.datasets.movielens
+
+torchrec.datasets.movielens
+
+
+torchrec.datasets.movielens. movielens_20m ( root: str, *, include_movies_data: bool = False, row_mapper: ~typing.Optional[~typing.Callable[[~typing.List[str]], ~typing.Any]] = <function _default_row_mapper>, **open_kw ) → IterDataPipe
+MovieLens 20M Dataset
+:param root: local path to root directory containing MovieLens 20M dataset files.
+:type root: str
+:param include_movies_data: if True, adds movies data to each line.
+:type include_movies_data: bool
+:param row_mapper: function to apply to each split line.
+:type row_mapper: Optional[Callable[[List[str]], Any]]
+:param open_kw: options to pass to underlying invocation of iopath.common.file_io.PathManager.open.
+Example:
+datapipe = movielens_20m ( "/home/datasets/ml-20" )
+datapipe = dp . iter . Batch ( datapipe , 100 )
+datapipe = dp . iter . Collate ( datapipe )
+batch = next ( iter ( datapipe ))
+
+
+
+
+
+
+torchrec.datasets.movielens. movielens_25m ( root: str, *, include_movies_data: bool = False, row_mapper: ~typing.Optional[~typing.Callable[[~typing.List[str]], ~typing.Any]] = <function _default_row_mapper>, **open_kw ) → IterDataPipe
+MovieLens 25M Dataset
+:param root: local path to root directory containing MovieLens 25M dataset files.
+:type root: str
+:param include_movies_data: if True, adds movies data to each line.
+:type include_movies_data: bool
+:param row_mapper: function to apply to each split line.
+:type row_mapper: Optional[Callable[[List[str]], Any]]
+:param open_kw: options to pass to underlying invocation of iopath.common.file_io.PathManager.open.
+Example:
+datapipe = movielens_25m ( "/home/datasets/ml-25" )
+datapipe = dp . iter . Batch ( datapipe , 100 )
+datapipe = dp . iter . Collate ( datapipe )
+batch = next ( iter ( datapipe ))
+
+
+
+
-
-torchrec.datasets.random
+
+torchrec.datasets.random
+
+
+class torchrec.datasets.random. RandomRecDataset ( keys : List [ str ] , batch_size : int , hash_size : Optional [ int ] = None , hash_sizes : Optional [ List [ int ] ] = None , ids_per_feature : Optional [ int ] = None , ids_per_features : Optional [ List [ int ] ] = None , num_dense : int = 50 , manual_seed : Optional [ int ] = None , num_batches : Optional [ int ] = None , num_generated_batches : int = 10 , min_ids_per_feature : Optional [ int ] = None , min_ids_per_features : Optional [ List [ int ] ] = None )
+Bases: IterableDataset
[Batch
]
+Random iterable dataset used to generate batches for recommender systems
+(RecSys). Currently produces unweighted sparse features only. TODO: Add
+weighted sparse features.
+
+Parameters:
+
+keys (List [ str ] ) – List of feature names for sparse features.
+batch_size (int ) – batch size.
+hash_size (Optional [ int ] ) – Max sparse id value. All sparse IDs will be taken
+modulo this value.
+hash_sizes (Optional [ List [ int ] ] ) – Max sparse id value per feature in keys. Each
+sparse ID will be taken modulo the corresponding value from this argument. Note, if this is used, hash_size will be ignored.
+ids_per_feature (Optional [ int ] ) – Number of IDs per sparse feature per sample.
+ids_per_features (Optional [ List [ int ] ] ) – Number of IDs per sparse feature per sample in each key. Note, if this is used, ids_per_feature will be ignored.
+num_dense (int ) – Number of dense features.
+manual_seed (int ) – Seed for deterministic behavior.
+num_batches – (Optional[int]): Num batches to generate before raising StopIteration
+int (num_generated_batches ) – Num batches to cache. If num_batches > num_generated batches, then we will cycle to the first generated batch.
+If this value is negative, batches will be generated on the fly.
+min_ids_per_feature (Optional [ int ] ) – Minimum number of IDs per features.
+min_ids_per_features (Optional [ List [ int ] ] ) – Minimum number of IDs per sparse feature per sample in each key. Note, if this is used, min_ids_per_feature will be ignored.
+
+
+
+Example:
+dataset = RandomRecDataset (
+ keys = [ "feat1" , "feat2" ],
+ batch_size = 16 ,
+ hash_size = 100_000 ,
+ ids_per_feature = 1 ,
+ num_dense = 13 ,
+),
+example = next ( iter ( dataset ))
+
+
+
+
-
-torchrec.datasets.utils
+
+torchrec.datasets.utils
+
+
+class torchrec.datasets.utils. Batch ( dense_features : torch.Tensor , sparse_features : torchrec.sparse.jagged_tensor.KeyedJaggedTensor , labels : torch.Tensor )
+Bases: Pipelineable
+
+
+dense_features : Tensor
+
+
+
+
+labels : Tensor
+
+
+
+
+pin_memory ( ) → Batch
+
+
+
+
+record_stream ( stream : Stream ) → None
+See https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html
+
+
+
+
+sparse_features : KeyedJaggedTensor
+
+
+
+
+to ( device : device , non_blocking : bool = False ) → Batch
+Please be aware that according to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html ,
+to might return self or a copy of self. So please remember to use to with the assignment operator,
+for example, in = in.to(new_device) .
+
+
+
+
+
+
+class torchrec.datasets.utils. Limit ( datapipe : IterDataPipe , limit : int )
+Bases: IterDataPipe
+
+
+
+
+class torchrec.datasets.utils. LoadFiles ( datapipe : Iterable [ str ] , mode : str = 'b' , length : int = - 1 , path_manager_key : str = 'torchrec' , ** open_kw )
+Bases: IterDataPipe
[Tuple
[str
, IOBase
]]
+Taken and adapted from torch.utils.data.datapipes.iter.LoadFilesFromDisk
+TODO:
+Merge this back or replace this with something in core Datapipes lib
+
+
+
+
+class torchrec.datasets.utils. ParallelReadConcat ( *datapipes: ~torch.utils.data.datapipes.datapipe.IterDataPipe, dp_selector: ~typing.Callable[[~typing.Sequence[~torch.utils.data.datapipes.datapipe.IterDataPipe]], ~typing.Sequence[~torch.utils.data.datapipes.datapipe.IterDataPipe]] = <function _default_dp_selector> )
+Bases: IterDataPipe
+ParallelReadConcat
.
+Iterable DataPipe that concatenates multiple Iterable DataPipes.
+When used with a DataLoader, assigns a subset of datapipes to each DataLoader worker
+to allow for parallel reading.
+:param datapipes: IterDataPipe instances to read from.
+:param dp_selector: function that each DataLoader worker would use to determine the subset of datapipes
+:param to read from.:
+Example:
+datapipes = [
+ criteo_terabyte (
+ ( f "/home/local/datasets/criteo/shard_ { idx } .tsv" ,),
+ )
+ . batch ( 100 )
+ . collate ()
+ for idx in range ( 4 )
+]
+dataloader = DataLoader (
+ ParallelReadConcat ( * datapipes ), num_workers = 4 , batch_size = None
+)
+
+
+
+
+
+
+class torchrec.datasets.utils. ReadLinesFromCSV ( datapipe : IterDataPipe [ Tuple [ str , IOBase ] ] , skip_first_line : bool = False , ** kw )
+Bases: IterDataPipe
+
+
+
+
+torchrec.datasets.utils. idx_split_train_val ( datapipe: ~torch.utils.data.datapipes.datapipe.IterDataPipe, train_perc: float, decimal_places: int = 2, key_fn: ~typing.Callable[[int], int] = <function _default_key_fn> ) → Tuple [ IterDataPipe , IterDataPipe ]
+
+
+
+
+torchrec.datasets.utils. rand_split_train_val ( datapipe : IterDataPipe , train_perc : float , random_seed : int = 0 ) → Tuple [ IterDataPipe , IterDataPipe ]
+Via uniform random sampling, generates two IterDataPipe instances representing
+disjoint train and val splits of the given IterDataPipe.
+
+Parameters:
+
+datapipe (IterDataPipe ) – datapipe to split.
+train_perc (float ) – value in range (0.0, 1.0) specifying target proportion of
+datapipe samples to include in train split. Note that the actual proportion
+is not guaranteed to match train_perc exactly.
+random_seed (int ) – determines split membership for a given sample
+and train_perc. Use the same value across calls to generate consistent splits.
+
+
+
+Example:
+datapipe = criteo_terabyte (
+ ( "/home/datasets/criteo/day_0.tsv" , "/home/datasets/criteo/day_1.tsv" )
+)
+train_datapipe , val_datapipe = rand_split_train_val ( datapipe , 0.75 )
+train_batch = next ( iter ( train_datapipe ))
+val_batch = next ( iter ( val_datapipe ))
+
+
+
+
+
+
+torchrec.datasets.utils. safe_cast ( val : T , dest_type : Callable [ [ T ] , T ] , default : T ) → T
+
+
+
+
+torchrec.datasets.utils. train_filter ( key_fn : Callable [ [ int ] , int ] , train_perc : float , decimal_places : int , idx : int ) → bool
+
+
+
+
+torchrec.datasets.utils. val_filter ( key_fn : Callable [ [ int ] , int ] , train_perc : float , decimal_places : int , idx : int ) → bool
+
+
@@ -431,7 +982,7 @@ torchrec.datasets.utils
@@ -451,10 +1002,10 @@ torchrec.datasets.utils