From b1db83e874075c237942408f7dd38556a1916522 Mon Sep 17 00:00:00 2001 From: facebook-github-bot Date: Thu, 26 Sep 2024 13:02:39 +0000 Subject: [PATCH] =?UTF-8?q?Deploying=20to=20gh-pages=20from=20=20@=20a4ea5?= =?UTF-8?q?d1b13e4da2d23ad491a400178f50e15e7fc=20=F0=9F=9A=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- genindex.html | 6 +- index.html | 2 +- objects.inv | Bin 13507 -> 13519 bytes overview.html | 2 +- py-modindex.html | 2 +- search.html | 2 +- searchindex.js | 2 +- torchrec.datasets.html | 2 +- torchrec.datasets.scripts.html | 2 +- torchrec.distributed.html | 7 +- torchrec.distributed.planner.html | 2 +- torchrec.distributed.sharding.html | 2 +- torchrec.fx.html | 2 +- torchrec.inference.html | 2 +- torchrec.metrics.html | 2 +- torchrec.models.html | 2 +- torchrec.modules.html | 131 ++++- torchrec.optim.html | 14 +- torchrec.quant.html | 2 +- torchrec.sparse.html | 787 +++++++++++++++++++++++++---- 20 files changed, 847 insertions(+), 126 deletions(-) diff --git a/genindex.html b/genindex.html index d5c770c3d..dc02aa4aa 100644 --- a/genindex.html +++ b/genindex.html @@ -268,7 +268,7 @@
- 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
@@ -736,6 +736,8 @@

C

  • clamp() (torchrec.distributed.planner.utils.LuusJaakolaSearch method)
  • clf_to_bytes() (torchrec.distributed.planner.proposers.EmbeddingOffloadScaleupProposer static method) +
  • +
  • clip_grad_norm_() (torchrec.optim.clipping.GradientClippingOptimizer method)
  • coalesce_history_metadata() (torchrec.modules.mc_modules.DistanceLFU_EvictionPolicy method) @@ -3679,6 +3681,8 @@

    S

  • sgdr_period (torchrec.optim.warmup.WarmupStage attribute)
  • SHAMPOO (torchrec.distributed.embedding_types.OptimType attribute) +
  • +
  • SHAMPOO_MRS (torchrec.distributed.embedding_types.OptimType attribute)
  • SHAMPOO_V2 (torchrec.distributed.embedding_types.OptimType attribute)
  • diff --git a/index.html b/index.html index 915c9dbfe..16120ec95 100644 --- a/index.html +++ b/index.html @@ -272,7 +272,7 @@
    - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
    diff --git a/objects.inv b/objects.inv index e7fde3a634168b8b378cb78a004ede9aa76b9deb..b93e51a7a2416b1dd79c2269583ba734fe4c92b6 100644 GIT binary patch delta 1473 zcmV;y1wQ)2Y0qhpOaV5rO?VA|xm&KRaK5~^gh9$tafmR=d{IN&31XK9@^)&Dos9yO z#;=H|YBVzP{z?dl!``x#l9VMWbSTpUgg+E*5Qn%Ap-o(aZl4=qV9$CKP`Lt%*@_SP zj#s>^TA!rD&XltY;*Dt3R+-w5AR?Pn=khRv`uo8yEiY97+NZdo;BNt43(4BTHT1UB zhWaJNyw;^31^m9!Ns##3iw`_ya2o1IFD&Kq9ahhGVdrK}!*NFGr@7T$rQ(_m%h`S@ z?qMPuc@dLZ94zL>?4$#DC*DaA(mg+?xv;nWjO1bMz>}X27=K9W<=|XU)Q7{noT!h8 zaGFCO3hM}|J|50lQNuLtKWx$LwP%zYrTSQ0|E+dk5bXMAFV~)JV$1>bu;6d5I|;Jo zo-N*DHzEaAy9z5e+MQpqZgw>wQmyW{RO8p~TWuE@1G1tHpq_#}6;vZ+32uoFPX z*Z7G)V%$1cV}Dt>&2JPB>q0*Xw|b|a1g1>A!9%L|c*v`tJ*oLhayXd647k&KuSgPe zU#c|rvH?sa_Ua#Q644)gE&dL&+%l62^=jr^JZRW%mIyd!KQ|q4A?e%{0QNnhOQ3H> zuS+_2G@?7XJ{kl*7db(@!+a(cb_RHAB4B3FW>R4%)|0glED&8)p*n0SPfQbMU8kb>FIKfOUlu68pjpMv<-i=^TrvF-iR%T~H3K-TFTz(NJ67L! zuHtaxM`pF!qvTW>r;}BjtB3p5`r>{)bFAKza4;+aW^9w5FdYKjXp_A#8GkVHV-HSLE!#|%`zDV^ z?e_ZWqy7n6{qe~6M31$;8=*Dju)!L^K>EZEDj>*&!`y&~gG7u23 z#Qtyt^(;n^L^AGDxZG2U0Vp~J1soBeh8&oH+M5~h1h8PGM?IPZ{xg#GYy%Nr zvIsI}>43=CsP-Nw5}}M^_&#zGF=g0+htV%r)ZQ1vKpdT6j}C$CW{zQ^sF1R8{Qe%3 z+A?qjm$}QA&y!CxFak+plYui90(U)=nKPUMlbn+`G(-XXlYBHJBmT?r=`UU1IJ}N< b{4zq<^M?&Q9ou)hzEJZu;v4@DgKnZ8E`P|3 delta 1423 zcmV;A1#tS$X~SudOaV2qO?VA|`SRWphTJVz2&m$4UzB;AhPD&LE)C?Z(HuJ)1uBhS z5mD7BVdTA%5DPIguu!ZX4~sNkZ+wkO+~wV?lc<9TQsh&wKxSK+KBihceApoK4-8KK*!hX zi9cf823KQ6S-E9z6c6j7JqfpZ*PaBXOubn{s`qQitDilo`ATv)Fv1MD(|e*w5_4aw zH1?PQOeFT|A8z>2AABwT4zk=blM409Q{O4xAKJXuF6lZ|G7aQO7K`$br8FxJ){|Kgdp`1j zv-PyrYn%oi8GY)pK`P4AVdUuqarCj=L5Ee~9k}9pEH}=h>mB%8pasxs%}(DFWqplMEF>3SoLG{8t0` z)LWO6R}~!v;_2ch=#zOBJ_7q{lera(2CKauds{fON)`?oe>yy4J%g!X>21!7_Xp?M zj{3f{khcFp{%DfTUM;iSbt;PgVzLv(IbRkizo1#g&gH=D+*~pJ5sB*s*)=yetSZdGWkus-p zumL5_=?f$65Yh${7XeXaA!M!-=v6PLr<3t76AO>kYaJ?iI!+y!HbOIeBIq5T{&IDNp=F1D#xopUuH)GWhLf`UX>2RPT{gu$VEYP!&NCahku4 zwm<1yIS2}3v&U@n&9ixvTQMRGDl2w0Z1qvqUfj1GlZi1M0(%FOs4*WKYZx}b7+HzU z`JgG%`pEe1A&dU3fBaK?-q>A};xR~n_IZJ-s{Yp#KQlz>`mRz?1hTH&fEdLQe9;sK zzG3O2rF+EAZ3FczMvz1@?ozni?TG;>wq&DapstUWkHemIP+6?xmYa8xT97=T0ieH8 zYwy5)*pa5(Meg>E%S^T5I>`=gx$9kSb3v33=!OLx5uk<~n1I?l6YvDEV5LVrngsqc zlJ#r@5nr+hGG^(3$k?d%h9wfAjAQsdauG3Q*nx-9uRYY> - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu diff --git a/py-modindex.html b/py-modindex.html index f2ff4c600..1c73cf620 100644 --- a/py-modindex.html +++ b/py-modindex.html @@ -271,7 +271,7 @@
    - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
    diff --git a/search.html b/search.html index 2c3d6d576..edc67890d 100644 --- a/search.html +++ b/search.html @@ -268,7 +268,7 @@
    - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
    diff --git a/searchindex.js b/searchindex.js index 9b80fbef5..4596f5e4b 100644 --- a/searchindex.js +++ b/searchindex.js @@ -1 +1 @@ -Search.setIndex({"docnames": ["index", "overview", "torchrec.datasets", "torchrec.datasets.scripts", "torchrec.distributed", "torchrec.distributed.planner", "torchrec.distributed.sharding", "torchrec.fx", "torchrec.inference", "torchrec.metrics", "torchrec.models", "torchrec.modules", "torchrec.optim", "torchrec.quant", "torchrec.sparse"], "filenames": ["index.rst", "overview.rst", "torchrec.datasets.rst", "torchrec.datasets.scripts.rst", "torchrec.distributed.rst", "torchrec.distributed.planner.rst", "torchrec.distributed.sharding.rst", "torchrec.fx.rst", "torchrec.inference.rst", "torchrec.metrics.rst", "torchrec.models.rst", "torchrec.modules.rst", "torchrec.optim.rst", "torchrec.quant.rst", "torchrec.sparse.rst"], "titles": ["Welcome to the TorchRec documentation!", "TorchRec Overview", "torchrec.datasets", "torchrec.datasets.scripts", "torchrec.distributed", "torchrec.distributed.planner", "torchrec.distributed.sharding", "torchrec.fx", "torchrec.inference", "torchrec.metrics", "torchrec.models", "torchrec.modules", "torchrec.optim", "torchrec.quant", "torchrec.sparse"], "terms": {"recommendation system": 0, "shard": [0, 1, 4, 5, 8, 11, 12, 13], "distributed train": 0, "special": [0, 1, 7, 9, 11, 12], "librari": [0, 1], "within": [0, 4, 5, 6, 8, 10, 11, 14], "pytorch": [0, 1, 2, 4, 11, 12, 14], "ecosystem": [0, 1], "tailor": 0, "build": [0, 1, 5], "scale": [0, 1], "deploi": [0, 1, 8], "larg": [0, 1, 2, 5], "recommend": [0, 1, 2, 9, 10], "system": [0, 1, 2, 4, 5, 10], "nich": 0, "directli": [0, 4, 12], "address": [0, 1, 4], "standard": 0, "offer": 0, "advanc": [0, 1, 12], "featur": [0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 13, 14], "complex": [0, 7], "techniqu": [0, 1], "massiv": [0, 1], "embed": [0, 1, 5, 6, 7, 10, 11, 13, 14], "tabl": [0, 1, 4, 5, 6, 7, 8, 10, 11, 13], "enhanc": 0, "distribut": [0, 1, 2, 8, 9, 11, 12, 14], "train": [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "capabl": [0, 1], "topic": 0, "thi": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "section": 0, "help": [0, 4], "you": [0, 4, 6, 7, 14], "overview": 0, "A": [0, 2, 4, 5, 6, 7, 8, 9, 12, 14], "short": 0, "intro": 0, "why": 0, "need": [0, 2, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14], "set": [0, 2, 4, 5, 6, 8, 9, 11, 12], "up": [0, 4, 5, 13], "learn": [0, 8, 10, 11, 12], "instal": 0, "us": [0, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "your": [0, 7, 9], "environ": [0, 1, 4, 8], "tutori": 0, "follow": [0, 1, 4, 5, 6, 9, 10, 11, 12, 14], "our": 0, "interact": [0, 4, 10, 11], "step": [0, 4, 5, 12], "real": [0, 5], "life": 0, "exampl": [0, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "we": [0, 2, 4, 5, 6, 7, 9, 11, 12, 13, 14], "feedback": [0, 5], "from": [0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14], "commun": [0, 1, 4, 5, 6, 9], "If": [0, 2, 4, 5, 8, 9, 11, 12, 14], "ar": [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "interest": 0, "improv": [0, 1, 12], "project": [0, 10], "here": [0, 5, 10], "can": [0, 1, 2, 4, 5, 9, 10, 11, 12, 14], "visit": 0, "github": [0, 11], "repositori": 0, "There": [0, 4], "yoou": 0, "find": [0, 5], "sourc": [0, 2, 10, 11], "code": [0, 1, 4, 11], "issu": [0, 4, 6, 11], "ongo": 0, "submit": 0, "encount": 0, "ani": [0, 2, 4, 5, 6, 7, 8, 9, 11, 12, 14], "bug": 0, "have": [0, 2, 4, 5, 6, 9, 10, 11, 12, 14], "suggest": 0, "pleas": [0, 2, 4, 5, 8, 9, 11, 14], "an": [0, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "through": [0, 4, 7, 9, 12], "tracker": 0, "propos": [0, 10], "chang": [0, 4, 11, 12], "fork": [0, 4], "pull": 0, "request": [0, 4, 8, 12], "whether": [0, 2, 4, 5, 7, 9, 11, 13], "s": [0, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14], "fix": [0, 14], "ad": [0, 2, 4, 8, 9, 11, 12], "new": [0, 4, 5, 9], "alwai": [0, 11, 14], "make": [0, 4, 11, 12], "sure": [0, 12], "review": 0, "md": 0, "go": [0, 12], "repo": 0, "design": [1, 2, 4, 8, 9, 11], "provid": [1, 4, 5, 6, 8, 9, 10, 11, 13], "common": [1, 2, 11, 14], "primit": [1, 4, 6], "creat": [1, 4, 7, 8, 9, 11, 12, 14], "state": [1, 4, 5, 8, 9, 11, 12], "art": 1, "person": [1, 10], "model": [1, 4, 5, 6, 7, 8, 9, 11, 12, 13], "path": [1, 2, 4, 5, 8], "product": [1, 10], "wide": 1, "adopt": 1, "mani": [1, 4, 6], "meta": [1, 4, 5, 8], "infer": [1, 4, 5, 6, 13, 14], "workflow": 1, "uniqu": [1, 2, 5, 9, 11], "challeng": [1, 2], "which": [1, 2, 4, 5, 6, 8, 9, 11, 12, 14], "focu": [1, 9], "regular": 1, "more": [1, 4, 5, 6, 9, 11], "specif": [1, 4, 5, 8, 12], "gener": [1, 2, 4, 5, 7, 8, 10, 11, 12, 14], "compon": [1, 9, 11], "simplist": 1, "modul": [1, 4, 5, 6, 9], "author": [1, 4], "flexibl": [1, 11], "customiz": [1, 5], "method": [1, 4, 7, 8, 9, 11], "row": [1, 2, 4, 5, 6], "wise": [1, 4, 5, 6, 11], "column": [1, 2, 5, 6], "so": [1, 2, 4, 5, 9, 10, 12, 14], "automat": [1, 4, 5, 9, 14], "determin": [1, 2, 4, 5, 6], "best": [1, 5], "plan": [1, 4, 5, 8, 11], "devic": [1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14], "topolog": [1, 4, 5, 6], "effici": [1, 5, 11], "memori": [1, 2, 4, 5, 8, 9, 12], "balanc": [1, 5], "while": [1, 2, 4, 5, 6, 7, 8, 11], "support": [1, 4, 5, 6, 7, 9, 11, 12], "basic": [1, 10, 14], "extend": [1, 4], "sophist": 1, "parallel": [1, 2, 4, 6], "incred": 1, "optim": [1, 4, 5, 8, 9, 11, 13], "top": [1, 4, 9, 11], "fbgemm": [1, 4, 5, 6, 13], "after": [1, 4, 5, 6, 9, 11], "all": [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 14], "power": 1, "some": [1, 4, 9, 10, 14], "largest": [1, 5], "frictionless": 1, "deploy": 1, "simpl": [1, 10], "api": [1, 4, 6, 7, 9, 11], "transform": [1, 2, 4, 8, 11], "load": [1, 2, 4, 5, 6, 11, 12], "c": [1, 2, 4, 6, 8, 14], "most": [1, 4, 8, 12], "integr": 1, "built": [1, 11], "mean": [1, 4, 5, 9, 11], "seamlessli": 1, "exist": [1, 4, 6, 11, 14], "tool": 1, "allow": [1, 2, 4, 5, 7, 9, 11, 12], "develop": 1, "leverag": [1, 11], "knowledg": [1, 4, 5, 9], "codebas": 1, "util": [1, 6], "By": 1, "being": [1, 4, 5, 8, 9, 11], "part": [1, 2, 4, 5, 6, 11, 12], "benefit": 1, "robust": 1, "continu": [1, 2], "updat": [1, 4, 5, 6, 8, 9, 11, 12], "come": [1, 11], "contain": [2, 4, 5, 6, 8, 9, 11, 12, 13], "two": [2, 4, 5, 9, 10, 11, 14], "popular": [2, 10], "reci": 2, "kaggl": 2, "displai": 2, "advertis": 2, "20m": 2, "addition": 2, "randomdataset": 2, "data": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "same": [2, 4, 5, 6, 8, 9, 10, 11, 14], "format": [2, 3, 5, 8, 14], "abov": [2, 11, 14], "lastli": 2, "script": [2, 14], "pre": [2, 9, 11, 12], "process": [2, 3, 4, 5, 6, 9, 10, 11, 13], "etc": [2, 4, 8, 12, 14], "import": [2, 4, 5, 8, 11, 13], "criteo_kaggl": 2, "datapip": 2, "criteo_terabyt": 2, "home": 2, "day_0": 2, "tsv": [2, 3], "day_1": 2, "dp": [2, 5], "iter": [2, 4, 5, 11, 12], "batcher": 2, "100": [2, 4, 5, 9, 10, 11], "collat": 2, "batch": [2, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14], "next": [2, 5], "class": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "binarycriteoutil": 2, "base": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "object": [2, 4, 5, 8, 9, 11, 12], "function": [2, 3, 4, 5, 6, 7, 8, 11, 12, 14], "preprocess": [2, 3, 11], "save": [2, 3, 4, 5, 11, 12], "partit": [2, 5, 6], "binari": [2, 3, 5, 9], "numpi": 2, "static": [2, 4, 5, 9, 12, 14], "get_file_row_ranges_and_remaind": 2, "length": [2, 4, 5, 6, 10, 11, 13, 14], "list": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "int": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "rank": [2, 4, 5, 6, 9, 10, 11, 12, 14], "world_siz": [2, 4, 5, 6, 8, 9], "start_row": 2, "0": [2, 4, 5, 6, 9, 10, 11, 12, 13, 14], "last_row": 2, "option": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "none": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "tupl": [2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14], "dict": [2, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14], "given": [2, 4, 5, 6, 7, 11], "number": [2, 4, 5, 6, 8, 9, 10, 11, 14], "file": [2, 3, 4], "return": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "portion": 2, "those": [2, 10, 11], "repres": [2, 4, 5, 8, 10, 11, 13, 14], "rang": [2, 5, 7, 11], "indic": [2, 4, 5, 6, 8, 11, 12, 13, 14], "inclus": 2, "should": [2, 4, 5, 6, 8, 9, 10, 11, 12, 14], "handl": [2, 4, 5, 6, 7, 11, 12], "each": [2, 4, 5, 6, 9, 10, 11, 13, 14], "assign": [2, 4, 14], "The": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14], "wai": [2, 4, 5], "deal": 2, "enabl": [2, 4, 5, 9, 12], "reduc": [2, 4, 6, 11, 13], "amount": [2, 5], "read": 2, "avoid": [2, 4, 8, 9, 11, 12], "seek": 2, "paramet": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "count": [2, 5, 9, 11], "world": [2, 4, 6], "size": [2, 4, 5, 6, 9, 10, 11, 13, 14], "first": [2, 4, 5, 6, 10, 11, 12, 14], "item": [2, 4, 6], "map": [2, 4, 8, 9, 11, 12, 13], "kei": [2, 4, 5, 6, 8, 9, 10, 11, 13, 14], "second": [2, 4, 5, 6, 9, 10, 11, 14], "remaind": 2, "type": [2, 6, 7, 8, 9, 10, 11, 12, 13, 14], "output": [2, 4, 5, 6, 8, 9, 10, 11, 13, 14], "get_shape_from_npi": 2, "str": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "path_manager_kei": 2, "shape": [2, 4, 6, 9, 10, 11, 14], "npy": [2, 3], "onli": [2, 4, 5, 6, 9, 11, 14], "its": [2, 4, 5, 6, 8, 9, 11, 12, 14], "header": 2, "input": [2, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14], "manag": [2, 11], "differ": [2, 4, 5, 6, 11, 12, 14], "filesystem": 2, "load_npy_rang": 2, "fname": 2, "num_row": 2, "mmap_mod": 2, "bool": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "fals": [2, 4, 5, 6, 8, 9, 11, 12, 13, 14], "ndarrai": 2, "note": [2, 4, 5, 6, 11, 14], "assum": [2, 4, 5, 6, 8, 9, 10, 12], "arrai": 2, "ndim": 2, "2": [2, 4, 5, 6, 9, 10, 11, 12, 13, 14], "string": [2, 5, 8, 11], "start": [2, 4, 11, 14], "get": [2, 4, 5, 6], "desir": [2, 4, 8], "suppli": 2, "np": 2, "shuffl": [2, 6], "input_dir_labels_and_dens": 2, "input_dir_spars": 2, "output_dir_shuffl": 2, "rows_per_dai": 2, "output_dir_full_set": 2, "dai": 2, "24": [2, 4, 6], "int_column": 2, "13": 2, "sparse_column": 2, "26": 2, "random_se": 2, "expect": [2, 3, 4, 5, 10, 11], "split": [2, 4, 5, 6, 8, 14], "dens": [2, 4, 5, 10, 11, 14], "spars": [2, 3, 4, 6, 10, 11, 13], "label": [2, 4, 6, 9, 10], "must": [2, 4, 5, 6, 8, 9, 10, 11], "day_x_dens": 2, "day_x_spars": 2, "day_x_label": 2, "reconstruct": 2, "back": 2, "separ": [2, 3, 4], "1": [2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "final": [2, 5, 9, 10, 11, 13, 14], "remain": 2, "untouch": 2, "valid": [2, 5, 11, 14], "directori": [2, 4], "full": [2, 11, 12, 14], "total": [2, 4, 5, 6, 9], "categor": 2, "seed": [2, 5], "oper": [2, 4, 5, 6, 7, 11, 14], "sparse_to_contigu": 2, "in_fil": 2, "output_dir": 2, "frequency_threshold": 2, "3": [2, 4, 5, 6, 9, 10, 11, 12, 13, 14], "output_file_suffix": 2, "_contig_freq": 2, "convert": [2, 4, 7, 8, 14], "contigu": [2, 3], "integ": 2, "store": [2, 4, 5, 6, 14], "togeth": [2, 4, 11], "becaus": [2, 5, 11, 12], "match": [2, 4, 5, 8, 9, 10, 11], "id": [2, 4, 5, 6, 11], "between": [2, 4, 5, 9, 10, 11, 14], "henc": 2, "thei": [2, 4, 5, 14], "also": [2, 4, 5, 8, 9, 10, 11, 12], "appear": [2, 11], "less": 2, "than": [2, 5, 11, 12], "time": [2, 4, 5, 6, 8, 9, 11], "remap": [2, 11], "valu": [2, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14], "day_0_spars": 2, "col_0": 2, "col_1": 2, "abc": [2, 4, 5, 8, 9, 11, 12], "xyz": 2, "iop": 2, "day_1_spars": 2, "tuv": 2, "lkj": 2, "day_0_sparse_contig": 2, "day_1_sparse_contig": 2, "occur": 2, "frequenc": [2, 4], "tsv_to_npi": 2, "out_dense_fil": 2, "out_sparse_fil": 2, "out_labels_fil": 2, "dataset_nam": 2, "criteo_1tb": 2, "one": [2, 4, 5, 6, 8, 9, 10, 11, 12], "three": [2, 9], "float32": [2, 4, 8, 11, 13], "int32": [2, 6, 14], "1tb": 2, "click": [2, 9], "log": [2, 5, 9], "For": [2, 4, 5, 6, 9, 10, 11, 12, 13, 14], "test": [2, 11], "filler": 2, "includ": [2, 4, 5, 7, 8, 9, 11, 14], "name": [2, 4, 5, 8, 9, 10, 11, 12, 13, 14], "criteoiterdatapip": 2, "row_mapp": 2, "callabl": [2, 4, 6, 7, 11, 12, 13], "_default_row_mapp": 2, "open_kw": 2, "iterdatapip": 2, "stream": [2, 4, 11, 14], "either": [2, 4, 5, 9, 11], "http": [2, 4, 5, 10, 11, 14], "ailab": 2, "com": [2, 11], "download": 2, "www": 2, "local": [2, 4, 5, 6, 9, 11], "constitut": 2, "appli": [2, 4, 6, 10, 11], "line": [2, 3], "pass": [2, 4, 5, 6, 8, 9, 11, 12, 13, 14], "underli": [2, 9], "invoc": [2, 5], "iopath": 2, "file_io": 2, "pathmanag": 2, "open": 2, "inmemorybinarycriteoiterdatapip": [2, 3], "stage": [2, 12], "dense_path": 2, "sparse_path": 2, "labels_path": 2, "batch_siz": [2, 4, 5, 6, 9, 11, 14], "drop_last": 2, "shuffle_batch": 2, "shuffle_training_set": 2, "shuffle_training_set_random_se": 2, "hash": [2, 6, 11], "iterabledataset": 2, "over": [2, 4, 6, 10, 11, 12], "version": [2, 4, 11, 13], "entir": [2, 5, 6], "prevent": 2, "disk": 2, "speed": [2, 13], "affect": [2, 5], "throughout": [2, 10], "respons": [2, 4], "npy_preproc_criteo": 2, "py": [2, 4], "val": [2, 4], "max": [2, 5, 11, 12], "cat_feature_count": 2, "templat": [2, 9], "1tb_binari": 2, "day_": 2, "_": [2, 8], "1024": 2, "torch": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "get_rank": [2, 4], "get_world_s": [2, 8], "train_datapip": 2, "txt": 2, "test_datapip": 2, "movielens_20m": 2, "root": [2, 7], "include_movies_data": 2, "param": [2, 4, 5, 9, 12], "true": [2, 4, 5, 8, 9, 11, 12, 14], "add": [2, 4, 7, 11, 12], "movi": 2, "ml": 2, "20": [2, 10, 11], "movielens_25m": 2, "25m": 2, "25": [2, 4, 9], "randomrecdataset": 2, "hash_siz": 2, "ids_per_featur": 2, "num_dens": 2, "50": 2, "manual_se": 2, "num_batch": 2, "num_generated_batch": 2, "10": [2, 5, 10, 11, 13, 14], "min_ids_per_featur": 2, "recsi": [2, 5, 10, 12], "current": [2, 4, 5, 6, 8, 9, 11], "produc": [2, 4, 5], "unweight": 2, "todo": 2, "weight": [2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "taken": [2, 5], "modulo": 2, "per": [2, 4, 5, 6, 9, 11, 14], "correspond": [2, 4, 5, 6, 8, 9, 11, 14], "argument": [2, 4, 7, 8, 9, 11], "ignor": [2, 4, 5, 6, 8, 11], "sampl": [2, 5, 9, 11], "determinist": [2, 5], "behavior": [2, 4, 7, 12], "num": 2, "befor": [2, 4, 5, 6, 9, 11, 12], "rais": [2, 4], "stopiter": 2, "cach": [2, 4, 5], "num_gener": 2, "cycl": 2, "neg": 2, "fly": 2, "minimum": [2, 5], "feat1": 2, "feat2": 2, "16": [2, 4, 6, 11, 13], "100_000": 2, "dense_featur": [2, 10, 11], "tensor": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "sparse_featur": [2, 4, 6, 10], "jagged_tensor": [2, 4], "keyedjaggedtensor": [2, 4, 6, 10, 11, 13, 14], "pipelin": [2, 4, 5, 11, 14], "pin_memori": [2, 14], "record_stream": [2, 4, 11, 14], "see": [2, 4, 5, 6, 7, 9, 11, 14], "org": [2, 4, 5, 10, 11, 14], "doc": [2, 4, 11, 14], "stabl": [2, 4, 11, 14], "html": [2, 4, 11, 14], "non_block": [2, 14], "awar": [2, 4, 14], "accord": [2, 4, 5, 6, 8, 10, 12, 14], "might": [2, 5, 14], "self": [2, 4, 5, 6, 8, 11, 14], "copi": [2, 4, 6, 8, 9, 11, 12, 14], "rememb": [2, 4, 14], "new_devic": [2, 14], "limit": [2, 8, 9, 11], "loadfil": 2, "mode": [2, 4, 5, 9], "b": [2, 4, 5, 6, 10, 11, 13, 14], "iobas": 2, "adapt": [2, 11], "loadfilesfromdisk": 2, "merg": [2, 4, 6], "replac": [2, 9, 12], "someth": 2, "core": 2, "lib": 2, "parallelreadconcat": 2, "dp_selector": 2, "sequenc": [2, 4, 5, 6], "_default_dp_selector": 2, "concaten": [2, 4, 6, 10, 11, 14], "multipl": [2, 4, 5, 9, 10, 11, 12], "when": [2, 4, 5, 7, 9, 11, 12], "dataload": [2, 4], "subset": [2, 5], "worker": [2, 4], "instanc": [2, 4, 6, 7, 8, 9, 11], "would": [2, 4, 6, 14], "f": [2, 4, 5, 6, 10, 11, 13], "shard_": 2, "idx": [2, 4], "4": [2, 4, 5, 6, 9, 10, 11, 13, 14], "num_work": [2, 14], "readlinesfromcsv": 2, "skip_first_lin": 2, "kw": 2, "idx_split_train_v": 2, "train_perc": 2, "float": [2, 4, 5, 7, 9, 11, 12, 14], "decimal_plac": 2, "key_fn": 2, "_default_key_fn": 2, "rand_split_train_v": 2, "via": [2, 4, 6], "uniform": [2, 5, 11], "disjoint": 2, "specifi": [2, 4, 5, 6, 7, 9, 10, 11, 12], "target": [2, 4, 10], "proport": 2, "actual": [2, 4, 5, 6, 8, 9, 11], "guarante": [2, 5, 7, 12], "exactli": [2, 4], "membership": 2, "across": [2, 4, 5, 6, 9], "call": [2, 4, 5, 6, 8, 9, 11, 12, 13], "consist": [2, 4], "val_datapip": 2, "75": 2, "train_batch": 2, "val_batch": 2, "safe_cast": 2, "t": [2, 4, 5, 6, 7, 8, 11, 12, 14], "dest_typ": 2, "default": [2, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14], "train_filt": 2, "val_filt": 2, "main": [3, 9], "argv": 3, "result": [3, 4, 5, 6, 8, 9, 11, 13], "command": 3, "arg": [3, 4, 5, 8, 9, 11, 13, 14], "parse_arg": 3, "namespac": [3, 14], "raw": [3, 11], "criteo": [3, 10], "necessari": [4, 5, 6, 8, 9], "These": [4, 5, 9, 11], "distributedmodelparallel": 4, "collect": [4, 6, 10, 11, 12, 13], "scatter": [4, 6], "wrapper": [4, 12], "kjt": [4, 5, 6, 10, 11, 13, 14], "variou": [4, 8, 11], "implement": [4, 5, 6, 8, 9, 10, 11, 12, 14], "shardedembeddingbag": 4, "nn": [4, 5, 7, 8, 10, 11, 13], "shardedembeddingbagcollect": [4, 11, 13], "embeddingbagcollect": [4, 8, 10, 11, 13], "sharder": [4, 5, 8], "defin": [4, 5, 6, 8, 9, 10, 11], "comput": [4, 5, 6, 8, 9, 10, 11, 13], "kernel": [4, 5, 11], "cpu": [4, 5, 9], "gpu": [4, 5], "mai": [4, 14], "fusion": 4, "trainpipelinesparsedist": 4, "overlap": 4, "transfer": 4, "inter": [4, 11], "input_dist": [4, 11], "forward": [4, 5, 6, 8, 9, 10, 11, 13, 14], "backward": [4, 5, 7, 12], "increas": [4, 9], "perform": [4, 5, 6, 8, 9, 11, 12, 13], "quantiz": [4, 6, 7, 8, 13], "precis": [4, 5, 11, 13], "construct": [4, 7, 11, 14], "control": [4, 7], "flow": [4, 7], "invoke_on_rank_and_broadcast_result": 4, "pg": [4, 5, 6, 9], "processgroup": [4, 5, 6, 9], "func": 4, "kwarg": [4, 9, 11, 14], "invok": [4, 5], "broadcast": [4, 5], "member": [4, 11], "group": [4, 5, 6, 9, 11, 12, 14], "allocate_id": 4, "is_lead": 4, "leader_rank": 4, "check": [4, 5, 9, 11, 12, 14], "processs": 4, "leader": [4, 9], "dist": [4, 6, 9], "impli": 4, "e": [4, 5, 6, 7, 8, 9, 10, 11, 12], "g": [4, 5, 8, 9, 10, 11, 12], "singl": [4, 5, 6, 11, 12], "program": [4, 5], "definit": [4, 7, 8], "caller": 4, "overrid": [4, 5, 7, 8, 9], "context": [4, 6, 14], "run_on_lead": 4, "get_group_rank": 4, "avail": [4, 5, 6], "group_rank": 4, "varibl": 4, "get_num_group": 4, "elast": 4, "run": [4, 5, 6, 8, 9, 11, 12], "get_local_rank": 4, "usual": [4, 5, 6, 9, 11], "node": [4, 7], "get_local_s": 4, "equival": 4, "max_nnod": 4, "intra_and_cross_node_pg": 4, "backend": [4, 6, 8, 9], "sub": 4, "intra": 4, "cross": [4, 10, 11], "all2alldenseinfo": 4, "output_split": [4, 6], "input_shap": 4, "input_split": [4, 6], "attribut": [4, 5, 9, 12], "alltoall_dens": 4, "all2allpooledinfo": 4, "batch_size_per_rank": [4, 6], "dim_sum_per_rank": [4, 6], "dim_sum_per_rank_tensor": 4, "cumsum_dim_sum_per_rank_tensor": 4, "codec": [4, 6], "quantizedcommcodec": [4, 6], "alltoall_pool": [4, 6], "sum": [4, 5, 6, 11], "dimens": [4, 5, 6, 10, 11, 13, 14], "fast": 4, "_recat_pooled_embedding_grad_out": 4, "cumul": [4, 9, 14], "all2allsequenceinfo": 4, "embedding_dim": [4, 5, 6, 10, 11, 13], "lengths_after_sparse_data_all2al": 4, "forward_recat_tensor": 4, "backward_recat_tensor": 4, "variable_batch_s": 4, "permuted_lengths_after_sparse_data_all2al": 4, "alltoall_sequ": 4, "alltoal": [4, 6], "recat": [4, 6, 11, 14], "variabl": [4, 6, 9, 11, 13, 14], "all2allvinfo": 4, "dims_sum_per_rank": 4, "b_global": 4, "b_local": 4, "b_local_list": 4, "d_local_list": 4, "input_split_s": 4, "factori": [4, 5, 11], "output_split_s": 4, "alltoallv": 4, "global": [4, 5, 6, 9], "my": 4, "how": [4, 5, 6, 12], "do": [4, 5, 9, 11, 12, 14], "all_to_all_singl": 4, "fill": 4, "all2all_pooled_req": 4, "ctx": 4, "unus": [4, 11], "formula": [4, 5], "differenti": 4, "overridden": [4, 6, 8, 9, 11], "subclass": [4, 6, 8, 11, 12], "vjp": 4, "It": [4, 5, 6, 8, 9, 11, 12, 13, 14], "accept": [4, 5, 8, 9, 11], "non": [4, 5, 6, 7, 9, 11, 13], "were": 4, "gradient": [4, 5, 12], "w": [4, 6, 11, 14], "r": [4, 11], "requir": [4, 5, 9, 11, 12], "grad": [4, 12], "just": [4, 5, 10, 11, 14], "retriev": 4, "dure": [4, 5, 9, 12], "ha": [4, 5, 9, 10, 11, 14], "needs_input_grad": 4, "boolean": 4, "myreq": 4, "a2ai": 4, "input_embed": [4, 11], "custom": [4, 5, 7, 11], "autograd": [4, 8, 9, 11], "usag": [4, 5, 9], "combin": [4, 11, 12], "staticmethod": 4, "def": [4, 11], "other": [4, 5, 6, 9, 12], "detail": [4, 5, 6, 9, 11], "setup_context": 4, "longer": [4, 5], "instead": [4, 5, 6, 8, 11, 12], "arbitrari": 4, "though": 4, "enforc": [4, 8, 9, 11], "compat": [4, 5, 7, 10, 12], "save_for_backward": 4, "intend": 4, "save_for_forward": 4, "jvp": 4, "all2all_pooled_wait": 4, "grad_output": 4, "dummy_tensor": 4, "all2all_seq_req": 4, "sharded_input_embed": 4, "all2all_seq_req_wait": 4, "sharded_grad_output": 4, "all2allv_req": 4, "all2allv_wait": 4, "allgatherbaseinfo": 4, "input_s": [4, 5, 11], "all_gatther_base_pool": 4, "allgatherbase_req": 4, "agi": 4, "allgatherbase_wait": 4, "reducescatterbaseinfo": 4, "reduce_scatter_base_pool": 4, "flatten": [4, 6, 11], "reducescatterbase_req": 4, "rsi": 4, "reducescatterbase_wait": 4, "reducescatterinfo": 4, "reduce_scatter_pool": 4, "reducescattervinfo": 4, "equal_split": 4, "total_input_s": 4, "reduce_scatter_v_pool": 4, "along": [4, 6, 9, 10, 12, 14], "dim": [4, 6], "reducescatterv_req": 4, "reducescatterv_wait": 4, "reducescatter_req": 4, "reducescatter_wait": 4, "await": [4, 6, 7], "variablebatchall2allpooledinfo": 4, "batch_size_per_rank_per_featur": [4, 6], "batch_size_per_feature_pre_a2a": [4, 6], "emb_dim_per_rank_per_featur": [4, 6], "variable_batch_alltoall_pool": [4, 6], "variable_batch_all2all_pooled_req": 4, "variable_batch_all2all_pooled_wait": 4, "all2all_pooled_sync": 4, "all2all_sequence_sync": 4, "all2allv_sync": 4, "all_gather_base_pool": 4, "gather": [4, 6], "form": [4, 11, 13], "pool": [4, 5, 6, 10, 11, 13, 14], "output_tensor_s": 4, "work": [4, 5, 8, 9, 11, 14], "async": [4, 6], "wait": [4, 6], "later": [4, 11], "experiment": [4, 11], "subject": 4, "all_gather_base_sync": 4, "all_gather_into_tensor_backward": 4, "all_gather_into_tensor_fak": 4, "gather_dim": 4, "group_siz": 4, "group_nam": 4, "gradient_divis": 4, "all_gather_into_tensor_setup_context": 4, "all_to_all_single_backward": 4, "all_to_all_single_fak": 4, "all_to_all_single_setup_context": 4, "a2a_pooled_embs_tensor": 4, "Then": 4, "receiv": [4, 6, 12], "Its": 4, "x": [4, 5, 6, 10, 11, 13, 14], "d_local_sum": 4, "where": [4, 5, 6, 9, 10, 11, 13], "a2a_sequence_embs_tensor": 4, "doe": [4, 5, 10, 11, 12, 14], "mix": 4, "out_split": 4, "per_rank_split_length": 4, "assumpt": [4, 14], "emb": 4, "get_gradient_divis": 4, "get_use_sync_collect": 4, "pg_name": 4, "reduce_scatter_base_sync": 4, "chunk": [4, 6], "reduce_scatter_sync": 4, "reduce_scatter_tensor_backward": 4, "reduce_scatter_tensor_fak": 4, "reduceop": 4, "reduce_scatter_tensor_setup_context": 4, "reduce_scatter_v_per_feature_pool": 4, "v": [4, 6, 11, 14], "d": [4, 5, 10, 11, 13, 14], "unevenli": 4, "reduce_scatter_v_sync": 4, "set_gradient_divis": 4, "set_use_sync_collect": 4, "torchrec_use_sync_collect": 4, "variable_batch_all2all_pooled_sync": 4, "embeddingsalltoon": [4, 6], "cat_dim": [4, 6, 14], "buffer": [4, 6, 8, 9, 11], "alloc": [4, 5, 6, 8], "like": [4, 5, 6, 7, 11, 12, 14], "alltoon": [4, 6], "set_devic": [4, 6], "device_str": [4, 6], "embeddingsalltoonereduc": [4, 6], "jaggedtensoralltoal": [4, 6], "jt": [4, 6, 11, 14], "jaggedtensor": [4, 6, 11, 13, 14], "num_items_to_send": [4, 6], "num_items_to_rec": [4, 6], "redistribut": [4, 6], "send": [4, 6], "known": [4, 5, 6, 11], "ahead": [4, 6], "keyedjaggedtensorpool": [4, 6], "lookup": [4, 5, 6, 10, 11, 13], "anoth": [4, 5, 6], "kjtalltoal": [4, 6], "stagger": [4, 6, 14], "kjtalltoallsplitsawait": [4, 6], "transmit": [4, 6], "correct": [4, 6, 14], "space": [4, 5, 6, 10], "kjtalltoalltensorsawait": [4, 6], "asynchron": [4, 6], "len": [4, 6, 10], "order": [4, 5, 6, 8, 9, 11, 14], "destin": [4, 6, 8, 9, 11], "_get_recat": [4, 6], "kjta2a": [4, 6], "rank0_input": [4, 6], "hold": [4, 5, 6, 12, 14], "v0": [4, 6, 14], "v1": [4, 6, 11, 14], "v2": [4, 6, 10, 11, 14], "rank1_input": [4, 6], "v3": [4, 6, 14], "v4": [4, 6, 14], "rank0_output": [4, 6], "5": [4, 6, 9, 10, 11, 13, 14], "rank1_output": [4, 6], "relev": [4, 5, 6], "tensor_split": [4, 6], "input_tensor": [4, 6], "ie": [4, 5, 6, 11, 14], "stride_per_rank": [4, 6, 14], "stride": [4, 6, 14], "case": [4, 5, 6, 9, 11, 12, 14], "kjtonetoal": [4, 6], "onetoal": [4, 6], "essenti": [4, 5, 6, 14], "p2p": [4, 6], "keyjaggedtensor": [4, 6], "them": [4, 6, 8, 10, 11, 12], "kjtlist": [4, 6], "slice": [4, 6, 7, 14], "mergepooledembeddingsmodul": [4, 6], "merge_pooled_embedding_optim": [4, 6], "_mergepooledembeddingsmoduleimpl": [4, 6], "merge_pooled_embed": [4, 6], "pooledembeddingsallgath": [4, 6], "wrap": [4, 6, 9, 10, 12], "layout": [4, 6, 7], "want": [4, 6, 9], "nccl": [4, 6], "happen": [4, 5, 6], "init_distribut": [4, 6], "new_group": [4, 6, 9], "randn": [4, 6, 10, 11], "m": [4, 5, 6, 7, 11], "local_emb": [4, 6], "pooledembeddingsawait": [4, 6], "num_bucket": [4, 6], "pooledembeddingsalltoal": [4, 6], "callback": [4, 6], "a2a": [4, 6], "t0": [4, 6], "rand": [4, 6, 10], "6": [4, 5, 6, 10, 11, 13, 14], "t1": [4, 6, 10, 11, 13], "print": [4, 6, 11, 13], "properti": [4, 5, 6, 8, 9, 10, 11, 12, 13], "tensor_await": [4, 6], "pooledembeddingsreducescatt": [4, 6], "twrw": [4, 5, 6], "unequ": [4, 6], "bucket": [4, 6], "seqembeddingsalltoon": [4, 6], "concat": [4, 6, 11, 14], "sequenceembeddingsalltoal": [4, 6], "features_per_rank": [4, 6], "sharding_ctx": [4, 6], "sequenceshardingcontext": [4, 6], "lengths_after_input_dist": [4, 6], "unbucketize_permute_tensor": [4, 6], "sparse_features_recat": [4, 6], "sequenceembeddingsawait": [4, 6], "permut": [4, 6, 14], "splitsalltoallawait": [4, 6], "tensoralltoal": [4, 6], "1d": [4, 5, 6], "tensoralltoallsplitsawait": [4, 6], "tensoralltoallvaluesawait": [4, 6], "tensor_a2a": [4, 6], "rank0": [4, 6], "rank1": [4, 6], "v5": [4, 6, 14], "v6": [4, 6, 14], "v7": [4, 6, 14], "v8": [4, 6], "v9": [4, 6], "v10": [4, 6], "v11": [4, 6], "v12": [4, 6], "tensorvaluesalltoal": [4, 6], "tensor_vals_a2a": [4, 6], "v13": [4, 6], "v14": [4, 6], "v15": [4, 6], "sent": [4, 6], "equal": [4, 5, 6, 11, 14], "_pg": [4, 6], "variablebatchpooledembeddingsalltoal": [4, 6], "kjt_split": [4, 6], "r0_batch_siz": [4, 6], "r1_batch_siz": [4, 6], "f_0": [4, 6], "f_1": [4, 6], "f_2": [4, 6], "r0_batch_size_per_rank_per_featur": [4, 6], "r1_batch_size_per_rank_per_featur": [4, 6], "r0_batch_size_per_feature_pre_a2a": [4, 6], "r1_batch_size_per_feature_pre_a2a": [4, 6], "r0": [4, 6], "r1": [4, 6], "14": [4, 6], "post": [4, 6], "rank_0": [4, 6], "rank_1": [4, 6], "variablebatchpooledembeddingsreducescatt": [4, 6], "rw": [4, 5, 6, 11], "multipli": [4, 5, 6], "batch_size_r0_f0": [4, 6], "emb_dim_f0": [4, 6], "embeddingcollectionawait": 4, "lazyawait": 4, "embeddingcollectioncontext": 4, "sharding_context": 4, "input_featur": 4, "reverse_indic": [4, 11], "seq_vbe_ctx": [4, 11], "sequencevbecontext": [4, 11], "multistream": [4, 11], "embeddingcollectionshard": 4, "fused_param": [4, 6], "qcomm_codecs_registri": [4, 6], "use_index_dedup": 4, "baseembeddingshard": 4, "embeddingcollect": [4, 8, 11, 13], "module_typ": [4, 8, 13], "parametershard": 4, "env": [4, 6], "shardingenv": [4, 6], "shardedembeddingcollect": [4, 11, 13], "locat": 4, "replic": [4, 5, 6], "embeddingmoduleshardingplan": 4, "fulli": [4, 5, 12], "qualifi": 4, "spec": 4, "shardedmodul": 4, "shardable_paramet": 4, "sharding_typ": [4, 5, 11], "compute_device_typ": 4, "shardingtyp": [4, 5, 11], "well": [4, 5, 11], "table_name_to_parameter_shard": 4, "shardedembeddingmodul": 4, "fusedoptimizermodul": [4, 12], "public": [4, 11], "manual": [4, 10, 12], "dist_input": 4, "compute_and_output_dist": 4, "In": [4, 5, 11, 12, 14], "sens": [4, 12], "initi": [4, 11, 12], "distibut": 4, "soon": 4, "complet": [4, 5], "create_context": 4, "fused_optim": [4, 12], "keyedoptim": [4, 12], "output_dist": [4, 5], "reset_paramet": [4, 11], "create_embedding_shard": 4, "sharding_info": [4, 6], "embeddingshardinginfo": [4, 6], "embeddingshard": [4, 6], "create_sharding_infos_by_shard": 4, "embeddingcollectioninterfac": [4, 11, 13], "create_sharding_infos_by_sharding_device_group": 4, "get_device_from_parameter_shard": 4, "ps": [4, 5], "get_ec_index_dedup": 4, "pad_vbe_kjt_length": 4, "set_ec_index_dedup": 4, "commopgradientsc": 4, "functionctx": 4, "scale_gradient_factor": 4, "groupedembeddingslookup": 4, "grouped_config": 4, "groupedembeddingconfig": [4, 6], "baseembeddinglookup": [4, 6], "i": [4, 5, 6, 7, 9, 10, 11], "flush": 4, "everi": [4, 5, 6, 8, 11], "although": [4, 6, 8, 11], "recip": [4, 6, 8, 11], "afterward": [4, 6, 8, 11], "sinc": [4, 5, 6, 8, 11], "former": [4, 6, 8, 11], "take": [4, 5, 6, 8, 11, 12], "care": [4, 6, 8, 11], "regist": [4, 6, 7, 8, 11], "hook": [4, 6, 8, 11], "latter": [4, 6, 8, 11], "silent": [4, 6, 8, 11], "load_state_dict": [4, 12], "state_dict": [4, 8, 9, 11, 12], "ordereddict": [4, 8, 9, 11], "union": [4, 5, 7, 8, 9, 11, 12], "shardedtensor": [4, 12], "strict": [4, 12], "_incompatiblekei": 4, "descend": [4, 5], "unless": [4, 12], "get_swap_module_params_on_convers": 4, "persist": [4, 8, 9, 11], "strictli": [4, 11], "preserv": [4, 11], "except": [4, 5, 9, 11], "requires_grad": 4, "field": [4, 11, 12, 14], "missing_kei": 4, "miss": [4, 5], "unexpected_kei": 4, "present": [4, 12], "namedtupl": 4, "runtimeerror": 4, "named_buff": [4, 11], "prefix": [4, 8, 9, 11], "recurs": [4, 11], "remove_dupl": [4, 11], "yield": [4, 11], "both": [4, 8, 9, 10, 11, 12, 14], "itself": [4, 10, 11], "prepend": [4, 11], "submodul": [4, 11, 12], "otherwis": [4, 5, 8, 9, 11, 12, 14], "direct": [4, 11], "remov": [4, 7, 11], "duplic": [4, 11, 12], "xdoctest": [4, 8, 9, 11], "skip": [4, 8, 9, 11, 12], "undefin": [4, 8, 9, 11], "var": [4, 8, 9, 11], "buf": [4, 11], "running_var": [4, 11], "named_paramet": 4, "bia": [4, 8, 9, 11], "named_parameters_by_t": 4, "tablebatchedembeddingslic": 4, "table_nam": 4, "embedding_weight": 4, "cw": [4, 5], "compos": [4, 8, 9, 11], "prefetch": [4, 5], "forward_stream": 4, "purg": 4, "keep_var": [4, 8, 9, 11], "dictionari": [4, 8, 9, 11], "refer": [4, 8, 9, 11, 14], "whole": [4, 8, 9, 11], "averag": [4, 5, 8, 9, 11], "shallow": [4, 8, 9, 11], "posit": [4, 5, 6, 8, 9, 11], "howev": [4, 8, 9, 11, 12], "deprec": [4, 8, 9, 11], "keyword": [4, 8, 9, 11], "futur": [4, 8, 9, 11], "releas": [4, 8, 9, 11], "end": [4, 5, 8, 9, 11], "user": [4, 5, 8, 9, 11, 12], "detach": [4, 8, 9, 11], "groupedpooledembeddingslookup": 4, "feature_processor": [4, 6, 13], "basegroupedfeatureprocessor": [4, 6, 11], "scale_weight_gradi": 4, "infercpugroupedembeddingslookup": 4, "grouped_configs_per_rank": 4, "infergroupedlookupmixin": 4, "inputdistoutput": [4, 6], "tbetoregistermixin": 4, "get_tbes_to_regist": 4, "intnbittablebatchedembeddingbagscodegen": 4, "infergroupedembeddingslookup": 4, "input_dist_output": 4, "infergroupedpooledembeddingslookup": 4, "metainfergroupedembeddingslookup": 4, "tbe": [4, 5, 13], "op": [4, 5, 6, 12, 13], "metainfergroupedpooledembeddingslookup": 4, "bag": [4, 6, 7, 10, 11], "dtype": [4, 5, 6, 7, 8, 11, 13, 14], "embeddings_cat_empty_rank_handl": 4, "dummy_embs_tensor": 4, "embeddings_cat_empty_rank_handle_infer": 4, "fx_wrap_tensor_view2d": 4, "dim0": 4, "dim1": 4, "baseembeddingdist": [4, 6], "embeddinglookup": 4, "abstract": [4, 5, 8, 9, 11, 12], "basesparsefeaturesdist": [4, 6], "featureshardingmixin": 4, "table_wis": [4, 11], "create_input_dist": [4, 6], "create_lookup": [4, 6], "create_output_dist": [4, 6], "embedding_nam": [4, 6, 11], "embedding_names_per_rank": [4, 6], "embedding_shard_metadata": [4, 6], "shardmetadata": [4, 6], "embedding_t": [4, 6], "shardedembeddingt": [4, 6], "uncombined_embedding_dim": [4, 6], "uncombined_embedding_nam": [4, 6], "embeddingshardingcontext": [4, 6], "variable_batch_per_featur": 4, "embedding_config": [4, 13], "embeddingtableconfig": [4, 11], "param_shard": 4, "nonetyp": [4, 9, 11], "fusedkjtlistsplitsawait": 4, "kjtlistsplitsawait": 4, "kjtlistawait": 4, "info": [4, 11], "metadata": [4, 8, 11], "kjtsplitsalltoallmeta": 4, "distributed_c10d": 4, "_input": 4, "splits_tensor": 4, "listofkjtlistawait": 4, "listofkjtlist": 4, "listofkjtlistsplitsawait": 4, "bucketize_kjt_before_all2al": 4, "block_siz": [4, 6], "output_permut": 4, "bucketize_po": 4, "block_bucketize_row_po": 4, "keep_original_indic": [4, 6], "readjust": 4, "unbucket": 4, "offset": [4, 5, 10, 11, 13, 14], "keep": [4, 5, 11], "origin": [4, 5, 8, 10], "bucketize_kjt_infer": 4, "is_sequ": [4, 6], "group_tabl": 4, "tables_per_rank": 4, "datatyp": [4, 5, 11, 13, 14], "poolingtyp": [4, 11], "embeddingcomputekernel": [4, 5], "weighted": 4, "interfac": [4, 8, 9, 11], "reli": [4, 8, 11, 13], "moduleshard": [4, 5, 8], "compute_kernel": [4, 5], "storage_usag": 4, "resourc": 4, "processor": [4, 6, 8, 11], "basequantembeddingshard": 4, "shardable_param": 4, "dtensormetadata": 4, "mesh": 4, "device_mesh": 4, "devicemesh": 4, "placement": [4, 5], "_tensor": 4, "placement_typ": 4, "embeddingattribut": 4, "enum": [4, 5, 11, 12], "enumer": [4, 11, 12], "fuse": [4, 6, 9], "fused_uvm": 4, "fused_uvm_cach": 4, "key_valu": 4, "quant": 4, "quant_uvm": 4, "quant_uvm_cach": 4, "feature_nam": [4, 5, 6, 10, 11, 13], "feature_names_per_rank": [4, 6], "data_typ": [4, 11], "is_weight": [4, 5, 11, 13, 14], "has_feature_processor": [4, 6, 11], "dim_sum": 4, "feature_hash_s": [4, 6], "num_featur": [4, 6, 10, 11], "bucket_mapping_tensor": 4, "bucketized_length": 4, "moduleshardingmixin": 4, "access": [4, 5, 12, 14], "scheme": 4, "optimtyp": 4, "adagrad": [4, 12], "adam": [4, 12], "adamw": 4, "lamb": 4, "lars_sgd": 4, "lion": 4, "partial_rowwise_adam": 4, "partial_rowwise_lamb": 4, "rowwise_adagrad": 4, "sgd": 4, "shampoo": 4, "shampoo_v2": 4, "shampoo_v2_mr": 4, "shardedconfig": 4, "local_row": [4, 5], "local_col": [4, 5], "compin": 4, "distout": 4, "out": [4, 11, 14], "shrdctx": 4, "commop": 4, "extra_repr": 4, "pretti": 4, "represent": [4, 5, 7, 11, 14], "num_embed": [4, 5, 10, 11, 13], "fp32": [4, 5, 11], "weight_init_max": [4, 11], "weight_init_min": [4, 11], "num_embeddings_post_prun": [4, 11], "init_fn": [4, 11], "need_po": [4, 6, 11], "local_metadata": 4, "_shard": 4, "global_metadata": 4, "sharded_tensor": 4, "shardedtensormetadata": 4, "dtensor_metadata": 4, "shardedmetaconfig": 4, "compute_kernel_to_embedding_loc": 4, "embeddingloc": 4, "embeddingawait": 4, "embeddingbagcollectionawait": 4, "lazygetitemmixin": 4, "keyedtensor": [4, 10, 11, 13, 14], "embeddingbagcollectioncontext": 4, "inverse_indic": [4, 11, 14], "divisor": 4, "embeddingbagcollectionshard": 4, "embeddingbagshard": 4, "nullshardedmodulecontext": 4, "per_sample_weight": 4, "named_modul": 4, "memo": 4, "network": [4, 5, 11, 12], "alreadi": [4, 6, 8, 12], "onc": [4, 11], "l": [4, 11, 13], "linear": [4, 5, 11, 12], "net": [4, 10, 11], "sequenti": [4, 5, 11], "in_featur": [4, 10, 11], "out_featur": [4, 11], "sharded_parameter_nam": 4, "embeddingbagcollectioninterfac": [4, 11, 13], "variablebatchembeddingbagcollectionawait": 4, "construct_output_kt": 4, "create_embedding_bag_shard": 4, "permute_embed": [4, 6], "suffix": 4, "replace_placement_with_meta_devic": 4, "could": [4, 5, 14], "unmatch": 4, "scenario": [4, 11, 13], "dmp": 4, "cuda": [4, 5, 8], "embeddingshardingplann": [4, 5], "planner": 4, "groupedpositionweightedmodul": 4, "max_feature_length": [4, 11], "dataparallelwrapp": 4, "defaultdataparallelwrapp": 4, "bucket_cap_mb": 4, "static_graph": 4, "find_unused_paramet": 4, "allreduce_comm_precis": 4, "params_to_ignor": 4, "ddp_kwarg": 4, "unshard": [4, 5, 11, 13], "shardingplan": [4, 5, 8], "init_data_parallel": 4, "init_paramet": 4, "data_parallel_wrapp": 4, "entri": 4, "point": [4, 5], "collective_plan": [4, 5], "lazi": [4, 11, 12], "delai": 4, "until": 4, "still": [4, 14], "no_grad": [4, 11], "init_weight": [4, 11], "isinst": 4, "fill_": [4, 11], "elif": 4, "init": 4, "kaiming_normal_": 4, "mymodel": 4, "bare_named_paramet": 4, "tor": 4, "safe": 4, "ddp": 4, "fsdp": 4, "sparse_grad_parameter_nam": [4, 12], "get_modul": 4, "unwrap": 4, "get_unwrapped_modul": 4, "quantembeddingbagcollectionshard": [4, 8], "shardedquantembeddingbagcollect": 4, "quantfeatureprocessedembeddingbagcollectionshard": [4, 8], "featureprocessedembeddingbagcollect": [4, 8, 13], "shardedquantebcinputdist": 4, "sharding_type_device_group_to_shard": 4, "nullshardingcontext": [4, 6], "sharding_type_to_shard": 4, "sqebc_input_dist": 4, "infertwsequenceembeddingshard": 4, "f1": [4, 10, 11, 13], "f2": [4, 10, 11, 13], "7": [4, 9, 10, 11, 13, 14], "8": [4, 5, 10, 11, 13, 14], "shardedquantembeddingmodulest": 4, "embedding_bag_config": [4, 11, 13], "embeddingbagconfig": [4, 10, 11, 13], "execut": [4, 5, 8, 11, 13], "sharding_type_device_group_to_sharding_info": 4, "tbes_config": 4, "shardedquantfeatureprocessedembeddingbagcollect": 4, "featureprocessorscollect": [4, 13], "apply_feature_processor": 4, "kjt_list": [4, 14], "embedding_bag": [4, 13], "moduledict": [4, 13], "modulelist": [4, 9, 11, 13], "create_infer_embedding_bag_shard": 4, "flatten_feature_length": 4, "get_device_from_sharding_info": 4, "emb_shard_info": 4, "cacheparam": [4, 5], "algorithm": 4, "cachealgorithm": 4, "load_factor": [4, 5], "reserved_memori": 4, "prefetch_pipelin": [4, 5], "stat": 4, "cachestatist": [4, 5], "multipass_prefetch_config": 4, "multipassprefetchconfig": 4, "relat": [4, 5, 9], "uvm": [4, 5], "lru": [4, 5], "lfu": 4, "factor": [4, 5, 11], "decid": 4, "crucial": 4, "reserv": [4, 5], "ideal": 4, "aka": 4, "statist": [4, 5], "better": [4, 5], "tune": [4, 12], "cacheabl": [4, 5], "summar": [4, 5], "measur": [4, 5, 9], "difficulti": [4, 5], "dataset": [4, 5, 10], "independ": [4, 5], "score": [4, 5, 6, 11], "veri": [4, 5], "high": [4, 5, 9, 11], "difficult": [4, 5], "expected_lookup": [4, 5], "distinct": [4, 5], "expected_miss_r": [4, 5], "clf": [4, 5], "rate": [4, 5, 9, 12], "hit": [4, 5], "extrem": [4, 5], "estim": [4, 5, 9], "pooled_embeddings_all_to_al": 4, "pooled_embeddings_reduce_scatt": 4, "sequence_embeddings_all_to_al": 4, "computekernel": 4, "moduleshardingplan": 4, "describ": 4, "genericmeta": 4, "getitemlazyawait": 4, "parentw": 4, "kt": [4, 14], "__getitem__": 4, "parent": 4, "keyvalueparam": [4, 5], "ssd_storage_directori": 4, "ssd_rocksdb_write_buffer_s": 4, "ssd_rocksdb_shard": 4, "gather_ssd_cache_stat": 4, "stats_reporter_config": 4, "tbestatsreporterconfig": 4, "use_passed_in_path": 4, "l2_cache_s": 4, "ps_host": 4, "ps_client_thread_num": 4, "ps_max_key_per_request": 4, "ps_max_local_index_length": 4, "ssd": [4, 5], "ssdtablebatchedembeddingbag": 4, "data00_nvidia": 4, "local_rank": 4, "rocksdb": 4, "write": 4, "relav": 4, "compact": 4, "std": 4, "report": [4, 9], "od": 4, "report_interv": 4, "interv": [4, 9, 11], "ods_prefix": 4, "server": 4, "host": [4, 5, 6], "ip": 4, "port": 4, "2000": 4, "2001": 4, "2002": 4, "reason": [4, 12], "hashabl": 4, "thread": [4, 5], "client": 4, "maximum": [4, 5, 9, 11], "index": [4, 11, 14], "expos": [4, 12], "concret": 4, "achiev": 4, "late": 4, "possibl": [4, 5, 9], "__torch_function__": 4, "below": 4, "doesn": [4, 11, 12], "python": [4, 7, 10], "magic": 4, "__getattr__": 4, "caveat": 4, "mechan": [4, 11], "ensur": [4, 11, 14], "perfect": 4, "quickli": 4, "long": [4, 5, 11], "kwd": 4, "vt_co": 4, "augment": 4, "trigger": [4, 11], "keyedlazyawait": 4, "defer": 4, "mixin": 4, "inherit": [4, 9, 11], "mro": 4, "properli": [4, 11], "select": [4, 5, 6, 14], "lazynowait": 4, "classmethod": [4, 5, 8, 13], "noopquantizedcommcodec": 4, "quantizationcontext": 4, "No": [4, 6, 9], "calc_quantized_s": 4, "input_len": 4, "decod": 4, "input_grad": 4, "encod": 4, "padded_s": 4, "dim_per_rank": 4, "my_rank": [4, 9], "qcomm_ctx": 4, "quantized_dtyp": 4, "nowait": [4, 7], "obj": 4, "objectpoolshardingplan": 4, "objectpoolshardingtyp": 4, "replicated_row_wis": 4, "row_wis": [4, 11], "sharding_spec": 4, "shardingspec": 4, "cache_param": [4, 5], "enforce_hbm": [4, 5], "stochastic_round": [4, 5], "bounds_check_mod": [4, 5], "boundscheckmod": [4, 5], "output_dtyp": [4, 5, 8, 13], "key_value_param": [4, 5], "hbm": [4, 5], "stochast": [4, 5], "round": [4, 5], "bound": [4, 5], "place": [4, 5, 6, 12, 14], "column_wis": [4, 11], "seen": [4, 7], "individu": [4, 5, 10], "table_row_wis": [4, 11], "data_parallel": [4, 5, 11], "parameterstorag": 4, "physic": 4, "constraint": [4, 5, 8], "shardingplann": [4, 5], "ddr": [4, 5], "pipelinetyp": [4, 5], "about": 4, "train_bas": 4, "train_prefetch_sparse_dist": 4, "train_sparse_dist": 4, "pooled_all_to_al": 4, "reduce_scatt": 4, "quantized_tensor": 4, "quantized_comm_codec": 4, "collective_cal": 4, "output_tensor": 4, "assert_clos": 4, "int8": [4, 8], "addit": [4, 5, 7, 8, 10, 11, 12, 14], "carri": 4, "session": 4, "padded_dim_sum": 4, "padding_s": 4, "respect": [4, 10, 11], "sequence_all_to_al": 4, "modulenocopymixin": [4, 13], "vise": [4, 12], "versa": [4, 12], "practic": 4, "from_loc": 4, "typic": [4, 5, 7, 11, 12, 14], "from_process_group": 4, "fqn": [4, 5], "larger": [4, 5], "get_plan_for_modul": 4, "module_path": 4, "re": [4, 12], "stabil": 4, "table_column_wis": [4, 11], "get_tensor_size_byt": 4, "rank_devic": 4, "device_typ": 4, "scope": 4, "copyablemixin": 4, "mymodul": 4, "forkedpdb": 4, "completekei": 4, "tab": 4, "stdin": 4, "stdout": 4, "nosigint": 4, "readrc": 4, "pdb": 4, "multiprocess": 4, "child": 4, "debug": [4, 5, 9], "multiprocessing_util": 4, "set_trac": 4, "barrier": 4, "add_params_from_parameter_shard": 4, "parameter_shard": 4, "extract": 4, "ones": 4, "add_prefix_to_state_dict": 4, "filter": [4, 11], "append_prefix": 4, "append": 4, "convert_to_fbgemm_typ": 4, "copy_to_devic": 4, "current_devic": [4, 8], "to_devic": 4, "filter_state_dict": 4, "strip": 4, "begin": [4, 5, 12], "get_unsharded_module_nam": 4, "level": [4, 6], "don": [4, 8, 11], "merge_fused_param": 4, "param_fused_param": 4, "configur": 4, "cache_precis": 4, "preset": 4, "table_level_fused_param": 4, "precid": 4, "grouped_fused_param": 4, "null": 4, "none_throw": 4, "_t": 4, "messag": [4, 5], "unexpect": 4, "assertionerror": 4, "optimizer_type_to_emb_opt_typ": 4, "optimizer_class": 4, "emboptimtyp": 4, "sharded_model_copi": 4, "m_cpu": 4, "deepcopi": 4, "managedcollisioncollectionawait": 4, "managedcollisioncollectioncontext": 4, "managedcollisioncollectionshard": 4, "managedcollisioncollect": [4, 11], "shardedmanagedcollisioncollect": 4, "evict": [4, 11], "global_to_local_index": 4, "jt_dict": [4, 14], "open_slot": [4, 11], "create_mc_shard": 4, "managedcollisionembeddingbagcollectioncontext": 4, "evictions_per_t": 4, "remapped_kjt": 4, "managedcollisionembeddingbagcollectionshard": 4, "ebc_shard": 4, "mc_sharder": 4, "basemanagedcollisionembeddingcollectionshard": 4, "managedcollisionembeddingbagcollect": [4, 11], "shardedmanagedcollisionembeddingbagcollect": 4, "baseshardedmanagedcollisionembeddingcollect": 4, "managedcollisionembeddingcollectioncontext": 4, "managedcollisionembeddingcollectionshard": 4, "ec_shard": 4, "managedcollisionembeddingcollect": [4, 11], "shardedmanagedcollisionembeddingcollect": 4, "consid": [5, 11, 13, 14], "perf": 5, "storag": [5, 14], "peak": 5, "elimin": 5, "oom": [5, 9], "kernel_bw_lookup": 5, "compute_devic": [5, 8], "hbm_mem_bw": 5, "ddr_mem_bw": 5, "caching_ratio": 5, "calcul": [5, 9], "bandwidth": 5, "ratio": [5, 9], "embeddingenumer": 5, "parameterconstraint": [5, 8], "shardestim": 5, "use_exact_enumerate_ord": 5, "shardabl": 5, "exact": 5, "name_children": 5, "shardingopt": 5, "popul": [5, 11], "populate_estim": 5, "sharding_opt": 5, "descript": [5, 9], "get_partition_by_typ": 5, "partitionbytyp": 5, "greedyperfpartition": 5, "sort_bi": 5, "sortbi": 5, "balance_modul": 5, "greedi": 5, "sort": [5, 11], "smaller": 5, "effect": [5, 11], "storage_constraint": 5, "partition_bi": 5, "strategi": 5, "docstr": [5, 9, 14], "partition_by_devic": 5, "done": [5, 11, 12, 14], "clariti": 5, "memorybalancedpartition": 5, "max_search_count": 5, "toler": 5, "02": 5, "greedypartition": 5, "reject": 5, "200": 5, "wors": 5, "repeatedli": 5, "least": 5, "ordereddevicehardwar": 5, "devicehardwar": 5, "local_world_s": 5, "shardingoptiongroup": 5, "storage_sum": 5, "perf_sum": 5, "param_count": 5, "set_hbm_per_devic": 5, "hbm_per_devic": 5, "noopperfmodel": 5, "perfmodel": 5, "among": [5, 10], "without": [5, 9, 14], "noopstoragemodel": 5, "storagereserv": 5, "performance_model": 5, "heteroembeddingshardingplann": 5, "topology_group": 5, "dynamicprogrammingpropos": 5, "hbm_bins_per_devic": 5, "dynam": 5, "fashion": [5, 6], "problem": 5, "frame": 5, "n": [5, 8, 10, 11, 14], "minim": 5, "overal": [5, 10], "k": [5, 9, 10, 11], "mathemat": [5, 9], "formul": 5, "matrix": [5, 10, 11], "let": 5, "element": [5, 11], "denot": 5, "a_": 5, "j": 5, "b_": 5, "aim": 5, "j_0": 5, "j_1": 5, "ldot": 5, "j_": 5, "condit": [5, 11], "satisfi": 5, "sum_": 5, "j_i": 5, "leq": 5, "tackl": 5, "discret": 5, "k_i": 5, "transit": 5, "min_": 5, "left": [5, 14], "right": [5, 9, 11], "simpli": 5, "fit": 5, "card": 5, "therefor": 5, "maintain": 5, "last": [5, 10, 11, 14], "layer": [5, 10, 11, 12], "under": [5, 9], "vari": 5, "hdm": 5, "bin": 5, "perf_rat": 5, "search_spac": 5, "search": 5, "embeddingoffloadscaleuppropos": 5, "use_depth": 5, "allocate_budget": 5, "budget": 5, "allocation_prior": 5, "build_affine_storage_model": 5, "uvm_caching_sharding_opt": 5, "clf_to_byt": 5, "get_budget": 5, "get_cach": 5, "get_expected_lookup": 5, "next_plan": 5, "starting_propos": 5, "promote_high_prefetch_overheaad_table_to_hbm": 5, "overhead": 5, "io": 5, "offload": 5, "undo": 5, "promot": 5, "greedypropos": 5, "threshold": [5, 9, 11], "On": [5, 11], "tri": [5, 12], "earli": 5, "stop": 5, "consecut": 5, "best_perf_r": 5, "gridsearchpropos": 5, "max_propos": 5, "10000": 5, "uniformpropos": 5, "proposers_to_proposals_list": 5, "proposers_list": 5, "static_feedback": 5, "embeddingoffloadstat": 5, "mrc_hist_count": 5, "height": 5, "uvm_fused_cach": 5, "cachebl": 5, "area": [5, 9], "curv": [5, 9], "histogram": 5, "nth": 5, "wa": [5, 8], "estimate_cache_miss_r": 5, "cache_s": 5, "hist": 5, "mrc": 5, "embeddingperfestim": 5, "is_infer": 5, "wall": 5, "sharder_map": 5, "perf_func_emb_wall_tim": 5, "shard_siz": 5, "input_length": 5, "input_data_type_s": 5, "table_data_type_s": 5, "output_data_type_s": 5, "fwd_a2a_comm_data_type_s": 5, "bwd_a2a_comm_data_type_s": 5, "fwd_sr_comm_data_type_s": 5, "bwd_sr_comm_data_type_s": 5, "num_pool": 5, "intra_host_bw": 5, "inter_host_bw": 5, "bwd_compute_multipli": 5, "weighted_feature_bwd_compute_multipli": 5, "is_pool": 5, "expected_cache_fetch": 5, "uneven_sharding_perf_multipli": 5, "attempt": 5, "rel": [5, 11], "tw": 5, "queri": 5, "fwd_comm_data_type_s": 5, "bwd_comm_data_type_s": 5, "machin": [5, 11], "embeddingbag": [5, 7, 10, 11, 13], "unpool": 5, "ebc": [5, 8, 10, 11, 13], "signifi": 5, "fetch": 5, "embeddingstorageestim": 5, "pipeline_typ": 5, "run_embedding_at_peak_memori": 5, "Will": [5, 9], "fwd": 5, "bwd": 5, "temporari": [5, 11], "toward": 5, "cost": [5, 11], "won": [5, 11], "ll": 5, "hidden": [5, 10, 11], "old": [5, 12], "agnost": 5, "forwrad": 5, "calculate_pipeline_io_cost": 5, "output_s": [5, 11], "prefetch_s": 5, "multipass_prefetch_max_pass": 5, "count_ephemeral_storage_cost": 5, "calculate_shard_storag": 5, "compris": 5, "synonym": 5, "byte": [5, 8, 9], "embeddingstat": 5, "sharding_plan": 5, "num_propos": 5, "num_plan": 5, "run_tim": 5, "best_plan": 5, "tabular": 5, "view": 5, "chosen": [5, 11], "evalu": [5, 11], "successfulli": 5, "noopembeddingstat": 5, "noop": 5, "round_to_one_sigfig": 5, "fixedpercentagestoragereserv": 5, "percentag": 5, "heuristicalstoragereserv": 5, "parameter_multipli": 5, "dense_tensor_estim": 5, "heurist": 5, "extra": 5, "percent": 5, "act": 5, "margin": 5, "error": [5, 9, 11, 14], "beyond": 5, "inferencestoragereserv": 5, "customtopologydata": 5, "get_data": 5, "has_data": 5, "supported_field": 5, "ddr_cap": 5, "hbm_cap": 5, "512": [5, 9], "min_partit": 5, "pooling_factor": 5, "fbgemm_gpu": 5, "split_table_batched_embeddings_ops_common": 5, "device_group": 5, "around": 5, "lower": [5, 7, 8, 12, 13], "divid": [5, 9], "divis": 5, "optionallist": 5, "momentum": 5, "accuraci": [5, 11], "term": [5, 11], "fp16": 5, "exce": 5, "todai": 5, "bldm": 5, "fwd_comput": 5, "fwd_comm": 5, "bwd_comput": 5, "bwd_comm": 5, "prefetch_comput": 5, "breakdown": 5, "plannererror": 5, "error_typ": 5, "plannererrortyp": 5, "classifi": 5, "insufficient_storag": 5, "strict_constraint": 5, "prospos": 5, "paritit": 5, "much": [5, 12], "depend": [5, 8, 11], "One": [5, 9, 11], "eval": 5, "job": 5, "tower": [5, 11], "cache_load_factor": 5, "module_pool": 5, "sharding_option_nam": 5, "num_input": 5, "num_shard": 5, "total_perf": 5, "total_storag": 5, "capac": 5, "hardwar": 5, "fits_in": 5, "963146416": 5, "128": [5, 9], "54760833": 5, "024": 5, "644245094": 5, "13421772": 5, "custom_topology_data": 5, "binarysearchpred": 5, "extern": [5, 10], "predic": 5, "discov": 5, "try": 5, "prior_result": 5, "probe": 5, "prior": 5, "explor": 5, "reach": [5, 9], "luusjaakolasearch": 5, "max_iter": 5, "42": 5, "left_cost": 5, "clamp": 5, "variant": 5, "luu": 5, "jaakola": 5, "en": 5, "wikipedia": 5, "wiki": 5, "far": 5, "associ": 5, "fy": 5, "y": [5, 10, 11], "previou": 5, "subsequ": 5, "been": [5, 11], "shrink_right": 5, "shrink": 5, "boundari": 5, "infin": [5, 12], "random": [5, 10], "bytes_to_gb": 5, "num_byt": 5, "bytes_to_mb": 5, "gb_to_byt": 5, "gb": 5, "local_s": [5, 6], "prod": 5, "reset_shard_rank": 5, "sharder_nam": 5, "storage_repr_in_gb": 5, "basecwembeddingshard": 6, "basetwembeddingshard": 6, "cwpooledembeddingshard": 6, "infercwpooledembeddingdist": 6, "infercwpooledembeddingdistwithpermut": 6, "infercwpooledembeddingshard": 6, "basedpembeddingshard": 6, "dppooledembeddingdist": 6, "dppooledembeddingshard": 6, "dpsparsefeaturesdist": 6, "sparsefeatur": 6, "baserwembeddingshard": 6, "inferrwpooledembeddingdist": 6, "inferrwpooledembeddingshard": 6, "inferrwsparsefeaturesdist": 6, "rwpooledembeddingdist": 6, "share": [6, 11], "rwpooledembeddingshard": 6, "evenli": 6, "rwsparsefeaturesdist": 6, "intra_pg": 6, "get_block_sizes_runtime_devic": 6, "runtime_devic": 6, "tensor_cach": 6, "get_embedding_shard_metadata": 6, "grouped_embedding_configs_per_rank": 6, "infertwembeddingshard": 6, "infertwpooledembeddingdist": 6, "infertwsparsefeaturesdist": 6, "twpooledembeddingdist": 6, "twpooledembeddingshard": 6, "twsparsefeaturesdist": 6, "twcwpooledembeddingshard": 6, "basetwrwembeddingshard": 6, "twrwpooledembeddingdist": 6, "cross_pg": 6, "dim_sum_per_nod": 6, "emb_dim_per_node_per_featur": 6, "twrwpooledembeddingshard": 6, "twrwsparsefeaturesdist": 6, "id_list_features_per_rank": 6, "id_score_list_features_per_rank": 6, "id_list_feature_hash_s": 6, "id_score_list_feature_hash_s": 6, "look": [6, 7, 14], "reorder": 6, "document": [7, 10], "leaf_modul": 7, "trace": [7, 8], "torchscript": 7, "create_arg": 7, "memory_format": 7, "opoverload": 7, "symint": 7, "symbool": 7, "symfloat": 7, "prepar": [7, 11], "graph": 7, "emit": 7, "appropri": 7, "is_leaf_modul": 7, "module_qualified_nam": 7, "path_of_modul": 7, "mod": 7, "abil": 7, "made": [7, 12], "concrete_arg": 7, "is_fx_trac": 7, "symbolic_trac": 7, "graphmodul": 7, "symbol": 7, "record": [7, 11], "partial": 7, "structur": [7, 12], "predictfactorypackag": 8, "save_predict_factori": 8, "predict_factori": 8, "predictfactori": 8, "config": [8, 9, 10, 11], "pathlib": 8, "binaryio": 8, "extra_fil": 8, "loader_cod": 8, "nimport": 8, "packag": 8, "nmodule_factori": 8, "package_import": 8, "_sysimport": 8, "set_extern_modul": 8, "decor": 8, "abstractmethod": 8, "set_mocked_modul": 8, "load_config_text": 8, "load_pickle_config": 8, "clazz": 8, "batchingmetadata": 8, "pin": 8, "kept": [8, 11], "sync": [8, 9, 14], "batching_metadata": 8, "infom": 8, "batching_metadata_json": 8, "serial": 8, "json": 8, "eas": [8, 11], "pars": 8, "create_predict_modul": 8, "transformmodul": 8, "transform_state_dict": 8, "init_process_group": 8, "model_inputs_data": 8, "benchmark": 8, "qualname_metadata": 8, "qualnamemetadata": 8, "qualnam": 8, "inform": [8, 9, 14], "qualname_metadata_json": 8, "result_metadata": 8, "run_weights_dependent_transform": 8, "predict_modul": 8, "predict": [8, 9], "run_weights_independent_tranform": 8, "fx": 8, "predictmodul": 8, "predict_forward": 8, "primari": 8, "need_preproc": 8, "assign_weights_to_tb": 8, "table_to_weight": 8, "get_table_to_weights_from_tb": 8, "quantize_dens": 8, "additional_embedding_module_typ": 8, "quantize_embed": 8, "inplac": [8, 13], "additional_qconfig_spec_kei": 8, "additional_map": 8, "per_table_weight_dtyp": [8, 11], "quantize_featur": 8, "quantize_inference_model": 8, "quantization_map": 8, "fp_weight_dtyp": 8, "quantization_dtyp": 8, "swap": 8, "counterpart": 8, "quantembeddingbagcollect": [8, 13], "quantembeddingcollect": 8, "eb_config": 8, "dlrmpredictmodul": 8, "embedding_bag_collect": [8, 10, 11], "dense_in_featur": [8, 10], "model_config": 8, "dense_arch_layer_s": [8, 10], "over_arch_layer_s": [8, 10], "id_list_features_kei": 8, "dense_devic": [8, 10], "quant_model": 8, "set_pruning_data": 8, "tables_to_rows_post_prun": 8, "shard_quant_model": 8, "sharding_devic": 8, "device_memory_s": 8, "quantembeddingcollectionshard": 8, "tablewis": 8, "sharded_model": 8, "trim_torch_package_prefix_from_typenam": 8, "typenam": 8, "accuracymetr": 9, "task": 9, "rectaskinfo": 9, "compute_mod": 9, "reccomputemod": 9, "unfused_tasks_comput": 9, "window_s": 9, "fused_update_limit": 9, "compute_on_all_rank": 9, "should_validate_upd": 9, "process_group": 9, "recmetr": 9, "accuracymetriccomput": 9, "recmetriccomput": 9, "constructor": [9, 11], "cut": [9, 11], "off": [9, 11], "compute_accuraci": 9, "accuracy_sum": 9, "weighted_num_sampl": 9, "compute_accuracy_sum": 9, "get_accuracy_st": 9, "aucmetr": 9, "aucmetriccomput": 9, "grouped_auc": 9, "apply_bin": 9, "grouping_kei": 9, "reset": [9, 11, 12], "n_task": 9, "n_exampl": 9, "compute_auc": 9, "classif": 9, "compute_auc_per_group": 9, "auprcmetr": 9, "auprcmetriccomput": 9, "grouped_auprc": 9, "pr": 9, "compute_auprc": 9, "compute_auprc_per_group": 9, "calibrationmetr": 9, "calibrationmetriccomput": 9, "convers": 9, "compute_calibr": 9, "calibration_num": 9, "calibration_denom": 9, "get_calibration_st": 9, "ctrmetric": 9, "ctrmetriccomput": 9, "compute_ctr": 9, "ctr_num": 9, "ctr_denom": 9, "get_ctr_stat": 9, "maemetr": 9, "maemetriccomput": 9, "absolut": 9, "compute_error_sum": 9, "compute_ma": 9, "error_sum": 9, "get_mae_st": 9, "msemetr": 9, "msemetriccomput": 9, "squar": [9, 11], "compute_ms": 9, "compute_rms": 9, "get_mse_st": 9, "multiclassrecallmetr": 9, "multiclassrecallmetriccomput": 9, "compute_multiclass_recall_at_k": 9, "tp_at_k": 9, "total_weight": 9, "compute_true_positives_at_k": 9, "n_class": 9, "tp": 9, "1st": 9, "2nd": [9, 11], "n_sampl": 9, "ground": 9, "truth": 9, "true_positives_list": 9, "9": [9, 10], "15": [9, 10], "compute_multiclass_k_sum": 9, "5000": 9, "7500": 9, "0000": [9, 11], "get_multiclass_recall_st": 9, "ndcgcomput": 9, "exponential_gain": 9, "session_kei": 9, "session_id": 9, "report_ndcg_as_decreasing_curv": 9, "remove_single_length_sess": 9, "scale_by_weights_tensor": 9, "is_negative_task_mask": 9, "normal": [9, 11], "discount": 9, "gain": 9, "tensorboard": 9, "captur": 9, "decreas": 9, "loss": [9, 10, 12], "oppos": 9, "visual": [9, 14], "similarli": 9, "entropi": 9, "pointwis": 9, "noth": 9, "ndcgmetric": 9, "nemetr": 9, "nemetriccomput": 9, "include_logloss": 9, "allow_missing_label_with_zero_weight": 9, "vanilla": 9, "logloss": 9, "compute_cross_entropi": 9, "eta": 9, "compute_logloss": 9, "ce_sum": 9, "pos_label": 9, "neg_label": 9, "compute_n": 9, "get_ne_st": 9, "recallmetr": 9, "recallmetriccomput": 9, "compute_false_neg_sum": 9, "compute_recal": 9, "num_true_posit": 9, "num_false_negit": 9, "compute_true_pos_sum": 9, "get_recall_st": 9, "precisionmetr": 9, "precisionmetriccomput": 9, "compute_false_pos_sum": 9, "compute_precis": 9, "num_false_posit": 9, "get_precision_st": 9, "raucmetr": 9, "raucmetriccomput": 9, "grouped_rauc": 9, "regress": 9, "compute_rauc": 9, "compute_rauc_per_group": 9, "conquer_and_count": 9, "left_index": 9, "mid_index": 9, "right_index": 9, "count_reverse_pairs_divide_and_conqu": 9, "low": [9, 10, 11], "throughputmetr": 9, "window_second": 9, "warmup_step": 9, "batch_size_stag": 9, "batchsizestag": 9, "32": [9, 11], "time_to_train_one_step": 9, "trainer": 9, "window": 9, "window_throughput": 9, "warmup": 9, "Not": 9, "weightedavgmetr": 9, "weightedavgmetriccomput": 9, "get_mean": 9, "value_sum": 9, "num_sampl": 9, "xaucmetr": 9, "xaucmetriccomput": 9, "compute_weighted_num_pair": 9, "compute_xauc": 9, "weighted_num_pair": 9, "get_xauc_st": 9, "recmetricmodul": 9, "rec_task": 9, "recmetriclist": 9, "throughput_metr": 9, "state_metr": 9, "statemetr": 9, "compute_interval_step": 9, "min_compute_interv": 9, "max_compute_interv": 9, "inf": [9, 12], "memory_usage_limit_mb": 9, "standalon": 9, "characterist": 9, "componenet": 9, "intern": [9, 11, 14], "logic": [9, 11], "unit": [9, 11], "dataclass": 9, "defaultmetricsconfig": 9, "statemetricenum": 9, "metricmodul": 9, "generate_metric_modul": 9, "metric_class": 9, "metrics_config": 9, "64": [9, 11], "state_metrics_map": 9, "mock_optim": 9, "check_memory_usag": 9, "compute_count": 9, "sink": 9, "get_memory_usag": 9, "get_required_input": 9, "last_compute_tim": 9, "local_comput": 9, "memory_usage_mb_avg": 9, "oom_count": 9, "should_comput": 9, "unsync": [9, 14], "model_out": 9, "model_output": 9, "due": 9, "slide": 9, "qat": 9, "get_metr": 9, "metricsconfig": 9, "metriccomputationreport": 9, "metrics_namespac": 9, "metricnamebas": 9, "metric_prefix": 9, "metricprefix": 9, "signal": 9, "own": 9, "__init__": 9, "_namespac": 9, "_metrics_comput": 9, "consum": 9, "invalid": 9, "defaulttaskinfo": 9, "rec": 9, "overwrit": 9, "synchron": 9, "get_window_st": 9, "state_nam": 9, "get_window_state_nam": 9, "pre_comput": 9, "torchmetr": 9, "aggreg": 9, "recmetricexcept": 9, "encapul": 9, "required_input": 9, "windowbuff": 9, "max_siz": 9, "max_buffer_count": 9, "aggregate_st": 9, "window_st": 9, "curr_stat": 9, "dequ": 9, "architectur": [10, 11], "deep": [10, 11], "sparsearch": 10, "densearch": 10, "interactionarch": 10, "overarch": 10, "found": 10, "notat": 10, "embedding_dimens": 10, "hidden_layer_s": 10, "deepfmnn": 10, "dimension": 10, "dense_arch": 10, "dense_arch_input": 10, "dense_embed": 10, "fminteractionarch": 10, "fm_in_featur": 10, "sparse_feature_nam": 10, "deep_fm_dimens": 10, "paper": [10, 11], "arxiv": 10, "pdf": 10, "1703": 10, "04247": 10, "cat": [10, 11], "dense_modul": [10, 11], "di": 10, "arch": 10, "fm_inter_arch": 10, "length_per_kei": [10, 14], "cat_fm_output": 10, "mlp": 10, "over_arch": 10, "logit": 10, "simpledeepfmnn": 10, "num_dense_featur": 10, "relationship": 10, "deep_fm": 10, "eb1_config": [10, 13], "f3": 10, "eb2_config": [10, 13], "t2": [10, 11, 13], "sparse_nn": 10, "over_embedding_dim": 10, "from_offsets_sync": [10, 11, 13, 14], "sparse_arch": 10, "ab": 10, "1906": 10, "00091": 10, "pairwis": 10, "dlrmtrain": 10, "dlrm_modul": 10, "train_pipelin": 10, "dlrm_project": 10, "dlrm_dcn": 10, "ebc_config": 10, "dlrm_model": 10, "dcn_num_lay": 10, "dcn_low_rank_dim": 10, "dcn": [10, 11], "modifi": [10, 11, 12], "similar": 10, "deepcrossnet": 10, "2008": 10, "13535": 10, "approxim": 10, "interaction_branch1_layer_s": 10, "interaction_branch2_layer_s": 10, "branch": 10, "layer_s": [10, 11], "num_sparse_featur": 10, "dot": [10, 11], "pair": 10, "inter_arch": 10, "choos": 10, "concat_dens": 10, "interactiondcnarch": 10, "crossnet": 10, "lowrankcrossnet": [10, 11], "dnc_low_rank_dim": 10, "interactionprojectionarch": 10, "interaction_branch1": 10, "interaction_branch2": 10, "z": 10, "bx": 10, "f1xd": 10, "dxf2": 10, "i1": 10, "i2": 10, "sparse_embed": 10, "math": 10, "comb": 10, "extens": 11, "establish": 11, "pattern": 11, "swishlayernorm": 11, "positionweightedmodul": 11, "lazymoduleextensionmixin": 11, "embeddingtow": 11, "embeddingtowercollect": 11, "input_dim": 11, "swish": 11, "sigmoid": 11, "layernorm": 11, "d1": 11, "d2": 11, "d3": 11, "sln": 11, "num_lay": 11, "stack": 11, "learnabl": 11, "polynom": 11, "nxn": 11, "cover": 11, "bit": 11, "x_": 11, "x_0": 11, "w_l": 11, "cdot": 11, "x_l": 11, "b_l": 11, "low_rank": 11, "highli": 11, "matric": 11, "simplifi": 11, "v_l": 11, "vector": 11, "smartli": 11, "setup": 11, "lowrankmixturecrossnet": 11, "num_expert": 11, "relu": 11, "mixtur": 11, "expert": 11, "compar": [11, 14], "subspac": 11, "gate": 11, "moe": 11, "expert_i": 11, "k_": 11, "u_": 11, "li": 11, "c_": 11, "v_": 11, "vectorcrossnet": 11, "nx1": 11, "thu": [11, 12], "further": [11, 14], "implent": 11, "framework": 11, "factorizationmachin": 11, "fm": 11, "publish": 11, "learnt": 11, "To": 11, "90": 11, "30": 11, "40": 11, "fb": 11, "lazymlp": 11, "output_dim": 11, "192": 11, "deep_fm_output": 11, "common_spars": 11, "specialized_spars": 11, "embedding_featur": 11, "raw_embedding_featur": 11, "nativ": 11, "trained_embed": 11, "native_embed": 11, "ident": 11, "mention": 11, "baseembeddingconfig": 11, "get_weight_init_max": 11, "get_weight_init_min": 11, "embeddingconfig": [11, 13], "quantconfig": 11, "placeholderobserv": [11, 13], "alia": 11, "data_type_to_dtyp": 11, "data_type_to_sparse_typ": 11, "sparsetyp": 11, "dtype_to_data_typ": 11, "pooling_type_to_pooling_mod": 11, "pooling_typ": 11, "poolingmod": 11, "pooling_type_to_str": 11, "sensit": [11, 13], "jag": [11, 13, 14], "table_0": [11, 13], "table_1": [11, 13], "pooled_embed": 11, "8899": 11, "1342": 11, "9060": 11, "0905": 11, "2814": 11, "9369": 11, "7783": 11, "1598": 11, "0695": 11, "3265": 11, "1011": 11, "4256": 11, "1846": 11, "1648": 11, "0893": 11, "3590": 11, "9784": 11, "7681": 11, "grad_fn": [11, 13], "catbackward0": 11, "offset_per_kei": [11, 14], "need_indic": [11, 13], "e1_config": [11, 13], "e2_config": [11, 13], "ec": [11, 13], "feature_embed": [11, 13], "2050": [11, 13], "5478": [11, 13], "6054": [11, 13], "7352": [11, 13], "3210": [11, 13], "0399": [11, 13], "1279": [11, 13], "1756": [11, 13], "4130": [11, 13], "7519": [11, 13], "4341": [11, 13], "0499": [11, 13], "9329": [11, 13], "0697": [11, 13], "8095": [11, 13], "embeddingbackward": [11, 13], "embedding_names_by_t": [11, 13], "get_embedding_names_by_t": 11, "process_pooled_embed": 11, "reorder_inverse_indic": 11, "basefeatureprocessor": 11, "max_length": 11, "truncat": 11, "positionweightedprocessor": 11, "feature_length": 11, "feature0": [11, 14], "feature1": [11, 14], "feature2": 11, "from_lengths_sync": [11, 14], "pw": 11, "featureprocessorcollect": 11, "feature_processor_modul": 11, "positionweightedfeatureprocessor": 11, "fp_featur": 11, "non_fp_featur": 11, "non_fp": 11, "feature_process": 11, "And": 11, "offsets_to_range_tracebl": 11, "position_weighted_module_update_featur": 11, "weighted_featur": 11, "lazymodulemixin": 11, "upstream": 11, "59923": 11, "testlazymoduleextensionmixin": 11, "_infer_paramet": 11, "pariti": 11, "_call_impl": 11, "fn": 11, "children": 11, "uniniti": 11, "dummi": [11, 12], "lazylinear": 11, "fail": [11, 14], "hasn": 11, "yet": 11, "now": [11, 14], "lazy_appli": 11, "attach": 11, "numer": 11, "immedi": 11, "seq": 11, "in_siz": 11, "perceptron": 11, "multi": 11, "out_siz": 11, "swish_layernorm": 11, "mlp_modul": 11, "assert": 11, "o": 11, "channel": 11, "unpadded_length": 11, "reindexed_length": 11, "reindexed_length_per_kei": 11, "reindexed_valu": 11, "check_module_output_dimens": 11, "verifi": 11, "construct_jagged_tensor": 11, "features_to_permute_indic": 11, "original_featur": 11, "construct_jagged_tensors_infer": 11, "construct_modulelist_from_single_modul": 11, "nest": 11, "reiniti": 11, "convert_list_of_modules_to_modulelist": 11, "deterministic_dedup": 11, "race": 11, "conflict": 11, "extract_module_or_tensor_cal": 11, "module_or_cal": 11, "get_module_output_dimens": 11, "init_mlp_weights_xavier_uniform": 11, "jagged_index_select_with_empti": 11, "output_offset": 11, "distancelfu_evictionpolici": 11, "decay_expon": 11, "threshold_filtering_func": 11, "mchevictionpolici": 11, "coalesce_history_metadata": 11, "current_it": 11, "history_metadata": 11, "unique_ids_count": 11, "unique_inverse_map": 11, "additional_id": 11, "threshold_mask": 11, "histori": 11, "invers": [11, 14], "history_accumul": 11, "coalesc": 11, "metadata_info": 11, "mchevictionpolicymetadatainfo": 11, "record_history_metadata": 11, "incoming_id": 11, "incom": 11, "polici": [11, 12], "update_metadata_and_generate_eviction_scor": 11, "mch_size": 11, "coalesced_history_argsort_map": 11, "coalesced_history_sorted_unique_ids_count": 11, "coalesced_history_mch_matching_elements_mask": 11, "coalesced_history_mch_matching_indic": 11, "mch_metadata": 11, "coalesced_history_metadata": 11, "evicted_indic": 11, "selected_new_indic": 11, "mch": 11, "lfu_evictionpolici": 11, "lru_evictionpolici": 11, "metadata_nam": 11, "is_mch_metadata": 11, "is_history_metadata": 11, "mchmanagedcollisionmodul": 11, "zch_size": 11, "eviction_polici": 11, "eviction_interv": 11, "input_hash_s": 11, "9223372036854775807": 11, "input_hash_func": 11, "mch_hash_func": 11, "output_global_offset": 11, "output_seg": 11, "managedcollisionmodul": 11, "zch": 11, "collis": 11, "output_size_offset": 11, "drive": 11, "greater": 11, "depreci": 11, "residu": 11, "legaci": 11, "shift": 11, "zch_output_rang": 11, "down": 11, "applic": 11, "slot": 11, "assumptionn": 11, "downstream": 11, "rtype": 11, "vs": 11, "profil": 11, "rebuild_with_output_id_rang": 11, "output_id_rang": 11, "mc": 11, "validate_st": 11, "checkpoint": [11, 12], "managed_collision_modul": 11, "need_preprocess": 11, "mcc": 11, "embedding_confg": 11, "collsion": 11, "skip_state_valid": 11, "max_output_id": 11, "remapping_range_start_index": 11, "mcm": 11, "mcm_jt": 11, "fp": 11, "apply_mc_method_to_jt_dict": 11, "features_dict": 11, "average_threshold_filt": 11, "id_count": 11, "dynamic_threshold_filt": 11, "threshold_skew_multipli": 11, "total_count": 11, "num_id": 11, "probabilistic_threshold_filt": 11, "per_id_prob": 11, "01": 11, "probabl": 11, "60": 11, "randomli": 11, "chanc": 11, "basemanagedcollisionembeddingcollect": 11, "managed_collision_collect": 11, "return_remapped_featur": 11, "embedding_collect": 11, "meaning": 12, "prohibit": 12, "empti": [12, 14], "sever": 12, "combinedoptim": 12, "optimizerwrapp": 12, "rowwis": 12, "gradientclip": 12, "norm": 12, "gradientclippingoptim": 12, "max_gradi": 12, "norm_typ": 12, "p": 12, "closur": 12, "reevalu": 12, "emptyfusedoptim": 12, "fusedoptim": 12, "zero_grad": 12, "set_to_non": 12, "zero": [12, 14], "footprint": 12, "modestli": 12, "certain": 12, "0s": 12, "behav": 12, "did": 12, "altogeth": 12, "param_group": 12, "meant": 12, "post_load_state_dict": 12, "prepend_opt_kei": 12, "opt_kei": 12, "save_param_group": 12, "set_optimizer_step": 12, "stricter": 12, "switch": 12, "flag": 12, "identifi": 12, "littl": 12, "add_param_group": 12, "fine": 12, "frozen": 12, "trainabl": 12, "progress": 12, "what": 12, "init_st": 12, "introduc": 12, "usabl": 12, "sd": 12, "load_checkpoint": 12, "protocol": 12, "keyedoptimizerwrapp": 12, "optim_factori": 12, "conveni": 12, "warmupoptim": 12, "warmupstag": 12, "lr": 12, "lr_param": 12, "param_nam": 12, "__warmup": 12, "adjust": 12, "schedul": 12, "fake": 12, "warmuppolici": 12, "constant": 12, "cosine_annealing_warm_restart": 12, "invsqrt": 12, "inv_sqrt": 12, "poli": 12, "max_it": 12, "lr_scale": 12, "decay_it": 12, "sgdr_period": 12, "trec_quant": 13, "trec": 13, "qconfig": 13, "activ": 13, "with_arg": 13, "qint8": 13, "quantize_dynam": 13, "qconfig_spec": 13, "table_name_to_quantized_weight": 13, "register_tb": 13, "quant_state_dict_split_scale_bia": 13, "row_align": 13, "qebc": 13, "from_float": 13, "quantized_embed": 13, "use_precomputed_fake_qu": 13, "for_each_module_of_type_do": 13, "quant_prep_customize_row_align": 13, "quant_prep_enable_quant_state_dict_split_scale_bia": 13, "quant_prep_enable_quant_state_dict_split_scale_bias_for_typ": 13, "quant_prep_enable_register_tb": 13, "quantize_state_dict": 13, "table_name_to_data_typ": 13, "table_name_to_num_embeddings_post_prun": 13, "whose": 14, "dimes": 14, "computejtdicttokjt": 14, "dim_1": 14, "dim_0": 14, "computekjttojtdict": 14, "keyed_jagged_tensor": 14, "jit": 14, "abl": 14, "NOT": 14, "expens": 14, "values_dtyp": 14, "weights_dtyp": 14, "lengths_dtyp": 14, "from_dens": 14, "2d": 14, "11": 14, "12": 14, "j1": 14, "from_dense_length": 14, "lengths_or_non": 14, "offsets_or_non": 14, "to_dens": 14, "inttensor": 14, "values_list": 14, "to_dense_weight": 14, "weights_list": 14, "to_padded_dens": 14, "desired_length": 14, "padding_valu": 14, "longest": 14, "pad": 14, "dt": 14, "to_padded_dense_weight": 14, "d_wt": 14, "weights_or_non": 14, "jaggedtensormeta": 14, "abcmeta": 14, "proxyableclassmeta": 14, "stride_per_key_per_rank": 14, "outer": 14, "inner": 14, "index_per_kei": 14, "expand": 14, "dedupl": 14, "dim_2": 14, "w0": 14, "w1": 14, "w2": 14, "w3": 14, "w4": 14, "w5": 14, "w6": 14, "w7": 14, "dist_init": 14, "variable_stride_per_kei": 14, "dist_label": 14, "dist_split": 14, "key_split": 14, "dist_tensor": 14, "empty_lik": 14, "flatten_length": 14, "from_jt_dict": 14, "implicit": 14, "variable_feature_dim": 14, "But": 14, "That": 14, "didn": 14, "notic": 14, "correctli": 14, "technic": 14, "know": 14, "violat": 14, "precondit": 14, "inverse_indices_or_non": 14, "length_per_key_or_non": 14, "lengths_offset_per_kei": 14, "offset_per_key_or_non": 14, "indices_tensor": 14, "segment": 14, "stride_per_kei": 14, "to_dict": 14, "key_dim": 14, "tensor_list": 14, "from_tensor_list": 14, "regroup": 14, "keyed_tensor": 14, "regroup_as_dict": 14, "flatten_kjt_list": 14, "kjt_arr": 14, "jt_is_equ": 14, "jt_1": 14, "jt_2": 14, "comparison": 14, "themselv": 14, "treat": 14, "kjt_is_equ": 14, "kjt_1": 14, "kjt_2": 14, "permute_multi_embed": 14, "regroup_kt": 14, "unflatten_kjt_list": 14}, "objects": {"torchrec": [[2, 0, 0, "-", "datasets"], [4, 0, 0, "-", "distributed"], [7, 0, 0, "module-0", "fx"], [8, 0, 0, "module-0", "inference"], [9, 0, 0, "-", "metrics"], [10, 0, 0, "module-0", "models"], [11, 0, 0, "-", "modules"], [12, 0, 0, "module-0", "optim"], [13, 0, 0, "module-0", "quant"], [14, 0, 0, "module-0", "sparse"]], "torchrec.datasets": [[2, 0, 0, "-", "criteo"], [2, 0, 0, "-", "movielens"], [2, 0, 0, "-", "random"], [3, 0, 0, "-", "scripts"], [2, 0, 0, "-", "utils"]], "torchrec.datasets.criteo": [[2, 1, 1, "", "BinaryCriteoUtils"], [2, 1, 1, "", "CriteoIterDataPipe"], [2, 1, 1, "", "InMemoryBinaryCriteoIterDataPipe"], [2, 3, 1, "", "criteo_kaggle"], [2, 3, 1, "", "criteo_terabyte"]], "torchrec.datasets.criteo.BinaryCriteoUtils": [[2, 2, 1, "", "get_file_row_ranges_and_remainder"], [2, 2, 1, "", "get_shape_from_npy"], [2, 2, 1, "", "load_npy_range"], [2, 2, 1, "", "shuffle"], [2, 2, 1, "", "sparse_to_contiguous"], [2, 2, 1, "", "tsv_to_npys"]], "torchrec.datasets.movielens": [[2, 3, 1, "", "movielens_20m"], [2, 3, 1, "", "movielens_25m"]], "torchrec.datasets.random": [[2, 1, 1, "", "RandomRecDataset"]], "torchrec.datasets.scripts": [[3, 0, 0, "-", "contiguous_preproc_criteo"], [3, 0, 0, "-", "npy_preproc_criteo"]], "torchrec.datasets.scripts.contiguous_preproc_criteo": [[3, 3, 1, "", "main"], [3, 3, 1, "", "parse_args"]], "torchrec.datasets.scripts.npy_preproc_criteo": [[3, 3, 1, "", "main"], [3, 3, 1, "", "parse_args"]], "torchrec.datasets.utils": [[2, 1, 1, "", "Batch"], [2, 1, 1, "", "Limit"], [2, 1, 1, "", "LoadFiles"], [2, 1, 1, "", "ParallelReadConcat"], [2, 1, 1, "", "ReadLinesFromCSV"], [2, 3, 1, "", "idx_split_train_val"], [2, 3, 1, "", "rand_split_train_val"], [2, 3, 1, "", "safe_cast"], [2, 3, 1, "", "train_filter"], [2, 3, 1, "", "val_filter"]], "torchrec.datasets.utils.Batch": [[2, 4, 1, "", "dense_features"], [2, 4, 1, "", "labels"], [2, 2, 1, "", "pin_memory"], [2, 2, 1, "", "record_stream"], [2, 4, 1, "", "sparse_features"], [2, 2, 1, "", "to"]], "torchrec.distributed": [[4, 0, 0, "-", "collective_utils"], [4, 0, 0, "-", "comm"], [4, 0, 0, "-", "comm_ops"], [6, 0, 0, "-", "dist_data"], [4, 0, 0, "-", "embedding"], [4, 0, 0, "-", "embedding_lookup"], [4, 0, 0, "-", "embedding_sharding"], [4, 0, 0, "-", "embedding_types"], [4, 0, 0, "-", "embeddingbag"], [4, 0, 0, "-", "grouped_position_weighted"], [4, 0, 0, "-", "mc_embedding"], [4, 0, 0, "-", "mc_embeddingbag"], [4, 0, 0, "-", "mc_modules"], [4, 0, 0, "-", "model_parallel"], [5, 0, 0, "-", "planner"], [4, 0, 0, "-", "quant_embeddingbag"], [6, 0, 0, "-", "sharding"], [4, 0, 0, "-", "train_pipeline"], [4, 0, 0, "-", "types"], [4, 0, 0, "-", "utils"]], "torchrec.distributed.collective_utils": [[4, 3, 1, "", "invoke_on_rank_and_broadcast_result"], [4, 3, 1, "", "is_leader"], [4, 3, 1, "", "run_on_leader"]], "torchrec.distributed.comm": [[4, 3, 1, "", "get_group_rank"], [4, 3, 1, "", "get_local_rank"], [4, 3, 1, "", "get_local_size"], [4, 3, 1, "", "get_num_groups"], [4, 3, 1, "", "intra_and_cross_node_pg"]], "torchrec.distributed.comm_ops": [[4, 1, 1, "", "All2AllDenseInfo"], [4, 1, 1, "", "All2AllPooledInfo"], [4, 1, 1, "", "All2AllSequenceInfo"], [4, 1, 1, "", "All2AllVInfo"], [4, 1, 1, "", "All2All_Pooled_Req"], [4, 1, 1, "", "All2All_Pooled_Wait"], [4, 1, 1, "", "All2All_Seq_Req"], [4, 1, 1, "", "All2All_Seq_Req_Wait"], [4, 1, 1, "", "All2Allv_Req"], [4, 1, 1, "", "All2Allv_Wait"], [4, 1, 1, "", "AllGatherBaseInfo"], [4, 1, 1, "", "AllGatherBase_Req"], [4, 1, 1, "", "AllGatherBase_Wait"], [4, 1, 1, "", "ReduceScatterBaseInfo"], [4, 1, 1, "", "ReduceScatterBase_Req"], [4, 1, 1, "", "ReduceScatterBase_Wait"], [4, 1, 1, "", "ReduceScatterInfo"], [4, 1, 1, "", "ReduceScatterVInfo"], [4, 1, 1, "", "ReduceScatterV_Req"], [4, 1, 1, "", "ReduceScatterV_Wait"], [4, 1, 1, "", "ReduceScatter_Req"], [4, 1, 1, "", "ReduceScatter_Wait"], [4, 1, 1, "", "Request"], [4, 1, 1, "", "VariableBatchAll2AllPooledInfo"], [4, 1, 1, "", "Variable_Batch_All2All_Pooled_Req"], [4, 1, 1, "", "Variable_Batch_All2All_Pooled_Wait"], [4, 3, 1, "", "all2all_pooled_sync"], [4, 3, 1, "", "all2all_sequence_sync"], [4, 3, 1, "", "all2allv_sync"], [4, 3, 1, "", "all_gather_base_pooled"], [4, 3, 1, "", "all_gather_base_sync"], [4, 3, 1, "", "all_gather_into_tensor_backward"], [4, 3, 1, "", "all_gather_into_tensor_fake"], [4, 3, 1, "", "all_gather_into_tensor_setup_context"], [4, 3, 1, "", "all_to_all_single_backward"], [4, 3, 1, "", "all_to_all_single_fake"], [4, 3, 1, "", "all_to_all_single_setup_context"], [4, 3, 1, "", "alltoall_pooled"], [4, 3, 1, "", "alltoall_sequence"], [4, 3, 1, "", "alltoallv"], [4, 3, 1, "", "get_gradient_division"], [4, 3, 1, "", "get_use_sync_collectives"], [4, 3, 1, "", "pg_name"], [4, 3, 1, "", "reduce_scatter_base_pooled"], [4, 3, 1, "", "reduce_scatter_base_sync"], [4, 3, 1, "", "reduce_scatter_pooled"], [4, 3, 1, "", "reduce_scatter_sync"], [4, 3, 1, "", "reduce_scatter_tensor_backward"], [4, 3, 1, "", "reduce_scatter_tensor_fake"], [4, 3, 1, "", "reduce_scatter_tensor_setup_context"], [4, 3, 1, "", "reduce_scatter_v_per_feature_pooled"], [4, 3, 1, "", "reduce_scatter_v_pooled"], [4, 3, 1, "", "reduce_scatter_v_sync"], [4, 3, 1, "", "set_gradient_division"], [4, 3, 1, "", "set_use_sync_collectives"], [4, 3, 1, "", "torchrec_use_sync_collectives"], [4, 3, 1, "", "variable_batch_all2all_pooled_sync"], [4, 3, 1, "", "variable_batch_alltoall_pooled"]], "torchrec.distributed.comm_ops.All2AllDenseInfo": [[4, 4, 1, "", "batch_size"], [4, 4, 1, "", "input_shape"], [4, 4, 1, "", "input_splits"], [4, 4, 1, "", "output_splits"]], "torchrec.distributed.comm_ops.All2AllPooledInfo": [[4, 4, 1, "id0", "batch_size_per_rank"], [4, 4, 1, "id1", "codecs"], [4, 4, 1, "id2", "cumsum_dim_sum_per_rank_tensor"], [4, 4, 1, "id3", "dim_sum_per_rank"], [4, 4, 1, "id4", "dim_sum_per_rank_tensor"]], "torchrec.distributed.comm_ops.All2AllSequenceInfo": [[4, 4, 1, "id5", "backward_recat_tensor"], [4, 4, 1, "id6", "codecs"], [4, 4, 1, "id7", "embedding_dim"], [4, 4, 1, "id8", "forward_recat_tensor"], [4, 4, 1, "id9", "input_splits"], [4, 4, 1, "id10", "lengths_after_sparse_data_all2all"], [4, 4, 1, "id11", "output_splits"], [4, 4, 1, "id12", "permuted_lengths_after_sparse_data_all2all"], [4, 4, 1, "id13", "variable_batch_size"]], "torchrec.distributed.comm_ops.All2AllVInfo": [[4, 4, 1, "id14", "B_global"], [4, 4, 1, "id15", "B_local"], [4, 4, 1, "id16", "B_local_list"], [4, 4, 1, "id17", "D_local_list"], [4, 4, 1, "", "codecs"], [4, 4, 1, "", "dim_sum_per_rank"], [4, 4, 1, "", "dims_sum_per_rank"], [4, 4, 1, "id18", "input_split_sizes"], [4, 4, 1, "id19", "output_split_sizes"]], "torchrec.distributed.comm_ops.All2All_Pooled_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Pooled_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Seq_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Seq_Req_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2Allv_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2Allv_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.AllGatherBaseInfo": [[4, 4, 1, "", "codecs"], [4, 4, 1, "id20", "input_size"]], "torchrec.distributed.comm_ops.AllGatherBase_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.AllGatherBase_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterBaseInfo": [[4, 4, 1, "", "codecs"], [4, 4, 1, "id21", "input_sizes"]], "torchrec.distributed.comm_ops.ReduceScatterBase_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterBase_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterInfo": [[4, 4, 1, "", "codecs"], [4, 4, 1, "id22", "input_sizes"]], "torchrec.distributed.comm_ops.ReduceScatterVInfo": [[4, 4, 1, "id23", "codecs"], [4, 4, 1, "id24", "equal_splits"], [4, 4, 1, "id25", "input_sizes"], [4, 4, 1, "id26", "input_splits"], [4, 4, 1, "id27", "total_input_size"]], "torchrec.distributed.comm_ops.ReduceScatterV_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterV_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatter_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatter_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.VariableBatchAll2AllPooledInfo": [[4, 4, 1, "id28", "batch_size_per_feature_pre_a2a"], [4, 4, 1, "id29", "batch_size_per_rank_per_feature"], [4, 4, 1, "id30", "codecs"], [4, 4, 1, "id31", "emb_dim_per_rank_per_feature"], [4, 4, 1, "id32", "input_splits"], [4, 4, 1, "id33", "output_splits"]], "torchrec.distributed.comm_ops.Variable_Batch_All2All_Pooled_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.Variable_Batch_All2All_Pooled_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.dist_data": [[6, 1, 1, "", "EmbeddingsAllToOne"], [6, 1, 1, "", "EmbeddingsAllToOneReduce"], [6, 1, 1, "", "JaggedTensorAllToAll"], [6, 1, 1, "", "KJTAllToAll"], [6, 1, 1, "", "KJTAllToAllSplitsAwaitable"], [6, 1, 1, "", "KJTAllToAllTensorsAwaitable"], [6, 1, 1, "", "KJTOneToAll"], [6, 1, 1, "", "MergePooledEmbeddingsModule"], [6, 1, 1, "", "PooledEmbeddingsAllGather"], [6, 1, 1, "", "PooledEmbeddingsAllToAll"], [6, 1, 1, "", "PooledEmbeddingsAwaitable"], [6, 1, 1, "", "PooledEmbeddingsReduceScatter"], [6, 1, 1, "", "SeqEmbeddingsAllToOne"], [6, 1, 1, "", "SequenceEmbeddingsAllToAll"], [6, 1, 1, "", "SequenceEmbeddingsAwaitable"], [6, 1, 1, "", "SplitsAllToAllAwaitable"], [6, 1, 1, "", "TensorAllToAll"], [6, 1, 1, "", "TensorAllToAllSplitsAwaitable"], [6, 1, 1, "", "TensorAllToAllValuesAwaitable"], [6, 1, 1, "", "TensorValuesAllToAll"], [6, 1, 1, "", "VariableBatchPooledEmbeddingsAllToAll"], [6, 1, 1, "", "VariableBatchPooledEmbeddingsReduceScatter"]], "torchrec.distributed.dist_data.EmbeddingsAllToOne": [[6, 2, 1, "", "forward"], [6, 2, 1, "", "set_device"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.EmbeddingsAllToOneReduce": [[6, 2, 1, "", "forward"], [6, 2, 1, "", "set_device"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.KJTAllToAll": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.KJTOneToAll": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.MergePooledEmbeddingsModule": [[6, 2, 1, "", "forward"], [6, 2, 1, "", "set_device"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAllGather": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAllToAll": [[6, 5, 1, "", "callbacks"], [6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAwaitable": [[6, 5, 1, "", "callbacks"]], "torchrec.distributed.dist_data.PooledEmbeddingsReduceScatter": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.SeqEmbeddingsAllToOne": [[6, 2, 1, "", "forward"], [6, 2, 1, "", "set_device"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.SequenceEmbeddingsAllToAll": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.TensorAllToAll": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.TensorValuesAllToAll": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsAllToAll": [[6, 5, 1, "", "callbacks"], [6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsReduceScatter": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.embedding": [[4, 1, 1, "", "EmbeddingCollectionAwaitable"], [4, 1, 1, "", "EmbeddingCollectionContext"], [4, 1, 1, "", "EmbeddingCollectionSharder"], [4, 1, 1, "", "ShardedEmbeddingCollection"], [4, 3, 1, "", "create_embedding_sharding"], [4, 3, 1, "", "create_sharding_infos_by_sharding"], [4, 3, 1, "", "create_sharding_infos_by_sharding_device_group"], [4, 3, 1, "", "get_device_from_parameter_sharding"], [4, 3, 1, "", "get_ec_index_dedup"], [4, 3, 1, "", "pad_vbe_kjt_lengths"], [4, 3, 1, "", "set_ec_index_dedup"]], "torchrec.distributed.embedding.EmbeddingCollectionContext": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.embedding.EmbeddingCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "shardable_parameters"], [4, 2, 1, "", "sharding_types"]], "torchrec.distributed.embedding.ShardedEmbeddingCollection": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "compute_and_output_dist"], [4, 2, 1, "", "create_context"], [4, 5, 1, "", "fused_optimizer"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "output_dist"], [4, 2, 1, "", "reset_parameters"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup": [[4, 1, 1, "", "CommOpGradientScaling"], [4, 1, 1, "", "GroupedEmbeddingsLookup"], [4, 1, 1, "", "GroupedPooledEmbeddingsLookup"], [4, 1, 1, "", "InferCPUGroupedEmbeddingsLookup"], [4, 1, 1, "", "InferGroupedEmbeddingsLookup"], [4, 1, 1, "", "InferGroupedLookupMixin"], [4, 1, 1, "", "InferGroupedPooledEmbeddingsLookup"], [4, 1, 1, "", "MetaInferGroupedEmbeddingsLookup"], [4, 1, 1, "", "MetaInferGroupedPooledEmbeddingsLookup"], [4, 3, 1, "", "dummy_tensor"], [4, 3, 1, "", "embeddings_cat_empty_rank_handle"], [4, 3, 1, "", "embeddings_cat_empty_rank_handle_inference"], [4, 3, 1, "", "fx_wrap_tensor_view2d"]], "torchrec.distributed.embedding_lookup.CommOpGradientScaling": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.embedding_lookup.GroupedEmbeddingsLookup": [[4, 2, 1, "", "flush"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "named_parameters_by_table"], [4, 2, 1, "", "prefetch"], [4, 2, 1, "", "purge"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.GroupedPooledEmbeddingsLookup": [[4, 2, 1, "", "flush"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "named_parameters_by_table"], [4, 2, 1, "", "prefetch"], [4, 2, 1, "", "purge"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferCPUGroupedEmbeddingsLookup": [[4, 2, 1, "", "get_tbes_to_register"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferGroupedEmbeddingsLookup": [[4, 2, 1, "", "get_tbes_to_register"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferGroupedLookupMixin": [[4, 2, 1, "", "forward"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "state_dict"]], "torchrec.distributed.embedding_lookup.InferGroupedPooledEmbeddingsLookup": [[4, 2, 1, "", "forward"], [4, 2, 1, "", "get_tbes_to_register"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.MetaInferGroupedEmbeddingsLookup": [[4, 2, 1, "", "flush"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "get_tbes_to_register"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "purge"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.MetaInferGroupedPooledEmbeddingsLookup": [[4, 2, 1, "", "flush"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "get_tbes_to_register"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "purge"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_sharding": [[4, 1, 1, "", "BaseEmbeddingDist"], [4, 1, 1, "", "BaseSparseFeaturesDist"], [4, 1, 1, "", "EmbeddingSharding"], [4, 1, 1, "", "EmbeddingShardingContext"], [4, 1, 1, "", "EmbeddingShardingInfo"], [4, 1, 1, "", "FusedKJTListSplitsAwaitable"], [4, 1, 1, "", "KJTListAwaitable"], [4, 1, 1, "", "KJTListSplitsAwaitable"], [4, 1, 1, "", "KJTSplitsAllToAllMeta"], [4, 1, 1, "", "ListOfKJTListAwaitable"], [4, 1, 1, "", "ListOfKJTListSplitsAwaitable"], [4, 3, 1, "", "bucketize_kjt_before_all2all"], [4, 3, 1, "", "bucketize_kjt_inference"], [4, 3, 1, "", "group_tables"]], "torchrec.distributed.embedding_sharding.BaseEmbeddingDist": [[4, 2, 1, "", "forward"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist": [[4, 2, 1, "", "forward"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_sharding.EmbeddingSharding": [[4, 2, 1, "", "create_input_dist"], [4, 2, 1, "", "create_lookup"], [4, 2, 1, "", "create_output_dist"], [4, 2, 1, "", "embedding_dims"], [4, 2, 1, "", "embedding_names"], [4, 2, 1, "", "embedding_names_per_rank"], [4, 2, 1, "", "embedding_shard_metadata"], [4, 2, 1, "", "embedding_tables"], [4, 5, 1, "", "qcomm_codecs_registry"], [4, 2, 1, "", "uncombined_embedding_dims"], [4, 2, 1, "", "uncombined_embedding_names"]], "torchrec.distributed.embedding_sharding.EmbeddingShardingContext": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.embedding_sharding.EmbeddingShardingInfo": [[4, 4, 1, "", "embedding_config"], [4, 4, 1, "", "fused_params"], [4, 4, 1, "", "param"], [4, 4, 1, "", "param_sharding"]], "torchrec.distributed.embedding_sharding.KJTSplitsAllToAllMeta": [[4, 4, 1, "", "device"], [4, 4, 1, "", "input_splits"], [4, 4, 1, "", "input_tensors"], [4, 4, 1, "", "keys"], [4, 4, 1, "", "labels"], [4, 4, 1, "", "pg"], [4, 4, 1, "", "splits"], [4, 4, 1, "", "splits_tensors"], [4, 4, 1, "", "stagger"]], "torchrec.distributed.embedding_types": [[4, 1, 1, "", "BaseEmbeddingLookup"], [4, 1, 1, "", "BaseEmbeddingSharder"], [4, 1, 1, "", "BaseGroupedFeatureProcessor"], [4, 1, 1, "", "BaseQuantEmbeddingSharder"], [4, 1, 1, "", "DTensorMetadata"], [4, 1, 1, "", "EmbeddingAttributes"], [4, 1, 1, "", "EmbeddingComputeKernel"], [4, 1, 1, "", "FeatureShardingMixIn"], [4, 1, 1, "", "GroupedEmbeddingConfig"], [4, 1, 1, "", "InputDistOutputs"], [4, 1, 1, "", "KJTList"], [4, 1, 1, "", "ListOfKJTList"], [4, 1, 1, "", "ModuleShardingMixIn"], [4, 1, 1, "", "OptimType"], [4, 1, 1, "", "ShardedConfig"], [4, 1, 1, "", "ShardedEmbeddingModule"], [4, 1, 1, "", "ShardedEmbeddingTable"], [4, 1, 1, "", "ShardedMetaConfig"], [4, 3, 1, "", "compute_kernel_to_embedding_location"]], "torchrec.distributed.embedding_types.BaseEmbeddingLookup": [[4, 2, 1, "", "forward"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_types.BaseEmbeddingSharder": [[4, 2, 1, "", "compute_kernels"], [4, 5, 1, "", "fused_params"], [4, 2, 1, "", "sharding_types"], [4, 2, 1, "", "storage_usage"]], "torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor": [[4, 2, 1, "", "forward"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_types.BaseQuantEmbeddingSharder": [[4, 2, 1, "", "compute_kernels"], [4, 5, 1, "", "fused_params"], [4, 2, 1, "", "shardable_parameters"], [4, 2, 1, "", "sharding_types"], [4, 2, 1, "", "storage_usage"]], "torchrec.distributed.embedding_types.DTensorMetadata": [[4, 4, 1, "", "mesh"], [4, 4, 1, "", "placements"], [4, 4, 1, "", "size"], [4, 4, 1, "", "stride"]], "torchrec.distributed.embedding_types.EmbeddingAttributes": [[4, 4, 1, "", "compute_kernel"]], "torchrec.distributed.embedding_types.EmbeddingComputeKernel": [[4, 4, 1, "", "DENSE"], [4, 4, 1, "", "FUSED"], [4, 4, 1, "", "FUSED_UVM"], [4, 4, 1, "", "FUSED_UVM_CACHING"], [4, 4, 1, "", "KEY_VALUE"], [4, 4, 1, "", "QUANT"], [4, 4, 1, "", "QUANT_UVM"], [4, 4, 1, "", "QUANT_UVM_CACHING"]], "torchrec.distributed.embedding_types.FeatureShardingMixIn": [[4, 2, 1, "", "feature_names"], [4, 2, 1, "", "feature_names_per_rank"], [4, 2, 1, "", "features_per_rank"]], "torchrec.distributed.embedding_types.GroupedEmbeddingConfig": [[4, 4, 1, "", "compute_kernel"], [4, 4, 1, "", "data_type"], [4, 2, 1, "", "dim_sum"], [4, 2, 1, "", "embedding_dims"], [4, 2, 1, "", "embedding_names"], [4, 2, 1, "", "embedding_shard_metadata"], [4, 4, 1, "", "embedding_tables"], [4, 2, 1, "", "feature_hash_sizes"], [4, 2, 1, "", "feature_names"], [4, 4, 1, "", "fused_params"], [4, 4, 1, "", "has_feature_processor"], [4, 4, 1, "", "is_weighted"], [4, 2, 1, "", "num_features"], [4, 4, 1, "", "pooling"], [4, 2, 1, "", "table_names"]], "torchrec.distributed.embedding_types.InputDistOutputs": [[4, 4, 1, "", "bucket_mapping_tensor"], [4, 4, 1, "", "bucketized_length"], [4, 4, 1, "", "features"], [4, 2, 1, "", "record_stream"], [4, 4, 1, "", "unbucketize_permute_tensor"]], "torchrec.distributed.embedding_types.KJTList": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.embedding_types.ListOfKJTList": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.embedding_types.ModuleShardingMixIn": [[4, 5, 1, "", "shardings"]], "torchrec.distributed.embedding_types.OptimType": [[4, 4, 1, "", "ADAGRAD"], [4, 4, 1, "", "ADAM"], [4, 4, 1, "", "ADAMW"], [4, 4, 1, "", "LAMB"], [4, 4, 1, "", "LARS_SGD"], [4, 4, 1, "", "LION"], [4, 4, 1, "", "PARTIAL_ROWWISE_ADAM"], [4, 4, 1, "", "PARTIAL_ROWWISE_LAMB"], [4, 4, 1, "", "ROWWISE_ADAGRAD"], [4, 4, 1, "", "SGD"], [4, 4, 1, "", "SHAMPOO"], [4, 4, 1, "", "SHAMPOO_V2"], [4, 4, 1, "", "SHAMPOO_V2_MRS"]], "torchrec.distributed.embedding_types.ShardedConfig": [[4, 4, 1, "", "local_cols"], [4, 4, 1, "", "local_rows"]], "torchrec.distributed.embedding_types.ShardedEmbeddingModule": [[4, 2, 1, "", "extra_repr"], [4, 2, 1, "", "prefetch"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_types.ShardedEmbeddingTable": [[4, 4, 1, "", "fused_params"]], "torchrec.distributed.embedding_types.ShardedMetaConfig": [[4, 4, 1, "", "dtensor_metadata"], [4, 4, 1, "", "global_metadata"], [4, 4, 1, "", "local_metadata"]], "torchrec.distributed.embeddingbag": [[4, 1, 1, "", "EmbeddingAwaitable"], [4, 1, 1, "", "EmbeddingBagCollectionAwaitable"], [4, 1, 1, "", "EmbeddingBagCollectionContext"], [4, 1, 1, "", "EmbeddingBagCollectionSharder"], [4, 1, 1, "", "EmbeddingBagSharder"], [4, 1, 1, "", "ShardedEmbeddingBag"], [4, 1, 1, "", "ShardedEmbeddingBagCollection"], [4, 1, 1, "", "VariableBatchEmbeddingBagCollectionAwaitable"], [4, 3, 1, "", "construct_output_kt"], [4, 3, 1, "", "create_embedding_bag_sharding"], [4, 3, 1, "", "create_sharding_infos_by_sharding"], [4, 3, 1, "", "create_sharding_infos_by_sharding_device_group"], [4, 3, 1, "", "get_device_from_parameter_sharding"], [4, 3, 1, "", "replace_placement_with_meta_device"]], "torchrec.distributed.embeddingbag.EmbeddingBagCollectionContext": [[4, 4, 1, "", "divisor"], [4, 4, 1, "", "inverse_indices"], [4, 2, 1, "", "record_stream"], [4, 4, 1, "", "sharding_contexts"], [4, 4, 1, "", "variable_batch_per_feature"]], "torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "shardable_parameters"]], "torchrec.distributed.embeddingbag.EmbeddingBagSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "shardable_parameters"]], "torchrec.distributed.embeddingbag.ShardedEmbeddingBag": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "create_context"], [4, 5, 1, "", "fused_optimizer"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_modules"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "output_dist"], [4, 2, 1, "", "sharded_parameter_names"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "compute_and_output_dist"], [4, 2, 1, "", "create_context"], [4, 5, 1, "", "fused_optimizer"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "output_dist"], [4, 2, 1, "", "reset_parameters"], [4, 4, 1, "", "training"]], "torchrec.distributed.grouped_position_weighted": [[4, 1, 1, "", "GroupedPositionWeightedModule"]], "torchrec.distributed.grouped_position_weighted.GroupedPositionWeightedModule": [[4, 2, 1, "", "forward"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.mc_embedding": [[4, 1, 1, "", "ManagedCollisionEmbeddingCollectionContext"], [4, 1, 1, "", "ManagedCollisionEmbeddingCollectionSharder"], [4, 1, 1, "", "ShardedManagedCollisionEmbeddingCollection"]], "torchrec.distributed.mc_embedding.ManagedCollisionEmbeddingCollectionContext": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.mc_embedding.ManagedCollisionEmbeddingCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"]], "torchrec.distributed.mc_embedding.ShardedManagedCollisionEmbeddingCollection": [[4, 2, 1, "", "create_context"], [4, 4, 1, "", "training"]], "torchrec.distributed.mc_embeddingbag": [[4, 1, 1, "", "ManagedCollisionEmbeddingBagCollectionContext"], [4, 1, 1, "", "ManagedCollisionEmbeddingBagCollectionSharder"], [4, 1, 1, "", "ShardedManagedCollisionEmbeddingBagCollection"]], "torchrec.distributed.mc_embeddingbag.ManagedCollisionEmbeddingBagCollectionContext": [[4, 4, 1, "", "evictions_per_table"], [4, 2, 1, "", "record_stream"], [4, 4, 1, "", "remapped_kjt"]], "torchrec.distributed.mc_embeddingbag.ManagedCollisionEmbeddingBagCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"]], "torchrec.distributed.mc_embeddingbag.ShardedManagedCollisionEmbeddingBagCollection": [[4, 2, 1, "", "create_context"], [4, 4, 1, "", "training"]], "torchrec.distributed.mc_modules": [[4, 1, 1, "", "ManagedCollisionCollectionAwaitable"], [4, 1, 1, "", "ManagedCollisionCollectionContext"], [4, 1, 1, "", "ManagedCollisionCollectionSharder"], [4, 1, 1, "", "ShardedManagedCollisionCollection"], [4, 3, 1, "", "create_mc_sharding"]], "torchrec.distributed.mc_modules.ManagedCollisionCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "shardable_parameters"], [4, 2, 1, "", "sharding_types"]], "torchrec.distributed.mc_modules.ShardedManagedCollisionCollection": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "create_context"], [4, 2, 1, "", "evict"], [4, 2, 1, "", "global_to_local_index"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "open_slots"], [4, 2, 1, "", "output_dist"], [4, 2, 1, "", "sharded_parameter_names"], [4, 4, 1, "", "training"]], "torchrec.distributed.model_parallel": [[4, 1, 1, "", "DataParallelWrapper"], [4, 1, 1, "", "DefaultDataParallelWrapper"], [4, 1, 1, "", "DistributedModelParallel"], [4, 3, 1, "", "get_module"], [4, 3, 1, "", "get_unwrapped_module"]], "torchrec.distributed.model_parallel.DataParallelWrapper": [[4, 2, 1, "", "wrap"]], "torchrec.distributed.model_parallel.DefaultDataParallelWrapper": [[4, 2, 1, "", "wrap"]], "torchrec.distributed.model_parallel.DistributedModelParallel": [[4, 2, 1, "", "bare_named_parameters"], [4, 2, 1, "", "copy"], [4, 2, 1, "", "forward"], [4, 5, 1, "", "fused_optimizer"], [4, 2, 1, "", "init_data_parallel"], [4, 2, 1, "", "load_state_dict"], [4, 5, 1, "", "module"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 5, 1, "", "plan"], [4, 2, 1, "", "sparse_grad_parameter_names"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.planner": [[5, 0, 0, "-", "constants"], [5, 0, 0, "-", "enumerators"], [5, 0, 0, "-", "partitioners"], [5, 0, 0, "-", "perf_models"], [5, 0, 0, "-", "planners"], [5, 0, 0, "-", "proposers"], [5, 0, 0, "-", "shard_estimators"], [5, 0, 0, "-", "stats"], [5, 0, 0, "-", "storage_reservations"], [5, 0, 0, "-", "types"], [5, 0, 0, "-", "utils"]], "torchrec.distributed.planner.constants": [[5, 3, 1, "", "kernel_bw_lookup"]], "torchrec.distributed.planner.enumerators": [[5, 1, 1, "", "EmbeddingEnumerator"], [5, 3, 1, "", "get_partition_by_type"]], "torchrec.distributed.planner.enumerators.EmbeddingEnumerator": [[5, 2, 1, "", "enumerate"], [5, 2, 1, "", "populate_estimates"]], "torchrec.distributed.planner.partitioners": [[5, 1, 1, "", "GreedyPerfPartitioner"], [5, 1, 1, "", "MemoryBalancedPartitioner"], [5, 1, 1, "", "OrderedDeviceHardware"], [5, 1, 1, "", "ShardingOptionGroup"], [5, 1, 1, "", "SortBy"], [5, 3, 1, "", "set_hbm_per_device"]], "torchrec.distributed.planner.partitioners.GreedyPerfPartitioner": [[5, 2, 1, "", "partition"]], "torchrec.distributed.planner.partitioners.MemoryBalancedPartitioner": [[5, 2, 1, "", "partition"]], "torchrec.distributed.planner.partitioners.OrderedDeviceHardware": [[5, 4, 1, "", "device"], [5, 4, 1, "", "local_world_size"]], "torchrec.distributed.planner.partitioners.ShardingOptionGroup": [[5, 4, 1, "", "param_count"], [5, 4, 1, "", "perf_sum"], [5, 4, 1, "", "sharding_options"], [5, 4, 1, "", "storage_sum"]], "torchrec.distributed.planner.partitioners.SortBy": [[5, 4, 1, "", "PERF"], [5, 4, 1, "", "STORAGE"]], "torchrec.distributed.planner.perf_models": [[5, 1, 1, "", "NoopPerfModel"], [5, 1, 1, "", "NoopStorageModel"]], "torchrec.distributed.planner.perf_models.NoopPerfModel": [[5, 2, 1, "", "rate"]], "torchrec.distributed.planner.perf_models.NoopStorageModel": [[5, 2, 1, "", "rate"]], "torchrec.distributed.planner.planners": [[5, 1, 1, "", "EmbeddingShardingPlanner"], [5, 1, 1, "", "HeteroEmbeddingShardingPlanner"]], "torchrec.distributed.planner.planners.EmbeddingShardingPlanner": [[5, 2, 1, "", "collective_plan"], [5, 2, 1, "", "plan"]], "torchrec.distributed.planner.planners.HeteroEmbeddingShardingPlanner": [[5, 2, 1, "", "collective_plan"], [5, 2, 1, "", "plan"]], "torchrec.distributed.planner.proposers": [[5, 1, 1, "", "DynamicProgrammingProposer"], [5, 1, 1, "", "EmbeddingOffloadScaleupProposer"], [5, 1, 1, "", "GreedyProposer"], [5, 1, 1, "", "GridSearchProposer"], [5, 1, 1, "", "UniformProposer"], [5, 3, 1, "", "proposers_to_proposals_list"]], "torchrec.distributed.planner.proposers.DynamicProgrammingProposer": [[5, 2, 1, "", "feedback"], [5, 2, 1, "", "load"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.proposers.EmbeddingOffloadScaleupProposer": [[5, 2, 1, "", "allocate_budget"], [5, 2, 1, "", "build_affine_storage_model"], [5, 2, 1, "", "clf_to_bytes"], [5, 2, 1, "", "feedback"], [5, 2, 1, "", "get_budget"], [5, 2, 1, "", "get_cacheability"], [5, 2, 1, "", "get_expected_lookups"], [5, 2, 1, "", "load"], [5, 2, 1, "", "next_plan"], [5, 2, 1, "", "promote_high_prefetch_overheaad_table_to_hbm"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.proposers.GreedyProposer": [[5, 2, 1, "", "feedback"], [5, 2, 1, "", "load"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.proposers.GridSearchProposer": [[5, 2, 1, "", "feedback"], [5, 2, 1, "", "load"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.proposers.UniformProposer": [[5, 2, 1, "", "feedback"], [5, 2, 1, "", "load"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.shard_estimators": [[5, 1, 1, "", "EmbeddingOffloadStats"], [5, 1, 1, "", "EmbeddingPerfEstimator"], [5, 1, 1, "", "EmbeddingStorageEstimator"], [5, 3, 1, "", "calculate_pipeline_io_cost"], [5, 3, 1, "", "calculate_shard_storages"]], "torchrec.distributed.planner.shard_estimators.EmbeddingOffloadStats": [[5, 5, 1, "", "cacheability"], [5, 2, 1, "", "estimate_cache_miss_rate"], [5, 5, 1, "", "expected_lookups"], [5, 2, 1, "", "expected_miss_rate"]], "torchrec.distributed.planner.shard_estimators.EmbeddingPerfEstimator": [[5, 2, 1, "", "estimate"], [5, 2, 1, "", "perf_func_emb_wall_time"]], "torchrec.distributed.planner.shard_estimators.EmbeddingStorageEstimator": [[5, 2, 1, "", "estimate"]], "torchrec.distributed.planner.stats": [[5, 1, 1, "", "EmbeddingStats"], [5, 1, 1, "", "NoopEmbeddingStats"], [5, 3, 1, "", "round_to_one_sigfig"]], "torchrec.distributed.planner.stats.EmbeddingStats": [[5, 2, 1, "", "log"]], "torchrec.distributed.planner.stats.NoopEmbeddingStats": [[5, 2, 1, "", "log"]], "torchrec.distributed.planner.storage_reservations": [[5, 1, 1, "", "FixedPercentageStorageReservation"], [5, 1, 1, "", "HeuristicalStorageReservation"], [5, 1, 1, "", "InferenceStorageReservation"]], "torchrec.distributed.planner.storage_reservations.FixedPercentageStorageReservation": [[5, 2, 1, "", "reserve"]], "torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation": [[5, 2, 1, "", "reserve"]], "torchrec.distributed.planner.storage_reservations.InferenceStorageReservation": [[5, 2, 1, "", "reserve"]], "torchrec.distributed.planner.types": [[5, 1, 1, "", "CustomTopologyData"], [5, 1, 1, "", "DeviceHardware"], [5, 1, 1, "", "Enumerator"], [5, 1, 1, "", "ParameterConstraints"], [5, 1, 1, "", "PartitionByType"], [5, 1, 1, "", "Partitioner"], [5, 1, 1, "", "Perf"], [5, 1, 1, "", "PerfModel"], [5, 6, 1, "", "PlannerError"], [5, 1, 1, "", "PlannerErrorType"], [5, 1, 1, "", "Proposer"], [5, 1, 1, "", "Shard"], [5, 1, 1, "", "ShardEstimator"], [5, 1, 1, "", "ShardingOption"], [5, 1, 1, "", "Stats"], [5, 1, 1, "", "Storage"], [5, 1, 1, "", "StorageReservation"], [5, 1, 1, "", "Topology"]], "torchrec.distributed.planner.types.CustomTopologyData": [[5, 2, 1, "", "get_data"], [5, 2, 1, "", "has_data"], [5, 4, 1, "", "supported_fields"]], "torchrec.distributed.planner.types.DeviceHardware": [[5, 4, 1, "", "perf"], [5, 4, 1, "", "rank"], [5, 4, 1, "", "storage"]], "torchrec.distributed.planner.types.Enumerator": [[5, 2, 1, "", "enumerate"], [5, 2, 1, "", "populate_estimates"]], "torchrec.distributed.planner.types.ParameterConstraints": [[5, 4, 1, "id0", "batch_sizes"], [5, 4, 1, "id1", "bounds_check_mode"], [5, 4, 1, "id2", "cache_params"], [5, 4, 1, "id3", "compute_kernels"], [5, 4, 1, "id4", "device_group"], [5, 4, 1, "id5", "enforce_hbm"], [5, 4, 1, "id6", "feature_names"], [5, 4, 1, "id7", "is_weighted"], [5, 4, 1, "id8", "key_value_params"], [5, 4, 1, "id9", "min_partition"], [5, 4, 1, "id10", "num_poolings"], [5, 4, 1, "id11", "output_dtype"], [5, 4, 1, "id12", "pooling_factors"], [5, 4, 1, "id13", "sharding_types"], [5, 4, 1, "id14", "stochastic_rounding"]], "torchrec.distributed.planner.types.PartitionByType": [[5, 4, 1, "", "DEVICE"], [5, 4, 1, "", "HOST"], [5, 4, 1, "", "UNIFORM"]], "torchrec.distributed.planner.types.Partitioner": [[5, 2, 1, "", "partition"]], "torchrec.distributed.planner.types.Perf": [[5, 4, 1, "", "bwd_comms"], [5, 4, 1, "", "bwd_compute"], [5, 4, 1, "", "fwd_comms"], [5, 4, 1, "", "fwd_compute"], [5, 4, 1, "", "prefetch_compute"], [5, 5, 1, "", "total"]], "torchrec.distributed.planner.types.PerfModel": [[5, 2, 1, "", "rate"]], "torchrec.distributed.planner.types.PlannerErrorType": [[5, 4, 1, "", "INSUFFICIENT_STORAGE"], [5, 4, 1, "", "OTHER"], [5, 4, 1, "", "PARTITION"], [5, 4, 1, "", "STRICT_CONSTRAINTS"]], "torchrec.distributed.planner.types.Proposer": [[5, 2, 1, "", "feedback"], [5, 2, 1, "", "load"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.types.Shard": [[5, 4, 1, "", "offset"], [5, 4, 1, "", "perf"], [5, 4, 1, "", "rank"], [5, 4, 1, "", "size"], [5, 4, 1, "", "storage"]], "torchrec.distributed.planner.types.ShardEstimator": [[5, 2, 1, "", "estimate"]], "torchrec.distributed.planner.types.ShardingOption": [[5, 4, 1, "", "batch_size"], [5, 4, 1, "", "bounds_check_mode"], [5, 5, 1, "", "cache_load_factor"], [5, 4, 1, "", "cache_params"], [5, 4, 1, "", "compute_kernel"], [5, 4, 1, "", "dependency"], [5, 4, 1, "", "enforce_hbm"], [5, 4, 1, "", "feature_names"], [5, 5, 1, "", "fqn"], [5, 4, 1, "", "input_lengths"], [5, 5, 1, "id15", "is_pooled"], [5, 4, 1, "", "key_value_params"], [5, 5, 1, "id16", "module"], [5, 2, 1, "", "module_pooled"], [5, 4, 1, "", "name"], [5, 5, 1, "", "num_inputs"], [5, 5, 1, "", "num_shards"], [5, 4, 1, "", "output_dtype"], [5, 5, 1, "", "path"], [5, 4, 1, "", "sharding_type"], [5, 4, 1, "", "shards"], [5, 4, 1, "", "stochastic_rounding"], [5, 5, 1, "id17", "tensor"], [5, 5, 1, "", "total_perf"], [5, 5, 1, "", "total_storage"]], "torchrec.distributed.planner.types.Stats": [[5, 2, 1, "", "log"]], "torchrec.distributed.planner.types.Storage": [[5, 4, 1, "", "ddr"], [5, 2, 1, "", "fits_in"], [5, 4, 1, "", "hbm"]], "torchrec.distributed.planner.types.StorageReservation": [[5, 2, 1, "", "reserve"]], "torchrec.distributed.planner.types.Topology": [[5, 5, 1, "", "bwd_compute_multiplier"], [5, 5, 1, "", "compute_device"], [5, 5, 1, "", "ddr_mem_bw"], [5, 5, 1, "", "devices"], [5, 5, 1, "", "hbm_mem_bw"], [5, 5, 1, "", "inter_host_bw"], [5, 5, 1, "", "intra_host_bw"], [5, 5, 1, "", "local_world_size"], [5, 5, 1, "", "uneven_sharding_perf_multiplier"], [5, 5, 1, "", "weighted_feature_bwd_compute_multiplier"], [5, 5, 1, "", "world_size"]], "torchrec.distributed.planner.utils": [[5, 1, 1, "", "BinarySearchPredicate"], [5, 1, 1, "", "LuusJaakolaSearch"], [5, 3, 1, "", "bytes_to_gb"], [5, 3, 1, "", "bytes_to_mb"], [5, 3, 1, "", "gb_to_bytes"], [5, 3, 1, "", "placement"], [5, 3, 1, "", "prod"], [5, 3, 1, "", "reset_shard_rank"], [5, 3, 1, "", "sharder_name"], [5, 3, 1, "", "storage_repr_in_gb"]], "torchrec.distributed.planner.utils.BinarySearchPredicate": [[5, 2, 1, "", "next"]], "torchrec.distributed.planner.utils.LuusJaakolaSearch": [[5, 2, 1, "", "best"], [5, 2, 1, "", "clamp"], [5, 2, 1, "", "next"], [5, 2, 1, "", "shrink_right"], [5, 2, 1, "", "uniform"]], "torchrec.distributed.quant_embeddingbag": [[4, 1, 1, "", "QuantEmbeddingBagCollectionSharder"], [4, 1, 1, "", "QuantFeatureProcessedEmbeddingBagCollectionSharder"], [4, 1, 1, "", "ShardedQuantEbcInputDist"], [4, 1, 1, "", "ShardedQuantEmbeddingBagCollection"], [4, 1, 1, "", "ShardedQuantFeatureProcessedEmbeddingBagCollection"], [4, 3, 1, "", "create_infer_embedding_bag_sharding"], [4, 3, 1, "", "flatten_feature_lengths"], [4, 3, 1, "", "get_device_from_parameter_sharding"], [4, 3, 1, "", "get_device_from_sharding_infos"]], "torchrec.distributed.quant_embeddingbag.QuantEmbeddingBagCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"]], "torchrec.distributed.quant_embeddingbag.QuantFeatureProcessedEmbeddingBagCollectionSharder": [[4, 2, 1, "", "compute_kernels"], [4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "sharding_types"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantEbcInputDist": [[4, 2, 1, "", "forward"], [4, 4, 1, "", "training"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantEmbeddingBagCollection": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "compute_and_output_dist"], [4, 2, 1, "", "copy"], [4, 2, 1, "", "create_context"], [4, 2, 1, "", "embedding_bag_configs"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "output_dist"], [4, 2, 1, "", "sharding_type_device_group_to_sharding_infos"], [4, 5, 1, "", "shardings"], [4, 2, 1, "", "tbes_configs"], [4, 4, 1, "", "training"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantFeatureProcessedEmbeddingBagCollection": [[4, 2, 1, "", "apply_feature_processor"], [4, 2, 1, "", "compute"], [4, 4, 1, "", "embedding_bags"], [4, 4, 1, "", "tbes"], [4, 4, 1, "", "training"]], "torchrec.distributed.sharding": [[6, 0, 0, "-", "cw_sharding"], [6, 0, 0, "-", "dp_sharding"], [6, 0, 0, "-", "rw_sharding"], [6, 0, 0, "-", "tw_sharding"], [6, 0, 0, "-", "twcw_sharding"], [6, 0, 0, "-", "twrw_sharding"]], "torchrec.distributed.sharding.cw_sharding": [[6, 1, 1, "", "BaseCwEmbeddingSharding"], [6, 1, 1, "", "CwPooledEmbeddingSharding"], [6, 1, 1, "", "InferCwPooledEmbeddingDist"], [6, 1, 1, "", "InferCwPooledEmbeddingDistWithPermute"], [6, 1, 1, "", "InferCwPooledEmbeddingSharding"]], "torchrec.distributed.sharding.cw_sharding.BaseCwEmbeddingSharding": [[6, 2, 1, "", "embedding_dims"], [6, 2, 1, "", "embedding_names"], [6, 2, 1, "", "uncombined_embedding_dims"], [6, 2, 1, "", "uncombined_embedding_names"]], "torchrec.distributed.sharding.cw_sharding.CwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingDistWithPermute": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.dp_sharding": [[6, 1, 1, "", "BaseDpEmbeddingSharding"], [6, 1, 1, "", "DpPooledEmbeddingDist"], [6, 1, 1, "", "DpPooledEmbeddingSharding"], [6, 1, 1, "", "DpSparseFeaturesDist"]], "torchrec.distributed.sharding.dp_sharding.BaseDpEmbeddingSharding": [[6, 2, 1, "", "embedding_dims"], [6, 2, 1, "", "embedding_names"], [6, 2, 1, "", "embedding_names_per_rank"], [6, 2, 1, "", "embedding_shard_metadata"], [6, 2, 1, "", "embedding_tables"], [6, 2, 1, "", "feature_names"]], "torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.dp_sharding.DpSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding": [[6, 1, 1, "", "BaseRwEmbeddingSharding"], [6, 1, 1, "", "InferRwPooledEmbeddingDist"], [6, 1, 1, "", "InferRwPooledEmbeddingSharding"], [6, 1, 1, "", "InferRwSparseFeaturesDist"], [6, 1, 1, "", "RwPooledEmbeddingDist"], [6, 1, 1, "", "RwPooledEmbeddingSharding"], [6, 1, 1, "", "RwSparseFeaturesDist"], [6, 3, 1, "", "get_block_sizes_runtime_device"], [6, 3, 1, "", "get_embedding_shard_metadata"]], "torchrec.distributed.sharding.rw_sharding.BaseRwEmbeddingSharding": [[6, 2, 1, "", "embedding_dims"], [6, 2, 1, "", "embedding_names"], [6, 2, 1, "", "embedding_names_per_rank"], [6, 2, 1, "", "embedding_shard_metadata"], [6, 2, 1, "", "embedding_tables"], [6, 2, 1, "", "feature_names"]], "torchrec.distributed.sharding.rw_sharding.InferRwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.InferRwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.rw_sharding.InferRwSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.rw_sharding.RwSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding": [[6, 1, 1, "", "BaseTwEmbeddingSharding"], [6, 1, 1, "", "InferTwEmbeddingSharding"], [6, 1, 1, "", "InferTwPooledEmbeddingDist"], [6, 1, 1, "", "InferTwSparseFeaturesDist"], [6, 1, 1, "", "TwPooledEmbeddingDist"], [6, 1, 1, "", "TwPooledEmbeddingSharding"], [6, 1, 1, "", "TwSparseFeaturesDist"]], "torchrec.distributed.sharding.tw_sharding.BaseTwEmbeddingSharding": [[6, 2, 1, "", "embedding_dims"], [6, 2, 1, "", "embedding_names"], [6, 2, 1, "", "embedding_names_per_rank"], [6, 2, 1, "", "embedding_shard_metadata"], [6, 2, 1, "", "embedding_tables"], [6, 2, 1, "", "feature_names"], [6, 2, 1, "", "feature_names_per_rank"], [6, 2, 1, "", "features_per_rank"]], "torchrec.distributed.sharding.tw_sharding.InferTwEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.tw_sharding.InferTwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.InferTwSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.tw_sharding.TwSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.twcw_sharding": [[6, 1, 1, "", "TwCwPooledEmbeddingSharding"]], "torchrec.distributed.sharding.twrw_sharding": [[6, 1, 1, "", "BaseTwRwEmbeddingSharding"], [6, 1, 1, "", "TwRwPooledEmbeddingDist"], [6, 1, 1, "", "TwRwPooledEmbeddingSharding"], [6, 1, 1, "", "TwRwSparseFeaturesDist"]], "torchrec.distributed.sharding.twrw_sharding.BaseTwRwEmbeddingSharding": [[6, 2, 1, "", "embedding_dims"], [6, 2, 1, "", "embedding_names"], [6, 2, 1, "", "embedding_names_per_rank"], [6, 2, 1, "", "embedding_shard_metadata"], [6, 2, 1, "", "feature_names"]], "torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.twrw_sharding.TwRwSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.types": [[4, 1, 1, "", "Awaitable"], [4, 1, 1, "", "CacheParams"], [4, 1, 1, "", "CacheStatistics"], [4, 1, 1, "", "CommOp"], [4, 1, 1, "", "ComputeKernel"], [4, 1, 1, "", "EmbeddingModuleShardingPlan"], [4, 1, 1, "", "GenericMeta"], [4, 1, 1, "", "GetItemLazyAwaitable"], [4, 1, 1, "", "KeyValueParams"], [4, 1, 1, "", "LazyAwaitable"], [4, 1, 1, "", "LazyGetItemMixin"], [4, 1, 1, "", "LazyNoWait"], [4, 1, 1, "", "ModuleSharder"], [4, 1, 1, "", "ModuleShardingPlan"], [4, 1, 1, "", "NoOpQuantizedCommCodec"], [4, 1, 1, "", "NoWait"], [4, 1, 1, "", "NullShardedModuleContext"], [4, 1, 1, "", "NullShardingContext"], [4, 1, 1, "", "ObjectPoolShardingPlan"], [4, 1, 1, "", "ObjectPoolShardingType"], [4, 1, 1, "", "ParameterSharding"], [4, 1, 1, "", "ParameterStorage"], [4, 1, 1, "", "PipelineType"], [4, 1, 1, "", "QuantizedCommCodec"], [4, 1, 1, "", "QuantizedCommCodecs"], [4, 1, 1, "", "ShardedModule"], [4, 1, 1, "", "ShardingEnv"], [4, 1, 1, "", "ShardingPlan"], [4, 1, 1, "", "ShardingPlanner"], [4, 1, 1, "", "ShardingType"], [4, 3, 1, "", "get_tensor_size_bytes"], [4, 3, 1, "", "rank_device"], [4, 3, 1, "", "scope"]], "torchrec.distributed.types.Awaitable": [[4, 5, 1, "", "callbacks"], [4, 2, 1, "", "wait"]], "torchrec.distributed.types.CacheParams": [[4, 4, 1, "id34", "algorithm"], [4, 4, 1, "id35", "load_factor"], [4, 4, 1, "", "multipass_prefetch_config"], [4, 4, 1, "id36", "precision"], [4, 4, 1, "id37", "prefetch_pipeline"], [4, 4, 1, "id38", "reserved_memory"], [4, 4, 1, "id39", "stats"]], "torchrec.distributed.types.CacheStatistics": [[4, 5, 1, "", "cacheability"], [4, 5, 1, "", "expected_lookups"], [4, 2, 1, "", "expected_miss_rate"]], "torchrec.distributed.types.CommOp": [[4, 4, 1, "", "POOLED_EMBEDDINGS_ALL_TO_ALL"], [4, 4, 1, "", "POOLED_EMBEDDINGS_REDUCE_SCATTER"], [4, 4, 1, "", "SEQUENCE_EMBEDDINGS_ALL_TO_ALL"]], "torchrec.distributed.types.ComputeKernel": [[4, 4, 1, "", "DEFAULT"]], "torchrec.distributed.types.KeyValueParams": [[4, 4, 1, "id40", "gather_ssd_cache_stats"], [4, 4, 1, "", "l2_cache_size"], [4, 4, 1, "", "ods_prefix"], [4, 4, 1, "id41", "ps_client_thread_num"], [4, 4, 1, "id42", "ps_hosts"], [4, 4, 1, "id43", "ps_max_key_per_request"], [4, 4, 1, "id44", "ps_max_local_index_length"], [4, 4, 1, "", "report_interval"], [4, 4, 1, "id45", "ssd_rocksdb_shards"], [4, 4, 1, "id46", "ssd_rocksdb_write_buffer_size"], [4, 4, 1, "id47", "ssd_storage_directory"], [4, 4, 1, "", "stats_reporter_config"], [4, 4, 1, "", "use_passed_in_path"]], "torchrec.distributed.types.ModuleSharder": [[4, 2, 1, "", "compute_kernels"], [4, 5, 1, "", "module_type"], [4, 5, 1, "", "qcomm_codecs_registry"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "shardable_parameters"], [4, 2, 1, "", "sharding_types"], [4, 2, 1, "", "storage_usage"]], "torchrec.distributed.types.NoOpQuantizedCommCodec": [[4, 2, 1, "", "calc_quantized_size"], [4, 2, 1, "", "create_context"], [4, 2, 1, "", "decode"], [4, 2, 1, "", "encode"], [4, 2, 1, "", "padded_size"], [4, 2, 1, "", "quantized_dtype"]], "torchrec.distributed.types.NullShardedModuleContext": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.types.NullShardingContext": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.types.ObjectPoolShardingPlan": [[4, 4, 1, "", "inference"], [4, 4, 1, "", "sharding_type"]], "torchrec.distributed.types.ObjectPoolShardingType": [[4, 4, 1, "", "REPLICATED_ROW_WISE"], [4, 4, 1, "", "ROW_WISE"]], "torchrec.distributed.types.ParameterSharding": [[4, 4, 1, "", "bounds_check_mode"], [4, 4, 1, "", "cache_params"], [4, 4, 1, "", "compute_kernel"], [4, 4, 1, "", "enforce_hbm"], [4, 4, 1, "", "key_value_params"], [4, 4, 1, "", "output_dtype"], [4, 4, 1, "", "ranks"], [4, 4, 1, "", "sharding_spec"], [4, 4, 1, "", "sharding_type"], [4, 4, 1, "", "stochastic_rounding"]], "torchrec.distributed.types.ParameterStorage": [[4, 4, 1, "", "DDR"], [4, 4, 1, "", "HBM"]], "torchrec.distributed.types.PipelineType": [[4, 4, 1, "", "NONE"], [4, 4, 1, "", "TRAIN_BASE"], [4, 4, 1, "", "TRAIN_PREFETCH_SPARSE_DIST"], [4, 4, 1, "", "TRAIN_SPARSE_DIST"]], "torchrec.distributed.types.QuantizedCommCodec": [[4, 2, 1, "", "calc_quantized_size"], [4, 2, 1, "", "create_context"], [4, 2, 1, "", "decode"], [4, 2, 1, "", "encode"], [4, 2, 1, "", "padded_size"], [4, 5, 1, "", "quantized_dtype"]], "torchrec.distributed.types.QuantizedCommCodecs": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.types.ShardedModule": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "compute_and_output_dist"], [4, 2, 1, "", "create_context"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "output_dist"], [4, 5, 1, "", "qcomm_codecs_registry"], [4, 2, 1, "", "sharded_parameter_names"], [4, 4, 1, "", "training"]], "torchrec.distributed.types.ShardingEnv": [[4, 2, 1, "", "from_local"], [4, 2, 1, "", "from_process_group"]], "torchrec.distributed.types.ShardingPlan": [[4, 2, 1, "", "get_plan_for_module"], [4, 4, 1, "id48", "plan"]], "torchrec.distributed.types.ShardingPlanner": [[4, 2, 1, "", "collective_plan"], [4, 2, 1, "", "plan"]], "torchrec.distributed.types.ShardingType": [[4, 4, 1, "", "COLUMN_WISE"], [4, 4, 1, "", "DATA_PARALLEL"], [4, 4, 1, "", "ROW_WISE"], [4, 4, 1, "", "TABLE_COLUMN_WISE"], [4, 4, 1, "", "TABLE_ROW_WISE"], [4, 4, 1, "", "TABLE_WISE"]], "torchrec.distributed.utils": [[4, 1, 1, "", "CopyableMixin"], [4, 1, 1, "", "ForkedPdb"], [4, 3, 1, "", "add_params_from_parameter_sharding"], [4, 3, 1, "", "add_prefix_to_state_dict"], [4, 3, 1, "", "append_prefix"], [4, 3, 1, "", "convert_to_fbgemm_types"], [4, 3, 1, "", "copy_to_device"], [4, 3, 1, "", "filter_state_dict"], [4, 3, 1, "", "get_unsharded_module_names"], [4, 3, 1, "", "init_parameters"], [4, 3, 1, "", "merge_fused_params"], [4, 3, 1, "", "none_throws"], [4, 3, 1, "", "optimizer_type_to_emb_opt_type"], [4, 1, 1, "", "sharded_model_copy"]], "torchrec.distributed.utils.CopyableMixin": [[4, 2, 1, "", "copy"], [4, 4, 1, "", "training"]], "torchrec.distributed.utils.ForkedPdb": [[4, 2, 1, "", "interaction"]], "torchrec.fx": [[7, 0, 0, "-", "tracer"]], "torchrec.fx.tracer": [[7, 1, 1, "", "Tracer"], [7, 3, 1, "", "is_fx_tracing"], [7, 3, 1, "", "symbolic_trace"]], "torchrec.fx.tracer.Tracer": [[7, 2, 1, "", "create_arg"], [7, 2, 1, "", "is_leaf_module"], [7, 2, 1, "", "path_of_module"], [7, 2, 1, "", "trace"]], "torchrec.inference": [[8, 0, 0, "-", "model_packager"], [8, 0, 0, "-", "modules"]], "torchrec.inference.model_packager": [[8, 1, 1, "", "PredictFactoryPackager"], [8, 3, 1, "", "load_config_text"], [8, 3, 1, "", "load_pickle_config"]], "torchrec.inference.model_packager.PredictFactoryPackager": [[8, 2, 1, "", "save_predict_factory"], [8, 2, 1, "", "set_extern_modules"], [8, 2, 1, "", "set_mocked_modules"]], "torchrec.inference.modules": [[8, 1, 1, "", "BatchingMetadata"], [8, 1, 1, "", "PredictFactory"], [8, 1, 1, "", "PredictModule"], [8, 1, 1, "", "QualNameMetadata"], [8, 3, 1, "", "assign_weights_to_tbe"], [8, 3, 1, "", "get_table_to_weights_from_tbe"], [8, 3, 1, "", "quantize_dense"], [8, 3, 1, "", "quantize_embeddings"], [8, 3, 1, "", "quantize_feature"], [8, 3, 1, "", "quantize_inference_model"], [8, 3, 1, "", "set_pruning_data"], [8, 3, 1, "", "shard_quant_model"], [8, 3, 1, "", "trim_torch_package_prefix_from_typename"]], "torchrec.inference.modules.BatchingMetadata": [[8, 4, 1, "", "device"], [8, 4, 1, "", "pinned"], [8, 4, 1, "", "type"]], "torchrec.inference.modules.PredictFactory": [[8, 2, 1, "", "batching_metadata"], [8, 2, 1, "", "batching_metadata_json"], [8, 2, 1, "", "create_predict_module"], [8, 2, 1, "", "model_inputs_data"], [8, 2, 1, "", "qualname_metadata"], [8, 2, 1, "", "qualname_metadata_json"], [8, 2, 1, "", "result_metadata"], [8, 2, 1, "", "run_weights_dependent_transformations"], [8, 2, 1, "", "run_weights_independent_tranformations"]], "torchrec.inference.modules.PredictModule": [[8, 2, 1, "", "forward"], [8, 2, 1, "", "predict_forward"], [8, 5, 1, "", "predict_module"], [8, 2, 1, "", "state_dict"], [8, 4, 1, "", "training"]], "torchrec.inference.modules.QualNameMetadata": [[8, 4, 1, "", "need_preproc"]], "torchrec.metrics": [[9, 0, 0, "-", "accuracy"], [9, 0, 0, "-", "auc"], [9, 0, 0, "-", "auprc"], [9, 0, 0, "-", "calibration"], [9, 0, 0, "-", "ctr"], [9, 0, 0, "-", "mae"], [9, 0, 0, "-", "metric_module"], [9, 0, 0, "-", "mse"], [9, 0, 0, "-", "multiclass_recall"], [9, 0, 0, "-", "ndcg"], [9, 0, 0, "-", "ne"], [9, 0, 0, "-", "precision"], [9, 0, 0, "-", "rauc"], [9, 0, 0, "-", "rec_metric"], [9, 0, 0, "-", "recall"], [9, 0, 0, "-", "throughput"], [9, 0, 0, "-", "weighted_avg"], [9, 0, 0, "-", "xauc"]], "torchrec.metrics.accuracy": [[9, 1, 1, "", "AccuracyMetric"], [9, 1, 1, "", "AccuracyMetricComputation"], [9, 3, 1, "", "compute_accuracy"], [9, 3, 1, "", "compute_accuracy_sum"], [9, 3, 1, "", "get_accuracy_states"]], "torchrec.metrics.accuracy.AccuracyMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.auc": [[9, 1, 1, "", "AUCMetric"], [9, 1, 1, "", "AUCMetricComputation"], [9, 3, 1, "", "compute_auc"], [9, 3, 1, "", "compute_auc_per_group"]], "torchrec.metrics.auc.AUCMetricComputation": [[9, 2, 1, "", "reset"], [9, 2, 1, "", "update"]], "torchrec.metrics.auprc": [[9, 1, 1, "", "AUPRCMetric"], [9, 1, 1, "", "AUPRCMetricComputation"], [9, 3, 1, "", "compute_auprc"], [9, 3, 1, "", "compute_auprc_per_group"]], "torchrec.metrics.auprc.AUPRCMetricComputation": [[9, 2, 1, "", "reset"], [9, 2, 1, "", "update"]], "torchrec.metrics.calibration": [[9, 1, 1, "", "CalibrationMetric"], [9, 1, 1, "", "CalibrationMetricComputation"], [9, 3, 1, "", "compute_calibration"], [9, 3, 1, "", "get_calibration_states"]], "torchrec.metrics.calibration.CalibrationMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.ctr": [[9, 1, 1, "", "CTRMetric"], [9, 1, 1, "", "CTRMetricComputation"], [9, 3, 1, "", "compute_ctr"], [9, 3, 1, "", "get_ctr_states"]], "torchrec.metrics.ctr.CTRMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.mae": [[9, 1, 1, "", "MAEMetric"], [9, 1, 1, "", "MAEMetricComputation"], [9, 3, 1, "", "compute_error_sum"], [9, 3, 1, "", "compute_mae"], [9, 3, 1, "", "get_mae_states"]], "torchrec.metrics.mae.MAEMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.metric_module": [[9, 1, 1, "", "RecMetricModule"], [9, 1, 1, "", "StateMetric"], [9, 3, 1, "", "generate_metric_module"]], "torchrec.metrics.metric_module.RecMetricModule": [[9, 4, 1, "", "batch_size"], [9, 2, 1, "", "check_memory_usage"], [9, 2, 1, "", "compute"], [9, 4, 1, "", "compute_count"], [9, 2, 1, "", "get_memory_usage"], [9, 2, 1, "", "get_required_inputs"], [9, 4, 1, "", "last_compute_time"], [9, 2, 1, "", "local_compute"], [9, 4, 1, "", "memory_usage_limit_mb"], [9, 4, 1, "", "memory_usage_mb_avg"], [9, 4, 1, "", "oom_count"], [9, 4, 1, "", "rec_metrics"], [9, 4, 1, "", "rec_tasks"], [9, 2, 1, "", "reset"], [9, 2, 1, "", "should_compute"], [9, 4, 1, "", "state_metrics"], [9, 2, 1, "", "sync"], [9, 4, 1, "", "throughput_metric"], [9, 2, 1, "", "unsync"], [9, 2, 1, "", "update"], [9, 4, 1, "", "world_size"]], "torchrec.metrics.metric_module.StateMetric": [[9, 2, 1, "", "get_metrics"]], "torchrec.metrics.mse": [[9, 1, 1, "", "MSEMetric"], [9, 1, 1, "", "MSEMetricComputation"], [9, 3, 1, "", "compute_error_sum"], [9, 3, 1, "", "compute_mse"], [9, 3, 1, "", "compute_rmse"], [9, 3, 1, "", "get_mse_states"]], "torchrec.metrics.mse.MSEMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.multiclass_recall": [[9, 1, 1, "", "MulticlassRecallMetric"], [9, 1, 1, "", "MulticlassRecallMetricComputation"], [9, 3, 1, "", "compute_multiclass_recall_at_k"], [9, 3, 1, "", "compute_true_positives_at_k"], [9, 3, 1, "", "get_multiclass_recall_states"]], "torchrec.metrics.multiclass_recall.MulticlassRecallMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.ndcg": [[9, 1, 1, "", "NDCGComputation"], [9, 1, 1, "", "NDCGMetric"]], "torchrec.metrics.ndcg.NDCGComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.ne": [[9, 1, 1, "", "NEMetric"], [9, 1, 1, "", "NEMetricComputation"], [9, 3, 1, "", "compute_cross_entropy"], [9, 3, 1, "", "compute_logloss"], [9, 3, 1, "", "compute_ne"], [9, 3, 1, "", "get_ne_states"]], "torchrec.metrics.ne.NEMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.precision": [[9, 1, 1, "", "PrecisionMetric"], [9, 1, 1, "", "PrecisionMetricComputation"], [9, 3, 1, "", "compute_false_pos_sum"], [9, 3, 1, "", "compute_precision"], [9, 3, 1, "", "compute_true_pos_sum"], [9, 3, 1, "", "get_precision_states"]], "torchrec.metrics.precision.PrecisionMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.rauc": [[9, 1, 1, "", "RAUCMetric"], [9, 1, 1, "", "RAUCMetricComputation"], [9, 3, 1, "", "compute_rauc"], [9, 3, 1, "", "compute_rauc_per_group"], [9, 3, 1, "", "conquer_and_count"], [9, 3, 1, "", "count_reverse_pairs_divide_and_conquer"], [9, 3, 1, "", "divide"]], "torchrec.metrics.rauc.RAUCMetricComputation": [[9, 2, 1, "", "reset"], [9, 2, 1, "", "update"]], "torchrec.metrics.rec_metric": [[9, 1, 1, "", "MetricComputationReport"], [9, 1, 1, "", "RecMetric"], [9, 1, 1, "", "RecMetricComputation"], [9, 6, 1, "", "RecMetricException"], [9, 1, 1, "", "RecMetricList"], [9, 1, 1, "", "WindowBuffer"]], "torchrec.metrics.rec_metric.MetricComputationReport": [[9, 4, 1, "", "description"], [9, 4, 1, "", "metric_prefix"], [9, 4, 1, "", "name"], [9, 4, 1, "", "value"]], "torchrec.metrics.rec_metric.RecMetric": [[9, 4, 1, "", "LABELS"], [9, 4, 1, "", "PREDICTIONS"], [9, 4, 1, "", "WEIGHTS"], [9, 2, 1, "", "compute"], [9, 2, 1, "", "get_memory_usage"], [9, 2, 1, "", "get_required_inputs"], [9, 2, 1, "", "local_compute"], [9, 2, 1, "", "reset"], [9, 2, 1, "", "state_dict"], [9, 2, 1, "", "sync"], [9, 2, 1, "", "unsync"], [9, 2, 1, "", "update"]], "torchrec.metrics.rec_metric.RecMetricComputation": [[9, 2, 1, "", "compute"], [9, 2, 1, "", "get_window_state"], [9, 2, 1, "", "get_window_state_name"], [9, 2, 1, "", "local_compute"], [9, 2, 1, "", "pre_compute"], [9, 2, 1, "", "reset"], [9, 2, 1, "", "update"]], "torchrec.metrics.rec_metric.RecMetricList": [[9, 2, 1, "", "compute"], [9, 2, 1, "", "get_required_inputs"], [9, 2, 1, "", "local_compute"], [9, 4, 1, "", "rec_metrics"], [9, 4, 1, "", "required_inputs"], [9, 2, 1, "", "reset"], [9, 2, 1, "", "sync"], [9, 2, 1, "", "unsync"], [9, 2, 1, "", "update"]], "torchrec.metrics.rec_metric.WindowBuffer": [[9, 2, 1, "", "aggregate_state"], [9, 5, 1, "", "buffers"]], "torchrec.metrics.recall": [[9, 1, 1, "", "RecallMetric"], [9, 1, 1, "", "RecallMetricComputation"], [9, 3, 1, "", "compute_false_neg_sum"], [9, 3, 1, "", "compute_recall"], [9, 3, 1, "", "compute_true_pos_sum"], [9, 3, 1, "", "get_recall_states"]], "torchrec.metrics.recall.RecallMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.throughput": [[9, 1, 1, "", "ThroughputMetric"]], "torchrec.metrics.throughput.ThroughputMetric": [[9, 2, 1, "", "compute"], [9, 2, 1, "", "update"]], "torchrec.metrics.weighted_avg": [[9, 1, 1, "", "WeightedAvgMetric"], [9, 1, 1, "", "WeightedAvgMetricComputation"], [9, 3, 1, "", "get_mean"]], "torchrec.metrics.weighted_avg.WeightedAvgMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.xauc": [[9, 1, 1, "", "XAUCMetric"], [9, 1, 1, "", "XAUCMetricComputation"], [9, 3, 1, "", "compute_error_sum"], [9, 3, 1, "", "compute_weighted_num_pairs"], [9, 3, 1, "", "compute_xauc"], [9, 3, 1, "", "get_xauc_states"]], "torchrec.metrics.xauc.XAUCMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.models": [[10, 0, 0, "-", "deepfm"], [10, 0, 0, "-", "dlrm"]], "torchrec.models.deepfm": [[10, 1, 1, "", "DenseArch"], [10, 1, 1, "", "FMInteractionArch"], [10, 1, 1, "", "OverArch"], [10, 1, 1, "", "SimpleDeepFMNN"], [10, 1, 1, "", "SparseArch"]], "torchrec.models.deepfm.DenseArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.deepfm.FMInteractionArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.deepfm.OverArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.deepfm.SimpleDeepFMNN": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.deepfm.SparseArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm": [[10, 1, 1, "", "DLRM"], [10, 1, 1, "", "DLRMTrain"], [10, 1, 1, "", "DLRM_DCN"], [10, 1, 1, "", "DLRM_Projection"], [10, 1, 1, "", "DenseArch"], [10, 1, 1, "", "InteractionArch"], [10, 1, 1, "", "InteractionDCNArch"], [10, 1, 1, "", "InteractionProjectionArch"], [10, 1, 1, "", "OverArch"], [10, 1, 1, "", "SparseArch"], [10, 3, 1, "", "choose"]], "torchrec.models.dlrm.DLRM": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.DLRMTrain": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.DLRM_DCN": [[10, 4, 1, "", "sparse_arch"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.DLRM_Projection": [[10, 4, 1, "", "sparse_arch"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.DenseArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.InteractionArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.InteractionDCNArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.InteractionProjectionArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.OverArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.SparseArch": [[10, 2, 1, "", "forward"], [10, 5, 1, "", "sparse_feature_names"], [10, 4, 1, "", "training"]], "torchrec.modules": [[11, 0, 0, "-", "activation"], [11, 0, 0, "-", "crossnet"], [11, 0, 0, "-", "deepfm"], [11, 0, 0, "-", "embedding_configs"], [11, 0, 0, "-", "embedding_modules"], [11, 0, 0, "-", "feature_processor"], [11, 0, 0, "-", "lazy_extension"], [11, 0, 0, "-", "mc_embedding_modules"], [11, 0, 0, "-", "mc_modules"], [11, 0, 0, "-", "mlp"], [11, 0, 0, "-", "utils"]], "torchrec.modules.activation": [[11, 1, 1, "", "SwishLayerNorm"]], "torchrec.modules.activation.SwishLayerNorm": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.crossnet": [[11, 1, 1, "", "CrossNet"], [11, 1, 1, "", "LowRankCrossNet"], [11, 1, 1, "", "LowRankMixtureCrossNet"], [11, 1, 1, "", "VectorCrossNet"]], "torchrec.modules.crossnet.CrossNet": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.crossnet.LowRankCrossNet": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.crossnet.LowRankMixtureCrossNet": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.crossnet.VectorCrossNet": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.deepfm": [[11, 1, 1, "", "DeepFM"], [11, 1, 1, "", "FactorizationMachine"]], "torchrec.modules.deepfm.DeepFM": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.deepfm.FactorizationMachine": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.embedding_configs": [[11, 1, 1, "", "BaseEmbeddingConfig"], [11, 1, 1, "", "EmbeddingBagConfig"], [11, 1, 1, "", "EmbeddingConfig"], [11, 1, 1, "", "EmbeddingTableConfig"], [11, 1, 1, "", "PoolingType"], [11, 1, 1, "", "QuantConfig"], [11, 1, 1, "", "ShardingType"], [11, 3, 1, "", "data_type_to_dtype"], [11, 3, 1, "", "data_type_to_sparse_type"], [11, 3, 1, "", "dtype_to_data_type"], [11, 3, 1, "", "pooling_type_to_pooling_mode"], [11, 3, 1, "", "pooling_type_to_str"]], "torchrec.modules.embedding_configs.BaseEmbeddingConfig": [[11, 4, 1, "", "data_type"], [11, 4, 1, "", "embedding_dim"], [11, 4, 1, "", "feature_names"], [11, 2, 1, "", "get_weight_init_max"], [11, 2, 1, "", "get_weight_init_min"], [11, 4, 1, "", "init_fn"], [11, 4, 1, "", "name"], [11, 4, 1, "", "need_pos"], [11, 4, 1, "", "num_embeddings"], [11, 4, 1, "", "num_embeddings_post_pruning"], [11, 2, 1, "", "num_features"], [11, 4, 1, "", "weight_init_max"], [11, 4, 1, "", "weight_init_min"]], "torchrec.modules.embedding_configs.EmbeddingBagConfig": [[11, 4, 1, "", "pooling"]], "torchrec.modules.embedding_configs.EmbeddingConfig": [[11, 4, 1, "", "embedding_dim"], [11, 4, 1, "", "feature_names"], [11, 4, 1, "", "num_embeddings"]], "torchrec.modules.embedding_configs.EmbeddingTableConfig": [[11, 4, 1, "", "embedding_names"], [11, 4, 1, "", "has_feature_processor"], [11, 4, 1, "", "is_weighted"], [11, 4, 1, "", "pooling"]], "torchrec.modules.embedding_configs.PoolingType": [[11, 4, 1, "", "MEAN"], [11, 4, 1, "", "NONE"], [11, 4, 1, "", "SUM"]], "torchrec.modules.embedding_configs.QuantConfig": [[11, 4, 1, "", "activation"], [11, 4, 1, "", "per_table_weight_dtype"], [11, 4, 1, "", "weight"]], "torchrec.modules.embedding_configs.ShardingType": [[11, 4, 1, "", "COLUMN_WISE"], [11, 4, 1, "", "DATA_PARALLEL"], [11, 4, 1, "", "ROW_WISE"], [11, 4, 1, "", "TABLE_COLUMN_WISE"], [11, 4, 1, "", "TABLE_ROW_WISE"], [11, 4, 1, "", "TABLE_WISE"]], "torchrec.modules.embedding_modules": [[11, 1, 1, "", "EmbeddingBagCollection"], [11, 1, 1, "", "EmbeddingBagCollectionInterface"], [11, 1, 1, "", "EmbeddingCollection"], [11, 1, 1, "", "EmbeddingCollectionInterface"], [11, 3, 1, "", "get_embedding_names_by_table"], [11, 3, 1, "", "process_pooled_embeddings"], [11, 3, 1, "", "reorder_inverse_indices"]], "torchrec.modules.embedding_modules.EmbeddingBagCollection": [[11, 5, 1, "", "device"], [11, 2, 1, "", "embedding_bag_configs"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "is_weighted"], [11, 2, 1, "", "reset_parameters"], [11, 4, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface": [[11, 2, 1, "", "embedding_bag_configs"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "is_weighted"], [11, 4, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingCollection": [[11, 5, 1, "", "device"], [11, 2, 1, "", "embedding_configs"], [11, 2, 1, "", "embedding_dim"], [11, 2, 1, "", "embedding_names_by_table"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "need_indices"], [11, 2, 1, "", "reset_parameters"], [11, 4, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingCollectionInterface": [[11, 2, 1, "", "embedding_configs"], [11, 2, 1, "", "embedding_dim"], [11, 2, 1, "", "embedding_names_by_table"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "need_indices"], [11, 4, 1, "", "training"]], "torchrec.modules.feature_processor": [[11, 1, 1, "", "BaseFeatureProcessor"], [11, 1, 1, "", "BaseGroupedFeatureProcessor"], [11, 1, 1, "", "PositionWeightedModule"], [11, 1, 1, "", "PositionWeightedProcessor"], [11, 3, 1, "", "offsets_to_range_traceble"], [11, 3, 1, "", "position_weighted_module_update_features"]], "torchrec.modules.feature_processor.BaseFeatureProcessor": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.feature_processor.BaseGroupedFeatureProcessor": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.feature_processor.PositionWeightedModule": [[11, 2, 1, "", "forward"], [11, 2, 1, "", "reset_parameters"], [11, 4, 1, "", "training"]], "torchrec.modules.feature_processor.PositionWeightedProcessor": [[11, 2, 1, "", "forward"], [11, 2, 1, "", "named_buffers"], [11, 2, 1, "", "state_dict"], [11, 4, 1, "", "training"]], "torchrec.modules.lazy_extension": [[11, 1, 1, "", "LazyModuleExtensionMixin"], [11, 3, 1, "", "lazy_apply"]], "torchrec.modules.lazy_extension.LazyModuleExtensionMixin": [[11, 2, 1, "", "apply"]], "torchrec.modules.mc_embedding_modules": [[11, 1, 1, "", "BaseManagedCollisionEmbeddingCollection"], [11, 1, 1, "", "ManagedCollisionEmbeddingBagCollection"], [11, 1, 1, "", "ManagedCollisionEmbeddingCollection"], [11, 3, 1, "", "evict"]], "torchrec.modules.mc_embedding_modules.BaseManagedCollisionEmbeddingCollection": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingBagCollection": [[11, 4, 1, "", "training"]], "torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingCollection": [[11, 4, 1, "", "training"]], "torchrec.modules.mc_modules": [[11, 1, 1, "", "DistanceLFU_EvictionPolicy"], [11, 1, 1, "", "LFU_EvictionPolicy"], [11, 1, 1, "", "LRU_EvictionPolicy"], [11, 1, 1, "", "MCHEvictionPolicy"], [11, 1, 1, "", "MCHEvictionPolicyMetadataInfo"], [11, 1, 1, "", "MCHManagedCollisionModule"], [11, 1, 1, "", "ManagedCollisionCollection"], [11, 1, 1, "", "ManagedCollisionModule"], [11, 3, 1, "", "apply_mc_method_to_jt_dict"], [11, 3, 1, "", "average_threshold_filter"], [11, 3, 1, "", "dynamic_threshold_filter"], [11, 3, 1, "", "probabilistic_threshold_filter"]], "torchrec.modules.mc_modules.DistanceLFU_EvictionPolicy": [[11, 2, 1, "", "coalesce_history_metadata"], [11, 5, 1, "", "metadata_info"], [11, 2, 1, "", "record_history_metadata"], [11, 2, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.LFU_EvictionPolicy": [[11, 2, 1, "", "coalesce_history_metadata"], [11, 5, 1, "", "metadata_info"], [11, 2, 1, "", "record_history_metadata"], [11, 2, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.LRU_EvictionPolicy": [[11, 2, 1, "", "coalesce_history_metadata"], [11, 5, 1, "", "metadata_info"], [11, 2, 1, "", "record_history_metadata"], [11, 2, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.MCHEvictionPolicy": [[11, 2, 1, "", "coalesce_history_metadata"], [11, 5, 1, "", "metadata_info"], [11, 2, 1, "", "record_history_metadata"], [11, 2, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.MCHEvictionPolicyMetadataInfo": [[11, 4, 1, "", "is_history_metadata"], [11, 4, 1, "", "is_mch_metadata"], [11, 4, 1, "", "metadata_name"]], "torchrec.modules.mc_modules.MCHManagedCollisionModule": [[11, 2, 1, "", "evict"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "input_size"], [11, 2, 1, "", "open_slots"], [11, 2, 1, "", "output_size"], [11, 2, 1, "", "preprocess"], [11, 2, 1, "", "profile"], [11, 2, 1, "", "rebuild_with_output_id_range"], [11, 2, 1, "", "remap"], [11, 4, 1, "", "training"], [11, 2, 1, "", "validate_state"]], "torchrec.modules.mc_modules.ManagedCollisionCollection": [[11, 2, 1, "", "embedding_configs"], [11, 2, 1, "", "evict"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "open_slots"]], "torchrec.modules.mc_modules.ManagedCollisionModule": [[11, 5, 1, "", "device"], [11, 2, 1, "", "evict"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "input_size"], [11, 2, 1, "", "open_slots"], [11, 2, 1, "", "output_size"], [11, 2, 1, "", "preprocess"], [11, 2, 1, "", "profile"], [11, 2, 1, "", "rebuild_with_output_id_range"], [11, 2, 1, "", "remap"], [11, 4, 1, "", "training"], [11, 2, 1, "", "validate_state"]], "torchrec.modules.mlp": [[11, 1, 1, "", "MLP"], [11, 1, 1, "", "Perceptron"]], "torchrec.modules.mlp.MLP": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.mlp.Perceptron": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.utils": [[11, 1, 1, "", "SequenceVBEContext"], [11, 3, 1, "", "check_module_output_dimension"], [11, 3, 1, "", "construct_jagged_tensors"], [11, 3, 1, "", "construct_jagged_tensors_inference"], [11, 3, 1, "", "construct_modulelist_from_single_module"], [11, 3, 1, "", "convert_list_of_modules_to_modulelist"], [11, 3, 1, "", "deterministic_dedup"], [11, 3, 1, "", "extract_module_or_tensor_callable"], [11, 3, 1, "", "get_module_output_dimension"], [11, 3, 1, "", "init_mlp_weights_xavier_uniform"], [11, 3, 1, "", "jagged_index_select_with_empty"]], "torchrec.modules.utils.SequenceVBEContext": [[11, 4, 1, "", "recat"], [11, 2, 1, "", "record_stream"], [11, 4, 1, "", "reindexed_length_per_key"], [11, 4, 1, "", "reindexed_lengths"], [11, 4, 1, "", "reindexed_values"], [11, 4, 1, "", "unpadded_lengths"]], "torchrec.optim": [[12, 0, 0, "-", "clipping"], [12, 0, 0, "-", "fused"], [12, 0, 0, "-", "keyed"], [12, 0, 0, "-", "warmup"]], "torchrec.optim.clipping": [[12, 1, 1, "", "GradientClipping"], [12, 1, 1, "", "GradientClippingOptimizer"]], "torchrec.optim.clipping.GradientClipping": [[12, 4, 1, "", "NONE"], [12, 4, 1, "", "NORM"], [12, 4, 1, "", "VALUE"]], "torchrec.optim.clipping.GradientClippingOptimizer": [[12, 2, 1, "", "step"]], "torchrec.optim.fused": [[12, 1, 1, "", "EmptyFusedOptimizer"], [12, 1, 1, "", "FusedOptimizer"], [12, 1, 1, "", "FusedOptimizerModule"]], "torchrec.optim.fused.EmptyFusedOptimizer": [[12, 2, 1, "", "step"], [12, 2, 1, "", "zero_grad"]], "torchrec.optim.fused.FusedOptimizer": [[12, 2, 1, "", "step"], [12, 2, 1, "", "zero_grad"]], "torchrec.optim.fused.FusedOptimizerModule": [[12, 5, 1, "", "fused_optimizer"]], "torchrec.optim.keyed": [[12, 1, 1, "", "CombinedOptimizer"], [12, 1, 1, "", "KeyedOptimizer"], [12, 1, 1, "", "KeyedOptimizerWrapper"], [12, 1, 1, "", "OptimizerWrapper"]], "torchrec.optim.keyed.CombinedOptimizer": [[12, 5, 1, "", "optimizers"], [12, 5, 1, "", "param_groups"], [12, 5, 1, "", "params"], [12, 2, 1, "", "post_load_state_dict"], [12, 2, 1, "", "prepend_opt_key"], [12, 2, 1, "", "save_param_groups"], [12, 2, 1, "", "set_optimizer_step"], [12, 5, 1, "", "state"], [12, 2, 1, "", "step"], [12, 2, 1, "", "zero_grad"]], "torchrec.optim.keyed.KeyedOptimizer": [[12, 2, 1, "", "add_param_group"], [12, 2, 1, "", "init_state"], [12, 2, 1, "", "load_state_dict"], [12, 2, 1, "", "post_load_state_dict"], [12, 2, 1, "", "save_param_groups"], [12, 2, 1, "", "state_dict"]], "torchrec.optim.keyed.KeyedOptimizerWrapper": [[12, 2, 1, "", "step"], [12, 2, 1, "", "zero_grad"]], "torchrec.optim.keyed.OptimizerWrapper": [[12, 2, 1, "", "add_param_group"], [12, 2, 1, "", "load_state_dict"], [12, 2, 1, "", "post_load_state_dict"], [12, 2, 1, "", "save_param_groups"], [12, 2, 1, "", "state_dict"], [12, 2, 1, "", "step"], [12, 2, 1, "", "zero_grad"]], "torchrec.optim.warmup": [[12, 1, 1, "", "WarmupOptimizer"], [12, 1, 1, "", "WarmupPolicy"], [12, 1, 1, "", "WarmupStage"]], "torchrec.optim.warmup.WarmupOptimizer": [[12, 2, 1, "", "post_load_state_dict"], [12, 2, 1, "", "step"]], "torchrec.optim.warmup.WarmupPolicy": [[12, 4, 1, "", "CONSTANT"], [12, 4, 1, "", "COSINE_ANNEALING_WARM_RESTARTS"], [12, 4, 1, "", "INVSQRT"], [12, 4, 1, "", "LINEAR"], [12, 4, 1, "", "NONE"], [12, 4, 1, "", "POLY"], [12, 4, 1, "", "STEP"]], "torchrec.optim.warmup.WarmupStage": [[12, 4, 1, "", "decay_iters"], [12, 4, 1, "", "lr_scale"], [12, 4, 1, "", "max_iters"], [12, 4, 1, "", "policy"], [12, 4, 1, "", "sgdr_period"], [12, 4, 1, "", "value"]], "torchrec.quant": [[13, 0, 0, "-", "embedding_modules"]], "torchrec.quant.embedding_modules": [[13, 1, 1, "", "EmbeddingBagCollection"], [13, 1, 1, "", "EmbeddingCollection"], [13, 1, 1, "", "FeatureProcessedEmbeddingBagCollection"], [13, 3, 1, "", "for_each_module_of_type_do"], [13, 3, 1, "", "quant_prep_customize_row_alignment"], [13, 3, 1, "", "quant_prep_enable_quant_state_dict_split_scale_bias"], [13, 3, 1, "", "quant_prep_enable_quant_state_dict_split_scale_bias_for_types"], [13, 3, 1, "", "quant_prep_enable_register_tbes"], [13, 3, 1, "", "quantize_state_dict"]], "torchrec.quant.embedding_modules.EmbeddingBagCollection": [[13, 5, 1, "", "device"], [13, 2, 1, "", "embedding_bag_configs"], [13, 2, 1, "", "forward"], [13, 2, 1, "", "from_float"], [13, 2, 1, "", "is_weighted"], [13, 2, 1, "", "output_dtype"], [13, 4, 1, "", "training"]], "torchrec.quant.embedding_modules.EmbeddingCollection": [[13, 5, 1, "", "device"], [13, 2, 1, "", "embedding_configs"], [13, 2, 1, "", "embedding_dim"], [13, 2, 1, "", "embedding_names_by_table"], [13, 2, 1, "", "forward"], [13, 2, 1, "", "from_float"], [13, 2, 1, "", "need_indices"], [13, 2, 1, "", "output_dtype"], [13, 4, 1, "", "training"]], "torchrec.quant.embedding_modules.FeatureProcessedEmbeddingBagCollection": [[13, 4, 1, "", "embedding_bags"], [13, 2, 1, "", "forward"], [13, 2, 1, "", "from_float"], [13, 4, 1, "", "tbes"], [13, 4, 1, "", "training"]], "torchrec.sparse": [[14, 0, 0, "-", "jagged_tensor"]], "torchrec.sparse.jagged_tensor": [[14, 1, 1, "", "ComputeJTDictToKJT"], [14, 1, 1, "", "ComputeKJTToJTDict"], [14, 1, 1, "", "JaggedTensor"], [14, 1, 1, "", "JaggedTensorMeta"], [14, 1, 1, "", "KeyedJaggedTensor"], [14, 1, 1, "", "KeyedTensor"], [14, 3, 1, "", "flatten_kjt_list"], [14, 3, 1, "", "jt_is_equal"], [14, 3, 1, "", "kjt_is_equal"], [14, 3, 1, "", "permute_multi_embedding"], [14, 3, 1, "", "regroup_kts"], [14, 3, 1, "", "unflatten_kjt_list"]], "torchrec.sparse.jagged_tensor.ComputeJTDictToKJT": [[14, 2, 1, "", "forward"], [14, 4, 1, "", "training"]], "torchrec.sparse.jagged_tensor.ComputeKJTToJTDict": [[14, 2, 1, "", "forward"], [14, 4, 1, "", "training"]], "torchrec.sparse.jagged_tensor.JaggedTensor": [[14, 2, 1, "", "device"], [14, 2, 1, "", "empty"], [14, 2, 1, "", "from_dense"], [14, 2, 1, "", "from_dense_lengths"], [14, 2, 1, "", "lengths"], [14, 2, 1, "", "lengths_or_none"], [14, 2, 1, "", "offsets"], [14, 2, 1, "", "offsets_or_none"], [14, 2, 1, "", "record_stream"], [14, 2, 1, "", "to"], [14, 2, 1, "", "to_dense"], [14, 2, 1, "", "to_dense_weights"], [14, 2, 1, "", "to_padded_dense"], [14, 2, 1, "", "to_padded_dense_weights"], [14, 2, 1, "", "values"], [14, 2, 1, "", "weights"], [14, 2, 1, "", "weights_or_none"]], "torchrec.sparse.jagged_tensor.KeyedJaggedTensor": [[14, 2, 1, "", "concat"], [14, 2, 1, "", "device"], [14, 2, 1, "", "dist_init"], [14, 2, 1, "", "dist_labels"], [14, 2, 1, "", "dist_splits"], [14, 2, 1, "", "dist_tensors"], [14, 2, 1, "", "empty"], [14, 2, 1, "", "empty_like"], [14, 2, 1, "", "flatten_lengths"], [14, 2, 1, "", "from_jt_dict"], [14, 2, 1, "", "from_lengths_sync"], [14, 2, 1, "", "from_offsets_sync"], [14, 2, 1, "", "index_per_key"], [14, 2, 1, "", "inverse_indices"], [14, 2, 1, "", "inverse_indices_or_none"], [14, 2, 1, "", "keys"], [14, 2, 1, "", "length_per_key"], [14, 2, 1, "", "length_per_key_or_none"], [14, 2, 1, "", "lengths"], [14, 2, 1, "", "lengths_offset_per_key"], [14, 2, 1, "", "lengths_or_none"], [14, 2, 1, "", "offset_per_key"], [14, 2, 1, "", "offset_per_key_or_none"], [14, 2, 1, "", "offsets"], [14, 2, 1, "", "offsets_or_none"], [14, 2, 1, "", "permute"], [14, 2, 1, "", "pin_memory"], [14, 2, 1, "", "record_stream"], [14, 2, 1, "", "split"], [14, 2, 1, "", "stride"], [14, 2, 1, "", "stride_per_key"], [14, 2, 1, "", "stride_per_key_per_rank"], [14, 2, 1, "", "sync"], [14, 2, 1, "", "to"], [14, 2, 1, "", "to_dict"], [14, 2, 1, "", "unsync"], [14, 2, 1, "", "values"], [14, 2, 1, "", "variable_stride_per_key"], [14, 2, 1, "", "weights"], [14, 2, 1, "", "weights_or_none"]], "torchrec.sparse.jagged_tensor.KeyedTensor": [[14, 2, 1, "", "device"], [14, 2, 1, "", "from_tensor_list"], [14, 2, 1, "", "key_dim"], [14, 2, 1, "", "keys"], [14, 2, 1, "", "length_per_key"], [14, 2, 1, "", "offset_per_key"], [14, 2, 1, "", "record_stream"], [14, 2, 1, "", "regroup"], [14, 2, 1, "", "regroup_as_dict"], [14, 2, 1, "", "to"], [14, 2, 1, "", "to_dict"], [14, 2, 1, "", "values"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:method", "3": "py:function", "4": "py:attribute", "5": "py:property", "6": "py:exception"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "method", "Python method"], "3": ["py", "function", "Python function"], "4": ["py", "attribute", "Python attribute"], "5": ["py", "property", "Python property"], "6": ["py", "exception", "Python exception"]}, "titleterms": {"welcom": 0, "torchrec": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "document": 0, "get": 0, "start": 0, "how": 0, "contribut": 0, "overview": 1, "why": 1, "dataset": [2, 3], "criteo": 2, "movielen": 2, "random": 2, "util": [2, 4, 5, 11], "script": 3, "contiguous_preproc_criteo": 3, "npy_preproc_criteo": 3, "distribut": [4, 5, 6], "collective_util": 4, "comm": 4, "comm_op": 4, "dist_data": [4, 6], "embed": 4, "embedding_lookup": 4, "embedding_shard": 4, "embedding_typ": 4, "embeddingbag": 4, "grouped_position_weight": 4, "model_parallel": 4, "quant_embeddingbag": 4, "train_pipelin": 4, "type": [4, 5], "mc_modul": [4, 11], "mc_embeddingbag": 4, "mc_embed": 4, "planner": 5, "constant": 5, "enumer": 5, "partition": 5, "perf_model": 5, "propos": 5, "shard_estim": 5, "stat": 5, "storage_reserv": 5, "shard": 6, "cw_shard": 6, "dp_shard": 6, "rw_shard": 6, "tw_shard": 6, "twcw_shard": 6, "twrw_shard": 6, "fx": 7, "tracer": 7, "modul": [7, 8, 10, 11, 12, 13, 14], "content": [7, 8, 10, 12, 13, 14], "infer": 8, "model_packag": 8, "metric": 9, "accuraci": 9, "auc": 9, "auprc": 9, "calibr": 9, "ctr": 9, "mae": 9, "mse": 9, "multiclass_recal": 9, "ndcg": 9, "ne": 9, "recal": 9, "precis": 9, "rauc": 9, "throughput": 9, "weighted_avg": 9, "xauc": 9, "metric_modul": 9, "rec_metr": 9, "model": 10, "deepfm": [10, 11], "dlrm": 10, "activ": 11, "crossnet": 11, "embedding_config": 11, "embedding_modul": [11, 13], "feature_processor": 11, "lazy_extens": 11, "mlp": 11, "mc_embedding_modul": 11, "optim": 12, "clip": 12, "fuse": 12, "kei": 12, "warmup": 12, "quant": 13, "spars": 14, "jagged_tensor": 14}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 6, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx": 56}}) \ No newline at end of file +Search.setIndex({"docnames": ["index", "overview", "torchrec.datasets", "torchrec.datasets.scripts", "torchrec.distributed", "torchrec.distributed.planner", "torchrec.distributed.sharding", "torchrec.fx", "torchrec.inference", "torchrec.metrics", "torchrec.models", "torchrec.modules", "torchrec.optim", "torchrec.quant", "torchrec.sparse"], "filenames": ["index.rst", "overview.rst", "torchrec.datasets.rst", "torchrec.datasets.scripts.rst", "torchrec.distributed.rst", "torchrec.distributed.planner.rst", "torchrec.distributed.sharding.rst", "torchrec.fx.rst", "torchrec.inference.rst", "torchrec.metrics.rst", "torchrec.models.rst", "torchrec.modules.rst", "torchrec.optim.rst", "torchrec.quant.rst", "torchrec.sparse.rst"], "titles": ["Welcome to the TorchRec documentation!", "TorchRec Overview", "torchrec.datasets", "torchrec.datasets.scripts", "torchrec.distributed", "torchrec.distributed.planner", "torchrec.distributed.sharding", "torchrec.fx", "torchrec.inference", "torchrec.metrics", "torchrec.models", "torchrec.modules", "torchrec.optim", "torchrec.quant", "torchrec.sparse"], "terms": {"recommendation system": 0, "shard": [0, 1, 4, 5, 8, 11, 12, 13], "distributed train": 0, "special": [0, 1, 7, 9, 11, 12], "librari": [0, 1], "within": [0, 4, 5, 6, 8, 10, 11, 14], "pytorch": [0, 1, 2, 4, 11, 12, 14], "ecosystem": [0, 1], "tailor": 0, "build": [0, 1, 5], "scale": [0, 1], "deploi": [0, 1, 8], "larg": [0, 1, 2, 5], "recommend": [0, 1, 2, 9, 10], "system": [0, 1, 2, 4, 5, 10], "nich": 0, "directli": [0, 4, 12], "address": [0, 1, 4], "standard": 0, "offer": 0, "advanc": [0, 1, 12], "featur": [0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 13, 14], "complex": [0, 7], "techniqu": [0, 1], "massiv": [0, 1], "embed": [0, 1, 5, 6, 7, 10, 11, 13, 14], "tabl": [0, 1, 4, 5, 6, 7, 8, 10, 11, 13], "enhanc": 0, "distribut": [0, 1, 2, 8, 9, 11, 12, 14], "train": [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "capabl": [0, 1], "topic": 0, "thi": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "section": 0, "help": [0, 4], "you": [0, 4, 6, 7, 14], "overview": 0, "A": [0, 2, 4, 5, 6, 7, 8, 9, 12, 14], "short": 0, "intro": 0, "why": 0, "need": [0, 2, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14], "set": [0, 2, 4, 5, 6, 8, 9, 11, 12], "up": [0, 4, 5, 13], "learn": [0, 8, 10, 11, 12], "instal": 0, "us": [0, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "your": [0, 7, 9], "environ": [0, 1, 4, 8], "tutori": 0, "follow": [0, 1, 4, 5, 6, 9, 10, 11, 12, 14], "our": 0, "interact": [0, 4, 10, 11], "step": [0, 4, 5, 12], "real": [0, 5], "life": 0, "exampl": [0, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "we": [0, 2, 4, 5, 6, 7, 9, 11, 12, 13, 14], "feedback": [0, 5], "from": [0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14], "commun": [0, 1, 4, 5, 6, 9], "If": [0, 2, 4, 5, 8, 9, 11, 12, 14], "ar": [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "interest": 0, "improv": [0, 1, 12], "project": [0, 10], "here": [0, 5, 10], "can": [0, 1, 2, 4, 5, 9, 10, 11, 12, 14], "visit": 0, "github": [0, 11], "repositori": 0, "There": [0, 4], "yoou": 0, "find": [0, 5], "sourc": [0, 2, 10, 11], "code": [0, 1, 4, 11], "issu": [0, 4, 6, 11], "ongo": 0, "submit": 0, "encount": 0, "ani": [0, 2, 4, 5, 6, 7, 8, 9, 11, 12, 14], "bug": 0, "have": [0, 2, 4, 5, 6, 9, 10, 11, 12, 14], "suggest": 0, "pleas": [0, 2, 4, 5, 8, 9, 11, 14], "an": [0, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "through": [0, 4, 7, 9, 12], "tracker": 0, "propos": [0, 10], "chang": [0, 4, 11, 12], "fork": [0, 4], "pull": 0, "request": [0, 4, 8, 12], "whether": [0, 2, 4, 5, 7, 9, 11, 12, 13, 14], "s": [0, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14], "fix": [0, 14], "ad": [0, 2, 4, 8, 9, 11, 12], "new": [0, 4, 5, 9], "alwai": [0, 11, 14], "make": [0, 4, 11, 12, 14], "sure": [0, 12], "review": 0, "md": 0, "go": [0, 12], "repo": 0, "design": [1, 2, 4, 8, 9, 11], "provid": [1, 4, 5, 6, 8, 9, 10, 11, 13], "common": [1, 2, 11, 14], "primit": [1, 4, 6], "creat": [1, 4, 7, 8, 9, 11, 12, 14], "state": [1, 4, 5, 8, 9, 11, 12], "art": 1, "person": [1, 10], "model": [1, 4, 5, 6, 7, 8, 9, 11, 12, 13], "path": [1, 2, 4, 5, 8], "product": [1, 10], "wide": 1, "adopt": 1, "mani": [1, 4, 6], "meta": [1, 4, 5, 8], "infer": [1, 4, 5, 6, 11, 13, 14], "workflow": 1, "uniqu": [1, 2, 5, 9, 11], "challeng": [1, 2], "which": [1, 2, 4, 5, 6, 8, 9, 11, 12, 14], "focu": [1, 9], "regular": 1, "more": [1, 4, 5, 6, 9, 11], "specif": [1, 4, 5, 8, 12], "gener": [1, 2, 4, 5, 7, 8, 10, 11, 12, 14], "compon": [1, 9, 11], "simplist": 1, "modul": [1, 4, 5, 6, 9], "author": [1, 4], "flexibl": [1, 11], "customiz": [1, 5], "method": [1, 4, 7, 8, 9, 11], "row": [1, 2, 4, 5, 6], "wise": [1, 4, 5, 6, 11], "column": [1, 2, 5, 6], "so": [1, 2, 4, 5, 9, 10, 12, 14], "automat": [1, 4, 5, 9, 14], "determin": [1, 2, 4, 5, 6], "best": [1, 5], "plan": [1, 4, 5, 8, 11], "devic": [1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14], "topolog": [1, 4, 5, 6], "effici": [1, 5, 11], "memori": [1, 2, 4, 5, 8, 9, 12], "balanc": [1, 5], "while": [1, 2, 4, 5, 6, 7, 8, 11], "support": [1, 4, 5, 6, 7, 9, 11, 12], "basic": [1, 10, 14], "extend": [1, 4], "sophist": 1, "parallel": [1, 2, 4, 6, 12], "incred": 1, "optim": [1, 4, 5, 8, 9, 11, 13], "top": [1, 4, 9, 11], "fbgemm": [1, 4, 5, 6, 13], "after": [1, 4, 5, 6, 9, 11], "all": [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 14], "power": 1, "some": [1, 4, 9, 10, 14], "largest": [1, 5], "frictionless": 1, "deploy": 1, "simpl": [1, 10], "api": [1, 4, 6, 7, 9, 11], "transform": [1, 2, 4, 8, 11], "load": [1, 2, 4, 5, 6, 11, 12], "c": [1, 2, 4, 6, 8, 14], "most": [1, 4, 8, 12], "integr": 1, "built": [1, 11], "mean": [1, 4, 5, 9, 11], "seamlessli": 1, "exist": [1, 4, 6, 11, 14], "tool": 1, "allow": [1, 2, 4, 5, 7, 9, 11, 12, 14], "develop": 1, "leverag": [1, 11], "knowledg": [1, 4, 5, 9], "codebas": 1, "util": [1, 6], "By": 1, "being": [1, 4, 5, 8, 9, 11], "part": [1, 2, 4, 5, 6, 11, 12], "benefit": 1, "robust": 1, "continu": [1, 2], "updat": [1, 4, 5, 6, 8, 9, 11, 12], "come": [1, 11], "contain": [2, 4, 5, 6, 8, 9, 11, 12, 13], "two": [2, 4, 5, 9, 10, 11, 14], "popular": [2, 10], "reci": 2, "kaggl": 2, "displai": 2, "advertis": 2, "20m": 2, "addition": 2, "randomdataset": 2, "data": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "same": [2, 4, 5, 6, 8, 9, 10, 11, 14], "format": [2, 3, 5, 8, 14], "abov": [2, 11, 14], "lastli": 2, "script": [2, 14], "pre": [2, 9, 11, 12], "process": [2, 3, 4, 5, 6, 9, 10, 11, 12, 13], "etc": [2, 4, 8, 12, 14], "import": [2, 4, 5, 8, 11, 13], "criteo_kaggl": 2, "datapip": 2, "criteo_terabyt": 2, "home": 2, "day_0": 2, "tsv": [2, 3], "day_1": 2, "dp": [2, 5], "iter": [2, 4, 5, 11, 12], "batcher": 2, "100": [2, 4, 5, 9, 10, 11], "collat": 2, "batch": [2, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14], "next": [2, 5], "class": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "binarycriteoutil": 2, "base": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "object": [2, 4, 5, 8, 9, 11, 12], "function": [2, 3, 4, 5, 6, 7, 8, 11, 12, 14], "preprocess": [2, 3, 11], "save": [2, 3, 4, 5, 11, 12], "partit": [2, 5, 6], "binari": [2, 3, 5, 9], "numpi": 2, "static": [2, 4, 5, 9, 12, 14], "get_file_row_ranges_and_remaind": 2, "length": [2, 4, 5, 6, 10, 11, 13, 14], "list": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "int": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "rank": [2, 4, 5, 6, 9, 10, 11, 12, 14], "world_siz": [2, 4, 5, 6, 8, 9], "start_row": 2, "0": [2, 4, 5, 6, 9, 10, 11, 12, 13, 14], "last_row": 2, "option": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "none": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "tupl": [2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14], "dict": [2, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14], "given": [2, 4, 5, 6, 7, 11], "number": [2, 4, 5, 6, 8, 9, 10, 11, 14], "file": [2, 3, 4], "return": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "portion": 2, "those": [2, 10, 11], "repres": [2, 4, 5, 8, 10, 11, 13, 14], "rang": [2, 5, 7, 11], "indic": [2, 4, 5, 6, 8, 11, 12, 13, 14], "inclus": 2, "should": [2, 4, 5, 6, 8, 9, 10, 11, 12, 14], "handl": [2, 4, 5, 6, 7, 11, 12], "each": [2, 4, 5, 6, 9, 10, 11, 13, 14], "assign": [2, 4], "The": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14], "wai": [2, 4, 5], "deal": 2, "enabl": [2, 4, 5, 9, 12], "reduc": [2, 4, 6, 11, 13], "amount": [2, 5], "read": 2, "avoid": [2, 4, 8, 9, 11, 12], "seek": 2, "paramet": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "count": [2, 5, 9, 11], "world": [2, 4, 6], "size": [2, 4, 5, 6, 9, 10, 11, 13, 14], "first": [2, 4, 5, 6, 10, 11, 12, 14], "item": [2, 4, 6], "map": [2, 4, 8, 9, 11, 12, 13], "kei": [2, 4, 5, 6, 8, 9, 10, 11, 13, 14], "second": [2, 4, 5, 6, 9, 10, 11, 14], "remaind": 2, "type": [2, 6, 7, 8, 9, 10, 11, 12, 13, 14], "output": [2, 4, 5, 6, 8, 9, 10, 11, 13, 14], "get_shape_from_npi": 2, "str": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "path_manager_kei": 2, "shape": [2, 4, 6, 9, 10, 11, 14], "npy": [2, 3], "onli": [2, 4, 5, 6, 9, 11, 14], "its": [2, 4, 5, 6, 8, 9, 11, 12, 14], "header": 2, "input": [2, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14], "manag": [2, 11], "differ": [2, 4, 5, 6, 11, 12, 14], "filesystem": 2, "load_npy_rang": 2, "fname": 2, "num_row": 2, "mmap_mod": 2, "bool": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "fals": [2, 4, 5, 6, 8, 9, 11, 12, 13, 14], "ndarrai": 2, "note": [2, 4, 5, 6, 11, 14], "assum": [2, 4, 5, 6, 8, 9, 10, 12], "arrai": 2, "ndim": 2, "2": [2, 4, 5, 6, 9, 10, 11, 12, 13, 14], "string": [2, 5, 8, 11], "start": [2, 4, 11, 14], "get": [2, 4, 5, 6, 14], "desir": [2, 4, 8, 14], "suppli": 2, "np": 2, "shuffl": [2, 6], "input_dir_labels_and_dens": 2, "input_dir_spars": 2, "output_dir_shuffl": 2, "rows_per_dai": 2, "output_dir_full_set": 2, "dai": 2, "24": [2, 4, 6], "int_column": 2, "13": 2, "sparse_column": 2, "26": 2, "random_se": 2, "expect": [2, 3, 4, 5, 10, 11], "split": [2, 4, 5, 6, 8, 14], "dens": [2, 4, 5, 10, 11, 14], "spars": [2, 3, 4, 6, 10, 11, 13], "label": [2, 4, 6, 9, 10], "must": [2, 4, 5, 6, 8, 9, 10, 11], "day_x_dens": 2, "day_x_spars": 2, "day_x_label": 2, "reconstruct": 2, "back": 2, "separ": [2, 3, 4], "1": [2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "final": [2, 5, 9, 10, 11, 13, 14], "remain": 2, "untouch": 2, "valid": [2, 5, 11, 14], "directori": [2, 4], "full": [2, 11, 12, 14], "total": [2, 4, 5, 6, 9], "categor": 2, "seed": [2, 5], "oper": [2, 4, 5, 6, 7, 11, 14], "sparse_to_contigu": 2, "in_fil": 2, "output_dir": 2, "frequency_threshold": 2, "3": [2, 4, 5, 6, 9, 10, 11, 12, 13, 14], "output_file_suffix": 2, "_contig_freq": 2, "convert": [2, 4, 7, 8, 14], "contigu": [2, 3], "integ": 2, "store": [2, 4, 5, 6, 14], "togeth": [2, 4, 11], "becaus": [2, 5, 11, 12], "match": [2, 4, 5, 8, 9, 10, 11], "id": [2, 4, 5, 6, 11], "between": [2, 4, 5, 9, 10, 11, 14], "henc": 2, "thei": [2, 4, 5, 14], "also": [2, 4, 5, 8, 9, 10, 11, 12], "appear": [2, 11], "less": 2, "than": [2, 5, 11, 12], "time": [2, 4, 5, 6, 8, 9, 11], "remap": [2, 11], "valu": [2, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14], "day_0_spars": 2, "col_0": 2, "col_1": 2, "abc": [2, 4, 5, 8, 9, 11, 12], "xyz": 2, "iop": 2, "day_1_spars": 2, "tuv": 2, "lkj": 2, "day_0_sparse_contig": 2, "day_1_sparse_contig": 2, "occur": 2, "frequenc": [2, 4], "tsv_to_npi": 2, "out_dense_fil": 2, "out_sparse_fil": 2, "out_labels_fil": 2, "dataset_nam": 2, "criteo_1tb": 2, "one": [2, 4, 5, 6, 8, 9, 10, 11, 12], "three": [2, 9], "float32": [2, 4, 8, 11, 13], "int32": [2, 6, 14], "1tb": 2, "click": [2, 9], "log": [2, 5, 9], "For": [2, 4, 5, 6, 9, 10, 11, 12, 13, 14], "test": [2, 11], "filler": 2, "includ": [2, 4, 5, 7, 8, 9, 11, 14], "name": [2, 4, 5, 8, 9, 10, 11, 12, 13, 14], "criteoiterdatapip": 2, "row_mapp": 2, "callabl": [2, 4, 6, 7, 11, 12, 13], "_default_row_mapp": 2, "open_kw": 2, "iterdatapip": 2, "stream": [2, 4, 11, 14], "either": [2, 4, 5, 9, 11], "http": [2, 4, 5, 10, 11, 14], "ailab": 2, "com": [2, 11], "download": 2, "www": 2, "local": [2, 4, 5, 6, 9, 11, 12], "constitut": 2, "appli": [2, 4, 6, 10, 11], "line": [2, 3], "pass": [2, 4, 5, 6, 8, 9, 11, 12, 13, 14], "underli": [2, 9], "invoc": [2, 5], "iopath": 2, "file_io": 2, "pathmanag": 2, "open": 2, "inmemorybinarycriteoiterdatapip": [2, 3], "stage": [2, 12], "dense_path": 2, "sparse_path": 2, "labels_path": 2, "batch_siz": [2, 4, 5, 6, 9, 11, 14], "drop_last": 2, "shuffle_batch": 2, "shuffle_training_set": 2, "shuffle_training_set_random_se": 2, "hash": [2, 6, 11], "iterabledataset": 2, "over": [2, 4, 6, 10, 11, 12], "version": [2, 4, 11, 13], "entir": [2, 5, 6], "prevent": 2, "disk": 2, "speed": [2, 13], "affect": [2, 5], "throughout": [2, 10], "respons": [2, 4], "npy_preproc_criteo": 2, "py": [2, 4], "val": [2, 4], "max": [2, 5, 11, 12], "cat_feature_count": 2, "templat": [2, 9], "1tb_binari": 2, "day_": 2, "_": [2, 8], "1024": 2, "torch": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "get_rank": [2, 4], "get_world_s": [2, 8], "train_datapip": 2, "txt": 2, "test_datapip": 2, "movielens_20m": 2, "root": [2, 7], "include_movies_data": 2, "param": [2, 4, 5, 9, 12], "true": [2, 4, 5, 8, 9, 11, 12, 14], "add": [2, 4, 7, 11, 12], "movi": 2, "ml": 2, "20": [2, 10, 11], "movielens_25m": 2, "25m": 2, "25": [2, 4, 9], "randomrecdataset": 2, "hash_siz": 2, "ids_per_featur": 2, "num_dens": 2, "50": 2, "manual_se": 2, "num_batch": 2, "num_generated_batch": 2, "10": [2, 5, 10, 11, 13, 14], "min_ids_per_featur": 2, "recsi": [2, 5, 10, 12], "current": [2, 4, 5, 6, 8, 9, 11], "produc": [2, 4, 5], "unweight": 2, "todo": 2, "weight": [2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], "taken": [2, 5], "modulo": 2, "per": [2, 4, 5, 6, 9, 11, 14], "correspond": [2, 4, 5, 6, 8, 9, 11, 14], "argument": [2, 4, 7, 8, 9, 11], "ignor": [2, 4, 5, 6, 8, 11], "sampl": [2, 5, 9, 11], "determinist": [2, 5], "behavior": [2, 4, 7, 12], "num": 2, "befor": [2, 4, 5, 6, 9, 11, 12], "rais": [2, 4], "stopiter": 2, "cach": [2, 4, 5, 14], "num_gener": 2, "cycl": 2, "neg": 2, "fly": 2, "minimum": [2, 5], "feat1": 2, "feat2": 2, "16": [2, 4, 6, 11, 13], "100_000": 2, "dense_featur": [2, 10, 11], "tensor": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "sparse_featur": [2, 4, 6, 10], "jagged_tensor": [2, 4], "keyedjaggedtensor": [2, 4, 6, 10, 11, 13, 14], "pipelin": [2, 4, 5, 11, 14], "pin_memori": [2, 14], "record_stream": [2, 4, 11, 14], "see": [2, 4, 5, 6, 7, 9, 11, 14], "org": [2, 4, 5, 10, 11, 14], "doc": [2, 4, 11, 14], "stabl": [2, 4, 11, 14], "html": [2, 4, 11, 14], "non_block": [2, 14], "awar": [2, 4], "accord": [2, 4, 5, 6, 8, 10, 12], "might": [2, 5], "self": [2, 4, 5, 6, 8, 11, 14], "copi": [2, 4, 6, 8, 9, 11, 12, 14], "rememb": [2, 4], "new_devic": 2, "limit": [2, 8, 9, 11], "loadfil": 2, "mode": [2, 4, 5, 9], "b": [2, 4, 5, 6, 10, 11, 13, 14], "iobas": 2, "adapt": [2, 11], "loadfilesfromdisk": 2, "merg": [2, 4, 6], "replac": [2, 9, 12], "someth": 2, "core": 2, "lib": 2, "parallelreadconcat": 2, "dp_selector": 2, "sequenc": [2, 4, 5, 6], "_default_dp_selector": 2, "concaten": [2, 4, 6, 10, 11, 14], "multipl": [2, 4, 5, 9, 10, 11, 12], "when": [2, 4, 5, 7, 9, 11, 12], "dataload": [2, 4], "subset": [2, 5], "worker": [2, 4], "instanc": [2, 4, 6, 7, 8, 9, 11], "would": [2, 4, 6, 14], "f": [2, 4, 5, 6, 10, 11, 13], "shard_": 2, "idx": [2, 4], "4": [2, 4, 5, 6, 9, 10, 11, 13, 14], "num_work": [2, 14], "readlinesfromcsv": 2, "skip_first_lin": 2, "kw": 2, "idx_split_train_v": 2, "train_perc": 2, "float": [2, 4, 5, 7, 9, 11, 12, 14], "decimal_plac": 2, "key_fn": 2, "_default_key_fn": 2, "rand_split_train_v": 2, "via": [2, 4, 6], "uniform": [2, 5, 11], "disjoint": 2, "specifi": [2, 4, 5, 6, 7, 9, 10, 11, 12, 14], "target": [2, 4, 10], "proport": 2, "actual": [2, 4, 5, 6, 8, 9, 11], "guarante": [2, 5, 7, 12], "exactli": [2, 4], "membership": 2, "across": [2, 4, 5, 6, 9], "call": [2, 4, 5, 6, 8, 9, 11, 12, 13, 14], "consist": [2, 4], "val_datapip": 2, "75": 2, "train_batch": 2, "val_batch": 2, "safe_cast": 2, "t": [2, 4, 5, 6, 7, 8, 11, 12, 14], "dest_typ": 2, "default": [2, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14], "train_filt": 2, "val_filt": 2, "main": [3, 9], "argv": 3, "result": [3, 4, 5, 6, 8, 9, 11, 13, 14], "command": 3, "arg": [3, 4, 5, 8, 9, 11, 13, 14], "parse_arg": 3, "namespac": [3, 14], "raw": [3, 11], "criteo": [3, 10], "necessari": [4, 5, 6, 8, 9], "These": [4, 5, 9, 11], "distributedmodelparallel": 4, "collect": [4, 6, 10, 11, 12, 13], "scatter": [4, 6], "wrapper": [4, 12], "kjt": [4, 5, 6, 10, 11, 13, 14], "variou": [4, 8, 11], "implement": [4, 5, 6, 8, 9, 10, 11, 12, 14], "shardedembeddingbag": 4, "nn": [4, 5, 7, 8, 10, 11, 12, 13], "shardedembeddingbagcollect": [4, 11, 13], "embeddingbagcollect": [4, 8, 10, 11, 13], "sharder": [4, 5, 8], "defin": [4, 5, 6, 8, 9, 10, 11], "comput": [4, 5, 6, 8, 9, 10, 11, 13, 14], "kernel": [4, 5, 11], "cpu": [4, 5, 9], "gpu": [4, 5], "mai": [4, 14], "fusion": 4, "trainpipelinesparsedist": 4, "overlap": 4, "transfer": 4, "inter": [4, 11], "input_dist": [4, 11], "forward": [4, 5, 6, 8, 9, 10, 11, 13, 14], "backward": [4, 5, 7, 12], "increas": [4, 9], "perform": [4, 5, 6, 8, 9, 11, 12, 13, 14], "quantiz": [4, 6, 7, 8, 13], "precis": [4, 5, 11, 13], "construct": [4, 7, 11, 14], "control": [4, 7], "flow": [4, 7], "invoke_on_rank_and_broadcast_result": 4, "pg": [4, 5, 6, 9], "processgroup": [4, 5, 6, 9, 12], "func": 4, "kwarg": [4, 9, 11, 14], "invok": [4, 5], "broadcast": [4, 5], "member": [4, 11], "group": [4, 5, 6, 9, 11, 12, 14], "allocate_id": 4, "is_lead": 4, "leader_rank": 4, "check": [4, 5, 9, 11, 12, 14], "processs": 4, "leader": [4, 9], "dist": [4, 6, 9, 12], "impli": 4, "e": [4, 5, 6, 7, 8, 9, 10, 11, 12], "g": [4, 5, 8, 9, 10, 11, 12], "singl": [4, 5, 6, 11, 12, 14], "program": [4, 5], "definit": [4, 7, 8], "caller": 4, "overrid": [4, 5, 7, 8, 9], "context": [4, 6, 14], "run_on_lead": 4, "get_group_rank": 4, "avail": [4, 5, 6], "group_rank": 4, "varibl": 4, "get_num_group": 4, "elast": 4, "run": [4, 5, 6, 8, 9, 11, 12], "get_local_rank": 4, "usual": [4, 5, 6, 9, 11], "node": [4, 7], "get_local_s": 4, "equival": 4, "max_nnod": 4, "intra_and_cross_node_pg": 4, "backend": [4, 6, 8, 9], "sub": 4, "intra": 4, "cross": [4, 10, 11], "all2alldenseinfo": 4, "output_split": [4, 6], "input_shap": 4, "input_split": [4, 6], "attribut": [4, 5, 9, 12], "alltoall_dens": 4, "all2allpooledinfo": 4, "batch_size_per_rank": [4, 6], "dim_sum_per_rank": [4, 6], "dim_sum_per_rank_tensor": 4, "cumsum_dim_sum_per_rank_tensor": 4, "codec": [4, 6], "quantizedcommcodec": [4, 6], "alltoall_pool": [4, 6], "sum": [4, 5, 6, 11], "dimens": [4, 5, 6, 10, 11, 13, 14], "fast": 4, "_recat_pooled_embedding_grad_out": 4, "cumul": [4, 9, 14], "all2allsequenceinfo": 4, "embedding_dim": [4, 5, 6, 10, 11, 13], "lengths_after_sparse_data_all2al": 4, "forward_recat_tensor": 4, "backward_recat_tensor": 4, "variable_batch_s": 4, "permuted_lengths_after_sparse_data_all2al": 4, "alltoall_sequ": 4, "alltoal": [4, 6], "recat": [4, 6, 11, 14], "variabl": [4, 6, 9, 11, 13, 14], "all2allvinfo": 4, "dims_sum_per_rank": 4, "b_global": 4, "b_local": 4, "b_local_list": 4, "d_local_list": 4, "input_split_s": 4, "factori": [4, 5, 11], "output_split_s": 4, "alltoallv": 4, "global": [4, 5, 6, 9, 12], "my": 4, "how": [4, 5, 6, 12], "do": [4, 5, 9, 11, 12, 14], "all_to_all_singl": 4, "fill": 4, "all2all_pooled_req": 4, "ctx": 4, "unus": [4, 11], "formula": [4, 5], "differenti": 4, "overridden": [4, 6, 8, 9, 11], "subclass": [4, 6, 8, 11, 12], "vjp": 4, "It": [4, 5, 6, 8, 9, 11, 12, 13, 14], "accept": [4, 5, 8, 9, 11], "non": [4, 5, 6, 7, 9, 11, 13, 14], "were": 4, "gradient": [4, 5, 12], "w": [4, 6, 11, 14], "r": [4, 11], "requir": [4, 5, 9, 11, 12], "grad": [4, 12], "just": [4, 5, 10, 11, 14], "retriev": 4, "dure": [4, 5, 9, 12], "ha": [4, 5, 9, 10, 11, 14], "needs_input_grad": 4, "boolean": 4, "myreq": 4, "a2ai": 4, "input_embed": [4, 11], "custom": [4, 5, 7, 11], "autograd": [4, 8, 9, 11], "usag": [4, 5, 9], "combin": [4, 11, 12], "staticmethod": 4, "def": [4, 11], "other": [4, 5, 6, 9, 12], "detail": [4, 5, 6, 9, 11], "setup_context": 4, "longer": [4, 5], "instead": [4, 5, 6, 8, 11, 12, 14], "arbitrari": 4, "though": 4, "enforc": [4, 8, 9, 11], "compat": [4, 5, 7, 10, 12], "save_for_backward": 4, "intend": 4, "save_for_forward": 4, "jvp": 4, "all2all_pooled_wait": 4, "grad_output": 4, "dummy_tensor": 4, "all2all_seq_req": 4, "sharded_input_embed": 4, "all2all_seq_req_wait": 4, "sharded_grad_output": 4, "all2allv_req": 4, "all2allv_wait": 4, "allgatherbaseinfo": 4, "input_s": [4, 5, 11], "all_gatther_base_pool": 4, "allgatherbase_req": 4, "agi": 4, "allgatherbase_wait": 4, "reducescatterbaseinfo": 4, "reduce_scatter_base_pool": 4, "flatten": [4, 6, 11], "reducescatterbase_req": 4, "rsi": 4, "reducescatterbase_wait": 4, "reducescatterinfo": 4, "reduce_scatter_pool": 4, "reducescattervinfo": 4, "equal_split": 4, "total_input_s": 4, "reduce_scatter_v_pool": 4, "along": [4, 6, 9, 10, 12, 14], "dim": [4, 6], "reducescatterv_req": 4, "reducescatterv_wait": 4, "reducescatter_req": 4, "reducescatter_wait": 4, "await": [4, 6, 7], "variablebatchall2allpooledinfo": 4, "batch_size_per_rank_per_featur": [4, 6], "batch_size_per_feature_pre_a2a": [4, 6], "emb_dim_per_rank_per_featur": [4, 6], "variable_batch_alltoall_pool": [4, 6], "variable_batch_all2all_pooled_req": 4, "variable_batch_all2all_pooled_wait": 4, "all2all_pooled_sync": 4, "all2all_sequence_sync": 4, "all2allv_sync": 4, "all_gather_base_pool": 4, "gather": [4, 6], "form": [4, 11, 13], "pool": [4, 5, 6, 10, 11, 13, 14], "output_tensor_s": 4, "work": [4, 5, 8, 9, 11, 14], "async": [4, 6], "wait": [4, 6], "later": [4, 11], "experiment": [4, 11], "subject": 4, "all_gather_base_sync": 4, "all_gather_into_tensor_backward": 4, "all_gather_into_tensor_fak": 4, "gather_dim": 4, "group_siz": 4, "group_nam": 4, "gradient_divis": 4, "all_gather_into_tensor_setup_context": 4, "all_to_all_single_backward": 4, "all_to_all_single_fak": 4, "all_to_all_single_setup_context": 4, "a2a_pooled_embs_tensor": 4, "Then": 4, "receiv": [4, 6, 12], "Its": 4, "x": [4, 5, 6, 10, 11, 13, 14], "d_local_sum": 4, "where": [4, 5, 6, 9, 10, 11, 13, 14], "a2a_sequence_embs_tensor": 4, "doe": [4, 5, 10, 11, 12, 14], "mix": 4, "out_split": 4, "per_rank_split_length": 4, "assumpt": [4, 14], "emb": 4, "get_gradient_divis": 4, "get_use_sync_collect": 4, "pg_name": 4, "reduce_scatter_base_sync": 4, "chunk": [4, 6], "reduce_scatter_sync": 4, "reduce_scatter_tensor_backward": 4, "reduce_scatter_tensor_fak": 4, "reduceop": 4, "reduce_scatter_tensor_setup_context": 4, "reduce_scatter_v_per_feature_pool": 4, "v": [4, 6, 11, 14], "d": [4, 5, 10, 11, 12, 13, 14], "unevenli": 4, "reduce_scatter_v_sync": 4, "set_gradient_divis": 4, "set_use_sync_collect": 4, "torchrec_use_sync_collect": 4, "variable_batch_all2all_pooled_sync": 4, "embeddingsalltoon": [4, 6], "cat_dim": [4, 6, 14], "buffer": [4, 6, 8, 9, 11], "alloc": [4, 5, 6, 8], "like": [4, 5, 6, 7, 11, 12, 14], "alltoon": [4, 6], "set_devic": [4, 6], "device_str": [4, 6], "embeddingsalltoonereduc": [4, 6], "jaggedtensoralltoal": [4, 6], "jt": [4, 6, 11, 14], "jaggedtensor": [4, 6, 11, 13, 14], "num_items_to_send": [4, 6], "num_items_to_rec": [4, 6], "redistribut": [4, 6], "send": [4, 6], "known": [4, 5, 6, 11], "ahead": [4, 6], "keyedjaggedtensorpool": [4, 6], "lookup": [4, 5, 6, 10, 11, 13], "anoth": [4, 5, 6], "kjtalltoal": [4, 6], "stagger": [4, 6, 14], "kjtalltoallsplitsawait": [4, 6], "transmit": [4, 6], "correct": [4, 6, 14], "space": [4, 5, 6, 10], "kjtalltoalltensorsawait": [4, 6], "asynchron": [4, 6, 14], "len": [4, 6, 10, 14], "order": [4, 5, 6, 8, 9, 11, 14], "destin": [4, 6, 8, 9, 11], "_get_recat": [4, 6], "kjta2a": [4, 6], "rank0_input": [4, 6], "hold": [4, 5, 6, 12, 14], "v0": [4, 6, 14], "v1": [4, 6, 11, 14], "v2": [4, 6, 10, 11, 14], "rank1_input": [4, 6], "v3": [4, 6, 14], "v4": [4, 6, 14], "rank0_output": [4, 6], "5": [4, 6, 9, 10, 11, 13, 14], "rank1_output": [4, 6], "relev": [4, 5, 6], "tensor_split": [4, 6], "input_tensor": [4, 6], "ie": [4, 5, 6, 11, 14], "stride_per_rank": [4, 6, 14], "stride": [4, 6, 14], "case": [4, 5, 6, 9, 11, 12, 14], "kjtonetoal": [4, 6], "onetoal": [4, 6], "essenti": [4, 5, 6, 14], "p2p": [4, 6], "keyjaggedtensor": [4, 6], "them": [4, 6, 8, 10, 11, 12, 14], "kjtlist": [4, 6], "slice": [4, 6, 7, 14], "mergepooledembeddingsmodul": [4, 6], "merge_pooled_embedding_optim": [4, 6], "_mergepooledembeddingsmoduleimpl": [4, 6], "merge_pooled_embed": [4, 6], "pooledembeddingsallgath": [4, 6], "wrap": [4, 6, 9, 10, 12], "layout": [4, 6, 7], "want": [4, 6, 9], "nccl": [4, 6], "happen": [4, 5, 6], "init_distribut": [4, 6], "new_group": [4, 6, 9], "randn": [4, 6, 10, 11], "m": [4, 5, 6, 7, 11], "local_emb": [4, 6], "pooledembeddingsawait": [4, 6], "num_bucket": [4, 6], "pooledembeddingsalltoal": [4, 6], "callback": [4, 6], "a2a": [4, 6], "t0": [4, 6], "rand": [4, 6, 10], "6": [4, 5, 6, 10, 11, 13, 14], "t1": [4, 6, 10, 11, 13], "print": [4, 6, 11, 13], "properti": [4, 5, 6, 8, 9, 10, 11, 12, 13], "tensor_await": [4, 6], "pooledembeddingsreducescatt": [4, 6], "twrw": [4, 5, 6], "unequ": [4, 6], "bucket": [4, 6], "seqembeddingsalltoon": [4, 6], "concat": [4, 6, 11, 14], "sequenceembeddingsalltoal": [4, 6], "features_per_rank": [4, 6], "sharding_ctx": [4, 6], "sequenceshardingcontext": [4, 6], "lengths_after_input_dist": [4, 6], "unbucketize_permute_tensor": [4, 6], "sparse_features_recat": [4, 6], "sequenceembeddingsawait": [4, 6], "permut": [4, 6, 14], "splitsalltoallawait": [4, 6], "tensoralltoal": [4, 6], "1d": [4, 5, 6], "tensoralltoallsplitsawait": [4, 6], "tensoralltoallvaluesawait": [4, 6], "tensor_a2a": [4, 6], "rank0": [4, 6], "rank1": [4, 6], "v5": [4, 6, 14], "v6": [4, 6, 14], "v7": [4, 6, 14], "v8": [4, 6], "v9": [4, 6], "v10": [4, 6], "v11": [4, 6], "v12": [4, 6], "tensorvaluesalltoal": [4, 6], "tensor_vals_a2a": [4, 6], "v13": [4, 6], "v14": [4, 6], "v15": [4, 6], "sent": [4, 6], "equal": [4, 5, 6, 11, 14], "_pg": [4, 6], "variablebatchpooledembeddingsalltoal": [4, 6], "kjt_split": [4, 6], "r0_batch_siz": [4, 6], "r1_batch_siz": [4, 6], "f_0": [4, 6], "f_1": [4, 6], "f_2": [4, 6], "r0_batch_size_per_rank_per_featur": [4, 6], "r1_batch_size_per_rank_per_featur": [4, 6], "r0_batch_size_per_feature_pre_a2a": [4, 6], "r1_batch_size_per_feature_pre_a2a": [4, 6], "r0": [4, 6], "r1": [4, 6], "14": [4, 6], "post": [4, 6], "rank_0": [4, 6], "rank_1": [4, 6], "variablebatchpooledembeddingsreducescatt": [4, 6], "rw": [4, 5, 6, 11], "multipli": [4, 5, 6], "batch_size_r0_f0": [4, 6], "emb_dim_f0": [4, 6], "embeddingcollectionawait": 4, "lazyawait": 4, "embeddingcollectioncontext": 4, "sharding_context": 4, "input_featur": 4, "reverse_indic": [4, 11], "seq_vbe_ctx": [4, 11], "sequencevbecontext": [4, 11], "multistream": [4, 11], "embeddingcollectionshard": 4, "fused_param": [4, 6], "qcomm_codecs_registri": [4, 6], "use_index_dedup": 4, "baseembeddingshard": 4, "embeddingcollect": [4, 8, 11, 13], "module_typ": [4, 8, 13], "parametershard": 4, "env": [4, 6], "shardingenv": [4, 6], "shardedembeddingcollect": [4, 11, 13], "locat": 4, "replic": [4, 5, 6], "embeddingmoduleshardingplan": 4, "fulli": [4, 5, 12], "qualifi": 4, "spec": 4, "shardedmodul": 4, "shardable_paramet": 4, "sharding_typ": [4, 5, 11], "compute_device_typ": 4, "shardingtyp": [4, 5, 11], "well": [4, 5, 11], "table_name_to_parameter_shard": 4, "shardedembeddingmodul": 4, "fusedoptimizermodul": [4, 12], "public": [4, 11], "manual": [4, 10, 12], "dist_input": 4, "compute_and_output_dist": 4, "In": [4, 5, 11, 12, 14], "sens": [4, 12], "initi": [4, 11, 12], "distibut": 4, "soon": 4, "complet": [4, 5], "create_context": 4, "fused_optim": [4, 12], "keyedoptim": [4, 12], "output_dist": [4, 5], "reset_paramet": [4, 11], "create_embedding_shard": 4, "sharding_info": [4, 6], "embeddingshardinginfo": [4, 6], "embeddingshard": [4, 6], "create_sharding_infos_by_shard": 4, "embeddingcollectioninterfac": [4, 11, 13], "create_sharding_infos_by_sharding_device_group": 4, "get_device_from_parameter_shard": 4, "ps": [4, 5], "get_ec_index_dedup": 4, "pad_vbe_kjt_length": 4, "set_ec_index_dedup": 4, "commopgradientsc": 4, "functionctx": 4, "scale_gradient_factor": 4, "groupedembeddingslookup": 4, "grouped_config": 4, "groupedembeddingconfig": [4, 6], "baseembeddinglookup": [4, 6], "i": [4, 5, 6, 7, 9, 10, 11], "flush": 4, "everi": [4, 5, 6, 8, 11], "although": [4, 6, 8, 11], "recip": [4, 6, 8, 11], "afterward": [4, 6, 8, 11], "sinc": [4, 5, 6, 8, 11], "former": [4, 6, 8, 11], "take": [4, 5, 6, 8, 11, 12], "care": [4, 6, 8, 11], "regist": [4, 6, 7, 8, 11], "hook": [4, 6, 8, 11], "latter": [4, 6, 8, 11], "silent": [4, 6, 8, 11], "load_state_dict": [4, 12], "state_dict": [4, 8, 9, 11, 12], "ordereddict": [4, 8, 9, 11], "union": [4, 5, 7, 8, 9, 11, 12], "shardedtensor": [4, 12], "strict": [4, 12], "_incompatiblekei": 4, "descend": [4, 5], "unless": [4, 12], "get_swap_module_params_on_convers": 4, "persist": [4, 8, 9, 11], "strictli": [4, 11], "preserv": [4, 11], "except": [4, 5, 9, 11, 14], "requires_grad": 4, "field": [4, 11, 12, 14], "missing_kei": 4, "miss": [4, 5], "unexpected_kei": 4, "present": [4, 12], "namedtupl": 4, "runtimeerror": 4, "named_buff": [4, 11], "prefix": [4, 8, 9, 11], "recurs": [4, 11], "remove_dupl": [4, 11], "yield": [4, 11], "both": [4, 8, 9, 10, 11, 12, 14], "itself": [4, 10, 11], "prepend": [4, 11], "submodul": [4, 11, 12], "otherwis": [4, 5, 8, 9, 11, 12, 14], "direct": [4, 11], "remov": [4, 7, 11], "duplic": [4, 11, 12], "xdoctest": [4, 8, 9, 11], "skip": [4, 8, 9, 11, 12], "undefin": [4, 8, 9, 11], "var": [4, 8, 9, 11], "buf": [4, 11], "running_var": [4, 11], "named_paramet": 4, "bia": [4, 8, 9, 11], "named_parameters_by_t": 4, "tablebatchedembeddingslic": 4, "table_nam": 4, "embedding_weight": 4, "cw": [4, 5], "compos": [4, 8, 9, 11], "prefetch": [4, 5], "forward_stream": 4, "purg": 4, "keep_var": [4, 8, 9, 11], "dictionari": [4, 8, 9, 11, 14], "refer": [4, 8, 9, 11, 14], "whole": [4, 8, 9, 11], "averag": [4, 5, 8, 9, 11], "shallow": [4, 8, 9, 11], "posit": [4, 5, 6, 8, 9, 11], "howev": [4, 8, 9, 11, 12], "deprec": [4, 8, 9, 11], "keyword": [4, 8, 9, 11], "futur": [4, 8, 9, 11], "releas": [4, 8, 9, 11], "end": [4, 5, 8, 9, 11], "user": [4, 5, 8, 9, 11, 12], "detach": [4, 8, 9, 11], "groupedpooledembeddingslookup": 4, "feature_processor": [4, 6, 13], "basegroupedfeatureprocessor": [4, 6, 11], "scale_weight_gradi": 4, "infercpugroupedembeddingslookup": 4, "grouped_configs_per_rank": 4, "infergroupedlookupmixin": 4, "inputdistoutput": [4, 6], "tbetoregistermixin": 4, "get_tbes_to_regist": 4, "intnbittablebatchedembeddingbagscodegen": 4, "infergroupedembeddingslookup": 4, "input_dist_output": 4, "infergroupedpooledembeddingslookup": 4, "metainfergroupedembeddingslookup": 4, "tbe": [4, 5, 13], "op": [4, 5, 6, 12, 13], "metainfergroupedpooledembeddingslookup": 4, "bag": [4, 6, 7, 10, 11], "dtype": [4, 5, 6, 7, 8, 11, 13, 14], "embeddings_cat_empty_rank_handl": 4, "dummy_embs_tensor": 4, "embeddings_cat_empty_rank_handle_infer": 4, "fx_wrap_tensor_view2d": 4, "dim0": 4, "dim1": 4, "baseembeddingdist": [4, 6], "embeddinglookup": 4, "abstract": [4, 5, 8, 9, 11, 12], "basesparsefeaturesdist": [4, 6], "featureshardingmixin": 4, "table_wis": [4, 11], "create_input_dist": [4, 6], "create_lookup": [4, 6], "create_output_dist": [4, 6], "embedding_nam": [4, 6, 11], "embedding_names_per_rank": [4, 6], "embedding_shard_metadata": [4, 6], "shardmetadata": [4, 6], "embedding_t": [4, 6], "shardedembeddingt": [4, 6], "uncombined_embedding_dim": [4, 6], "uncombined_embedding_nam": [4, 6], "embeddingshardingcontext": [4, 6], "variable_batch_per_featur": 4, "embedding_config": [4, 13], "embeddingtableconfig": [4, 11], "param_shard": 4, "nonetyp": [4, 9, 11], "fusedkjtlistsplitsawait": 4, "kjtlistsplitsawait": 4, "kjtlistawait": 4, "info": [4, 11], "metadata": [4, 8, 11], "kjtsplitsalltoallmeta": 4, "distributed_c10d": 4, "_input": 4, "splits_tensor": 4, "listofkjtlistawait": 4, "listofkjtlist": 4, "listofkjtlistsplitsawait": 4, "bucketize_kjt_before_all2al": 4, "block_siz": [4, 6], "output_permut": 4, "bucketize_po": 4, "block_bucketize_row_po": 4, "keep_original_indic": [4, 6], "readjust": 4, "unbucket": 4, "offset": [4, 5, 10, 11, 13, 14], "keep": [4, 5, 11], "origin": [4, 5, 8, 10], "bucketize_kjt_infer": 4, "is_sequ": [4, 6], "group_tabl": 4, "tables_per_rank": 4, "datatyp": [4, 5, 11, 13, 14], "poolingtyp": [4, 11], "embeddingcomputekernel": [4, 5], "weighted": 4, "interfac": [4, 8, 9, 11], "reli": [4, 8, 11, 13], "moduleshard": [4, 5, 8], "compute_kernel": [4, 5], "storage_usag": 4, "resourc": 4, "processor": [4, 6, 8, 11], "basequantembeddingshard": 4, "shardable_param": 4, "dtensormetadata": 4, "mesh": 4, "device_mesh": 4, "devicemesh": 4, "placement": [4, 5], "_tensor": 4, "placement_typ": 4, "embeddingattribut": 4, "enum": [4, 5, 11, 12], "enumer": [4, 12], "fuse": [4, 6, 9], "fused_uvm": 4, "fused_uvm_cach": 4, "key_valu": 4, "quant": 4, "quant_uvm": 4, "quant_uvm_cach": 4, "feature_nam": [4, 5, 6, 10, 11, 13], "feature_names_per_rank": [4, 6], "data_typ": [4, 11], "is_weight": [4, 5, 11, 13, 14], "has_feature_processor": [4, 6, 11], "dim_sum": 4, "feature_hash_s": [4, 6], "num_featur": [4, 6, 10, 11], "bucket_mapping_tensor": 4, "bucketized_length": 4, "moduleshardingmixin": 4, "access": [4, 5, 12, 14], "scheme": 4, "optimtyp": 4, "adagrad": [4, 12], "adam": [4, 12], "adamw": 4, "lamb": 4, "lars_sgd": 4, "lion": 4, "partial_rowwise_adam": 4, "partial_rowwise_lamb": 4, "rowwise_adagrad": 4, "sgd": 4, "shampoo": 4, "shampoo_mr": 4, "shampoo_v2": 4, "shampoo_v2_mr": 4, "shardedconfig": 4, "local_row": [4, 5], "local_col": [4, 5], "compin": 4, "distout": 4, "out": [4, 11, 14], "shrdctx": 4, "commop": 4, "extra_repr": 4, "pretti": 4, "represent": [4, 5, 7, 11, 14], "num_embed": [4, 5, 10, 11, 13], "fp32": [4, 5, 11], "weight_init_max": [4, 11], "weight_init_min": [4, 11], "num_embeddings_post_prun": [4, 11], "init_fn": [4, 11], "need_po": [4, 6, 11], "local_metadata": 4, "_shard": 4, "global_metadata": 4, "sharded_tensor": 4, "shardedtensormetadata": 4, "dtensor_metadata": 4, "shardedmetaconfig": 4, "compute_kernel_to_embedding_loc": 4, "embeddingloc": 4, "embeddingawait": 4, "embeddingbagcollectionawait": 4, "lazygetitemmixin": 4, "keyedtensor": [4, 10, 11, 13, 14], "embeddingbagcollectioncontext": 4, "inverse_indic": [4, 11, 14], "divisor": 4, "embeddingbagcollectionshard": 4, "embeddingbagshard": 4, "nullshardedmodulecontext": 4, "per_sample_weight": 4, "named_modul": 4, "memo": 4, "network": [4, 5, 11, 12], "alreadi": [4, 6, 8, 12, 14], "onc": [4, 11], "l": [4, 11, 13], "linear": [4, 5, 11, 12], "net": [4, 10, 11], "sequenti": [4, 5, 11], "in_featur": [4, 10, 11], "out_featur": [4, 11], "sharded_parameter_nam": 4, "embeddingbagcollectioninterfac": [4, 11, 13], "variablebatchembeddingbagcollectionawait": 4, "construct_output_kt": 4, "create_embedding_bag_shard": 4, "permute_embed": [4, 6], "suffix": 4, "replace_placement_with_meta_devic": 4, "could": [4, 5, 14], "unmatch": 4, "scenario": [4, 11, 13], "dmp": 4, "cuda": [4, 5, 8], "embeddingshardingplann": [4, 5], "planner": 4, "groupedpositionweightedmodul": 4, "max_feature_length": [4, 11], "dataparallelwrapp": 4, "defaultdataparallelwrapp": 4, "bucket_cap_mb": 4, "static_graph": 4, "find_unused_paramet": 4, "allreduce_comm_precis": 4, "params_to_ignor": 4, "ddp_kwarg": 4, "unshard": [4, 5, 11, 13], "shardingplan": [4, 5, 8], "init_data_parallel": 4, "init_paramet": 4, "data_parallel_wrapp": 4, "entri": 4, "point": [4, 5], "collective_plan": [4, 5], "lazi": [4, 11, 12, 14], "delai": 4, "until": 4, "still": [4, 14], "no_grad": [4, 11], "init_weight": [4, 11], "isinst": 4, "fill_": [4, 11], "elif": 4, "init": [4, 11], "kaiming_normal_": 4, "mymodel": 4, "bare_named_paramet": 4, "tor": 4, "safe": 4, "ddp": 4, "fsdp": 4, "sparse_grad_parameter_nam": [4, 12], "get_modul": 4, "unwrap": 4, "get_unwrapped_modul": 4, "quantembeddingbagcollectionshard": [4, 8], "shardedquantembeddingbagcollect": 4, "quantfeatureprocessedembeddingbagcollectionshard": [4, 8], "featureprocessedembeddingbagcollect": [4, 8, 13], "shardedquantebcinputdist": 4, "sharding_type_device_group_to_shard": 4, "nullshardingcontext": [4, 6], "sharding_type_to_shard": 4, "sqebc_input_dist": 4, "infertwsequenceembeddingshard": 4, "f1": [4, 10, 11, 13], "f2": [4, 10, 11, 13], "7": [4, 9, 10, 11, 13, 14], "8": [4, 5, 10, 11, 13, 14], "shardedquantembeddingmodulest": 4, "embedding_bag_config": [4, 11, 13], "embeddingbagconfig": [4, 10, 11, 13], "execut": [4, 5, 8, 11, 13], "sharding_type_device_group_to_sharding_info": 4, "tbes_config": 4, "shardedquantfeatureprocessedembeddingbagcollect": 4, "featureprocessorscollect": [4, 13], "apply_feature_processor": 4, "kjt_list": [4, 14], "embedding_bag": [4, 13], "moduledict": [4, 13], "modulelist": [4, 9, 11, 13], "create_infer_embedding_bag_shard": 4, "flatten_feature_length": 4, "get_device_from_sharding_info": 4, "emb_shard_info": 4, "cacheparam": [4, 5], "algorithm": 4, "cachealgorithm": 4, "load_factor": [4, 5], "reserved_memori": 4, "prefetch_pipelin": [4, 5], "stat": 4, "cachestatist": [4, 5], "multipass_prefetch_config": 4, "multipassprefetchconfig": 4, "relat": [4, 5, 9], "uvm": [4, 5], "lru": [4, 5], "lfu": 4, "factor": [4, 5, 11], "decid": 4, "crucial": 4, "reserv": [4, 5], "ideal": 4, "aka": 4, "statist": [4, 5], "better": [4, 5], "tune": [4, 12], "cacheabl": [4, 5, 14], "summar": [4, 5], "measur": [4, 5, 9], "difficulti": [4, 5], "dataset": [4, 5, 10], "independ": [4, 5], "score": [4, 5, 6, 11], "veri": [4, 5], "high": [4, 5, 9, 11], "difficult": [4, 5], "expected_lookup": [4, 5], "distinct": [4, 5], "expected_miss_r": [4, 5], "clf": [4, 5], "rate": [4, 5, 9, 12], "hit": [4, 5], "extrem": [4, 5], "estim": [4, 5, 9], "pooled_embeddings_all_to_al": 4, "pooled_embeddings_reduce_scatt": 4, "sequence_embeddings_all_to_al": 4, "computekernel": 4, "moduleshardingplan": 4, "describ": 4, "genericmeta": 4, "getitemlazyawait": 4, "parentw": 4, "kt": [4, 14], "__getitem__": 4, "parent": 4, "keyvalueparam": [4, 5], "ssd_storage_directori": 4, "ssd_rocksdb_write_buffer_s": 4, "ssd_rocksdb_shard": 4, "gather_ssd_cache_stat": 4, "stats_reporter_config": 4, "tbestatsreporterconfig": 4, "use_passed_in_path": 4, "l2_cache_s": 4, "ps_host": 4, "ps_client_thread_num": 4, "ps_max_key_per_request": 4, "ps_max_local_index_length": 4, "ssd": [4, 5], "ssdtablebatchedembeddingbag": 4, "data00_nvidia": 4, "local_rank": 4, "rocksdb": 4, "write": 4, "relav": 4, "compact": 4, "std": 4, "report": [4, 9], "od": 4, "report_interv": 4, "interv": [4, 9, 11], "ods_prefix": 4, "server": 4, "host": [4, 5, 6], "ip": 4, "port": 4, "2000": 4, "2001": 4, "2002": 4, "reason": [4, 12], "hashabl": 4, "thread": [4, 5], "client": 4, "maximum": [4, 5, 9, 11], "index": [4, 11, 14], "expos": [4, 12], "concret": 4, "achiev": 4, "late": 4, "possibl": [4, 5, 9], "__torch_function__": 4, "below": 4, "doesn": [4, 11, 12], "python": [4, 7, 10], "magic": 4, "__getattr__": 4, "caveat": 4, "mechan": [4, 11], "ensur": [4, 11, 14], "perfect": 4, "quickli": 4, "long": [4, 5, 11], "kwd": 4, "vt_co": 4, "augment": 4, "trigger": [4, 11], "keyedlazyawait": 4, "defer": 4, "mixin": 4, "inherit": [4, 9, 11], "mro": 4, "properli": [4, 11], "select": [4, 5, 6, 14], "lazynowait": 4, "classmethod": [4, 5, 8, 13], "noopquantizedcommcodec": 4, "quantizationcontext": 4, "No": [4, 6, 9], "calc_quantized_s": 4, "input_len": 4, "decod": 4, "input_grad": 4, "encod": 4, "padded_s": 4, "dim_per_rank": 4, "my_rank": [4, 9], "qcomm_ctx": 4, "quantized_dtyp": 4, "nowait": [4, 7], "obj": 4, "objectpoolshardingplan": 4, "objectpoolshardingtyp": 4, "replicated_row_wis": 4, "row_wis": [4, 11], "sharding_spec": 4, "shardingspec": 4, "cache_param": [4, 5], "enforce_hbm": [4, 5], "stochastic_round": [4, 5], "bounds_check_mod": [4, 5], "boundscheckmod": [4, 5], "output_dtyp": [4, 5, 8, 13], "key_value_param": [4, 5], "hbm": [4, 5], "stochast": [4, 5], "round": [4, 5], "bound": [4, 5], "place": [4, 5, 6, 12, 14], "column_wis": [4, 11], "seen": [4, 7], "individu": [4, 5, 10, 11], "table_row_wis": [4, 11], "data_parallel": [4, 5, 11], "parameterstorag": 4, "physic": 4, "constraint": [4, 5, 8], "shardingplann": [4, 5], "ddr": [4, 5], "pipelinetyp": [4, 5], "about": 4, "train_bas": 4, "train_prefetch_sparse_dist": 4, "train_sparse_dist": 4, "pooled_all_to_al": 4, "reduce_scatt": 4, "quantized_tensor": 4, "quantized_comm_codec": 4, "collective_cal": 4, "output_tensor": 4, "assert_clos": 4, "int8": [4, 8], "addit": [4, 5, 7, 8, 10, 11, 12, 14], "carri": 4, "session": 4, "padded_dim_sum": 4, "padding_s": 4, "respect": [4, 10, 11], "sequence_all_to_al": 4, "modulenocopymixin": [4, 13], "vise": [4, 12], "versa": [4, 12], "practic": 4, "from_loc": 4, "typic": [4, 5, 7, 11, 12, 14], "from_process_group": 4, "fqn": [4, 5], "larger": [4, 5], "get_plan_for_modul": 4, "module_path": 4, "re": [4, 12], "stabil": 4, "table_column_wis": [4, 11], "get_tensor_size_byt": 4, "rank_devic": 4, "device_typ": 4, "scope": 4, "copyablemixin": 4, "mymodul": 4, "forkedpdb": 4, "completekei": 4, "tab": 4, "stdin": 4, "stdout": 4, "nosigint": 4, "readrc": 4, "pdb": 4, "multiprocess": 4, "child": 4, "debug": [4, 5, 9], "multiprocessing_util": 4, "set_trac": 4, "barrier": 4, "add_params_from_parameter_shard": 4, "parameter_shard": 4, "extract": 4, "ones": 4, "add_prefix_to_state_dict": 4, "filter": [4, 11], "append_prefix": 4, "append": 4, "convert_to_fbgemm_typ": 4, "copy_to_devic": 4, "current_devic": [4, 8], "to_devic": 4, "filter_state_dict": 4, "strip": 4, "begin": [4, 5, 12], "get_unsharded_module_nam": 4, "level": [4, 6], "don": [4, 8, 11, 14], "merge_fused_param": 4, "param_fused_param": 4, "configur": 4, "cache_precis": 4, "preset": 4, "table_level_fused_param": 4, "precid": 4, "grouped_fused_param": 4, "null": 4, "none_throw": 4, "_t": 4, "messag": [4, 5], "unexpect": 4, "assertionerror": 4, "optimizer_type_to_emb_opt_typ": 4, "optimizer_class": 4, "emboptimtyp": 4, "sharded_model_copi": 4, "m_cpu": 4, "deepcopi": 4, "managedcollisioncollectionawait": 4, "managedcollisioncollectioncontext": 4, "managedcollisioncollectionshard": 4, "managedcollisioncollect": [4, 11], "shardedmanagedcollisioncollect": 4, "evict": [4, 11], "global_to_local_index": 4, "jt_dict": [4, 14], "open_slot": [4, 11], "create_mc_shard": 4, "managedcollisionembeddingbagcollectioncontext": 4, "evictions_per_t": 4, "remapped_kjt": 4, "managedcollisionembeddingbagcollectionshard": 4, "ebc_shard": 4, "mc_sharder": 4, "basemanagedcollisionembeddingcollectionshard": 4, "managedcollisionembeddingbagcollect": [4, 11], "shardedmanagedcollisionembeddingbagcollect": 4, "baseshardedmanagedcollisionembeddingcollect": 4, "managedcollisionembeddingcollectioncontext": 4, "managedcollisionembeddingcollectionshard": 4, "ec_shard": 4, "managedcollisionembeddingcollect": [4, 11], "shardedmanagedcollisionembeddingcollect": 4, "consid": [5, 11, 13, 14], "perf": 5, "storag": [5, 14], "peak": 5, "elimin": 5, "oom": [5, 9], "kernel_bw_lookup": 5, "compute_devic": [5, 8], "hbm_mem_bw": 5, "ddr_mem_bw": 5, "caching_ratio": 5, "calcul": [5, 9], "bandwidth": 5, "ratio": [5, 9], "embeddingenumer": 5, "parameterconstraint": [5, 8], "shardestim": 5, "use_exact_enumerate_ord": 5, "shardabl": 5, "exact": 5, "name_children": 5, "shardingopt": 5, "popul": [5, 11], "populate_estim": 5, "sharding_opt": 5, "descript": [5, 9], "get_partition_by_typ": 5, "partitionbytyp": 5, "greedyperfpartition": 5, "sort_bi": 5, "sortbi": 5, "balance_modul": 5, "greedi": 5, "sort": [5, 11], "smaller": 5, "effect": [5, 11], "storage_constraint": 5, "partition_bi": 5, "strategi": 5, "docstr": [5, 9, 14], "partition_by_devic": 5, "done": [5, 11, 12, 14], "clariti": 5, "memorybalancedpartition": 5, "max_search_count": 5, "toler": 5, "02": 5, "greedypartition": 5, "reject": 5, "200": 5, "wors": 5, "repeatedli": 5, "least": 5, "ordereddevicehardwar": 5, "devicehardwar": 5, "local_world_s": 5, "shardingoptiongroup": 5, "storage_sum": 5, "perf_sum": 5, "param_count": 5, "set_hbm_per_devic": 5, "hbm_per_devic": 5, "noopperfmodel": 5, "perfmodel": 5, "among": [5, 10], "without": [5, 9, 14], "noopstoragemodel": 5, "storagereserv": 5, "performance_model": 5, "heteroembeddingshardingplann": 5, "topology_group": 5, "dynamicprogrammingpropos": 5, "hbm_bins_per_devic": 5, "dynam": 5, "fashion": [5, 6, 14], "problem": 5, "frame": 5, "n": [5, 8, 10, 11, 12, 14], "minim": 5, "overal": [5, 10], "k": [5, 9, 10, 11], "mathemat": [5, 9], "formul": 5, "matrix": [5, 10, 11], "let": 5, "element": [5, 11], "denot": 5, "a_": 5, "j": 5, "b_": 5, "aim": 5, "j_0": 5, "j_1": 5, "ldot": 5, "j_": 5, "condit": [5, 11], "satisfi": 5, "sum_": 5, "j_i": 5, "leq": 5, "tackl": 5, "discret": 5, "k_i": 5, "transit": 5, "min_": 5, "left": [5, 14], "right": [5, 9, 11], "simpli": 5, "fit": 5, "card": 5, "therefor": 5, "maintain": 5, "last": [5, 10, 11, 14], "layer": [5, 10, 11, 12], "under": [5, 9], "vari": 5, "hdm": 5, "bin": 5, "perf_rat": 5, "search_spac": 5, "search": 5, "embeddingoffloadscaleuppropos": 5, "use_depth": 5, "allocate_budget": 5, "budget": 5, "allocation_prior": 5, "build_affine_storage_model": 5, "uvm_caching_sharding_opt": 5, "clf_to_byt": 5, "get_budget": 5, "get_cach": 5, "get_expected_lookup": 5, "next_plan": 5, "starting_propos": 5, "promote_high_prefetch_overheaad_table_to_hbm": 5, "overhead": 5, "io": 5, "offload": 5, "undo": 5, "promot": 5, "greedypropos": 5, "threshold": [5, 9, 11], "On": [5, 11], "tri": [5, 12], "earli": 5, "stop": 5, "consecut": 5, "best_perf_r": 5, "gridsearchpropos": 5, "max_propos": 5, "10000": 5, "uniformpropos": 5, "proposers_to_proposals_list": 5, "proposers_list": 5, "static_feedback": 5, "embeddingoffloadstat": 5, "mrc_hist_count": 5, "height": 5, "uvm_fused_cach": 5, "cachebl": 5, "area": [5, 9], "curv": [5, 9], "histogram": 5, "nth": 5, "wa": [5, 8], "estimate_cache_miss_r": 5, "cache_s": 5, "hist": 5, "mrc": 5, "embeddingperfestim": 5, "is_infer": 5, "wall": 5, "sharder_map": 5, "perf_func_emb_wall_tim": 5, "shard_siz": 5, "input_length": 5, "input_data_type_s": 5, "table_data_type_s": 5, "output_data_type_s": 5, "fwd_a2a_comm_data_type_s": 5, "bwd_a2a_comm_data_type_s": 5, "fwd_sr_comm_data_type_s": 5, "bwd_sr_comm_data_type_s": 5, "num_pool": 5, "intra_host_bw": 5, "inter_host_bw": 5, "bwd_compute_multipli": 5, "weighted_feature_bwd_compute_multipli": 5, "is_pool": 5, "expected_cache_fetch": 5, "uneven_sharding_perf_multipli": 5, "attempt": 5, "rel": [5, 11], "tw": 5, "queri": 5, "fwd_comm_data_type_s": 5, "bwd_comm_data_type_s": 5, "machin": [5, 11], "embeddingbag": [5, 7, 10, 11, 13], "unpool": 5, "ebc": [5, 8, 10, 11, 13], "signifi": 5, "fetch": 5, "embeddingstorageestim": 5, "pipeline_typ": 5, "run_embedding_at_peak_memori": 5, "Will": [5, 9, 14], "fwd": 5, "bwd": 5, "temporari": [5, 11], "toward": 5, "cost": [5, 11], "won": [5, 11], "ll": 5, "hidden": [5, 10, 11], "old": [5, 12], "agnost": 5, "forwrad": 5, "calculate_pipeline_io_cost": 5, "output_s": [5, 11], "prefetch_s": 5, "multipass_prefetch_max_pass": 5, "count_ephemeral_storage_cost": 5, "calculate_shard_storag": 5, "compris": 5, "synonym": 5, "byte": [5, 8, 9], "embeddingstat": 5, "sharding_plan": 5, "num_propos": 5, "num_plan": 5, "run_tim": 5, "best_plan": 5, "tabular": 5, "view": 5, "chosen": [5, 11], "evalu": [5, 11], "successfulli": 5, "noopembeddingstat": 5, "noop": 5, "round_to_one_sigfig": 5, "fixedpercentagestoragereserv": 5, "percentag": 5, "heuristicalstoragereserv": 5, "parameter_multipli": 5, "dense_tensor_estim": 5, "heurist": 5, "extra": 5, "percent": 5, "act": 5, "margin": 5, "error": [5, 9, 11, 14], "beyond": 5, "inferencestoragereserv": 5, "customtopologydata": 5, "get_data": 5, "has_data": 5, "supported_field": 5, "ddr_cap": 5, "hbm_cap": 5, "512": [5, 9], "min_partit": 5, "pooling_factor": 5, "fbgemm_gpu": 5, "split_table_batched_embeddings_ops_common": 5, "device_group": 5, "around": 5, "lower": [5, 7, 8, 12, 13], "divid": [5, 9], "divis": 5, "optionallist": 5, "momentum": 5, "accuraci": [5, 11], "term": [5, 11], "fp16": 5, "exce": 5, "todai": 5, "bldm": 5, "fwd_comput": 5, "fwd_comm": 5, "bwd_comput": 5, "bwd_comm": 5, "prefetch_comput": 5, "breakdown": 5, "plannererror": 5, "error_typ": 5, "plannererrortyp": 5, "classifi": 5, "insufficient_storag": 5, "strict_constraint": 5, "prospos": 5, "paritit": 5, "much": [5, 12], "depend": [5, 8, 11], "One": [5, 9, 11], "eval": 5, "job": 5, "tower": [5, 11], "cache_load_factor": 5, "module_pool": 5, "sharding_option_nam": 5, "num_input": 5, "num_shard": 5, "total_perf": 5, "total_storag": 5, "capac": 5, "hardwar": 5, "fits_in": 5, "963146416": 5, "128": [5, 9], "54760833": 5, "024": 5, "644245094": 5, "13421772": 5, "custom_topology_data": 5, "binarysearchpred": 5, "extern": [5, 10], "predic": 5, "discov": 5, "try": 5, "prior_result": 5, "probe": 5, "prior": 5, "explor": 5, "reach": [5, 9], "luusjaakolasearch": 5, "max_iter": 5, "42": 5, "left_cost": 5, "clamp": 5, "variant": 5, "luu": 5, "jaakola": 5, "en": 5, "wikipedia": 5, "wiki": 5, "far": 5, "associ": 5, "fy": 5, "y": [5, 10, 11], "previou": 5, "subsequ": 5, "been": [5, 11, 14], "shrink_right": 5, "shrink": 5, "boundari": 5, "infin": [5, 12], "random": [5, 10], "bytes_to_gb": 5, "num_byt": 5, "bytes_to_mb": 5, "gb_to_byt": 5, "gb": 5, "local_s": [5, 6], "prod": 5, "reset_shard_rank": 5, "sharder_nam": 5, "storage_repr_in_gb": 5, "basecwembeddingshard": 6, "basetwembeddingshard": 6, "cwpooledembeddingshard": 6, "infercwpooledembeddingdist": 6, "infercwpooledembeddingdistwithpermut": 6, "infercwpooledembeddingshard": 6, "basedpembeddingshard": 6, "dppooledembeddingdist": 6, "dppooledembeddingshard": 6, "dpsparsefeaturesdist": 6, "sparsefeatur": 6, "baserwembeddingshard": 6, "inferrwpooledembeddingdist": 6, "inferrwpooledembeddingshard": 6, "inferrwsparsefeaturesdist": 6, "rwpooledembeddingdist": 6, "share": [6, 11], "rwpooledembeddingshard": 6, "evenli": 6, "rwsparsefeaturesdist": 6, "intra_pg": 6, "get_block_sizes_runtime_devic": 6, "runtime_devic": 6, "tensor_cach": 6, "get_embedding_shard_metadata": 6, "grouped_embedding_configs_per_rank": 6, "infertwembeddingshard": 6, "infertwpooledembeddingdist": 6, "infertwsparsefeaturesdist": 6, "twpooledembeddingdist": 6, "twpooledembeddingshard": 6, "twsparsefeaturesdist": 6, "twcwpooledembeddingshard": 6, "basetwrwembeddingshard": 6, "twrwpooledembeddingdist": 6, "cross_pg": 6, "dim_sum_per_nod": 6, "emb_dim_per_node_per_featur": 6, "twrwpooledembeddingshard": 6, "twrwsparsefeaturesdist": 6, "id_list_features_per_rank": 6, "id_score_list_features_per_rank": 6, "id_list_feature_hash_s": 6, "id_score_list_feature_hash_s": 6, "look": [6, 7], "reorder": 6, "document": [7, 10], "leaf_modul": 7, "trace": [7, 8], "torchscript": 7, "create_arg": 7, "memory_format": 7, "opoverload": 7, "symint": 7, "symbool": 7, "symfloat": 7, "prepar": [7, 11], "graph": 7, "emit": 7, "appropri": 7, "is_leaf_modul": 7, "module_qualified_nam": 7, "path_of_modul": 7, "mod": 7, "abil": [7, 14], "made": [7, 12], "concrete_arg": 7, "is_fx_trac": 7, "symbolic_trac": 7, "graphmodul": 7, "symbol": 7, "record": [7, 11], "partial": 7, "structur": [7, 12], "predictfactorypackag": 8, "save_predict_factori": 8, "predict_factori": 8, "predictfactori": 8, "config": [8, 9, 10, 11], "pathlib": 8, "binaryio": 8, "extra_fil": 8, "loader_cod": 8, "nimport": 8, "packag": 8, "nmodule_factori": 8, "package_import": 8, "_sysimport": 8, "set_extern_modul": 8, "decor": 8, "abstractmethod": 8, "set_mocked_modul": 8, "load_config_text": 8, "load_pickle_config": 8, "clazz": 8, "batchingmetadata": 8, "pin": 8, "kept": [8, 11], "sync": [8, 9, 14], "batching_metadata": 8, "infom": 8, "batching_metadata_json": 8, "serial": 8, "json": 8, "eas": [8, 11], "pars": 8, "create_predict_modul": 8, "transformmodul": 8, "transform_state_dict": 8, "init_process_group": 8, "model_inputs_data": 8, "benchmark": 8, "qualname_metadata": 8, "qualnamemetadata": 8, "qualnam": 8, "inform": [8, 9, 14], "qualname_metadata_json": 8, "result_metadata": 8, "run_weights_dependent_transform": 8, "predict_modul": 8, "predict": [8, 9], "run_weights_independent_tranform": 8, "fx": 8, "predictmodul": 8, "predict_forward": 8, "primari": 8, "need_preproc": 8, "assign_weights_to_tb": 8, "table_to_weight": 8, "get_table_to_weights_from_tb": 8, "quantize_dens": 8, "additional_embedding_module_typ": 8, "quantize_embed": 8, "inplac": [8, 13], "additional_qconfig_spec_kei": 8, "additional_map": 8, "per_table_weight_dtyp": [8, 11], "quantize_featur": 8, "quantize_inference_model": 8, "quantization_map": 8, "fp_weight_dtyp": 8, "quantization_dtyp": 8, "swap": 8, "counterpart": 8, "quantembeddingbagcollect": [8, 13], "quantembeddingcollect": 8, "eb_config": 8, "dlrmpredictmodul": 8, "embedding_bag_collect": [8, 10, 11], "dense_in_featur": [8, 10], "model_config": 8, "dense_arch_layer_s": [8, 10], "over_arch_layer_s": [8, 10], "id_list_features_kei": 8, "dense_devic": [8, 10], "quant_model": 8, "set_pruning_data": 8, "tables_to_rows_post_prun": 8, "shard_quant_model": 8, "sharding_devic": 8, "device_memory_s": 8, "quantembeddingcollectionshard": 8, "tablewis": 8, "sharded_model": 8, "trim_torch_package_prefix_from_typenam": 8, "typenam": 8, "accuracymetr": 9, "task": 9, "rectaskinfo": 9, "compute_mod": 9, "reccomputemod": 9, "unfused_tasks_comput": 9, "window_s": 9, "fused_update_limit": 9, "compute_on_all_rank": 9, "should_validate_upd": 9, "process_group": 9, "recmetr": 9, "accuracymetriccomput": 9, "recmetriccomput": 9, "constructor": [9, 11], "cut": [9, 11], "off": [9, 11], "compute_accuraci": 9, "accuracy_sum": 9, "weighted_num_sampl": 9, "compute_accuracy_sum": 9, "get_accuracy_st": 9, "aucmetr": 9, "aucmetriccomput": 9, "grouped_auc": 9, "apply_bin": 9, "grouping_kei": 9, "reset": [9, 11, 12], "n_task": 9, "n_exampl": 9, "compute_auc": 9, "classif": 9, "compute_auc_per_group": 9, "auprcmetr": 9, "auprcmetriccomput": 9, "grouped_auprc": 9, "pr": 9, "compute_auprc": 9, "compute_auprc_per_group": 9, "calibrationmetr": 9, "calibrationmetriccomput": 9, "convers": 9, "compute_calibr": 9, "calibration_num": 9, "calibration_denom": 9, "get_calibration_st": 9, "ctrmetric": 9, "ctrmetriccomput": 9, "compute_ctr": 9, "ctr_num": 9, "ctr_denom": 9, "get_ctr_stat": 9, "maemetr": 9, "maemetriccomput": 9, "absolut": 9, "compute_error_sum": 9, "compute_ma": 9, "error_sum": 9, "get_mae_st": 9, "msemetr": 9, "msemetriccomput": 9, "squar": [9, 11], "compute_ms": 9, "compute_rms": 9, "get_mse_st": 9, "multiclassrecallmetr": 9, "multiclassrecallmetriccomput": 9, "compute_multiclass_recall_at_k": 9, "tp_at_k": 9, "total_weight": 9, "compute_true_positives_at_k": 9, "n_class": 9, "tp": 9, "1st": 9, "2nd": [9, 11], "n_sampl": 9, "ground": 9, "truth": 9, "true_positives_list": 9, "9": [9, 10], "15": [9, 10], "compute_multiclass_k_sum": 9, "5000": 9, "7500": 9, "0000": [9, 11], "get_multiclass_recall_st": 9, "ndcgcomput": 9, "exponential_gain": 9, "session_kei": 9, "session_id": 9, "report_ndcg_as_decreasing_curv": 9, "remove_single_length_sess": 9, "scale_by_weights_tensor": 9, "is_negative_task_mask": 9, "normal": [9, 11], "discount": 9, "gain": 9, "tensorboard": 9, "captur": 9, "decreas": 9, "loss": [9, 10, 12], "oppos": 9, "visual": [9, 14], "similarli": 9, "entropi": 9, "pointwis": 9, "noth": 9, "ndcgmetric": 9, "nemetr": 9, "nemetriccomput": 9, "include_logloss": 9, "allow_missing_label_with_zero_weight": 9, "vanilla": 9, "logloss": 9, "compute_cross_entropi": 9, "eta": 9, "compute_logloss": 9, "ce_sum": 9, "pos_label": 9, "neg_label": 9, "compute_n": 9, "get_ne_st": 9, "recallmetr": 9, "recallmetriccomput": 9, "compute_false_neg_sum": 9, "compute_recal": 9, "num_true_posit": 9, "num_false_negit": 9, "compute_true_pos_sum": 9, "get_recall_st": 9, "precisionmetr": 9, "precisionmetriccomput": 9, "compute_false_pos_sum": 9, "compute_precis": 9, "num_false_posit": 9, "get_precision_st": 9, "raucmetr": 9, "raucmetriccomput": 9, "grouped_rauc": 9, "regress": 9, "compute_rauc": 9, "compute_rauc_per_group": 9, "conquer_and_count": 9, "left_index": 9, "mid_index": 9, "right_index": 9, "count_reverse_pairs_divide_and_conqu": 9, "low": [9, 10, 11], "throughputmetr": 9, "window_second": 9, "warmup_step": 9, "batch_size_stag": 9, "batchsizestag": 9, "32": [9, 11], "time_to_train_one_step": 9, "trainer": 9, "window": 9, "window_throughput": 9, "warmup": 9, "Not": 9, "weightedavgmetr": 9, "weightedavgmetriccomput": 9, "get_mean": 9, "value_sum": 9, "num_sampl": 9, "xaucmetr": 9, "xaucmetriccomput": 9, "compute_weighted_num_pair": 9, "compute_xauc": 9, "weighted_num_pair": 9, "get_xauc_st": 9, "recmetricmodul": 9, "rec_task": 9, "recmetriclist": 9, "throughput_metr": 9, "state_metr": 9, "statemetr": 9, "compute_interval_step": 9, "min_compute_interv": 9, "max_compute_interv": 9, "inf": [9, 12], "memory_usage_limit_mb": 9, "standalon": 9, "characterist": 9, "componenet": 9, "intern": [9, 11, 14], "logic": [9, 11], "unit": [9, 11], "dataclass": [9, 11], "defaultmetricsconfig": 9, "statemetricenum": 9, "metricmodul": 9, "generate_metric_modul": 9, "metric_class": 9, "metrics_config": 9, "64": [9, 11], "state_metrics_map": 9, "mock_optim": 9, "check_memory_usag": 9, "compute_count": 9, "sink": 9, "get_memory_usag": 9, "get_required_input": 9, "last_compute_tim": 9, "local_comput": 9, "memory_usage_mb_avg": 9, "oom_count": 9, "should_comput": 9, "unsync": [9, 14], "model_out": 9, "model_output": 9, "due": 9, "slide": 9, "qat": 9, "get_metr": 9, "metricsconfig": 9, "metriccomputationreport": 9, "metrics_namespac": 9, "metricnamebas": 9, "metric_prefix": 9, "metricprefix": 9, "signal": 9, "own": 9, "__init__": 9, "_namespac": 9, "_metrics_comput": 9, "consum": 9, "invalid": 9, "defaulttaskinfo": 9, "rec": 9, "overwrit": 9, "synchron": [9, 14], "get_window_st": 9, "state_nam": 9, "get_window_state_nam": 9, "pre_comput": 9, "torchmetr": 9, "aggreg": 9, "recmetricexcept": 9, "encapul": 9, "required_input": 9, "windowbuff": 9, "max_siz": 9, "max_buffer_count": 9, "aggregate_st": 9, "window_st": 9, "curr_stat": 9, "dequ": 9, "architectur": [10, 11], "deep": [10, 11], "sparsearch": 10, "densearch": 10, "interactionarch": 10, "overarch": 10, "found": 10, "notat": 10, "embedding_dimens": 10, "hidden_layer_s": 10, "deepfmnn": 10, "dimension": 10, "dense_arch": 10, "dense_arch_input": 10, "dense_embed": 10, "fminteractionarch": 10, "fm_in_featur": 10, "sparse_feature_nam": 10, "deep_fm_dimens": 10, "paper": [10, 11], "arxiv": 10, "pdf": 10, "1703": 10, "04247": 10, "cat": [10, 11], "dense_modul": [10, 11], "di": 10, "arch": 10, "fm_inter_arch": 10, "length_per_kei": [10, 14], "cat_fm_output": 10, "mlp": 10, "over_arch": 10, "logit": 10, "simpledeepfmnn": 10, "num_dense_featur": 10, "relationship": 10, "deep_fm": 10, "eb1_config": [10, 13], "f3": 10, "eb2_config": [10, 13], "t2": [10, 11, 13], "sparse_nn": 10, "over_embedding_dim": 10, "from_offsets_sync": [10, 11, 13, 14], "sparse_arch": 10, "ab": 10, "1906": 10, "00091": 10, "pairwis": 10, "dlrmtrain": 10, "dlrm_modul": 10, "train_pipelin": 10, "dlrm_project": 10, "dlrm_dcn": 10, "ebc_config": 10, "dlrm_model": 10, "dcn_num_lay": 10, "dcn_low_rank_dim": 10, "dcn": [10, 11], "modifi": [10, 11, 12], "similar": 10, "deepcrossnet": 10, "2008": 10, "13535": 10, "approxim": 10, "interaction_branch1_layer_s": 10, "interaction_branch2_layer_s": 10, "branch": 10, "layer_s": [10, 11], "num_sparse_featur": 10, "dot": [10, 11], "pair": 10, "inter_arch": 10, "choos": 10, "concat_dens": 10, "interactiondcnarch": 10, "crossnet": 10, "lowrankcrossnet": [10, 11], "dnc_low_rank_dim": 10, "interactionprojectionarch": 10, "interaction_branch1": 10, "interaction_branch2": 10, "z": 10, "bx": 10, "f1xd": 10, "dxf2": 10, "i1": 10, "i2": 10, "sparse_embed": 10, "math": 10, "comb": 10, "extens": 11, "establish": 11, "pattern": 11, "swishlayernorm": 11, "positionweightedmodul": 11, "lazymoduleextensionmixin": 11, "embeddingtow": 11, "embeddingtowercollect": 11, "input_dim": 11, "swish": 11, "sigmoid": 11, "layernorm": 11, "d1": 11, "d2": 11, "d3": 11, "sln": 11, "num_lay": 11, "stack": 11, "learnabl": 11, "polynom": 11, "nxn": 11, "cover": 11, "bit": 11, "x_": 11, "x_0": 11, "w_l": 11, "cdot": 11, "x_l": 11, "b_l": 11, "low_rank": 11, "highli": 11, "matric": 11, "simplifi": 11, "v_l": 11, "vector": 11, "smartli": 11, "setup": 11, "lowrankmixturecrossnet": 11, "num_expert": 11, "relu": 11, "mixtur": 11, "expert": 11, "compar": [11, 14], "subspac": 11, "gate": 11, "moe": 11, "expert_i": 11, "k_": 11, "u_": 11, "li": 11, "c_": 11, "v_": 11, "vectorcrossnet": 11, "nx1": 11, "thu": [11, 12], "further": [11, 14], "implent": 11, "framework": 11, "factorizationmachin": 11, "fm": 11, "publish": 11, "learnt": 11, "To": 11, "90": 11, "30": 11, "40": 11, "fb": 11, "lazymlp": 11, "output_dim": 11, "192": 11, "deep_fm_output": 11, "common_spars": 11, "specialized_spars": 11, "embedding_featur": 11, "raw_embedding_featur": 11, "nativ": 11, "trained_embed": 11, "native_embed": 11, "ident": 11, "mention": 11, "baseembeddingconfig": 11, "min": 11, "prune": 11, "get_weight_init_max": 11, "get_weight_init_min": 11, "meant": [11, 12], "embeddingconfig": [11, 13], "quantconfig": 11, "placeholderobserv": [11, 13], "alia": 11, "data_type_to_dtyp": 11, "data_type_to_sparse_typ": 11, "sparsetyp": 11, "dtype_to_data_typ": 11, "pooling_type_to_pooling_mod": 11, "pooling_typ": 11, "poolingmod": 11, "pooling_type_to_str": 11, "sensit": [11, 13], "jag": [11, 13, 14], "table_0": [11, 13], "table_1": [11, 13], "pooled_embed": 11, "8899": 11, "1342": 11, "9060": 11, "0905": 11, "2814": 11, "9369": 11, "7783": 11, "1598": 11, "0695": 11, "3265": 11, "1011": 11, "4256": 11, "1846": 11, "1648": 11, "0893": 11, "3590": 11, "9784": 11, "7681": 11, "grad_fn": [11, 13], "catbackward0": 11, "offset_per_kei": [11, 14], "intiial": 11, "need_indic": [11, 13], "e1_config": [11, 13], "e2_config": [11, 13], "ec": [11, 13], "feature_embed": [11, 13], "2050": [11, 13], "5478": [11, 13], "6054": [11, 13], "7352": [11, 13], "3210": [11, 13], "0399": [11, 13], "1279": [11, 13], "1756": [11, 13], "4130": [11, 13], "7519": [11, 13], "4341": [11, 13], "0499": [11, 13], "9329": [11, 13], "0697": [11, 13], "8095": [11, 13], "embeddingbackward": [11, 13], "embedding_names_by_t": [11, 13], "get_embedding_names_by_t": 11, "process_pooled_embed": 11, "reorder_inverse_indic": 11, "basefeatureprocessor": 11, "max_length": 11, "truncat": 11, "positionweightedprocessor": 11, "feature_length": 11, "feature0": [11, 14], "feature1": [11, 14], "feature2": 11, "from_lengths_sync": [11, 14], "pw": 11, "featureprocessorcollect": 11, "feature_processor_modul": 11, "positionweightedfeatureprocessor": 11, "fp_featur": 11, "non_fp_featur": 11, "non_fp": 11, "feature_process": 11, "And": 11, "offsets_to_range_tracebl": 11, "position_weighted_module_update_featur": 11, "weighted_featur": 11, "lazymodulemixin": 11, "upstream": 11, "59923": 11, "testlazymoduleextensionmixin": 11, "_infer_paramet": 11, "pariti": 11, "_call_impl": 11, "fn": 11, "children": 11, "uniniti": 11, "dummi": [11, 12], "lazylinear": 11, "fail": [11, 14], "hasn": [11, 14], "yet": [11, 14], "now": [11, 14], "lazy_appli": 11, "attach": 11, "numer": 11, "immedi": 11, "seq": 11, "in_siz": 11, "perceptron": 11, "multi": 11, "out_siz": 11, "swish_layernorm": 11, "mlp_modul": 11, "assert": 11, "o": 11, "channel": 11, "unpadded_length": 11, "reindexed_length": 11, "reindexed_length_per_kei": 11, "reindexed_valu": 11, "check_module_output_dimens": 11, "verifi": 11, "construct_jagged_tensor": 11, "features_to_permute_indic": 11, "original_featur": 11, "construct_jagged_tensors_infer": 11, "construct_modulelist_from_single_modul": 11, "nest": 11, "reiniti": 11, "convert_list_of_modules_to_modulelist": 11, "deterministic_dedup": 11, "race": 11, "conflict": 11, "extract_module_or_tensor_cal": 11, "module_or_cal": 11, "get_module_output_dimens": 11, "init_mlp_weights_xavier_uniform": 11, "jagged_index_select_with_empti": 11, "output_offset": 11, "distancelfu_evictionpolici": 11, "decay_expon": 11, "threshold_filtering_func": 11, "mchevictionpolici": 11, "coalesce_history_metadata": 11, "current_it": 11, "history_metadata": 11, "unique_ids_count": 11, "unique_inverse_map": 11, "additional_id": 11, "threshold_mask": 11, "histori": 11, "invers": [11, 14], "history_accumul": 11, "coalesc": 11, "metadata_info": 11, "mchevictionpolicymetadatainfo": 11, "record_history_metadata": 11, "incoming_id": 11, "incom": 11, "polici": [11, 12], "update_metadata_and_generate_eviction_scor": 11, "mch_size": 11, "coalesced_history_argsort_map": 11, "coalesced_history_sorted_unique_ids_count": 11, "coalesced_history_mch_matching_elements_mask": 11, "coalesced_history_mch_matching_indic": 11, "mch_metadata": 11, "coalesced_history_metadata": 11, "evicted_indic": 11, "selected_new_indic": 11, "mch": 11, "lfu_evictionpolici": 11, "lru_evictionpolici": 11, "metadata_nam": 11, "is_mch_metadata": 11, "is_history_metadata": 11, "mchmanagedcollisionmodul": 11, "zch_size": 11, "eviction_polici": 11, "eviction_interv": 11, "input_hash_s": 11, "9223372036854775807": 11, "input_hash_func": 11, "mch_hash_func": 11, "output_global_offset": 11, "output_seg": 11, "managedcollisionmodul": 11, "zch": 11, "collis": 11, "output_size_offset": 11, "drive": 11, "greater": 11, "depreci": 11, "residu": 11, "legaci": 11, "shift": 11, "zch_output_rang": 11, "down": 11, "applic": 11, "slot": 11, "assumptionn": 11, "downstream": 11, "rtype": 11, "vs": 11, "profil": 11, "rebuild_with_output_id_rang": 11, "output_id_rang": 11, "mc": 11, "validate_st": 11, "checkpoint": [11, 12], "managed_collision_modul": 11, "need_preprocess": 11, "mcc": 11, "embedding_confg": 11, "collsion": 11, "skip_state_valid": 11, "max_output_id": 11, "remapping_range_start_index": 11, "mcm": 11, "mcm_jt": 11, "fp": 11, "apply_mc_method_to_jt_dict": 11, "features_dict": 11, "average_threshold_filt": 11, "id_count": 11, "dynamic_threshold_filt": 11, "threshold_skew_multipli": 11, "total_count": 11, "num_id": 11, "probabilistic_threshold_filt": 11, "per_id_prob": 11, "01": 11, "probabl": 11, "60": 11, "randomli": 11, "chanc": 11, "basemanagedcollisionembeddingcollect": 11, "managed_collision_collect": 11, "return_remapped_featur": 11, "embedding_collect": 11, "meaning": 12, "prohibit": 12, "empti": [12, 14], "sever": 12, "combinedoptim": 12, "optimizerwrapp": 12, "rowwis": 12, "gradientclip": 12, "norm": 12, "gradientclippingoptim": 12, "max_gradi": 12, "norm_typ": 12, "enable_global_grad_clip": 12, "param_to_pg": 12, "p": 12, "clip_grad_norm_": 12, "closur": 12, "reevalu": 12, "emptyfusedoptim": 12, "fusedoptim": 12, "zero_grad": 12, "set_to_non": 12, "zero": [12, 14], "footprint": 12, "modestli": 12, "certain": 12, "0s": 12, "behav": 12, "did": 12, "altogeth": 12, "param_group": 12, "post_load_state_dict": 12, "prepend_opt_kei": 12, "opt_kei": 12, "save_param_group": 12, "set_optimizer_step": 12, "stricter": 12, "switch": 12, "flag": 12, "identifi": 12, "littl": 12, "add_param_group": 12, "fine": 12, "frozen": 12, "trainabl": 12, "progress": 12, "what": 12, "init_st": 12, "introduc": 12, "usabl": 12, "sd": 12, "load_checkpoint": 12, "protocol": 12, "keyedoptimizerwrapp": 12, "optim_factori": 12, "conveni": 12, "warmupoptim": 12, "warmupstag": 12, "lr": 12, "lr_param": 12, "param_nam": 12, "__warmup": 12, "adjust": 12, "schedul": 12, "fake": 12, "warmuppolici": 12, "constant": 12, "cosine_annealing_warm_restart": 12, "invsqrt": 12, "inv_sqrt": 12, "poli": 12, "max_it": 12, "lr_scale": 12, "decay_it": 12, "sgdr_period": 12, "trec_quant": 13, "trec": 13, "qconfig": 13, "activ": 13, "with_arg": 13, "qint8": 13, "quantize_dynam": 13, "qconfig_spec": 13, "table_name_to_quantized_weight": 13, "register_tb": 13, "quant_state_dict_split_scale_bia": 13, "row_align": 13, "qebc": 13, "from_float": 13, "quantized_embed": 13, "use_precomputed_fake_qu": 13, "for_each_module_of_type_do": 13, "quant_prep_customize_row_align": 13, "quant_prep_enable_quant_state_dict_split_scale_bia": 13, "quant_prep_enable_quant_state_dict_split_scale_bias_for_typ": 13, "quant_prep_enable_register_tb": 13, "quantize_state_dict": 13, "table_name_to_data_typ": 13, "table_name_to_num_embeddings_post_prun": 13, "whose": 14, "dimes": 14, "computejtdicttokjt": 14, "dim_1": 14, "dim_0": 14, "computekjttojtdict": 14, "keyed_jagged_tensor": 14, "jit": 14, "abl": 14, "NOT": 14, "expens": 14, "values_dtyp": 14, "weights_dtyp": 14, "lengths_dtyp": 14, "from_dens": 14, "2d": 14, "11": 14, "12": 14, "j1": 14, "from_dense_length": 14, "lengths_or_non": 14, "offsets_or_non": 14, "move": 14, "to_dens": 14, "inttensor": 14, "values_list": 14, "to_dense_weight": 14, "weights_list": 14, "to_padded_dens": 14, "desired_length": 14, "padding_valu": 14, "longest": 14, "pad": 14, "dt": 14, "to_padded_dense_weight": 14, "d_wt": 14, "throw": 14, "weights_or_non": 14, "jaggedtensormeta": 14, "abcmeta": 14, "proxyableclassmeta": 14, "stride_per_key_per_rank": 14, "outer": 14, "inner": 14, "index_per_kei": 14, "to_dict": 14, "expand": 14, "dedupl": 14, "dim_2": 14, "w0": 14, "w1": 14, "w2": 14, "w3": 14, "w4": 14, "w5": 14, "w6": 14, "w7": 14, "dist_init": 14, "variable_stride_per_kei": 14, "dist_label": 14, "dist_split": 14, "key_split": 14, "dist_tensor": 14, "empty_lik": 14, "flatten_length": 14, "from_jt_dict": 14, "newli": 14, "implicit": 14, "variable_feature_dim": 14, "But": 14, "That": 14, "didn": 14, "correctli": 14, "technic": 14, "know": 14, "violat": 14, "precondit": 14, "inverse_indices_or_non": 14, "length_per_key_or_non": 14, "lengths_offset_per_kei": 14, "offset_per_key_or_non": 14, "indices_tensor": 14, "segment": 14, "stride_per_kei": 14, "block": 14, "_jt_dict": 14, "clear": 14, "key_dim": 14, "tensor_list": 14, "from_tensor_list": 14, "regroup": 14, "keyed_tensor": 14, "regroup_as_dict": 14, "flatten_kjt_list": 14, "kjt_arr": 14, "jt_is_equ": 14, "jt_1": 14, "jt_2": 14, "comparison": 14, "themselv": 14, "treat": 14, "kjt_is_equ": 14, "kjt_1": 14, "kjt_2": 14, "permute_multi_embed": 14, "regroup_kt": 14, "unflatten_kjt_list": 14}, "objects": {"torchrec": [[2, 0, 0, "-", "datasets"], [4, 0, 0, "-", "distributed"], [7, 0, 0, "module-0", "fx"], [8, 0, 0, "module-0", "inference"], [9, 0, 0, "-", "metrics"], [10, 0, 0, "module-0", "models"], [11, 0, 0, "-", "modules"], [12, 0, 0, "module-0", "optim"], [13, 0, 0, "module-0", "quant"], [14, 0, 0, "module-0", "sparse"]], "torchrec.datasets": [[2, 0, 0, "-", "criteo"], [2, 0, 0, "-", "movielens"], [2, 0, 0, "-", "random"], [3, 0, 0, "-", "scripts"], [2, 0, 0, "-", "utils"]], "torchrec.datasets.criteo": [[2, 1, 1, "", "BinaryCriteoUtils"], [2, 1, 1, "", "CriteoIterDataPipe"], [2, 1, 1, "", "InMemoryBinaryCriteoIterDataPipe"], [2, 3, 1, "", "criteo_kaggle"], [2, 3, 1, "", "criteo_terabyte"]], "torchrec.datasets.criteo.BinaryCriteoUtils": [[2, 2, 1, "", "get_file_row_ranges_and_remainder"], [2, 2, 1, "", "get_shape_from_npy"], [2, 2, 1, "", "load_npy_range"], [2, 2, 1, "", "shuffle"], [2, 2, 1, "", "sparse_to_contiguous"], [2, 2, 1, "", "tsv_to_npys"]], "torchrec.datasets.movielens": [[2, 3, 1, "", "movielens_20m"], [2, 3, 1, "", "movielens_25m"]], "torchrec.datasets.random": [[2, 1, 1, "", "RandomRecDataset"]], "torchrec.datasets.scripts": [[3, 0, 0, "-", "contiguous_preproc_criteo"], [3, 0, 0, "-", "npy_preproc_criteo"]], "torchrec.datasets.scripts.contiguous_preproc_criteo": [[3, 3, 1, "", "main"], [3, 3, 1, "", "parse_args"]], "torchrec.datasets.scripts.npy_preproc_criteo": [[3, 3, 1, "", "main"], [3, 3, 1, "", "parse_args"]], "torchrec.datasets.utils": [[2, 1, 1, "", "Batch"], [2, 1, 1, "", "Limit"], [2, 1, 1, "", "LoadFiles"], [2, 1, 1, "", "ParallelReadConcat"], [2, 1, 1, "", "ReadLinesFromCSV"], [2, 3, 1, "", "idx_split_train_val"], [2, 3, 1, "", "rand_split_train_val"], [2, 3, 1, "", "safe_cast"], [2, 3, 1, "", "train_filter"], [2, 3, 1, "", "val_filter"]], "torchrec.datasets.utils.Batch": [[2, 4, 1, "", "dense_features"], [2, 4, 1, "", "labels"], [2, 2, 1, "", "pin_memory"], [2, 2, 1, "", "record_stream"], [2, 4, 1, "", "sparse_features"], [2, 2, 1, "", "to"]], "torchrec.distributed": [[4, 0, 0, "-", "collective_utils"], [4, 0, 0, "-", "comm"], [4, 0, 0, "-", "comm_ops"], [6, 0, 0, "-", "dist_data"], [4, 0, 0, "-", "embedding"], [4, 0, 0, "-", "embedding_lookup"], [4, 0, 0, "-", "embedding_sharding"], [4, 0, 0, "-", "embedding_types"], [4, 0, 0, "-", "embeddingbag"], [4, 0, 0, "-", "grouped_position_weighted"], [4, 0, 0, "-", "mc_embedding"], [4, 0, 0, "-", "mc_embeddingbag"], [4, 0, 0, "-", "mc_modules"], [4, 0, 0, "-", "model_parallel"], [5, 0, 0, "-", "planner"], [4, 0, 0, "-", "quant_embeddingbag"], [6, 0, 0, "-", "sharding"], [4, 0, 0, "-", "train_pipeline"], [4, 0, 0, "-", "types"], [4, 0, 0, "-", "utils"]], "torchrec.distributed.collective_utils": [[4, 3, 1, "", "invoke_on_rank_and_broadcast_result"], [4, 3, 1, "", "is_leader"], [4, 3, 1, "", "run_on_leader"]], "torchrec.distributed.comm": [[4, 3, 1, "", "get_group_rank"], [4, 3, 1, "", "get_local_rank"], [4, 3, 1, "", "get_local_size"], [4, 3, 1, "", "get_num_groups"], [4, 3, 1, "", "intra_and_cross_node_pg"]], "torchrec.distributed.comm_ops": [[4, 1, 1, "", "All2AllDenseInfo"], [4, 1, 1, "", "All2AllPooledInfo"], [4, 1, 1, "", "All2AllSequenceInfo"], [4, 1, 1, "", "All2AllVInfo"], [4, 1, 1, "", "All2All_Pooled_Req"], [4, 1, 1, "", "All2All_Pooled_Wait"], [4, 1, 1, "", "All2All_Seq_Req"], [4, 1, 1, "", "All2All_Seq_Req_Wait"], [4, 1, 1, "", "All2Allv_Req"], [4, 1, 1, "", "All2Allv_Wait"], [4, 1, 1, "", "AllGatherBaseInfo"], [4, 1, 1, "", "AllGatherBase_Req"], [4, 1, 1, "", "AllGatherBase_Wait"], [4, 1, 1, "", "ReduceScatterBaseInfo"], [4, 1, 1, "", "ReduceScatterBase_Req"], [4, 1, 1, "", "ReduceScatterBase_Wait"], [4, 1, 1, "", "ReduceScatterInfo"], [4, 1, 1, "", "ReduceScatterVInfo"], [4, 1, 1, "", "ReduceScatterV_Req"], [4, 1, 1, "", "ReduceScatterV_Wait"], [4, 1, 1, "", "ReduceScatter_Req"], [4, 1, 1, "", "ReduceScatter_Wait"], [4, 1, 1, "", "Request"], [4, 1, 1, "", "VariableBatchAll2AllPooledInfo"], [4, 1, 1, "", "Variable_Batch_All2All_Pooled_Req"], [4, 1, 1, "", "Variable_Batch_All2All_Pooled_Wait"], [4, 3, 1, "", "all2all_pooled_sync"], [4, 3, 1, "", "all2all_sequence_sync"], [4, 3, 1, "", "all2allv_sync"], [4, 3, 1, "", "all_gather_base_pooled"], [4, 3, 1, "", "all_gather_base_sync"], [4, 3, 1, "", "all_gather_into_tensor_backward"], [4, 3, 1, "", "all_gather_into_tensor_fake"], [4, 3, 1, "", "all_gather_into_tensor_setup_context"], [4, 3, 1, "", "all_to_all_single_backward"], [4, 3, 1, "", "all_to_all_single_fake"], [4, 3, 1, "", "all_to_all_single_setup_context"], [4, 3, 1, "", "alltoall_pooled"], [4, 3, 1, "", "alltoall_sequence"], [4, 3, 1, "", "alltoallv"], [4, 3, 1, "", "get_gradient_division"], [4, 3, 1, "", "get_use_sync_collectives"], [4, 3, 1, "", "pg_name"], [4, 3, 1, "", "reduce_scatter_base_pooled"], [4, 3, 1, "", "reduce_scatter_base_sync"], [4, 3, 1, "", "reduce_scatter_pooled"], [4, 3, 1, "", "reduce_scatter_sync"], [4, 3, 1, "", "reduce_scatter_tensor_backward"], [4, 3, 1, "", "reduce_scatter_tensor_fake"], [4, 3, 1, "", "reduce_scatter_tensor_setup_context"], [4, 3, 1, "", "reduce_scatter_v_per_feature_pooled"], [4, 3, 1, "", "reduce_scatter_v_pooled"], [4, 3, 1, "", "reduce_scatter_v_sync"], [4, 3, 1, "", "set_gradient_division"], [4, 3, 1, "", "set_use_sync_collectives"], [4, 3, 1, "", "torchrec_use_sync_collectives"], [4, 3, 1, "", "variable_batch_all2all_pooled_sync"], [4, 3, 1, "", "variable_batch_alltoall_pooled"]], "torchrec.distributed.comm_ops.All2AllDenseInfo": [[4, 4, 1, "", "batch_size"], [4, 4, 1, "", "input_shape"], [4, 4, 1, "", "input_splits"], [4, 4, 1, "", "output_splits"]], "torchrec.distributed.comm_ops.All2AllPooledInfo": [[4, 4, 1, "id0", "batch_size_per_rank"], [4, 4, 1, "id1", "codecs"], [4, 4, 1, "id2", "cumsum_dim_sum_per_rank_tensor"], [4, 4, 1, "id3", "dim_sum_per_rank"], [4, 4, 1, "id4", "dim_sum_per_rank_tensor"]], "torchrec.distributed.comm_ops.All2AllSequenceInfo": [[4, 4, 1, "id5", "backward_recat_tensor"], [4, 4, 1, "id6", "codecs"], [4, 4, 1, "id7", "embedding_dim"], [4, 4, 1, "id8", "forward_recat_tensor"], [4, 4, 1, "id9", "input_splits"], [4, 4, 1, "id10", "lengths_after_sparse_data_all2all"], [4, 4, 1, "id11", "output_splits"], [4, 4, 1, "id12", "permuted_lengths_after_sparse_data_all2all"], [4, 4, 1, "id13", "variable_batch_size"]], "torchrec.distributed.comm_ops.All2AllVInfo": [[4, 4, 1, "id14", "B_global"], [4, 4, 1, "id15", "B_local"], [4, 4, 1, "id16", "B_local_list"], [4, 4, 1, "id17", "D_local_list"], [4, 4, 1, "", "codecs"], [4, 4, 1, "", "dim_sum_per_rank"], [4, 4, 1, "", "dims_sum_per_rank"], [4, 4, 1, "id18", "input_split_sizes"], [4, 4, 1, "id19", "output_split_sizes"]], "torchrec.distributed.comm_ops.All2All_Pooled_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Pooled_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Seq_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2All_Seq_Req_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2Allv_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.All2Allv_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.AllGatherBaseInfo": [[4, 4, 1, "", "codecs"], [4, 4, 1, "id20", "input_size"]], "torchrec.distributed.comm_ops.AllGatherBase_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.AllGatherBase_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterBaseInfo": [[4, 4, 1, "", "codecs"], [4, 4, 1, "id21", "input_sizes"]], "torchrec.distributed.comm_ops.ReduceScatterBase_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterBase_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterInfo": [[4, 4, 1, "", "codecs"], [4, 4, 1, "id22", "input_sizes"]], "torchrec.distributed.comm_ops.ReduceScatterVInfo": [[4, 4, 1, "id23", "codecs"], [4, 4, 1, "id24", "equal_splits"], [4, 4, 1, "id25", "input_sizes"], [4, 4, 1, "id26", "input_splits"], [4, 4, 1, "id27", "total_input_size"]], "torchrec.distributed.comm_ops.ReduceScatterV_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatterV_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatter_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.ReduceScatter_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.VariableBatchAll2AllPooledInfo": [[4, 4, 1, "id28", "batch_size_per_feature_pre_a2a"], [4, 4, 1, "id29", "batch_size_per_rank_per_feature"], [4, 4, 1, "id30", "codecs"], [4, 4, 1, "id31", "emb_dim_per_rank_per_feature"], [4, 4, 1, "id32", "input_splits"], [4, 4, 1, "id33", "output_splits"]], "torchrec.distributed.comm_ops.Variable_Batch_All2All_Pooled_Req": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.comm_ops.Variable_Batch_All2All_Pooled_Wait": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.dist_data": [[6, 1, 1, "", "EmbeddingsAllToOne"], [6, 1, 1, "", "EmbeddingsAllToOneReduce"], [6, 1, 1, "", "JaggedTensorAllToAll"], [6, 1, 1, "", "KJTAllToAll"], [6, 1, 1, "", "KJTAllToAllSplitsAwaitable"], [6, 1, 1, "", "KJTAllToAllTensorsAwaitable"], [6, 1, 1, "", "KJTOneToAll"], [6, 1, 1, "", "MergePooledEmbeddingsModule"], [6, 1, 1, "", "PooledEmbeddingsAllGather"], [6, 1, 1, "", "PooledEmbeddingsAllToAll"], [6, 1, 1, "", "PooledEmbeddingsAwaitable"], [6, 1, 1, "", "PooledEmbeddingsReduceScatter"], [6, 1, 1, "", "SeqEmbeddingsAllToOne"], [6, 1, 1, "", "SequenceEmbeddingsAllToAll"], [6, 1, 1, "", "SequenceEmbeddingsAwaitable"], [6, 1, 1, "", "SplitsAllToAllAwaitable"], [6, 1, 1, "", "TensorAllToAll"], [6, 1, 1, "", "TensorAllToAllSplitsAwaitable"], [6, 1, 1, "", "TensorAllToAllValuesAwaitable"], [6, 1, 1, "", "TensorValuesAllToAll"], [6, 1, 1, "", "VariableBatchPooledEmbeddingsAllToAll"], [6, 1, 1, "", "VariableBatchPooledEmbeddingsReduceScatter"]], "torchrec.distributed.dist_data.EmbeddingsAllToOne": [[6, 2, 1, "", "forward"], [6, 2, 1, "", "set_device"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.EmbeddingsAllToOneReduce": [[6, 2, 1, "", "forward"], [6, 2, 1, "", "set_device"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.KJTAllToAll": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.KJTOneToAll": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.MergePooledEmbeddingsModule": [[6, 2, 1, "", "forward"], [6, 2, 1, "", "set_device"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAllGather": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAllToAll": [[6, 5, 1, "", "callbacks"], [6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.PooledEmbeddingsAwaitable": [[6, 5, 1, "", "callbacks"]], "torchrec.distributed.dist_data.PooledEmbeddingsReduceScatter": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.SeqEmbeddingsAllToOne": [[6, 2, 1, "", "forward"], [6, 2, 1, "", "set_device"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.SequenceEmbeddingsAllToAll": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.TensorAllToAll": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.TensorValuesAllToAll": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsAllToAll": [[6, 5, 1, "", "callbacks"], [6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.dist_data.VariableBatchPooledEmbeddingsReduceScatter": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.embedding": [[4, 1, 1, "", "EmbeddingCollectionAwaitable"], [4, 1, 1, "", "EmbeddingCollectionContext"], [4, 1, 1, "", "EmbeddingCollectionSharder"], [4, 1, 1, "", "ShardedEmbeddingCollection"], [4, 3, 1, "", "create_embedding_sharding"], [4, 3, 1, "", "create_sharding_infos_by_sharding"], [4, 3, 1, "", "create_sharding_infos_by_sharding_device_group"], [4, 3, 1, "", "get_device_from_parameter_sharding"], [4, 3, 1, "", "get_ec_index_dedup"], [4, 3, 1, "", "pad_vbe_kjt_lengths"], [4, 3, 1, "", "set_ec_index_dedup"]], "torchrec.distributed.embedding.EmbeddingCollectionContext": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.embedding.EmbeddingCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "shardable_parameters"], [4, 2, 1, "", "sharding_types"]], "torchrec.distributed.embedding.ShardedEmbeddingCollection": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "compute_and_output_dist"], [4, 2, 1, "", "create_context"], [4, 5, 1, "", "fused_optimizer"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "output_dist"], [4, 2, 1, "", "reset_parameters"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup": [[4, 1, 1, "", "CommOpGradientScaling"], [4, 1, 1, "", "GroupedEmbeddingsLookup"], [4, 1, 1, "", "GroupedPooledEmbeddingsLookup"], [4, 1, 1, "", "InferCPUGroupedEmbeddingsLookup"], [4, 1, 1, "", "InferGroupedEmbeddingsLookup"], [4, 1, 1, "", "InferGroupedLookupMixin"], [4, 1, 1, "", "InferGroupedPooledEmbeddingsLookup"], [4, 1, 1, "", "MetaInferGroupedEmbeddingsLookup"], [4, 1, 1, "", "MetaInferGroupedPooledEmbeddingsLookup"], [4, 3, 1, "", "dummy_tensor"], [4, 3, 1, "", "embeddings_cat_empty_rank_handle"], [4, 3, 1, "", "embeddings_cat_empty_rank_handle_inference"], [4, 3, 1, "", "fx_wrap_tensor_view2d"]], "torchrec.distributed.embedding_lookup.CommOpGradientScaling": [[4, 2, 1, "", "backward"], [4, 2, 1, "", "forward"]], "torchrec.distributed.embedding_lookup.GroupedEmbeddingsLookup": [[4, 2, 1, "", "flush"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "named_parameters_by_table"], [4, 2, 1, "", "prefetch"], [4, 2, 1, "", "purge"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.GroupedPooledEmbeddingsLookup": [[4, 2, 1, "", "flush"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "named_parameters_by_table"], [4, 2, 1, "", "prefetch"], [4, 2, 1, "", "purge"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferCPUGroupedEmbeddingsLookup": [[4, 2, 1, "", "get_tbes_to_register"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferGroupedEmbeddingsLookup": [[4, 2, 1, "", "get_tbes_to_register"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.InferGroupedLookupMixin": [[4, 2, 1, "", "forward"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "state_dict"]], "torchrec.distributed.embedding_lookup.InferGroupedPooledEmbeddingsLookup": [[4, 2, 1, "", "forward"], [4, 2, 1, "", "get_tbes_to_register"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.MetaInferGroupedEmbeddingsLookup": [[4, 2, 1, "", "flush"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "get_tbes_to_register"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "purge"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_lookup.MetaInferGroupedPooledEmbeddingsLookup": [[4, 2, 1, "", "flush"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "get_tbes_to_register"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "purge"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_sharding": [[4, 1, 1, "", "BaseEmbeddingDist"], [4, 1, 1, "", "BaseSparseFeaturesDist"], [4, 1, 1, "", "EmbeddingSharding"], [4, 1, 1, "", "EmbeddingShardingContext"], [4, 1, 1, "", "EmbeddingShardingInfo"], [4, 1, 1, "", "FusedKJTListSplitsAwaitable"], [4, 1, 1, "", "KJTListAwaitable"], [4, 1, 1, "", "KJTListSplitsAwaitable"], [4, 1, 1, "", "KJTSplitsAllToAllMeta"], [4, 1, 1, "", "ListOfKJTListAwaitable"], [4, 1, 1, "", "ListOfKJTListSplitsAwaitable"], [4, 3, 1, "", "bucketize_kjt_before_all2all"], [4, 3, 1, "", "bucketize_kjt_inference"], [4, 3, 1, "", "group_tables"]], "torchrec.distributed.embedding_sharding.BaseEmbeddingDist": [[4, 2, 1, "", "forward"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_sharding.BaseSparseFeaturesDist": [[4, 2, 1, "", "forward"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_sharding.EmbeddingSharding": [[4, 2, 1, "", "create_input_dist"], [4, 2, 1, "", "create_lookup"], [4, 2, 1, "", "create_output_dist"], [4, 2, 1, "", "embedding_dims"], [4, 2, 1, "", "embedding_names"], [4, 2, 1, "", "embedding_names_per_rank"], [4, 2, 1, "", "embedding_shard_metadata"], [4, 2, 1, "", "embedding_tables"], [4, 5, 1, "", "qcomm_codecs_registry"], [4, 2, 1, "", "uncombined_embedding_dims"], [4, 2, 1, "", "uncombined_embedding_names"]], "torchrec.distributed.embedding_sharding.EmbeddingShardingContext": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.embedding_sharding.EmbeddingShardingInfo": [[4, 4, 1, "", "embedding_config"], [4, 4, 1, "", "fused_params"], [4, 4, 1, "", "param"], [4, 4, 1, "", "param_sharding"]], "torchrec.distributed.embedding_sharding.KJTSplitsAllToAllMeta": [[4, 4, 1, "", "device"], [4, 4, 1, "", "input_splits"], [4, 4, 1, "", "input_tensors"], [4, 4, 1, "", "keys"], [4, 4, 1, "", "labels"], [4, 4, 1, "", "pg"], [4, 4, 1, "", "splits"], [4, 4, 1, "", "splits_tensors"], [4, 4, 1, "", "stagger"]], "torchrec.distributed.embedding_types": [[4, 1, 1, "", "BaseEmbeddingLookup"], [4, 1, 1, "", "BaseEmbeddingSharder"], [4, 1, 1, "", "BaseGroupedFeatureProcessor"], [4, 1, 1, "", "BaseQuantEmbeddingSharder"], [4, 1, 1, "", "DTensorMetadata"], [4, 1, 1, "", "EmbeddingAttributes"], [4, 1, 1, "", "EmbeddingComputeKernel"], [4, 1, 1, "", "FeatureShardingMixIn"], [4, 1, 1, "", "GroupedEmbeddingConfig"], [4, 1, 1, "", "InputDistOutputs"], [4, 1, 1, "", "KJTList"], [4, 1, 1, "", "ListOfKJTList"], [4, 1, 1, "", "ModuleShardingMixIn"], [4, 1, 1, "", "OptimType"], [4, 1, 1, "", "ShardedConfig"], [4, 1, 1, "", "ShardedEmbeddingModule"], [4, 1, 1, "", "ShardedEmbeddingTable"], [4, 1, 1, "", "ShardedMetaConfig"], [4, 3, 1, "", "compute_kernel_to_embedding_location"]], "torchrec.distributed.embedding_types.BaseEmbeddingLookup": [[4, 2, 1, "", "forward"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_types.BaseEmbeddingSharder": [[4, 2, 1, "", "compute_kernels"], [4, 5, 1, "", "fused_params"], [4, 2, 1, "", "sharding_types"], [4, 2, 1, "", "storage_usage"]], "torchrec.distributed.embedding_types.BaseGroupedFeatureProcessor": [[4, 2, 1, "", "forward"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_types.BaseQuantEmbeddingSharder": [[4, 2, 1, "", "compute_kernels"], [4, 5, 1, "", "fused_params"], [4, 2, 1, "", "shardable_parameters"], [4, 2, 1, "", "sharding_types"], [4, 2, 1, "", "storage_usage"]], "torchrec.distributed.embedding_types.DTensorMetadata": [[4, 4, 1, "", "mesh"], [4, 4, 1, "", "placements"], [4, 4, 1, "", "size"], [4, 4, 1, "", "stride"]], "torchrec.distributed.embedding_types.EmbeddingAttributes": [[4, 4, 1, "", "compute_kernel"]], "torchrec.distributed.embedding_types.EmbeddingComputeKernel": [[4, 4, 1, "", "DENSE"], [4, 4, 1, "", "FUSED"], [4, 4, 1, "", "FUSED_UVM"], [4, 4, 1, "", "FUSED_UVM_CACHING"], [4, 4, 1, "", "KEY_VALUE"], [4, 4, 1, "", "QUANT"], [4, 4, 1, "", "QUANT_UVM"], [4, 4, 1, "", "QUANT_UVM_CACHING"]], "torchrec.distributed.embedding_types.FeatureShardingMixIn": [[4, 2, 1, "", "feature_names"], [4, 2, 1, "", "feature_names_per_rank"], [4, 2, 1, "", "features_per_rank"]], "torchrec.distributed.embedding_types.GroupedEmbeddingConfig": [[4, 4, 1, "", "compute_kernel"], [4, 4, 1, "", "data_type"], [4, 2, 1, "", "dim_sum"], [4, 2, 1, "", "embedding_dims"], [4, 2, 1, "", "embedding_names"], [4, 2, 1, "", "embedding_shard_metadata"], [4, 4, 1, "", "embedding_tables"], [4, 2, 1, "", "feature_hash_sizes"], [4, 2, 1, "", "feature_names"], [4, 4, 1, "", "fused_params"], [4, 4, 1, "", "has_feature_processor"], [4, 4, 1, "", "is_weighted"], [4, 2, 1, "", "num_features"], [4, 4, 1, "", "pooling"], [4, 2, 1, "", "table_names"]], "torchrec.distributed.embedding_types.InputDistOutputs": [[4, 4, 1, "", "bucket_mapping_tensor"], [4, 4, 1, "", "bucketized_length"], [4, 4, 1, "", "features"], [4, 2, 1, "", "record_stream"], [4, 4, 1, "", "unbucketize_permute_tensor"]], "torchrec.distributed.embedding_types.KJTList": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.embedding_types.ListOfKJTList": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.embedding_types.ModuleShardingMixIn": [[4, 5, 1, "", "shardings"]], "torchrec.distributed.embedding_types.OptimType": [[4, 4, 1, "", "ADAGRAD"], [4, 4, 1, "", "ADAM"], [4, 4, 1, "", "ADAMW"], [4, 4, 1, "", "LAMB"], [4, 4, 1, "", "LARS_SGD"], [4, 4, 1, "", "LION"], [4, 4, 1, "", "PARTIAL_ROWWISE_ADAM"], [4, 4, 1, "", "PARTIAL_ROWWISE_LAMB"], [4, 4, 1, "", "ROWWISE_ADAGRAD"], [4, 4, 1, "", "SGD"], [4, 4, 1, "", "SHAMPOO"], [4, 4, 1, "", "SHAMPOO_MRS"], [4, 4, 1, "", "SHAMPOO_V2"], [4, 4, 1, "", "SHAMPOO_V2_MRS"]], "torchrec.distributed.embedding_types.ShardedConfig": [[4, 4, 1, "", "local_cols"], [4, 4, 1, "", "local_rows"]], "torchrec.distributed.embedding_types.ShardedEmbeddingModule": [[4, 2, 1, "", "extra_repr"], [4, 2, 1, "", "prefetch"], [4, 4, 1, "", "training"]], "torchrec.distributed.embedding_types.ShardedEmbeddingTable": [[4, 4, 1, "", "fused_params"]], "torchrec.distributed.embedding_types.ShardedMetaConfig": [[4, 4, 1, "", "dtensor_metadata"], [4, 4, 1, "", "global_metadata"], [4, 4, 1, "", "local_metadata"]], "torchrec.distributed.embeddingbag": [[4, 1, 1, "", "EmbeddingAwaitable"], [4, 1, 1, "", "EmbeddingBagCollectionAwaitable"], [4, 1, 1, "", "EmbeddingBagCollectionContext"], [4, 1, 1, "", "EmbeddingBagCollectionSharder"], [4, 1, 1, "", "EmbeddingBagSharder"], [4, 1, 1, "", "ShardedEmbeddingBag"], [4, 1, 1, "", "ShardedEmbeddingBagCollection"], [4, 1, 1, "", "VariableBatchEmbeddingBagCollectionAwaitable"], [4, 3, 1, "", "construct_output_kt"], [4, 3, 1, "", "create_embedding_bag_sharding"], [4, 3, 1, "", "create_sharding_infos_by_sharding"], [4, 3, 1, "", "create_sharding_infos_by_sharding_device_group"], [4, 3, 1, "", "get_device_from_parameter_sharding"], [4, 3, 1, "", "replace_placement_with_meta_device"]], "torchrec.distributed.embeddingbag.EmbeddingBagCollectionContext": [[4, 4, 1, "", "divisor"], [4, 4, 1, "", "inverse_indices"], [4, 2, 1, "", "record_stream"], [4, 4, 1, "", "sharding_contexts"], [4, 4, 1, "", "variable_batch_per_feature"]], "torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "shardable_parameters"]], "torchrec.distributed.embeddingbag.EmbeddingBagSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "shardable_parameters"]], "torchrec.distributed.embeddingbag.ShardedEmbeddingBag": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "create_context"], [4, 5, 1, "", "fused_optimizer"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "load_state_dict"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_modules"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "output_dist"], [4, 2, 1, "", "sharded_parameter_names"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "compute_and_output_dist"], [4, 2, 1, "", "create_context"], [4, 5, 1, "", "fused_optimizer"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "output_dist"], [4, 2, 1, "", "reset_parameters"], [4, 4, 1, "", "training"]], "torchrec.distributed.grouped_position_weighted": [[4, 1, 1, "", "GroupedPositionWeightedModule"]], "torchrec.distributed.grouped_position_weighted.GroupedPositionWeightedModule": [[4, 2, 1, "", "forward"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.mc_embedding": [[4, 1, 1, "", "ManagedCollisionEmbeddingCollectionContext"], [4, 1, 1, "", "ManagedCollisionEmbeddingCollectionSharder"], [4, 1, 1, "", "ShardedManagedCollisionEmbeddingCollection"]], "torchrec.distributed.mc_embedding.ManagedCollisionEmbeddingCollectionContext": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.mc_embedding.ManagedCollisionEmbeddingCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"]], "torchrec.distributed.mc_embedding.ShardedManagedCollisionEmbeddingCollection": [[4, 2, 1, "", "create_context"], [4, 4, 1, "", "training"]], "torchrec.distributed.mc_embeddingbag": [[4, 1, 1, "", "ManagedCollisionEmbeddingBagCollectionContext"], [4, 1, 1, "", "ManagedCollisionEmbeddingBagCollectionSharder"], [4, 1, 1, "", "ShardedManagedCollisionEmbeddingBagCollection"]], "torchrec.distributed.mc_embeddingbag.ManagedCollisionEmbeddingBagCollectionContext": [[4, 4, 1, "", "evictions_per_table"], [4, 2, 1, "", "record_stream"], [4, 4, 1, "", "remapped_kjt"]], "torchrec.distributed.mc_embeddingbag.ManagedCollisionEmbeddingBagCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"]], "torchrec.distributed.mc_embeddingbag.ShardedManagedCollisionEmbeddingBagCollection": [[4, 2, 1, "", "create_context"], [4, 4, 1, "", "training"]], "torchrec.distributed.mc_modules": [[4, 1, 1, "", "ManagedCollisionCollectionAwaitable"], [4, 1, 1, "", "ManagedCollisionCollectionContext"], [4, 1, 1, "", "ManagedCollisionCollectionSharder"], [4, 1, 1, "", "ShardedManagedCollisionCollection"], [4, 3, 1, "", "create_mc_sharding"]], "torchrec.distributed.mc_modules.ManagedCollisionCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "shardable_parameters"], [4, 2, 1, "", "sharding_types"]], "torchrec.distributed.mc_modules.ShardedManagedCollisionCollection": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "create_context"], [4, 2, 1, "", "evict"], [4, 2, 1, "", "global_to_local_index"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "open_slots"], [4, 2, 1, "", "output_dist"], [4, 2, 1, "", "sharded_parameter_names"], [4, 4, 1, "", "training"]], "torchrec.distributed.model_parallel": [[4, 1, 1, "", "DataParallelWrapper"], [4, 1, 1, "", "DefaultDataParallelWrapper"], [4, 1, 1, "", "DistributedModelParallel"], [4, 3, 1, "", "get_module"], [4, 3, 1, "", "get_unwrapped_module"]], "torchrec.distributed.model_parallel.DataParallelWrapper": [[4, 2, 1, "", "wrap"]], "torchrec.distributed.model_parallel.DefaultDataParallelWrapper": [[4, 2, 1, "", "wrap"]], "torchrec.distributed.model_parallel.DistributedModelParallel": [[4, 2, 1, "", "bare_named_parameters"], [4, 2, 1, "", "copy"], [4, 2, 1, "", "forward"], [4, 5, 1, "", "fused_optimizer"], [4, 2, 1, "", "init_data_parallel"], [4, 2, 1, "", "load_state_dict"], [4, 5, 1, "", "module"], [4, 2, 1, "", "named_buffers"], [4, 2, 1, "", "named_parameters"], [4, 5, 1, "", "plan"], [4, 2, 1, "", "sparse_grad_parameter_names"], [4, 2, 1, "", "state_dict"], [4, 4, 1, "", "training"]], "torchrec.distributed.planner": [[5, 0, 0, "-", "constants"], [5, 0, 0, "-", "enumerators"], [5, 0, 0, "-", "partitioners"], [5, 0, 0, "-", "perf_models"], [5, 0, 0, "-", "planners"], [5, 0, 0, "-", "proposers"], [5, 0, 0, "-", "shard_estimators"], [5, 0, 0, "-", "stats"], [5, 0, 0, "-", "storage_reservations"], [5, 0, 0, "-", "types"], [5, 0, 0, "-", "utils"]], "torchrec.distributed.planner.constants": [[5, 3, 1, "", "kernel_bw_lookup"]], "torchrec.distributed.planner.enumerators": [[5, 1, 1, "", "EmbeddingEnumerator"], [5, 3, 1, "", "get_partition_by_type"]], "torchrec.distributed.planner.enumerators.EmbeddingEnumerator": [[5, 2, 1, "", "enumerate"], [5, 2, 1, "", "populate_estimates"]], "torchrec.distributed.planner.partitioners": [[5, 1, 1, "", "GreedyPerfPartitioner"], [5, 1, 1, "", "MemoryBalancedPartitioner"], [5, 1, 1, "", "OrderedDeviceHardware"], [5, 1, 1, "", "ShardingOptionGroup"], [5, 1, 1, "", "SortBy"], [5, 3, 1, "", "set_hbm_per_device"]], "torchrec.distributed.planner.partitioners.GreedyPerfPartitioner": [[5, 2, 1, "", "partition"]], "torchrec.distributed.planner.partitioners.MemoryBalancedPartitioner": [[5, 2, 1, "", "partition"]], "torchrec.distributed.planner.partitioners.OrderedDeviceHardware": [[5, 4, 1, "", "device"], [5, 4, 1, "", "local_world_size"]], "torchrec.distributed.planner.partitioners.ShardingOptionGroup": [[5, 4, 1, "", "param_count"], [5, 4, 1, "", "perf_sum"], [5, 4, 1, "", "sharding_options"], [5, 4, 1, "", "storage_sum"]], "torchrec.distributed.planner.partitioners.SortBy": [[5, 4, 1, "", "PERF"], [5, 4, 1, "", "STORAGE"]], "torchrec.distributed.planner.perf_models": [[5, 1, 1, "", "NoopPerfModel"], [5, 1, 1, "", "NoopStorageModel"]], "torchrec.distributed.planner.perf_models.NoopPerfModel": [[5, 2, 1, "", "rate"]], "torchrec.distributed.planner.perf_models.NoopStorageModel": [[5, 2, 1, "", "rate"]], "torchrec.distributed.planner.planners": [[5, 1, 1, "", "EmbeddingShardingPlanner"], [5, 1, 1, "", "HeteroEmbeddingShardingPlanner"]], "torchrec.distributed.planner.planners.EmbeddingShardingPlanner": [[5, 2, 1, "", "collective_plan"], [5, 2, 1, "", "plan"]], "torchrec.distributed.planner.planners.HeteroEmbeddingShardingPlanner": [[5, 2, 1, "", "collective_plan"], [5, 2, 1, "", "plan"]], "torchrec.distributed.planner.proposers": [[5, 1, 1, "", "DynamicProgrammingProposer"], [5, 1, 1, "", "EmbeddingOffloadScaleupProposer"], [5, 1, 1, "", "GreedyProposer"], [5, 1, 1, "", "GridSearchProposer"], [5, 1, 1, "", "UniformProposer"], [5, 3, 1, "", "proposers_to_proposals_list"]], "torchrec.distributed.planner.proposers.DynamicProgrammingProposer": [[5, 2, 1, "", "feedback"], [5, 2, 1, "", "load"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.proposers.EmbeddingOffloadScaleupProposer": [[5, 2, 1, "", "allocate_budget"], [5, 2, 1, "", "build_affine_storage_model"], [5, 2, 1, "", "clf_to_bytes"], [5, 2, 1, "", "feedback"], [5, 2, 1, "", "get_budget"], [5, 2, 1, "", "get_cacheability"], [5, 2, 1, "", "get_expected_lookups"], [5, 2, 1, "", "load"], [5, 2, 1, "", "next_plan"], [5, 2, 1, "", "promote_high_prefetch_overheaad_table_to_hbm"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.proposers.GreedyProposer": [[5, 2, 1, "", "feedback"], [5, 2, 1, "", "load"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.proposers.GridSearchProposer": [[5, 2, 1, "", "feedback"], [5, 2, 1, "", "load"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.proposers.UniformProposer": [[5, 2, 1, "", "feedback"], [5, 2, 1, "", "load"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.shard_estimators": [[5, 1, 1, "", "EmbeddingOffloadStats"], [5, 1, 1, "", "EmbeddingPerfEstimator"], [5, 1, 1, "", "EmbeddingStorageEstimator"], [5, 3, 1, "", "calculate_pipeline_io_cost"], [5, 3, 1, "", "calculate_shard_storages"]], "torchrec.distributed.planner.shard_estimators.EmbeddingOffloadStats": [[5, 5, 1, "", "cacheability"], [5, 2, 1, "", "estimate_cache_miss_rate"], [5, 5, 1, "", "expected_lookups"], [5, 2, 1, "", "expected_miss_rate"]], "torchrec.distributed.planner.shard_estimators.EmbeddingPerfEstimator": [[5, 2, 1, "", "estimate"], [5, 2, 1, "", "perf_func_emb_wall_time"]], "torchrec.distributed.planner.shard_estimators.EmbeddingStorageEstimator": [[5, 2, 1, "", "estimate"]], "torchrec.distributed.planner.stats": [[5, 1, 1, "", "EmbeddingStats"], [5, 1, 1, "", "NoopEmbeddingStats"], [5, 3, 1, "", "round_to_one_sigfig"]], "torchrec.distributed.planner.stats.EmbeddingStats": [[5, 2, 1, "", "log"]], "torchrec.distributed.planner.stats.NoopEmbeddingStats": [[5, 2, 1, "", "log"]], "torchrec.distributed.planner.storage_reservations": [[5, 1, 1, "", "FixedPercentageStorageReservation"], [5, 1, 1, "", "HeuristicalStorageReservation"], [5, 1, 1, "", "InferenceStorageReservation"]], "torchrec.distributed.planner.storage_reservations.FixedPercentageStorageReservation": [[5, 2, 1, "", "reserve"]], "torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation": [[5, 2, 1, "", "reserve"]], "torchrec.distributed.planner.storage_reservations.InferenceStorageReservation": [[5, 2, 1, "", "reserve"]], "torchrec.distributed.planner.types": [[5, 1, 1, "", "CustomTopologyData"], [5, 1, 1, "", "DeviceHardware"], [5, 1, 1, "", "Enumerator"], [5, 1, 1, "", "ParameterConstraints"], [5, 1, 1, "", "PartitionByType"], [5, 1, 1, "", "Partitioner"], [5, 1, 1, "", "Perf"], [5, 1, 1, "", "PerfModel"], [5, 6, 1, "", "PlannerError"], [5, 1, 1, "", "PlannerErrorType"], [5, 1, 1, "", "Proposer"], [5, 1, 1, "", "Shard"], [5, 1, 1, "", "ShardEstimator"], [5, 1, 1, "", "ShardingOption"], [5, 1, 1, "", "Stats"], [5, 1, 1, "", "Storage"], [5, 1, 1, "", "StorageReservation"], [5, 1, 1, "", "Topology"]], "torchrec.distributed.planner.types.CustomTopologyData": [[5, 2, 1, "", "get_data"], [5, 2, 1, "", "has_data"], [5, 4, 1, "", "supported_fields"]], "torchrec.distributed.planner.types.DeviceHardware": [[5, 4, 1, "", "perf"], [5, 4, 1, "", "rank"], [5, 4, 1, "", "storage"]], "torchrec.distributed.planner.types.Enumerator": [[5, 2, 1, "", "enumerate"], [5, 2, 1, "", "populate_estimates"]], "torchrec.distributed.planner.types.ParameterConstraints": [[5, 4, 1, "id0", "batch_sizes"], [5, 4, 1, "id1", "bounds_check_mode"], [5, 4, 1, "id2", "cache_params"], [5, 4, 1, "id3", "compute_kernels"], [5, 4, 1, "id4", "device_group"], [5, 4, 1, "id5", "enforce_hbm"], [5, 4, 1, "id6", "feature_names"], [5, 4, 1, "id7", "is_weighted"], [5, 4, 1, "id8", "key_value_params"], [5, 4, 1, "id9", "min_partition"], [5, 4, 1, "id10", "num_poolings"], [5, 4, 1, "id11", "output_dtype"], [5, 4, 1, "id12", "pooling_factors"], [5, 4, 1, "id13", "sharding_types"], [5, 4, 1, "id14", "stochastic_rounding"]], "torchrec.distributed.planner.types.PartitionByType": [[5, 4, 1, "", "DEVICE"], [5, 4, 1, "", "HOST"], [5, 4, 1, "", "UNIFORM"]], "torchrec.distributed.planner.types.Partitioner": [[5, 2, 1, "", "partition"]], "torchrec.distributed.planner.types.Perf": [[5, 4, 1, "", "bwd_comms"], [5, 4, 1, "", "bwd_compute"], [5, 4, 1, "", "fwd_comms"], [5, 4, 1, "", "fwd_compute"], [5, 4, 1, "", "prefetch_compute"], [5, 5, 1, "", "total"]], "torchrec.distributed.planner.types.PerfModel": [[5, 2, 1, "", "rate"]], "torchrec.distributed.planner.types.PlannerErrorType": [[5, 4, 1, "", "INSUFFICIENT_STORAGE"], [5, 4, 1, "", "OTHER"], [5, 4, 1, "", "PARTITION"], [5, 4, 1, "", "STRICT_CONSTRAINTS"]], "torchrec.distributed.planner.types.Proposer": [[5, 2, 1, "", "feedback"], [5, 2, 1, "", "load"], [5, 2, 1, "", "propose"]], "torchrec.distributed.planner.types.Shard": [[5, 4, 1, "", "offset"], [5, 4, 1, "", "perf"], [5, 4, 1, "", "rank"], [5, 4, 1, "", "size"], [5, 4, 1, "", "storage"]], "torchrec.distributed.planner.types.ShardEstimator": [[5, 2, 1, "", "estimate"]], "torchrec.distributed.planner.types.ShardingOption": [[5, 4, 1, "", "batch_size"], [5, 4, 1, "", "bounds_check_mode"], [5, 5, 1, "", "cache_load_factor"], [5, 4, 1, "", "cache_params"], [5, 4, 1, "", "compute_kernel"], [5, 4, 1, "", "dependency"], [5, 4, 1, "", "enforce_hbm"], [5, 4, 1, "", "feature_names"], [5, 5, 1, "", "fqn"], [5, 4, 1, "", "input_lengths"], [5, 5, 1, "id15", "is_pooled"], [5, 4, 1, "", "key_value_params"], [5, 5, 1, "id16", "module"], [5, 2, 1, "", "module_pooled"], [5, 4, 1, "", "name"], [5, 5, 1, "", "num_inputs"], [5, 5, 1, "", "num_shards"], [5, 4, 1, "", "output_dtype"], [5, 5, 1, "", "path"], [5, 4, 1, "", "sharding_type"], [5, 4, 1, "", "shards"], [5, 4, 1, "", "stochastic_rounding"], [5, 5, 1, "id17", "tensor"], [5, 5, 1, "", "total_perf"], [5, 5, 1, "", "total_storage"]], "torchrec.distributed.planner.types.Stats": [[5, 2, 1, "", "log"]], "torchrec.distributed.planner.types.Storage": [[5, 4, 1, "", "ddr"], [5, 2, 1, "", "fits_in"], [5, 4, 1, "", "hbm"]], "torchrec.distributed.planner.types.StorageReservation": [[5, 2, 1, "", "reserve"]], "torchrec.distributed.planner.types.Topology": [[5, 5, 1, "", "bwd_compute_multiplier"], [5, 5, 1, "", "compute_device"], [5, 5, 1, "", "ddr_mem_bw"], [5, 5, 1, "", "devices"], [5, 5, 1, "", "hbm_mem_bw"], [5, 5, 1, "", "inter_host_bw"], [5, 5, 1, "", "intra_host_bw"], [5, 5, 1, "", "local_world_size"], [5, 5, 1, "", "uneven_sharding_perf_multiplier"], [5, 5, 1, "", "weighted_feature_bwd_compute_multiplier"], [5, 5, 1, "", "world_size"]], "torchrec.distributed.planner.utils": [[5, 1, 1, "", "BinarySearchPredicate"], [5, 1, 1, "", "LuusJaakolaSearch"], [5, 3, 1, "", "bytes_to_gb"], [5, 3, 1, "", "bytes_to_mb"], [5, 3, 1, "", "gb_to_bytes"], [5, 3, 1, "", "placement"], [5, 3, 1, "", "prod"], [5, 3, 1, "", "reset_shard_rank"], [5, 3, 1, "", "sharder_name"], [5, 3, 1, "", "storage_repr_in_gb"]], "torchrec.distributed.planner.utils.BinarySearchPredicate": [[5, 2, 1, "", "next"]], "torchrec.distributed.planner.utils.LuusJaakolaSearch": [[5, 2, 1, "", "best"], [5, 2, 1, "", "clamp"], [5, 2, 1, "", "next"], [5, 2, 1, "", "shrink_right"], [5, 2, 1, "", "uniform"]], "torchrec.distributed.quant_embeddingbag": [[4, 1, 1, "", "QuantEmbeddingBagCollectionSharder"], [4, 1, 1, "", "QuantFeatureProcessedEmbeddingBagCollectionSharder"], [4, 1, 1, "", "ShardedQuantEbcInputDist"], [4, 1, 1, "", "ShardedQuantEmbeddingBagCollection"], [4, 1, 1, "", "ShardedQuantFeatureProcessedEmbeddingBagCollection"], [4, 3, 1, "", "create_infer_embedding_bag_sharding"], [4, 3, 1, "", "flatten_feature_lengths"], [4, 3, 1, "", "get_device_from_parameter_sharding"], [4, 3, 1, "", "get_device_from_sharding_infos"]], "torchrec.distributed.quant_embeddingbag.QuantEmbeddingBagCollectionSharder": [[4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"]], "torchrec.distributed.quant_embeddingbag.QuantFeatureProcessedEmbeddingBagCollectionSharder": [[4, 2, 1, "", "compute_kernels"], [4, 5, 1, "", "module_type"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "sharding_types"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantEbcInputDist": [[4, 2, 1, "", "forward"], [4, 4, 1, "", "training"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantEmbeddingBagCollection": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "compute_and_output_dist"], [4, 2, 1, "", "copy"], [4, 2, 1, "", "create_context"], [4, 2, 1, "", "embedding_bag_configs"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "output_dist"], [4, 2, 1, "", "sharding_type_device_group_to_sharding_infos"], [4, 5, 1, "", "shardings"], [4, 2, 1, "", "tbes_configs"], [4, 4, 1, "", "training"]], "torchrec.distributed.quant_embeddingbag.ShardedQuantFeatureProcessedEmbeddingBagCollection": [[4, 2, 1, "", "apply_feature_processor"], [4, 2, 1, "", "compute"], [4, 4, 1, "", "embedding_bags"], [4, 4, 1, "", "tbes"], [4, 4, 1, "", "training"]], "torchrec.distributed.sharding": [[6, 0, 0, "-", "cw_sharding"], [6, 0, 0, "-", "dp_sharding"], [6, 0, 0, "-", "rw_sharding"], [6, 0, 0, "-", "tw_sharding"], [6, 0, 0, "-", "twcw_sharding"], [6, 0, 0, "-", "twrw_sharding"]], "torchrec.distributed.sharding.cw_sharding": [[6, 1, 1, "", "BaseCwEmbeddingSharding"], [6, 1, 1, "", "CwPooledEmbeddingSharding"], [6, 1, 1, "", "InferCwPooledEmbeddingDist"], [6, 1, 1, "", "InferCwPooledEmbeddingDistWithPermute"], [6, 1, 1, "", "InferCwPooledEmbeddingSharding"]], "torchrec.distributed.sharding.cw_sharding.BaseCwEmbeddingSharding": [[6, 2, 1, "", "embedding_dims"], [6, 2, 1, "", "embedding_names"], [6, 2, 1, "", "uncombined_embedding_dims"], [6, 2, 1, "", "uncombined_embedding_names"]], "torchrec.distributed.sharding.cw_sharding.CwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingDistWithPermute": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.cw_sharding.InferCwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.dp_sharding": [[6, 1, 1, "", "BaseDpEmbeddingSharding"], [6, 1, 1, "", "DpPooledEmbeddingDist"], [6, 1, 1, "", "DpPooledEmbeddingSharding"], [6, 1, 1, "", "DpSparseFeaturesDist"]], "torchrec.distributed.sharding.dp_sharding.BaseDpEmbeddingSharding": [[6, 2, 1, "", "embedding_dims"], [6, 2, 1, "", "embedding_names"], [6, 2, 1, "", "embedding_names_per_rank"], [6, 2, 1, "", "embedding_shard_metadata"], [6, 2, 1, "", "embedding_tables"], [6, 2, 1, "", "feature_names"]], "torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.dp_sharding.DpPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.dp_sharding.DpSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding": [[6, 1, 1, "", "BaseRwEmbeddingSharding"], [6, 1, 1, "", "InferRwPooledEmbeddingDist"], [6, 1, 1, "", "InferRwPooledEmbeddingSharding"], [6, 1, 1, "", "InferRwSparseFeaturesDist"], [6, 1, 1, "", "RwPooledEmbeddingDist"], [6, 1, 1, "", "RwPooledEmbeddingSharding"], [6, 1, 1, "", "RwSparseFeaturesDist"], [6, 3, 1, "", "get_block_sizes_runtime_device"], [6, 3, 1, "", "get_embedding_shard_metadata"]], "torchrec.distributed.sharding.rw_sharding.BaseRwEmbeddingSharding": [[6, 2, 1, "", "embedding_dims"], [6, 2, 1, "", "embedding_names"], [6, 2, 1, "", "embedding_names_per_rank"], [6, 2, 1, "", "embedding_shard_metadata"], [6, 2, 1, "", "embedding_tables"], [6, 2, 1, "", "feature_names"]], "torchrec.distributed.sharding.rw_sharding.InferRwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.InferRwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.rw_sharding.InferRwSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.rw_sharding.RwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.rw_sharding.RwSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding": [[6, 1, 1, "", "BaseTwEmbeddingSharding"], [6, 1, 1, "", "InferTwEmbeddingSharding"], [6, 1, 1, "", "InferTwPooledEmbeddingDist"], [6, 1, 1, "", "InferTwSparseFeaturesDist"], [6, 1, 1, "", "TwPooledEmbeddingDist"], [6, 1, 1, "", "TwPooledEmbeddingSharding"], [6, 1, 1, "", "TwSparseFeaturesDist"]], "torchrec.distributed.sharding.tw_sharding.BaseTwEmbeddingSharding": [[6, 2, 1, "", "embedding_dims"], [6, 2, 1, "", "embedding_names"], [6, 2, 1, "", "embedding_names_per_rank"], [6, 2, 1, "", "embedding_shard_metadata"], [6, 2, 1, "", "embedding_tables"], [6, 2, 1, "", "feature_names"], [6, 2, 1, "", "feature_names_per_rank"], [6, 2, 1, "", "features_per_rank"]], "torchrec.distributed.sharding.tw_sharding.InferTwEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.tw_sharding.InferTwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.InferTwSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.tw_sharding.TwSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.twcw_sharding": [[6, 1, 1, "", "TwCwPooledEmbeddingSharding"]], "torchrec.distributed.sharding.twrw_sharding": [[6, 1, 1, "", "BaseTwRwEmbeddingSharding"], [6, 1, 1, "", "TwRwPooledEmbeddingDist"], [6, 1, 1, "", "TwRwPooledEmbeddingSharding"], [6, 1, 1, "", "TwRwSparseFeaturesDist"]], "torchrec.distributed.sharding.twrw_sharding.BaseTwRwEmbeddingSharding": [[6, 2, 1, "", "embedding_dims"], [6, 2, 1, "", "embedding_names"], [6, 2, 1, "", "embedding_names_per_rank"], [6, 2, 1, "", "embedding_shard_metadata"], [6, 2, 1, "", "feature_names"]], "torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.sharding.twrw_sharding.TwRwPooledEmbeddingSharding": [[6, 2, 1, "", "create_input_dist"], [6, 2, 1, "", "create_lookup"], [6, 2, 1, "", "create_output_dist"]], "torchrec.distributed.sharding.twrw_sharding.TwRwSparseFeaturesDist": [[6, 2, 1, "", "forward"], [6, 4, 1, "", "training"]], "torchrec.distributed.types": [[4, 1, 1, "", "Awaitable"], [4, 1, 1, "", "CacheParams"], [4, 1, 1, "", "CacheStatistics"], [4, 1, 1, "", "CommOp"], [4, 1, 1, "", "ComputeKernel"], [4, 1, 1, "", "EmbeddingModuleShardingPlan"], [4, 1, 1, "", "GenericMeta"], [4, 1, 1, "", "GetItemLazyAwaitable"], [4, 1, 1, "", "KeyValueParams"], [4, 1, 1, "", "LazyAwaitable"], [4, 1, 1, "", "LazyGetItemMixin"], [4, 1, 1, "", "LazyNoWait"], [4, 1, 1, "", "ModuleSharder"], [4, 1, 1, "", "ModuleShardingPlan"], [4, 1, 1, "", "NoOpQuantizedCommCodec"], [4, 1, 1, "", "NoWait"], [4, 1, 1, "", "NullShardedModuleContext"], [4, 1, 1, "", "NullShardingContext"], [4, 1, 1, "", "ObjectPoolShardingPlan"], [4, 1, 1, "", "ObjectPoolShardingType"], [4, 1, 1, "", "ParameterSharding"], [4, 1, 1, "", "ParameterStorage"], [4, 1, 1, "", "PipelineType"], [4, 1, 1, "", "QuantizedCommCodec"], [4, 1, 1, "", "QuantizedCommCodecs"], [4, 1, 1, "", "ShardedModule"], [4, 1, 1, "", "ShardingEnv"], [4, 1, 1, "", "ShardingPlan"], [4, 1, 1, "", "ShardingPlanner"], [4, 1, 1, "", "ShardingType"], [4, 3, 1, "", "get_tensor_size_bytes"], [4, 3, 1, "", "rank_device"], [4, 3, 1, "", "scope"]], "torchrec.distributed.types.Awaitable": [[4, 5, 1, "", "callbacks"], [4, 2, 1, "", "wait"]], "torchrec.distributed.types.CacheParams": [[4, 4, 1, "id34", "algorithm"], [4, 4, 1, "id35", "load_factor"], [4, 4, 1, "", "multipass_prefetch_config"], [4, 4, 1, "id36", "precision"], [4, 4, 1, "id37", "prefetch_pipeline"], [4, 4, 1, "id38", "reserved_memory"], [4, 4, 1, "id39", "stats"]], "torchrec.distributed.types.CacheStatistics": [[4, 5, 1, "", "cacheability"], [4, 5, 1, "", "expected_lookups"], [4, 2, 1, "", "expected_miss_rate"]], "torchrec.distributed.types.CommOp": [[4, 4, 1, "", "POOLED_EMBEDDINGS_ALL_TO_ALL"], [4, 4, 1, "", "POOLED_EMBEDDINGS_REDUCE_SCATTER"], [4, 4, 1, "", "SEQUENCE_EMBEDDINGS_ALL_TO_ALL"]], "torchrec.distributed.types.ComputeKernel": [[4, 4, 1, "", "DEFAULT"]], "torchrec.distributed.types.KeyValueParams": [[4, 4, 1, "id40", "gather_ssd_cache_stats"], [4, 4, 1, "", "l2_cache_size"], [4, 4, 1, "", "ods_prefix"], [4, 4, 1, "id41", "ps_client_thread_num"], [4, 4, 1, "id42", "ps_hosts"], [4, 4, 1, "id43", "ps_max_key_per_request"], [4, 4, 1, "id44", "ps_max_local_index_length"], [4, 4, 1, "", "report_interval"], [4, 4, 1, "id45", "ssd_rocksdb_shards"], [4, 4, 1, "id46", "ssd_rocksdb_write_buffer_size"], [4, 4, 1, "id47", "ssd_storage_directory"], [4, 4, 1, "", "stats_reporter_config"], [4, 4, 1, "", "use_passed_in_path"]], "torchrec.distributed.types.ModuleSharder": [[4, 2, 1, "", "compute_kernels"], [4, 5, 1, "", "module_type"], [4, 5, 1, "", "qcomm_codecs_registry"], [4, 2, 1, "", "shard"], [4, 2, 1, "", "shardable_parameters"], [4, 2, 1, "", "sharding_types"], [4, 2, 1, "", "storage_usage"]], "torchrec.distributed.types.NoOpQuantizedCommCodec": [[4, 2, 1, "", "calc_quantized_size"], [4, 2, 1, "", "create_context"], [4, 2, 1, "", "decode"], [4, 2, 1, "", "encode"], [4, 2, 1, "", "padded_size"], [4, 2, 1, "", "quantized_dtype"]], "torchrec.distributed.types.NullShardedModuleContext": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.types.NullShardingContext": [[4, 2, 1, "", "record_stream"]], "torchrec.distributed.types.ObjectPoolShardingPlan": [[4, 4, 1, "", "inference"], [4, 4, 1, "", "sharding_type"]], "torchrec.distributed.types.ObjectPoolShardingType": [[4, 4, 1, "", "REPLICATED_ROW_WISE"], [4, 4, 1, "", "ROW_WISE"]], "torchrec.distributed.types.ParameterSharding": [[4, 4, 1, "", "bounds_check_mode"], [4, 4, 1, "", "cache_params"], [4, 4, 1, "", "compute_kernel"], [4, 4, 1, "", "enforce_hbm"], [4, 4, 1, "", "key_value_params"], [4, 4, 1, "", "output_dtype"], [4, 4, 1, "", "ranks"], [4, 4, 1, "", "sharding_spec"], [4, 4, 1, "", "sharding_type"], [4, 4, 1, "", "stochastic_rounding"]], "torchrec.distributed.types.ParameterStorage": [[4, 4, 1, "", "DDR"], [4, 4, 1, "", "HBM"]], "torchrec.distributed.types.PipelineType": [[4, 4, 1, "", "NONE"], [4, 4, 1, "", "TRAIN_BASE"], [4, 4, 1, "", "TRAIN_PREFETCH_SPARSE_DIST"], [4, 4, 1, "", "TRAIN_SPARSE_DIST"]], "torchrec.distributed.types.QuantizedCommCodec": [[4, 2, 1, "", "calc_quantized_size"], [4, 2, 1, "", "create_context"], [4, 2, 1, "", "decode"], [4, 2, 1, "", "encode"], [4, 2, 1, "", "padded_size"], [4, 5, 1, "", "quantized_dtype"]], "torchrec.distributed.types.QuantizedCommCodecs": [[4, 4, 1, "", "backward"], [4, 4, 1, "", "forward"]], "torchrec.distributed.types.ShardedModule": [[4, 2, 1, "", "compute"], [4, 2, 1, "", "compute_and_output_dist"], [4, 2, 1, "", "create_context"], [4, 2, 1, "", "forward"], [4, 2, 1, "", "input_dist"], [4, 2, 1, "", "output_dist"], [4, 5, 1, "", "qcomm_codecs_registry"], [4, 2, 1, "", "sharded_parameter_names"], [4, 4, 1, "", "training"]], "torchrec.distributed.types.ShardingEnv": [[4, 2, 1, "", "from_local"], [4, 2, 1, "", "from_process_group"]], "torchrec.distributed.types.ShardingPlan": [[4, 2, 1, "", "get_plan_for_module"], [4, 4, 1, "id48", "plan"]], "torchrec.distributed.types.ShardingPlanner": [[4, 2, 1, "", "collective_plan"], [4, 2, 1, "", "plan"]], "torchrec.distributed.types.ShardingType": [[4, 4, 1, "", "COLUMN_WISE"], [4, 4, 1, "", "DATA_PARALLEL"], [4, 4, 1, "", "ROW_WISE"], [4, 4, 1, "", "TABLE_COLUMN_WISE"], [4, 4, 1, "", "TABLE_ROW_WISE"], [4, 4, 1, "", "TABLE_WISE"]], "torchrec.distributed.utils": [[4, 1, 1, "", "CopyableMixin"], [4, 1, 1, "", "ForkedPdb"], [4, 3, 1, "", "add_params_from_parameter_sharding"], [4, 3, 1, "", "add_prefix_to_state_dict"], [4, 3, 1, "", "append_prefix"], [4, 3, 1, "", "convert_to_fbgemm_types"], [4, 3, 1, "", "copy_to_device"], [4, 3, 1, "", "filter_state_dict"], [4, 3, 1, "", "get_unsharded_module_names"], [4, 3, 1, "", "init_parameters"], [4, 3, 1, "", "merge_fused_params"], [4, 3, 1, "", "none_throws"], [4, 3, 1, "", "optimizer_type_to_emb_opt_type"], [4, 1, 1, "", "sharded_model_copy"]], "torchrec.distributed.utils.CopyableMixin": [[4, 2, 1, "", "copy"], [4, 4, 1, "", "training"]], "torchrec.distributed.utils.ForkedPdb": [[4, 2, 1, "", "interaction"]], "torchrec.fx": [[7, 0, 0, "-", "tracer"]], "torchrec.fx.tracer": [[7, 1, 1, "", "Tracer"], [7, 3, 1, "", "is_fx_tracing"], [7, 3, 1, "", "symbolic_trace"]], "torchrec.fx.tracer.Tracer": [[7, 2, 1, "", "create_arg"], [7, 2, 1, "", "is_leaf_module"], [7, 2, 1, "", "path_of_module"], [7, 2, 1, "", "trace"]], "torchrec.inference": [[8, 0, 0, "-", "model_packager"], [8, 0, 0, "-", "modules"]], "torchrec.inference.model_packager": [[8, 1, 1, "", "PredictFactoryPackager"], [8, 3, 1, "", "load_config_text"], [8, 3, 1, "", "load_pickle_config"]], "torchrec.inference.model_packager.PredictFactoryPackager": [[8, 2, 1, "", "save_predict_factory"], [8, 2, 1, "", "set_extern_modules"], [8, 2, 1, "", "set_mocked_modules"]], "torchrec.inference.modules": [[8, 1, 1, "", "BatchingMetadata"], [8, 1, 1, "", "PredictFactory"], [8, 1, 1, "", "PredictModule"], [8, 1, 1, "", "QualNameMetadata"], [8, 3, 1, "", "assign_weights_to_tbe"], [8, 3, 1, "", "get_table_to_weights_from_tbe"], [8, 3, 1, "", "quantize_dense"], [8, 3, 1, "", "quantize_embeddings"], [8, 3, 1, "", "quantize_feature"], [8, 3, 1, "", "quantize_inference_model"], [8, 3, 1, "", "set_pruning_data"], [8, 3, 1, "", "shard_quant_model"], [8, 3, 1, "", "trim_torch_package_prefix_from_typename"]], "torchrec.inference.modules.BatchingMetadata": [[8, 4, 1, "", "device"], [8, 4, 1, "", "pinned"], [8, 4, 1, "", "type"]], "torchrec.inference.modules.PredictFactory": [[8, 2, 1, "", "batching_metadata"], [8, 2, 1, "", "batching_metadata_json"], [8, 2, 1, "", "create_predict_module"], [8, 2, 1, "", "model_inputs_data"], [8, 2, 1, "", "qualname_metadata"], [8, 2, 1, "", "qualname_metadata_json"], [8, 2, 1, "", "result_metadata"], [8, 2, 1, "", "run_weights_dependent_transformations"], [8, 2, 1, "", "run_weights_independent_tranformations"]], "torchrec.inference.modules.PredictModule": [[8, 2, 1, "", "forward"], [8, 2, 1, "", "predict_forward"], [8, 5, 1, "", "predict_module"], [8, 2, 1, "", "state_dict"], [8, 4, 1, "", "training"]], "torchrec.inference.modules.QualNameMetadata": [[8, 4, 1, "", "need_preproc"]], "torchrec.metrics": [[9, 0, 0, "-", "accuracy"], [9, 0, 0, "-", "auc"], [9, 0, 0, "-", "auprc"], [9, 0, 0, "-", "calibration"], [9, 0, 0, "-", "ctr"], [9, 0, 0, "-", "mae"], [9, 0, 0, "-", "metric_module"], [9, 0, 0, "-", "mse"], [9, 0, 0, "-", "multiclass_recall"], [9, 0, 0, "-", "ndcg"], [9, 0, 0, "-", "ne"], [9, 0, 0, "-", "precision"], [9, 0, 0, "-", "rauc"], [9, 0, 0, "-", "rec_metric"], [9, 0, 0, "-", "recall"], [9, 0, 0, "-", "throughput"], [9, 0, 0, "-", "weighted_avg"], [9, 0, 0, "-", "xauc"]], "torchrec.metrics.accuracy": [[9, 1, 1, "", "AccuracyMetric"], [9, 1, 1, "", "AccuracyMetricComputation"], [9, 3, 1, "", "compute_accuracy"], [9, 3, 1, "", "compute_accuracy_sum"], [9, 3, 1, "", "get_accuracy_states"]], "torchrec.metrics.accuracy.AccuracyMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.auc": [[9, 1, 1, "", "AUCMetric"], [9, 1, 1, "", "AUCMetricComputation"], [9, 3, 1, "", "compute_auc"], [9, 3, 1, "", "compute_auc_per_group"]], "torchrec.metrics.auc.AUCMetricComputation": [[9, 2, 1, "", "reset"], [9, 2, 1, "", "update"]], "torchrec.metrics.auprc": [[9, 1, 1, "", "AUPRCMetric"], [9, 1, 1, "", "AUPRCMetricComputation"], [9, 3, 1, "", "compute_auprc"], [9, 3, 1, "", "compute_auprc_per_group"]], "torchrec.metrics.auprc.AUPRCMetricComputation": [[9, 2, 1, "", "reset"], [9, 2, 1, "", "update"]], "torchrec.metrics.calibration": [[9, 1, 1, "", "CalibrationMetric"], [9, 1, 1, "", "CalibrationMetricComputation"], [9, 3, 1, "", "compute_calibration"], [9, 3, 1, "", "get_calibration_states"]], "torchrec.metrics.calibration.CalibrationMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.ctr": [[9, 1, 1, "", "CTRMetric"], [9, 1, 1, "", "CTRMetricComputation"], [9, 3, 1, "", "compute_ctr"], [9, 3, 1, "", "get_ctr_states"]], "torchrec.metrics.ctr.CTRMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.mae": [[9, 1, 1, "", "MAEMetric"], [9, 1, 1, "", "MAEMetricComputation"], [9, 3, 1, "", "compute_error_sum"], [9, 3, 1, "", "compute_mae"], [9, 3, 1, "", "get_mae_states"]], "torchrec.metrics.mae.MAEMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.metric_module": [[9, 1, 1, "", "RecMetricModule"], [9, 1, 1, "", "StateMetric"], [9, 3, 1, "", "generate_metric_module"]], "torchrec.metrics.metric_module.RecMetricModule": [[9, 4, 1, "", "batch_size"], [9, 2, 1, "", "check_memory_usage"], [9, 2, 1, "", "compute"], [9, 4, 1, "", "compute_count"], [9, 2, 1, "", "get_memory_usage"], [9, 2, 1, "", "get_required_inputs"], [9, 4, 1, "", "last_compute_time"], [9, 2, 1, "", "local_compute"], [9, 4, 1, "", "memory_usage_limit_mb"], [9, 4, 1, "", "memory_usage_mb_avg"], [9, 4, 1, "", "oom_count"], [9, 4, 1, "", "rec_metrics"], [9, 4, 1, "", "rec_tasks"], [9, 2, 1, "", "reset"], [9, 2, 1, "", "should_compute"], [9, 4, 1, "", "state_metrics"], [9, 2, 1, "", "sync"], [9, 4, 1, "", "throughput_metric"], [9, 2, 1, "", "unsync"], [9, 2, 1, "", "update"], [9, 4, 1, "", "world_size"]], "torchrec.metrics.metric_module.StateMetric": [[9, 2, 1, "", "get_metrics"]], "torchrec.metrics.mse": [[9, 1, 1, "", "MSEMetric"], [9, 1, 1, "", "MSEMetricComputation"], [9, 3, 1, "", "compute_error_sum"], [9, 3, 1, "", "compute_mse"], [9, 3, 1, "", "compute_rmse"], [9, 3, 1, "", "get_mse_states"]], "torchrec.metrics.mse.MSEMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.multiclass_recall": [[9, 1, 1, "", "MulticlassRecallMetric"], [9, 1, 1, "", "MulticlassRecallMetricComputation"], [9, 3, 1, "", "compute_multiclass_recall_at_k"], [9, 3, 1, "", "compute_true_positives_at_k"], [9, 3, 1, "", "get_multiclass_recall_states"]], "torchrec.metrics.multiclass_recall.MulticlassRecallMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.ndcg": [[9, 1, 1, "", "NDCGComputation"], [9, 1, 1, "", "NDCGMetric"]], "torchrec.metrics.ndcg.NDCGComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.ne": [[9, 1, 1, "", "NEMetric"], [9, 1, 1, "", "NEMetricComputation"], [9, 3, 1, "", "compute_cross_entropy"], [9, 3, 1, "", "compute_logloss"], [9, 3, 1, "", "compute_ne"], [9, 3, 1, "", "get_ne_states"]], "torchrec.metrics.ne.NEMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.precision": [[9, 1, 1, "", "PrecisionMetric"], [9, 1, 1, "", "PrecisionMetricComputation"], [9, 3, 1, "", "compute_false_pos_sum"], [9, 3, 1, "", "compute_precision"], [9, 3, 1, "", "compute_true_pos_sum"], [9, 3, 1, "", "get_precision_states"]], "torchrec.metrics.precision.PrecisionMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.rauc": [[9, 1, 1, "", "RAUCMetric"], [9, 1, 1, "", "RAUCMetricComputation"], [9, 3, 1, "", "compute_rauc"], [9, 3, 1, "", "compute_rauc_per_group"], [9, 3, 1, "", "conquer_and_count"], [9, 3, 1, "", "count_reverse_pairs_divide_and_conquer"], [9, 3, 1, "", "divide"]], "torchrec.metrics.rauc.RAUCMetricComputation": [[9, 2, 1, "", "reset"], [9, 2, 1, "", "update"]], "torchrec.metrics.rec_metric": [[9, 1, 1, "", "MetricComputationReport"], [9, 1, 1, "", "RecMetric"], [9, 1, 1, "", "RecMetricComputation"], [9, 6, 1, "", "RecMetricException"], [9, 1, 1, "", "RecMetricList"], [9, 1, 1, "", "WindowBuffer"]], "torchrec.metrics.rec_metric.MetricComputationReport": [[9, 4, 1, "", "description"], [9, 4, 1, "", "metric_prefix"], [9, 4, 1, "", "name"], [9, 4, 1, "", "value"]], "torchrec.metrics.rec_metric.RecMetric": [[9, 4, 1, "", "LABELS"], [9, 4, 1, "", "PREDICTIONS"], [9, 4, 1, "", "WEIGHTS"], [9, 2, 1, "", "compute"], [9, 2, 1, "", "get_memory_usage"], [9, 2, 1, "", "get_required_inputs"], [9, 2, 1, "", "local_compute"], [9, 2, 1, "", "reset"], [9, 2, 1, "", "state_dict"], [9, 2, 1, "", "sync"], [9, 2, 1, "", "unsync"], [9, 2, 1, "", "update"]], "torchrec.metrics.rec_metric.RecMetricComputation": [[9, 2, 1, "", "compute"], [9, 2, 1, "", "get_window_state"], [9, 2, 1, "", "get_window_state_name"], [9, 2, 1, "", "local_compute"], [9, 2, 1, "", "pre_compute"], [9, 2, 1, "", "reset"], [9, 2, 1, "", "update"]], "torchrec.metrics.rec_metric.RecMetricList": [[9, 2, 1, "", "compute"], [9, 2, 1, "", "get_required_inputs"], [9, 2, 1, "", "local_compute"], [9, 4, 1, "", "rec_metrics"], [9, 4, 1, "", "required_inputs"], [9, 2, 1, "", "reset"], [9, 2, 1, "", "sync"], [9, 2, 1, "", "unsync"], [9, 2, 1, "", "update"]], "torchrec.metrics.rec_metric.WindowBuffer": [[9, 2, 1, "", "aggregate_state"], [9, 5, 1, "", "buffers"]], "torchrec.metrics.recall": [[9, 1, 1, "", "RecallMetric"], [9, 1, 1, "", "RecallMetricComputation"], [9, 3, 1, "", "compute_false_neg_sum"], [9, 3, 1, "", "compute_recall"], [9, 3, 1, "", "compute_true_pos_sum"], [9, 3, 1, "", "get_recall_states"]], "torchrec.metrics.recall.RecallMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.throughput": [[9, 1, 1, "", "ThroughputMetric"]], "torchrec.metrics.throughput.ThroughputMetric": [[9, 2, 1, "", "compute"], [9, 2, 1, "", "update"]], "torchrec.metrics.weighted_avg": [[9, 1, 1, "", "WeightedAvgMetric"], [9, 1, 1, "", "WeightedAvgMetricComputation"], [9, 3, 1, "", "get_mean"]], "torchrec.metrics.weighted_avg.WeightedAvgMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.metrics.xauc": [[9, 1, 1, "", "XAUCMetric"], [9, 1, 1, "", "XAUCMetricComputation"], [9, 3, 1, "", "compute_error_sum"], [9, 3, 1, "", "compute_weighted_num_pairs"], [9, 3, 1, "", "compute_xauc"], [9, 3, 1, "", "get_xauc_states"]], "torchrec.metrics.xauc.XAUCMetricComputation": [[9, 2, 1, "", "update"]], "torchrec.models": [[10, 0, 0, "-", "deepfm"], [10, 0, 0, "-", "dlrm"]], "torchrec.models.deepfm": [[10, 1, 1, "", "DenseArch"], [10, 1, 1, "", "FMInteractionArch"], [10, 1, 1, "", "OverArch"], [10, 1, 1, "", "SimpleDeepFMNN"], [10, 1, 1, "", "SparseArch"]], "torchrec.models.deepfm.DenseArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.deepfm.FMInteractionArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.deepfm.OverArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.deepfm.SimpleDeepFMNN": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.deepfm.SparseArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm": [[10, 1, 1, "", "DLRM"], [10, 1, 1, "", "DLRMTrain"], [10, 1, 1, "", "DLRM_DCN"], [10, 1, 1, "", "DLRM_Projection"], [10, 1, 1, "", "DenseArch"], [10, 1, 1, "", "InteractionArch"], [10, 1, 1, "", "InteractionDCNArch"], [10, 1, 1, "", "InteractionProjectionArch"], [10, 1, 1, "", "OverArch"], [10, 1, 1, "", "SparseArch"], [10, 3, 1, "", "choose"]], "torchrec.models.dlrm.DLRM": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.DLRMTrain": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.DLRM_DCN": [[10, 4, 1, "", "sparse_arch"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.DLRM_Projection": [[10, 4, 1, "", "sparse_arch"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.DenseArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.InteractionArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.InteractionDCNArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.InteractionProjectionArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.OverArch": [[10, 2, 1, "", "forward"], [10, 4, 1, "", "training"]], "torchrec.models.dlrm.SparseArch": [[10, 2, 1, "", "forward"], [10, 5, 1, "", "sparse_feature_names"], [10, 4, 1, "", "training"]], "torchrec.modules": [[11, 0, 0, "-", "activation"], [11, 0, 0, "-", "crossnet"], [11, 0, 0, "-", "deepfm"], [11, 0, 0, "-", "embedding_configs"], [11, 0, 0, "-", "embedding_modules"], [11, 0, 0, "-", "feature_processor"], [11, 0, 0, "-", "lazy_extension"], [11, 0, 0, "-", "mc_embedding_modules"], [11, 0, 0, "-", "mc_modules"], [11, 0, 0, "-", "mlp"], [11, 0, 0, "-", "utils"]], "torchrec.modules.activation": [[11, 1, 1, "", "SwishLayerNorm"]], "torchrec.modules.activation.SwishLayerNorm": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.crossnet": [[11, 1, 1, "", "CrossNet"], [11, 1, 1, "", "LowRankCrossNet"], [11, 1, 1, "", "LowRankMixtureCrossNet"], [11, 1, 1, "", "VectorCrossNet"]], "torchrec.modules.crossnet.CrossNet": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.crossnet.LowRankCrossNet": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.crossnet.LowRankMixtureCrossNet": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.crossnet.VectorCrossNet": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.deepfm": [[11, 1, 1, "", "DeepFM"], [11, 1, 1, "", "FactorizationMachine"]], "torchrec.modules.deepfm.DeepFM": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.deepfm.FactorizationMachine": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.embedding_configs": [[11, 1, 1, "", "BaseEmbeddingConfig"], [11, 1, 1, "", "EmbeddingBagConfig"], [11, 1, 1, "", "EmbeddingConfig"], [11, 1, 1, "", "EmbeddingTableConfig"], [11, 1, 1, "", "PoolingType"], [11, 1, 1, "", "QuantConfig"], [11, 1, 1, "", "ShardingType"], [11, 3, 1, "", "data_type_to_dtype"], [11, 3, 1, "", "data_type_to_sparse_type"], [11, 3, 1, "", "dtype_to_data_type"], [11, 3, 1, "", "pooling_type_to_pooling_mode"], [11, 3, 1, "", "pooling_type_to_str"]], "torchrec.modules.embedding_configs.BaseEmbeddingConfig": [[11, 4, 1, "", "data_type"], [11, 4, 1, "", "embedding_dim"], [11, 4, 1, "", "feature_names"], [11, 2, 1, "", "get_weight_init_max"], [11, 2, 1, "", "get_weight_init_min"], [11, 4, 1, "", "init_fn"], [11, 4, 1, "", "name"], [11, 4, 1, "", "need_pos"], [11, 4, 1, "", "num_embeddings"], [11, 4, 1, "", "num_embeddings_post_pruning"], [11, 2, 1, "", "num_features"], [11, 4, 1, "", "weight_init_max"], [11, 4, 1, "", "weight_init_min"]], "torchrec.modules.embedding_configs.EmbeddingBagConfig": [[11, 4, 1, "", "pooling"]], "torchrec.modules.embedding_configs.EmbeddingConfig": [[11, 4, 1, "", "embedding_dim"], [11, 4, 1, "", "feature_names"], [11, 4, 1, "", "num_embeddings"]], "torchrec.modules.embedding_configs.EmbeddingTableConfig": [[11, 4, 1, "", "embedding_names"], [11, 4, 1, "", "has_feature_processor"], [11, 4, 1, "", "is_weighted"], [11, 4, 1, "", "pooling"]], "torchrec.modules.embedding_configs.PoolingType": [[11, 4, 1, "", "MEAN"], [11, 4, 1, "", "NONE"], [11, 4, 1, "", "SUM"]], "torchrec.modules.embedding_configs.QuantConfig": [[11, 4, 1, "", "activation"], [11, 4, 1, "", "per_table_weight_dtype"], [11, 4, 1, "", "weight"]], "torchrec.modules.embedding_configs.ShardingType": [[11, 4, 1, "", "COLUMN_WISE"], [11, 4, 1, "", "DATA_PARALLEL"], [11, 4, 1, "", "ROW_WISE"], [11, 4, 1, "", "TABLE_COLUMN_WISE"], [11, 4, 1, "", "TABLE_ROW_WISE"], [11, 4, 1, "", "TABLE_WISE"]], "torchrec.modules.embedding_modules": [[11, 1, 1, "", "EmbeddingBagCollection"], [11, 1, 1, "", "EmbeddingBagCollectionInterface"], [11, 1, 1, "", "EmbeddingCollection"], [11, 1, 1, "", "EmbeddingCollectionInterface"], [11, 3, 1, "", "get_embedding_names_by_table"], [11, 3, 1, "", "process_pooled_embeddings"], [11, 3, 1, "", "reorder_inverse_indices"]], "torchrec.modules.embedding_modules.EmbeddingBagCollection": [[11, 5, 1, "", "device"], [11, 2, 1, "", "embedding_bag_configs"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "is_weighted"], [11, 2, 1, "", "reset_parameters"], [11, 4, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingBagCollectionInterface": [[11, 2, 1, "", "embedding_bag_configs"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "is_weighted"], [11, 4, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingCollection": [[11, 5, 1, "", "device"], [11, 2, 1, "", "embedding_configs"], [11, 2, 1, "", "embedding_dim"], [11, 2, 1, "", "embedding_names_by_table"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "need_indices"], [11, 2, 1, "", "reset_parameters"], [11, 4, 1, "", "training"]], "torchrec.modules.embedding_modules.EmbeddingCollectionInterface": [[11, 2, 1, "", "embedding_configs"], [11, 2, 1, "", "embedding_dim"], [11, 2, 1, "", "embedding_names_by_table"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "need_indices"], [11, 4, 1, "", "training"]], "torchrec.modules.feature_processor": [[11, 1, 1, "", "BaseFeatureProcessor"], [11, 1, 1, "", "BaseGroupedFeatureProcessor"], [11, 1, 1, "", "PositionWeightedModule"], [11, 1, 1, "", "PositionWeightedProcessor"], [11, 3, 1, "", "offsets_to_range_traceble"], [11, 3, 1, "", "position_weighted_module_update_features"]], "torchrec.modules.feature_processor.BaseFeatureProcessor": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.feature_processor.BaseGroupedFeatureProcessor": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.feature_processor.PositionWeightedModule": [[11, 2, 1, "", "forward"], [11, 2, 1, "", "reset_parameters"], [11, 4, 1, "", "training"]], "torchrec.modules.feature_processor.PositionWeightedProcessor": [[11, 2, 1, "", "forward"], [11, 2, 1, "", "named_buffers"], [11, 2, 1, "", "state_dict"], [11, 4, 1, "", "training"]], "torchrec.modules.lazy_extension": [[11, 1, 1, "", "LazyModuleExtensionMixin"], [11, 3, 1, "", "lazy_apply"]], "torchrec.modules.lazy_extension.LazyModuleExtensionMixin": [[11, 2, 1, "", "apply"]], "torchrec.modules.mc_embedding_modules": [[11, 1, 1, "", "BaseManagedCollisionEmbeddingCollection"], [11, 1, 1, "", "ManagedCollisionEmbeddingBagCollection"], [11, 1, 1, "", "ManagedCollisionEmbeddingCollection"], [11, 3, 1, "", "evict"]], "torchrec.modules.mc_embedding_modules.BaseManagedCollisionEmbeddingCollection": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingBagCollection": [[11, 4, 1, "", "training"]], "torchrec.modules.mc_embedding_modules.ManagedCollisionEmbeddingCollection": [[11, 4, 1, "", "training"]], "torchrec.modules.mc_modules": [[11, 1, 1, "", "DistanceLFU_EvictionPolicy"], [11, 1, 1, "", "LFU_EvictionPolicy"], [11, 1, 1, "", "LRU_EvictionPolicy"], [11, 1, 1, "", "MCHEvictionPolicy"], [11, 1, 1, "", "MCHEvictionPolicyMetadataInfo"], [11, 1, 1, "", "MCHManagedCollisionModule"], [11, 1, 1, "", "ManagedCollisionCollection"], [11, 1, 1, "", "ManagedCollisionModule"], [11, 3, 1, "", "apply_mc_method_to_jt_dict"], [11, 3, 1, "", "average_threshold_filter"], [11, 3, 1, "", "dynamic_threshold_filter"], [11, 3, 1, "", "probabilistic_threshold_filter"]], "torchrec.modules.mc_modules.DistanceLFU_EvictionPolicy": [[11, 2, 1, "", "coalesce_history_metadata"], [11, 5, 1, "", "metadata_info"], [11, 2, 1, "", "record_history_metadata"], [11, 2, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.LFU_EvictionPolicy": [[11, 2, 1, "", "coalesce_history_metadata"], [11, 5, 1, "", "metadata_info"], [11, 2, 1, "", "record_history_metadata"], [11, 2, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.LRU_EvictionPolicy": [[11, 2, 1, "", "coalesce_history_metadata"], [11, 5, 1, "", "metadata_info"], [11, 2, 1, "", "record_history_metadata"], [11, 2, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.MCHEvictionPolicy": [[11, 2, 1, "", "coalesce_history_metadata"], [11, 5, 1, "", "metadata_info"], [11, 2, 1, "", "record_history_metadata"], [11, 2, 1, "", "update_metadata_and_generate_eviction_scores"]], "torchrec.modules.mc_modules.MCHEvictionPolicyMetadataInfo": [[11, 4, 1, "", "is_history_metadata"], [11, 4, 1, "", "is_mch_metadata"], [11, 4, 1, "", "metadata_name"]], "torchrec.modules.mc_modules.MCHManagedCollisionModule": [[11, 2, 1, "", "evict"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "input_size"], [11, 2, 1, "", "open_slots"], [11, 2, 1, "", "output_size"], [11, 2, 1, "", "preprocess"], [11, 2, 1, "", "profile"], [11, 2, 1, "", "rebuild_with_output_id_range"], [11, 2, 1, "", "remap"], [11, 4, 1, "", "training"], [11, 2, 1, "", "validate_state"]], "torchrec.modules.mc_modules.ManagedCollisionCollection": [[11, 2, 1, "", "embedding_configs"], [11, 2, 1, "", "evict"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "open_slots"]], "torchrec.modules.mc_modules.ManagedCollisionModule": [[11, 5, 1, "", "device"], [11, 2, 1, "", "evict"], [11, 2, 1, "", "forward"], [11, 2, 1, "", "input_size"], [11, 2, 1, "", "open_slots"], [11, 2, 1, "", "output_size"], [11, 2, 1, "", "preprocess"], [11, 2, 1, "", "profile"], [11, 2, 1, "", "rebuild_with_output_id_range"], [11, 2, 1, "", "remap"], [11, 4, 1, "", "training"], [11, 2, 1, "", "validate_state"]], "torchrec.modules.mlp": [[11, 1, 1, "", "MLP"], [11, 1, 1, "", "Perceptron"]], "torchrec.modules.mlp.MLP": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.mlp.Perceptron": [[11, 2, 1, "", "forward"], [11, 4, 1, "", "training"]], "torchrec.modules.utils": [[11, 1, 1, "", "SequenceVBEContext"], [11, 3, 1, "", "check_module_output_dimension"], [11, 3, 1, "", "construct_jagged_tensors"], [11, 3, 1, "", "construct_jagged_tensors_inference"], [11, 3, 1, "", "construct_modulelist_from_single_module"], [11, 3, 1, "", "convert_list_of_modules_to_modulelist"], [11, 3, 1, "", "deterministic_dedup"], [11, 3, 1, "", "extract_module_or_tensor_callable"], [11, 3, 1, "", "get_module_output_dimension"], [11, 3, 1, "", "init_mlp_weights_xavier_uniform"], [11, 3, 1, "", "jagged_index_select_with_empty"]], "torchrec.modules.utils.SequenceVBEContext": [[11, 4, 1, "", "recat"], [11, 2, 1, "", "record_stream"], [11, 4, 1, "", "reindexed_length_per_key"], [11, 4, 1, "", "reindexed_lengths"], [11, 4, 1, "", "reindexed_values"], [11, 4, 1, "", "unpadded_lengths"]], "torchrec.optim": [[12, 0, 0, "-", "clipping"], [12, 0, 0, "-", "fused"], [12, 0, 0, "-", "keyed"], [12, 0, 0, "-", "warmup"]], "torchrec.optim.clipping": [[12, 1, 1, "", "GradientClipping"], [12, 1, 1, "", "GradientClippingOptimizer"]], "torchrec.optim.clipping.GradientClipping": [[12, 4, 1, "", "NONE"], [12, 4, 1, "", "NORM"], [12, 4, 1, "", "VALUE"]], "torchrec.optim.clipping.GradientClippingOptimizer": [[12, 2, 1, "", "clip_grad_norm_"], [12, 2, 1, "", "step"]], "torchrec.optim.fused": [[12, 1, 1, "", "EmptyFusedOptimizer"], [12, 1, 1, "", "FusedOptimizer"], [12, 1, 1, "", "FusedOptimizerModule"]], "torchrec.optim.fused.EmptyFusedOptimizer": [[12, 2, 1, "", "step"], [12, 2, 1, "", "zero_grad"]], "torchrec.optim.fused.FusedOptimizer": [[12, 2, 1, "", "step"], [12, 2, 1, "", "zero_grad"]], "torchrec.optim.fused.FusedOptimizerModule": [[12, 5, 1, "", "fused_optimizer"]], "torchrec.optim.keyed": [[12, 1, 1, "", "CombinedOptimizer"], [12, 1, 1, "", "KeyedOptimizer"], [12, 1, 1, "", "KeyedOptimizerWrapper"], [12, 1, 1, "", "OptimizerWrapper"]], "torchrec.optim.keyed.CombinedOptimizer": [[12, 5, 1, "", "optimizers"], [12, 5, 1, "", "param_groups"], [12, 5, 1, "", "params"], [12, 2, 1, "", "post_load_state_dict"], [12, 2, 1, "", "prepend_opt_key"], [12, 2, 1, "", "save_param_groups"], [12, 2, 1, "", "set_optimizer_step"], [12, 5, 1, "", "state"], [12, 2, 1, "", "step"], [12, 2, 1, "", "zero_grad"]], "torchrec.optim.keyed.KeyedOptimizer": [[12, 2, 1, "", "add_param_group"], [12, 2, 1, "", "init_state"], [12, 2, 1, "", "load_state_dict"], [12, 2, 1, "", "post_load_state_dict"], [12, 2, 1, "", "save_param_groups"], [12, 2, 1, "", "state_dict"]], "torchrec.optim.keyed.KeyedOptimizerWrapper": [[12, 2, 1, "", "step"], [12, 2, 1, "", "zero_grad"]], "torchrec.optim.keyed.OptimizerWrapper": [[12, 2, 1, "", "add_param_group"], [12, 2, 1, "", "load_state_dict"], [12, 2, 1, "", "post_load_state_dict"], [12, 2, 1, "", "save_param_groups"], [12, 2, 1, "", "state_dict"], [12, 2, 1, "", "step"], [12, 2, 1, "", "zero_grad"]], "torchrec.optim.warmup": [[12, 1, 1, "", "WarmupOptimizer"], [12, 1, 1, "", "WarmupPolicy"], [12, 1, 1, "", "WarmupStage"]], "torchrec.optim.warmup.WarmupOptimizer": [[12, 2, 1, "", "post_load_state_dict"], [12, 2, 1, "", "step"]], "torchrec.optim.warmup.WarmupPolicy": [[12, 4, 1, "", "CONSTANT"], [12, 4, 1, "", "COSINE_ANNEALING_WARM_RESTARTS"], [12, 4, 1, "", "INVSQRT"], [12, 4, 1, "", "LINEAR"], [12, 4, 1, "", "NONE"], [12, 4, 1, "", "POLY"], [12, 4, 1, "", "STEP"]], "torchrec.optim.warmup.WarmupStage": [[12, 4, 1, "", "decay_iters"], [12, 4, 1, "", "lr_scale"], [12, 4, 1, "", "max_iters"], [12, 4, 1, "", "policy"], [12, 4, 1, "", "sgdr_period"], [12, 4, 1, "", "value"]], "torchrec.quant": [[13, 0, 0, "-", "embedding_modules"]], "torchrec.quant.embedding_modules": [[13, 1, 1, "", "EmbeddingBagCollection"], [13, 1, 1, "", "EmbeddingCollection"], [13, 1, 1, "", "FeatureProcessedEmbeddingBagCollection"], [13, 3, 1, "", "for_each_module_of_type_do"], [13, 3, 1, "", "quant_prep_customize_row_alignment"], [13, 3, 1, "", "quant_prep_enable_quant_state_dict_split_scale_bias"], [13, 3, 1, "", "quant_prep_enable_quant_state_dict_split_scale_bias_for_types"], [13, 3, 1, "", "quant_prep_enable_register_tbes"], [13, 3, 1, "", "quantize_state_dict"]], "torchrec.quant.embedding_modules.EmbeddingBagCollection": [[13, 5, 1, "", "device"], [13, 2, 1, "", "embedding_bag_configs"], [13, 2, 1, "", "forward"], [13, 2, 1, "", "from_float"], [13, 2, 1, "", "is_weighted"], [13, 2, 1, "", "output_dtype"], [13, 4, 1, "", "training"]], "torchrec.quant.embedding_modules.EmbeddingCollection": [[13, 5, 1, "", "device"], [13, 2, 1, "", "embedding_configs"], [13, 2, 1, "", "embedding_dim"], [13, 2, 1, "", "embedding_names_by_table"], [13, 2, 1, "", "forward"], [13, 2, 1, "", "from_float"], [13, 2, 1, "", "need_indices"], [13, 2, 1, "", "output_dtype"], [13, 4, 1, "", "training"]], "torchrec.quant.embedding_modules.FeatureProcessedEmbeddingBagCollection": [[13, 4, 1, "", "embedding_bags"], [13, 2, 1, "", "forward"], [13, 2, 1, "", "from_float"], [13, 4, 1, "", "tbes"], [13, 4, 1, "", "training"]], "torchrec.sparse": [[14, 0, 0, "-", "jagged_tensor"]], "torchrec.sparse.jagged_tensor": [[14, 1, 1, "", "ComputeJTDictToKJT"], [14, 1, 1, "", "ComputeKJTToJTDict"], [14, 1, 1, "", "JaggedTensor"], [14, 1, 1, "", "JaggedTensorMeta"], [14, 1, 1, "", "KeyedJaggedTensor"], [14, 1, 1, "", "KeyedTensor"], [14, 3, 1, "", "flatten_kjt_list"], [14, 3, 1, "", "jt_is_equal"], [14, 3, 1, "", "kjt_is_equal"], [14, 3, 1, "", "permute_multi_embedding"], [14, 3, 1, "", "regroup_kts"], [14, 3, 1, "", "unflatten_kjt_list"]], "torchrec.sparse.jagged_tensor.ComputeJTDictToKJT": [[14, 2, 1, "", "forward"], [14, 4, 1, "", "training"]], "torchrec.sparse.jagged_tensor.ComputeKJTToJTDict": [[14, 2, 1, "", "forward"], [14, 4, 1, "", "training"]], "torchrec.sparse.jagged_tensor.JaggedTensor": [[14, 2, 1, "", "device"], [14, 2, 1, "", "empty"], [14, 2, 1, "", "from_dense"], [14, 2, 1, "", "from_dense_lengths"], [14, 2, 1, "", "lengths"], [14, 2, 1, "", "lengths_or_none"], [14, 2, 1, "", "offsets"], [14, 2, 1, "", "offsets_or_none"], [14, 2, 1, "", "record_stream"], [14, 2, 1, "", "to"], [14, 2, 1, "", "to_dense"], [14, 2, 1, "", "to_dense_weights"], [14, 2, 1, "", "to_padded_dense"], [14, 2, 1, "", "to_padded_dense_weights"], [14, 2, 1, "", "values"], [14, 2, 1, "", "weights"], [14, 2, 1, "", "weights_or_none"]], "torchrec.sparse.jagged_tensor.KeyedJaggedTensor": [[14, 2, 1, "", "concat"], [14, 2, 1, "", "device"], [14, 2, 1, "", "dist_init"], [14, 2, 1, "", "dist_labels"], [14, 2, 1, "", "dist_splits"], [14, 2, 1, "", "dist_tensors"], [14, 2, 1, "", "empty"], [14, 2, 1, "", "empty_like"], [14, 2, 1, "", "flatten_lengths"], [14, 2, 1, "", "from_jt_dict"], [14, 2, 1, "", "from_lengths_sync"], [14, 2, 1, "", "from_offsets_sync"], [14, 2, 1, "", "index_per_key"], [14, 2, 1, "", "inverse_indices"], [14, 2, 1, "", "inverse_indices_or_none"], [14, 2, 1, "", "keys"], [14, 2, 1, "", "length_per_key"], [14, 2, 1, "", "length_per_key_or_none"], [14, 2, 1, "", "lengths"], [14, 2, 1, "", "lengths_offset_per_key"], [14, 2, 1, "", "lengths_or_none"], [14, 2, 1, "", "offset_per_key"], [14, 2, 1, "", "offset_per_key_or_none"], [14, 2, 1, "", "offsets"], [14, 2, 1, "", "offsets_or_none"], [14, 2, 1, "", "permute"], [14, 2, 1, "", "pin_memory"], [14, 2, 1, "", "record_stream"], [14, 2, 1, "", "split"], [14, 2, 1, "", "stride"], [14, 2, 1, "", "stride_per_key"], [14, 2, 1, "", "stride_per_key_per_rank"], [14, 2, 1, "", "sync"], [14, 2, 1, "", "to"], [14, 2, 1, "", "to_dict"], [14, 2, 1, "", "unsync"], [14, 2, 1, "", "values"], [14, 2, 1, "", "variable_stride_per_key"], [14, 2, 1, "", "weights"], [14, 2, 1, "", "weights_or_none"]], "torchrec.sparse.jagged_tensor.KeyedTensor": [[14, 2, 1, "", "device"], [14, 2, 1, "", "from_tensor_list"], [14, 2, 1, "", "key_dim"], [14, 2, 1, "", "keys"], [14, 2, 1, "", "length_per_key"], [14, 2, 1, "", "offset_per_key"], [14, 2, 1, "", "record_stream"], [14, 2, 1, "", "regroup"], [14, 2, 1, "", "regroup_as_dict"], [14, 2, 1, "", "to"], [14, 2, 1, "", "to_dict"], [14, 2, 1, "", "values"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:method", "3": "py:function", "4": "py:attribute", "5": "py:property", "6": "py:exception"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "method", "Python method"], "3": ["py", "function", "Python function"], "4": ["py", "attribute", "Python attribute"], "5": ["py", "property", "Python property"], "6": ["py", "exception", "Python exception"]}, "titleterms": {"welcom": 0, "torchrec": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], "document": 0, "get": 0, "start": 0, "how": 0, "contribut": 0, "overview": 1, "why": 1, "dataset": [2, 3], "criteo": 2, "movielen": 2, "random": 2, "util": [2, 4, 5, 11], "script": 3, "contiguous_preproc_criteo": 3, "npy_preproc_criteo": 3, "distribut": [4, 5, 6], "collective_util": 4, "comm": 4, "comm_op": 4, "dist_data": [4, 6], "embed": 4, "embedding_lookup": 4, "embedding_shard": 4, "embedding_typ": 4, "embeddingbag": 4, "grouped_position_weight": 4, "model_parallel": 4, "quant_embeddingbag": 4, "train_pipelin": 4, "type": [4, 5], "mc_modul": [4, 11], "mc_embeddingbag": 4, "mc_embed": 4, "planner": 5, "constant": 5, "enumer": 5, "partition": 5, "perf_model": 5, "propos": 5, "shard_estim": 5, "stat": 5, "storage_reserv": 5, "shard": 6, "cw_shard": 6, "dp_shard": 6, "rw_shard": 6, "tw_shard": 6, "twcw_shard": 6, "twrw_shard": 6, "fx": 7, "tracer": 7, "modul": [7, 8, 10, 11, 12, 13, 14], "content": [7, 8, 10, 12, 13, 14], "infer": 8, "model_packag": 8, "metric": 9, "accuraci": 9, "auc": 9, "auprc": 9, "calibr": 9, "ctr": 9, "mae": 9, "mse": 9, "multiclass_recal": 9, "ndcg": 9, "ne": 9, "recal": 9, "precis": 9, "rauc": 9, "throughput": 9, "weighted_avg": 9, "xauc": 9, "metric_modul": 9, "rec_metr": 9, "model": 10, "deepfm": [10, 11], "dlrm": 10, "activ": 11, "crossnet": 11, "embedding_config": 11, "embedding_modul": [11, 13], "feature_processor": 11, "lazy_extens": 11, "mlp": 11, "mc_embedding_modul": 11, "optim": 12, "clip": 12, "fuse": 12, "kei": 12, "warmup": 12, "quant": 13, "spars": 14, "jagged_tensor": 14}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 6, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx": 56}}) \ No newline at end of file diff --git a/torchrec.datasets.html b/torchrec.datasets.html index a97967b15..70444f5b2 100644 --- a/torchrec.datasets.html +++ b/torchrec.datasets.html @@ -271,7 +271,7 @@
    - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
    diff --git a/torchrec.datasets.scripts.html b/torchrec.datasets.scripts.html index b4198da4b..ec1ec409d 100644 --- a/torchrec.datasets.scripts.html +++ b/torchrec.datasets.scripts.html @@ -271,7 +271,7 @@
    - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
    diff --git a/torchrec.distributed.html b/torchrec.distributed.html index e7618f0d8..2439a4fe1 100644 --- a/torchrec.distributed.html +++ b/torchrec.distributed.html @@ -271,7 +271,7 @@
    - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
    @@ -5678,6 +5678,11 @@ SHAMPOO = 'SHAMPOO'
    +
    +
    +SHAMPOO_MRS = 'SHAMPOO_MRS'
    +
    +
    SHAMPOO_V2 = 'SHAMPOO_V2'
    diff --git a/torchrec.distributed.planner.html b/torchrec.distributed.planner.html index 7021d0cb7..f53b226de 100644 --- a/torchrec.distributed.planner.html +++ b/torchrec.distributed.planner.html @@ -271,7 +271,7 @@
    - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
    diff --git a/torchrec.distributed.sharding.html b/torchrec.distributed.sharding.html index 64b729c67..37a0c682f 100644 --- a/torchrec.distributed.sharding.html +++ b/torchrec.distributed.sharding.html @@ -271,7 +271,7 @@
    - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
    diff --git a/torchrec.fx.html b/torchrec.fx.html index fd3de17d7..13b94bc2e 100644 --- a/torchrec.fx.html +++ b/torchrec.fx.html @@ -271,7 +271,7 @@
    - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
    diff --git a/torchrec.inference.html b/torchrec.inference.html index bcb53bb99..16bd10dff 100644 --- a/torchrec.inference.html +++ b/torchrec.inference.html @@ -271,7 +271,7 @@
    - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
    diff --git a/torchrec.metrics.html b/torchrec.metrics.html index 58626bd3d..3453f2343 100644 --- a/torchrec.metrics.html +++ b/torchrec.metrics.html @@ -270,7 +270,7 @@
    - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
    diff --git a/torchrec.models.html b/torchrec.models.html index a6417d9ca..28c167326 100644 --- a/torchrec.models.html +++ b/torchrec.models.html @@ -271,7 +271,7 @@
    - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
    diff --git a/torchrec.modules.html b/torchrec.modules.html index 2cb69c008..c25f7d89d 100644 --- a/torchrec.modules.html +++ b/torchrec.modules.html @@ -271,7 +271,7 @@
    - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
    @@ -890,8 +890,26 @@

    torchrec.modules.embedding_configs

    -class torchrec.modules.embedding_configs.BaseEmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, num_embeddings_post_pruning: Union[int, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False)
    +class torchrec.modules.embedding_configs.BaseEmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: ~torchrec.types.DataType = DataType.FP32, feature_names: ~typing.List[str] = <factory>, weight_init_max: ~typing.Optional[float] = None, weight_init_min: ~typing.Optional[float] = None, num_embeddings_post_pruning: ~typing.Optional[int] = None, init_fn: ~typing.Optional[~typing.Callable[[~torch.Tensor], ~typing.Optional[~torch.Tensor]]] = None, need_pos: bool = False)

    Bases: object

    +

    Base class for embedding configs.

    +
    +
    Parameters:
    +
      +
    • num_embeddings (int) – number of embeddings.

    • +
    • embedding_dim (int) – embedding dimension.

    • +
    • name (str) – name of the embedding table.

    • +
    • data_type (DataType) – data type of the embedding table.

    • +
    • feature_names (List[str]) – list of feature names.

    • +
    • weight_init_max (Optional[float]) – max value for weight initialization.

    • +
    • weight_init_min (Optional[float]) – min value for weight initialization.

    • +
    • num_embeddings_post_pruning (Optional[int]) – number of embeddings after pruning for inference. +If None, no pruning is applied.

    • +
    • init_fn (Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]]) – init function for embedding weights.

    • +
    • need_pos (bool) – whether table is position weighted.

    • +
    +
    +
    data_type: DataType = 'FP32'
    @@ -961,8 +979,15 @@
    -class torchrec.modules.embedding_configs.EmbeddingBagConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, num_embeddings_post_pruning: Union[int, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False, pooling: torchrec.modules.embedding_configs.PoolingType = <PoolingType.SUM: 'SUM'>)
    +class torchrec.modules.embedding_configs.EmbeddingBagConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: ~torchrec.types.DataType = DataType.FP32, feature_names: ~typing.List[str] = <factory>, weight_init_max: ~typing.Optional[float] = None, weight_init_min: ~typing.Optional[float] = None, num_embeddings_post_pruning: ~typing.Optional[int] = None, init_fn: ~typing.Optional[~typing.Callable[[~torch.Tensor], ~typing.Optional[~torch.Tensor]]] = None, need_pos: bool = False, pooling: ~torchrec.modules.embedding_configs.PoolingType = PoolingType.SUM)

    Bases: BaseEmbeddingConfig

    +

    EmbeddingBagConfig is a dataclass that represents a single embedding table, +where outputs are meant to be pooled.

    +
    +
    Parameters:
    +

    pooling (PoolingType) – pooling type.

    +
    +
    pooling: PoolingType = 'SUM'
    @@ -972,8 +997,9 @@
    -class torchrec.modules.embedding_configs.EmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: torchrec.types.DataType = <DataType.FP32: 'FP32'>, feature_names: List[str] = <factory>, weight_init_max: Union[float, NoneType] = None, weight_init_min: Union[float, NoneType] = None, num_embeddings_post_pruning: Union[int, NoneType] = None, init_fn: Union[Callable[[torch.Tensor], Union[torch.Tensor, NoneType]], NoneType] = None, need_pos: bool = False)
    +class torchrec.modules.embedding_configs.EmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: ~torchrec.types.DataType = DataType.FP32, feature_names: ~typing.List[str] = <factory>, weight_init_max: ~typing.Optional[float] = None, weight_init_min: ~typing.Optional[float] = None, num_embeddings_post_pruning: ~typing.Optional[int] = None, init_fn: ~typing.Optional[~typing.Callable[[~torch.Tensor], ~typing.Optional[~torch.Tensor]]] = None, need_pos: bool = False)

    Bases: BaseEmbeddingConfig

    +

    EmbeddingConfig is a dataclass that represents a single embedding table.

    embedding_dim: int
    @@ -1021,7 +1047,16 @@
    class torchrec.modules.embedding_configs.PoolingType(value)

    Bases: Enum

    -

    An enumeration.

    +

    Pooling type for embedding table.

    +
    +
    Parameters:
    +
      +
    • SUM (str) – sum pooling.

    • +
    • MEAN (str) – mean pooling.

    • +
    • NONE (str) – no pooling.

    • +
    +
    +
    MEAN = 'MEAN'
    @@ -1197,19 +1232,31 @@
    property device: device
    -
    +

    Returns: +torch.device: The compute device.

    +
    embedding_bag_configs() List[EmbeddingBagConfig]
    -
    +
    +
    Returns:
    +

    The embedding bag configs.

    +
    +
    Return type:
    +

    List[EmbeddingBagConfig]

    +
    +
    +
    forward(features: KeyedJaggedTensor) KeyedTensor
    -
    +

    Run the EmbeddingBagCollection forward pass. This method takes in a KeyedJaggedTensor +and returns a KeyedTensor, which is the result of pooling the embeddings for each feature.

    +
    Parameters:
    -

    features (KeyedJaggedTensor) – KJT of form [F X B X L].

    +

    features (KeyedJaggedTensor) – Input KJT

    Returns:

    KeyedTensor

    @@ -1220,12 +1267,22 @@
    is_weighted() bool
    -
    +
    +
    Returns:
    +

    Whether the EmbeddingBagCollection is weighted.

    +
    +
    Return type:
    +

    bool

    +
    +
    +
    reset_parameters() None
    -
    +

    Reset the parameters of the EmbeddingBagCollection. Parameter values +are intiialized based on the init_fn of each EmbeddingBagConfig if it exists.

    +
    @@ -1337,27 +1394,55 @@
    property device: device
    -
    +

    Returns: +torch.device: The compute device.

    +
    embedding_configs() List[EmbeddingConfig]
    -
    +
    +
    Returns:
    +

    The embedding configs.

    +
    +
    Return type:
    +

    List[EmbeddingConfig]

    +
    +
    +
    embedding_dim() int
    -
    +
    +
    Returns:
    +

    The embedding dimension.

    +
    +
    Return type:
    +

    int

    +
    +
    +
    embedding_names_by_table() List[List[str]]
    -
    +
    +
    Returns:
    +

    The embedding names by table.

    +
    +
    Return type:
    +

    List[List[str]]

    +
    +
    +
    forward(features: KeyedJaggedTensor) Dict[str, JaggedTensor]
    -
    +

    Run the EmbeddingBagCollection forward pass. This method takes in a KeyedJaggedTensor +and returns a Dict[str, JaggedTensor], which is the result of the individual embeddings for each feature.

    +
    Parameters:

    features (KeyedJaggedTensor) – KJT of form [F X B X L].

    @@ -1370,12 +1455,22 @@
    need_indices() bool
    -
    +
    +
    Returns:
    +

    Whether the EmbeddingCollection needs indices.

    +
    +
    Return type:
    +

    bool

    +
    +
    +
    reset_parameters() None
    -
    +

    Reset the parameters of the EmbeddingCollection. Parameter values +are intiialized based on the init_fn of each EmbeddingConfig if it exists.

    +
    diff --git a/torchrec.optim.html b/torchrec.optim.html index e57d38c3a..62d3e03ee 100644 --- a/torchrec.optim.html +++ b/torchrec.optim.html @@ -271,7 +271,7 @@
    - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
    @@ -428,7 +428,7 @@
    -class torchrec.optim.clipping.GradientClippingOptimizer(optimizer: KeyedOptimizer, clipping: GradientClipping = GradientClipping.NONE, max_gradient: float = 0.1, norm_type: Union[float, str] = 2.0)
    +class torchrec.optim.clipping.GradientClippingOptimizer(optimizer: KeyedOptimizer, clipping: GradientClipping = GradientClipping.NONE, max_gradient: float = 0.1, norm_type: Union[float, str] = 2.0, enable_global_grad_clip: bool = False, param_to_pgs: Optional[Dict[Parameter, List[ProcessGroup]]] = None)

    Bases: OptimizerWrapper

    Clips gradients before doing optimization step.

    @@ -438,9 +438,19 @@
  • clipping (GradientClipping) – how to clip gradients

  • max_gradient (float) – max value for clipping

  • norm_type (float or str) – type of the used p-norm. Can be 'inf' for infinity norm.

  • +
  • enable_global_grad_clip (bool) – whether to enable global gradient clipping.

  • +
  • param_to_pgs (Dict[torch.nn.Parameter, List[dist.ProcessGroup]], optional) – Mapping of parameters +to process groups. Used for global gradient clipping in n-D model parallelism case. +Defaults to None, local gradient clipping is used.

  • +
    +
    +clip_grad_norm_() None
    +

    Clip the gradient norm of all parameters.

    +
    +
    step(closure: Optional[Any] = None) None
    diff --git a/torchrec.quant.html b/torchrec.quant.html index ba25fceca..95378e11e 100644 --- a/torchrec.quant.html +++ b/torchrec.quant.html @@ -271,7 +271,7 @@
    - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
    diff --git a/torchrec.sparse.html b/torchrec.sparse.html index deb19028a..c0a0e365f 100644 --- a/torchrec.sparse.html +++ b/torchrec.sparse.html @@ -271,7 +271,7 @@
    - 1.1.0.dev20240925+cpu + 1.1.0.dev20240926+cpu
    @@ -527,18 +527,46 @@
    device() device
    -
    +

    Get JaggedTensor device.

    +
    +
    Returns:
    +

    the device of the values tensor.

    +
    +
    Return type:
    +

    torch.device

    +
    +
    +
    static empty(is_weighted: bool = False, device: Optional[device] = None, values_dtype: Optional[dtype] = None, weights_dtype: Optional[dtype] = None, lengths_dtype: dtype = torch.int32) JaggedTensor
    -
    +

    Constructs an empty JaggedTensor.

    +
    +
    Parameters:
    +
      +
    • is_weighted (bool) – whether the JaggedTensor has weights.

    • +
    • device (Optional[torch.device]) – device for JaggedTensor.

    • +
    • values_dtype (Optional[torch.dtype]) – dtype for values.

    • +
    • weights_dtype (Optional[torch.dtype]) – dtype for weights.

    • +
    • lengths_dtype (torch.dtype) – dtype for lengths.

    • +
    +
    +
    Returns:
    +

    empty JaggedTensor.

    +
    +
    Return type:
    +

    JaggedTensor

    +
    +
    +
    static from_dense(values: List[Tensor], weights: Optional[List[Tensor]] = None) JaggedTensor
    -

    Constructs JaggedTensor from dense values/weights of shape (B, N,).

    -

    Note that lengths and offsets are still of shape (B,).

    +

    Constructs JaggedTensor from list of tensors as values, with optional weights. +lengths will be computed, of shape (B,), where B is len(values) which +represents the batch size.

    Parameters:
      @@ -572,7 +600,7 @@ weights=weights, ) -# j1 = [[1.0], [], [7.0], [8.0], [10.0, 11.0, 12.0]] +# j1 = [[1.0], [], [7.0, 8.0], [10.0, 11.0, 12.0]]
    @@ -580,29 +608,81 @@
    static from_dense_lengths(values: Tensor, lengths: Tensor, weights: Optional[Tensor] = None) JaggedTensor
    -

    Constructs JaggedTensor from dense values/weights of shape (B, N,).

    -

    Note that lengths is still of shape (B,).

    +

    Constructs JaggedTensor from values and lengths tensors, with optional weights. +Note that lengths is still of shape (B,), where B is the batch size.

    +
    +
    Parameters:
    +
      +
    • values (torch.Tensor) – dense representation of values.

    • +
    • lengths (torch.Tensor) – jagged slices, represented as lengths.

    • +
    • weights (Optional[torch.Tensor]) – if values have weights, tensor with +the same shape as values.

    • +
    +
    +
    Returns:
    +

    JaggedTensor created from 2D dense tensor.

    +
    +
    Return type:
    +

    JaggedTensor

    +
    +
    lengths() Tensor
    -
    +

    Get JaggedTensor lengths. If not computed, compute it from offsets.

    +
    +
    Returns:
    +

    the lengths tensor.

    +
    +
    Return type:
    +

    torch.Tensor

    +
    +
    +
    lengths_or_none() Optional[Tensor]
    -
    +

    Get JaggedTensor lengths. If not computed, return None.

    +
    +
    Returns:
    +

    the lengths tensor.

    +
    +
    Return type:
    +

    Optional[torch.Tensor]

    +
    +
    +
    offsets() Tensor
    -
    +

    Get JaggedTensor offsets. If not computed, compute it from lengths.

    +
    +
    Returns:
    +

    the offsets tensor.

    +
    +
    Return type:
    +

    torch.Tensor

    +
    +
    +
    offsets_or_none() Optional[Tensor]
    -
    +

    Get JaggedTensor offsets. If not computed, return None.

    +
    +
    Returns:
    +

    the offsets tensor.

    +
    +
    Return type:
    +

    Optional[torch.Tensor]

    +
    +
    +
    @@ -613,9 +693,21 @@
    to(device: device, non_blocking: bool = False) JaggedTensor
    -

    Please be aware that according to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html, -to might return self or a copy of self. So please remember to use to with the assignment operator, -for example, in = in.to(new_device).

    +

    Move the JaggedTensor to the specified device.

    +
    +
    Parameters:
    +
      +
    • device (torch.device) – the device to move to.

    • +
    • non_blocking (bool) – whether to perform the copy asynchronously.

    • +
    +
    +
    Returns:
    +

    the moved JaggedTensor.

    +
    +
    Return type:
    +

    JaggedTensor

    +
    +
    @@ -685,8 +777,8 @@
    to_padded_dense(desired_length: Optional[int] = None, padding_value: float = 0.0) Tensor

    Constructs a 2D dense tensor from the JT’s values of shape (B, N,).

    -

    Note that B is the length of self.lengths() and N is the longest feature -length or desired_length.

    +

    Note that B is the length of self.lengths() and +N is the longest feature length or desired_length.

    If desired_length > length we will pad with padding_value, otherwise we will select the last value at desired_length.

    @@ -729,10 +821,11 @@
    to_padded_dense_weights(desired_length: Optional[int] = None, padding_value: float = 0.0) Optional[Tensor]

    Constructs a 2D dense tensor from the JT’s weights of shape (B, N,).

    -

    Note that B is the length of self.lengths() and N is the longest feature -length or desired_length.

    +

    Note that B (batch size) is the length of self.lengths() and +N is the longest feature length or desired_length.

    If desired_length > length we will pad with padding_value, otherwise we will select the last value at desired_length.

    +

    Like to_padded_dense but for the JT’s weights instead of values.

    Parameters:
      @@ -773,17 +866,44 @@
      values() Tensor
      -
      +

      Get JaggedTensor values.

      +
      +
      Returns:
      +

      the values tensor.

      +
      +
      Return type:
      +

      torch.Tensor

      +
      +
      +
    weights() Tensor
    -
    +

    Get JaggedTensor weights. If None, throw an error.

    +
    +
    Returns:
    +

    the weights tensor.

    +
    +
    Return type:
    +

    torch.Tensor

    +
    +
    +
    weights_or_none() Optional[Tensor]
    -
    +

    Get JaggedTensor weights. If None, return None.

    +
    +
    Returns:
    +

    the weights tensor.

    +
    +
    Return type:
    +

    Optional[torch.Tensor]

    +
    +
    +
    @@ -822,7 +942,8 @@
  • offset_per_key (Optional[List[int]]) – start offset for each key and final offset.

  • index_per_key (Optional[Dict[str, int]]) – index for each key.

  • -
  • jt_dict (Optional[Dict[str, JaggedTensor]]) –

  • +
  • jt_dict (Optional[Dict[str, JaggedTensor]]) – dictionary of keys to JaggedTensors. +Allow ability to make to_dict() lazy/cacheable.

  • inverse_indices (Optional[Tuple[List[str], torch.Tensor]]) – inverse indices to expand deduplicated embedding output for variable stride per key.

  • @@ -853,12 +974,33 @@
    static concat(kjt_list: List[KeyedJaggedTensor]) KeyedJaggedTensor
    -
    +

    Concatenates a list of KeyedJaggedTensors into a single KeyedJaggedTensor.

    +
    +
    Parameters:
    +

    kjt_list (List[KeyedJaggedTensor]) – list of KeyedJaggedTensors to be concatenated.

    +
    +
    Returns:
    +

    concatenated KeyedJaggedTensor.

    +
    +
    Return type:
    +

    KeyedJaggedTensor

    +
    +
    +
    device() device
    -
    +

    Returns the device of the KeyedJaggedTensor.

    +
    +
    Returns:
    +

    device of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    torch.device

    +
    +
    +
    @@ -883,12 +1025,42 @@
    static empty(is_weighted: bool = False, device: Optional[device] = None, values_dtype: Optional[dtype] = None, weights_dtype: Optional[dtype] = None, lengths_dtype: dtype = torch.int32) KeyedJaggedTensor
    -
    +

    Constructs an empty KeyedJaggedTensor.

    +
    +
    Parameters:
    +
      +
    • is_weighted (bool) – whether the KeyedJaggedTensor is weighted or not.

    • +
    • device (Optional[torch.device]) – device on which the KeyedJaggedTensor will be placed.

    • +
    • values_dtype (Optional[torch.dtype]) – dtype of the values tensor.

    • +
    • weights_dtype (Optional[torch.dtype]) – dtype of the weights tensor.

    • +
    • lengths_dtype (torch.dtype) – dtype of the lengths tensor.

    • +
    +
    +
    Returns:
    +

    empty KeyedJaggedTensor.

    +
    +
    Return type:
    +

    KeyedJaggedTensor

    +
    +
    +
    static empty_like(kjt: KeyedJaggedTensor) KeyedJaggedTensor
    -
    +

    Constructs an empty KeyedJaggedTensor with the same device and dtypes as the input KeyedJaggedTensor.

    +
    +
    Parameters:
    +

    kjt (KeyedJaggedTensor) – input KeyedJaggedTensor.

    +
    +
    Returns:
    +

    empty KeyedJaggedTensor.

    +
    +
    Return type:
    +

    KeyedJaggedTensor

    +
    +
    +
    @@ -898,12 +1070,16 @@
    static from_jt_dict(jt_dict: Dict[str, JaggedTensor]) KeyedJaggedTensor
    -

    Constructs a KeyedJaggedTensor from a Dict[str, JaggedTensor], -but this function will ONLY work if the JaggedTensors all +

    Constructs a KeyedJaggedTensor from a dictionary of JaggedTensors. +Automatically calls kjt.sync() on newly created KJT.

    +
    +

    Note

    +

    This function will ONLY work if the JaggedTensors all have the same “implicit” batch_size dimension.

    +

    Basically, we can visualize JaggedTensors as 2-D tensors of the format of [batch_size x variable_feature_dim]. -In case, we have some batch without a feature value, +In the case, we have some batch without a feature value, the input JaggedTensor could just not include any values.

    But KeyedJaggedTensor (by default) typically pad “None” so that all the JaggedTensors stored in the KeyedJaggedTensor @@ -917,14 +1093,6 @@ # ^ # dim_0

    -
    Notice that the inputs for this KeyedJaggedTensor would have looked like:

    values: torch.Tensor = [V0, V1, V2, V3, V4, V5, V6, V7] # V == any tensor datatype -weights: torch.Tensor = [W0, W1, W2, W3, W4, W5, W6, W7] # W == any tensor datatype -lengths: torch.Tensor = [2, 0, 1, 1, 1, 3] # representing the jagged slice -offsets: torch.Tensor = [0, 2, 2, 3, 4, 5, 8] # offsets from 0 for each jagged slice -keys: List[str] = [“Feature0”, “Feature1”] # correspond to each value of dim_0 -index_per_key: Dict[str, int] = {“Feature0”: 0, “Feature1”: 1} # index for each key -offset_per_key: List[int] = [0, 3, 8] # start offset for each key and final offset

    -
    Now if the input jt_dict = {

    # “Feature0” [V0,V1] [V2] # “Feature1” [V3] [V4] [V5,V6,V7]

    @@ -937,87 +1105,285 @@ would be [2, 1, 1, 1, 3] indicating variable batch_size dim_1 violates the existing assumption / precondition that KeyedJaggedTensor’s should have fixed batch_size dimension.

    +
    +
    Parameters:
    +

    jt_dict (Dict[str, JaggedTensor]) – dictionary of JaggedTensors.

    +
    +
    Returns:
    +

    constructed KeyedJaggedTensor.

    +
    +
    Return type:
    +

    KeyedJaggedTensor

    +
    +
    static from_lengths_sync(keys: List[str], values: Tensor, lengths: Tensor, weights: Optional[Tensor] = None, stride: Optional[int] = None, stride_per_key_per_rank: Optional[List[List[int]]] = None, inverse_indices: Optional[Tuple[List[str], Tensor]] = None) KeyedJaggedTensor
    -
    +

    Constructs a KeyedJaggedTensor from a list of keys, lengths, and offsets. +Same as from_offsets_sync except lengths are used instead of offsets.

    +
    +
    Parameters:
    +
      +
    • keys (List[str]) – list of keys.

    • +
    • values (torch.Tensor) – values tensor in dense representation.

    • +
    • lengths (torch.Tensor) – jagged slices, represented as lengths.

    • +
    • weights (Optional[torch.Tensor]) – if the values have weights. Tensor with the +same shape as values.

    • +
    • stride (Optional[int]) – number of examples per batch.

    • +
    • stride_per_key_per_rank (Optional[List[List[int]]]) – batch size +(number of examples) per key per rank, with the outer list representing the +keys and the inner list representing the values.

    • +
    • inverse_indices (Optional[Tuple[List[str], torch.Tensor]]) – inverse indices to +expand deduplicated embedding output for variable stride per key.

    • +
    +
    +
    Returns:
    +

    constructed KeyedJaggedTensor.

    +
    +
    Return type:
    +

    KeyedJaggedTensor

    +
    +
    +
    static from_offsets_sync(keys: List[str], values: Tensor, offsets: Tensor, weights: Optional[Tensor] = None, stride: Optional[int] = None, stride_per_key_per_rank: Optional[List[List[int]]] = None, inverse_indices: Optional[Tuple[List[str], Tensor]] = None) KeyedJaggedTensor
    -
    +

    Constructs a KeyedJaggedTensor from a list of keys, values, and offsets.

    +
    +
    Parameters:
    +
      +
    • keys (List[str]) – list of keys.

    • +
    • values (torch.Tensor) – values tensor in dense representation.

    • +
    • offsets (torch.Tensor) – jagged slices, represented as cumulative offsets.

    • +
    • weights (Optional[torch.Tensor]) – if the values have weights. Tensor with the +same shape as values.

    • +
    • stride (Optional[int]) – number of examples per batch.

    • +
    • stride_per_key_per_rank (Optional[List[List[int]]]) – batch size +(number of examples) per key per rank, with the outer list representing the +keys and the inner list representing the values.

    • +
    • inverse_indices (Optional[Tuple[List[str], torch.Tensor]]) – inverse indices to +expand deduplicated embedding output for variable stride per key.

    • +
    +
    +
    Returns:
    +

    constructed KeyedJaggedTensor.

    +
    +
    Return type:
    +

    KeyedJaggedTensor

    +
    +
    +
    index_per_key() Dict[str, int]
    -
    +

    Returns the index per key of the KeyedJaggedTensor.

    +
    +
    Returns:
    +

    index per key of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    Dict[str, int]

    +
    +
    +
    inverse_indices() Tuple[List[str], Tensor]
    -
    +

    Returns the inverse indices of the KeyedJaggedTensor. +If inverse indices are None, this will throw an error.

    +
    +
    Returns:
    +

    inverse indices of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    Tuple[List[str], torch.Tensor]

    +
    +
    +
    inverse_indices_or_none() Optional[Tuple[List[str], Tensor]]
    -
    +

    Returns the inverse indices of the KeyedJaggedTensor or None if they don’t exist.

    +
    +
    Returns:
    +

    inverse indices of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    Optional[Tuple[List[str], torch.Tensor]]

    +
    +
    +
    keys() List[str]
    -
    +

    Returns the keys of the KeyedJaggedTensor.

    +
    +
    Returns:
    +

    keys of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    List[str]

    +
    +
    +
    length_per_key() List[int]
    -
    +

    Returns the length per key of the KeyedJaggedTensor. +If length per key is None, this will compute it.

    +
    +
    Returns:
    +

    length per key of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    List[int]

    +
    +
    +
    length_per_key_or_none() Optional[List[int]]
    -
    +

    Returns the length per key of the KeyedJaggedTensor or None if it hasn’t been computed.

    +
    +
    Returns:
    +

    length per key of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    List[int]

    +
    +
    +
    lengths() Tensor
    -
    +

    Returns the lengths of the KeyedJaggedTensor. +If the lengths are not computed yet, it will compute them.

    +
    +
    Returns:
    +

    lengths of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    torch.Tensor

    +
    +
    +
    lengths_offset_per_key() List[int]
    -
    - +

    Returns the lengths offset per key of the KeyedJaggedTensor. +If lengths offset per key is None, this will compute it.

    +
    +
    Returns:
    +

    lengths offset per key of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    List[int]

    +
    +
    +
    +
    lengths_or_none() Optional[Tensor]
    -
    +

    Returns the lengths of the KeyedJaggedTensor or None if they are not computed yet.

    +
    +
    Returns:
    +

    lengths of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    torch.Tensor

    +
    +
    +
    offset_per_key() List[int]
    -
    +

    Returns the offset per key of the KeyedJaggedTensor. +If offset per key is None, this will compute it.

    +
    +
    Returns:
    +

    offset per key of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    List[int]

    +
    +
    +
    offset_per_key_or_none() Optional[List[int]]
    -
    +

    Returns the offset per key of the KeyedJaggedTensor or None if it hasn’t been computed.

    +
    +
    Returns:
    +

    offset per key of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    List[int]

    +
    +
    +
    offsets() Tensor
    -
    +

    Returns the offsets of the KeyedJaggedTensor. +If the offsets are not computed yet, it will compute them.

    +
    +
    Returns:
    +

    offsets of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    torch.Tensor

    +
    +
    +
    offsets_or_none() Optional[Tensor]
    -
    +

    Returns the offsets of the KeyedJaggedTensor or None if they are not computed yet.

    +
    +
    Returns:
    +

    offsets of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    torch.Tensor

    +
    +
    +
    permute(indices: List[int], indices_tensor: Optional[Tensor] = None) KeyedJaggedTensor
    -
    +

    Permutes the KeyedJaggedTensor.

    +
    +
    Parameters:
    +
      +
    • indices (List[int]) – list of indices.

    • +
    • indices_tensor (Optional[torch.Tensor]) – tensor of indices.

    • +
    +
    +
    Returns:
    +

    permuted KeyedJaggedTensor.

    +
    +
    Return type:
    +

    KeyedJaggedTensor

    +
    +
    +
    @@ -1033,65 +1399,184 @@
    split(segments: List[int]) List[KeyedJaggedTensor]
    -
    +

    Splits the KeyedJaggedTensor into a list of KeyedJaggedTensor.

    +
    +
    Parameters:
    +

    segments (List[int]) – list of segments.

    +
    +
    Returns:
    +

    list of KeyedJaggedTensor.

    +
    +
    Return type:
    +

    List[KeyedJaggedTensor]

    +
    +
    +
    stride() int
    -
    +

    Returns the stride of the KeyedJaggedTensor. +If stride is None, this will compute it.

    +
    +
    Returns:
    +

    stride of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    int

    +
    +
    +
    stride_per_key() List[int]
    -
    +

    Returns the stride per key of the KeyedJaggedTensor. +If stride per key is None, this will compute it.

    +
    +
    Returns:
    +

    stride per key of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    List[int]

    +
    +
    +
    stride_per_key_per_rank() List[List[int]]
    -
    +

    Returns the stride per key per rank of the KeyedJaggedTensor.

    +
    +
    Returns:
    +

    stride per key per rank of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    List[List[int]]

    +
    +
    +
    sync() KeyedJaggedTensor
    -
    +

    Synchronizes the KeyedJaggedTensor by computing the offset_per_key and length_per_key.

    +
    +
    Returns:
    +

    synced KeyedJaggedTensor.

    +
    +
    Return type:
    +

    KeyedJaggedTensor

    +
    +
    +
    to(device: device, non_blocking: bool = False, dtype: Optional[dtype] = None) KeyedJaggedTensor
    -

    Please be aware that according to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html, -to might return self or a copy of self. So please remember to use to with the assignment operator, -for example, in = in.to(new_device).

    +

    Returns a copy of KeyedJaggedTensor in the specified device and dtype.

    +
    +
    Parameters:
    +
      +
    • device (torch.device) – the desired device of the copy.

    • +
    • non_blocking (bool) – whether to copy the tensors in a non-blocking fashion.

    • +
    • dtype (Optional[torch.dtype]) – the desired data type of the copy.

    • +
    +
    +
    Returns:
    +

    the copied KeyedJaggedTensor.

    +
    +
    Return type:
    +

    KeyedJaggedTensor

    +
    +
    to_dict() Dict[str, JaggedTensor]
    -
    +

    Returns a dictionary of JaggedTensor for each key. +Will cache result in self._jt_dict.

    +
    +
    Returns:
    +

    dictionary of JaggedTensor for each key.

    +
    +
    Return type:
    +

    Dict[str, JaggedTensor]

    +
    +
    +
    unsync() KeyedJaggedTensor
    -
    +

    Unsyncs the KeyedJaggedTensor by clearing the offset_per_key and length_per_key.

    +
    +
    Returns:
    +

    unsynced KeyedJaggedTensor.

    +
    +
    Return type:
    +

    KeyedJaggedTensor

    +
    +
    +
    values() Tensor
    -
    +

    Returns the values of the KeyedJaggedTensor.

    +
    +
    Returns:
    +

    values of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    torch.Tensor

    +
    +
    +
    variable_stride_per_key() bool
    -
    +

    Returns whether the KeyedJaggedTensor has variable stride per key.

    +
    +
    Returns:
    +

    whether the KeyedJaggedTensor has variable stride per key.

    +
    +
    Return type:
    +

    bool

    +
    +
    +
    weights() Tensor
    -
    +

    Returns the weights of the KeyedJaggedTensor. +If weights is None, this will throw an error.

    +
    +
    Returns:
    +

    weights of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    torch.Tensor

    +
    +
    +
    weights_or_none() Optional[Tensor]
    -
    +

    Returns the weights of the KeyedJaggedTensor or None if they don’t exist.

    +
    +
    Returns:
    +

    weights of the KeyedJaggedTensor.

    +
    +
    Return type:
    +

    torch.Tensor

    +
    +
    +
    @@ -1134,47 +1619,108 @@ kt = KeyedTensor.from_tensor_list(keys, tensor_list) kt.values() - # tensor( - # [ - # [1, 1, 2, 1, 2, 3, 1, 2, 3], - # [1, 1, 2, 1, 2, 3, 1, 2, 3], - # [1, 1, 2, 1, 2, 3, 1, 2, 3], - # ] - # ) +# torch.Tensor( +# [ +# [1, 1, 2, 1, 2, 3, 1, 2, 3], +# [1, 1, 2, 1, 2, 3, 1, 2, 3], +# [1, 1, 2, 1, 2, 3, 1, 2, 3], +# ] +# ) kt["Embedding B"] - # tensor([[2, 1, 2], [2, 1, 2], [2, 1, 2]]) +# torch.Tensor([[2, 1, 2], [2, 1, 2], [2, 1, 2]])
    device() device
    -
    +
    +
    Returns:
    +

    device of the values tensor.

    +
    +
    Return type:
    +

    torch.device

    +
    +
    +
    static from_tensor_list(keys: List[str], tensors: List[Tensor], key_dim: int = 1, cat_dim: int = 1) KeyedTensor
    -
    +

    Create a KeyedTensor from a list of tensors. The tensors are concatenated +along the cat_dim. The keys are used to index the tensors.

    +
    +
    Parameters:
    +
      +
    • keys (List[str]) – list of keys.

    • +
    • tensors (List[torch.Tensor]) – list of tensors.

    • +
    • key_dim (int) – key dimension, zero indexed - defaults to 1 +(typically B is 0-dimension).

    • +
    • cat_dim (int) – dimension along which to concatenate the tensors - defaults

    • +
    +
    +
    Returns:
    +

    keyed tensor.

    +
    +
    Return type:
    +

    KeyedTensor

    +
    +
    +
    key_dim() int
    -
    +
    +
    Returns:
    +

    key dimension, zero indexed - typically B is 0-dimension.

    +
    +
    Return type:
    +

    int

    +
    +
    +
    keys() List[str]
    -
    +
    +
    Returns:
    +

    list of keys.

    +
    +
    Return type:
    +

    List[str]

    +
    +
    +
    length_per_key() List[int]
    -
    +
    +
    Returns:
    +

    length of each key along key dimension.

    +
    +
    Return type:
    +

    List[int]

    +
    +
    +
    offset_per_key() List[int]
    -
    +

    Get the offset of each key along key dimension. +Compute and cache if not already computed.

    +
    +
    Returns:
    +

    offset of each key along key dimension.

    +
    +
    Return type:
    +

    List[int]

    +
    +
    +
    @@ -1185,30 +1731,91 @@
    static regroup(keyed_tensors: List[KeyedTensor], groups: List[List[str]]) List[Tensor]
    -
    +

    Regroup a list of KeyedTensors into a list of tensors.

    +
    +
    Parameters:
    +
      +
    • keyed_tensors (List[KeyedTensor]) – list of KeyedTensors.

    • +
    • groups (List[List[str]]) – list of groups of keys.

    • +
    +
    +
    Returns:
    +

    list of tensors.

    +
    +
    Return type:
    +

    List[torch.Tensor]

    +
    +
    +
    static regroup_as_dict(keyed_tensors: List[KeyedTensor], groups: List[List[str]], keys: List[str]) Dict[str, Tensor]
    -
    +

    Regroup a list of KeyedTensors into a dictionary of tensors.

    +
    +
    Parameters:
    +
      +
    • keyed_tensors (List[KeyedTensor]) – list of KeyedTensors.

    • +
    • groups (List[List[str]]) – list of groups of keys.

    • +
    • keys (List[str]) – list of keys.

    • +
    +
    +
    Returns:
    +

    dictionary of tensors.

    +
    +
    Return type:
    +

    Dict[str, torch.Tensor]

    +
    +
    +
    to(device: device, non_blocking: bool = False) KeyedTensor
    -

    Please be aware that according to https://pytorch.org/docs/stable/generated/torch.Tensor.to.html, -to might return self or a copy of self. So please remember to use to with the assignment operator, -for example, in = in.to(new_device).

    +

    Moves the values tensor to the specified device.

    +
    +
    Parameters:
    +
      +
    • device (torch.device) – device to move the values tensor to.

    • +
    • non_blocking (bool) – whether to perform the operation asynchronously +(default: False).

    • +
    +
    +
    Returns:
    +

    keyed tensor with values tensor moved to the specified device.

    +
    +
    Return type:
    +

    KeyedTensor

    +
    +
    to_dict() Dict[str, Tensor]
    -
    +
    +
    Returns:
    +

    dictionary of tensors keyed by the keys.

    +
    +
    Return type:
    +

    Dict[str, torch.Tensor]

    +
    +
    +
    values() Tensor
    -
    +

    Get the values tensor.

    +
    +
    Returns:
    +

    dense tensor, concatenated typically along key dimension.

    +
    +
    Return type:
    +

    torch.Tensor

    +
    +
    +