diff --git a/examples/qwen_image.json b/examples/qwen_image.json new file mode 100644 index 0000000..1de531a --- /dev/null +++ b/examples/qwen_image.json @@ -0,0 +1,180 @@ +{ + "last_node_id": 10, + "last_link_id": 10, + "nodes": [ + { + "id": 1, + "type": "UNETLoader", + "pos": [50, 50], + "size": {"0": 315, "1": 58}, + "flags": {}, + "order": 0, + "mode": 0, + "outputs": [ + {"name": "MODEL", "type": "MODEL", "links": [1], "shape": 3} + ], + "properties": {"Node name for S&R": "UNETLoader"}, + "widgets_values": ["qwen_image_fp8_e4m3fn.safetensors", "default"] + }, + { + "id": 2, + "type": "TeaCache", + "pos": [400, 50], + "size": {"0": 315, "1": 150}, + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [ + {"name": "model", "type": "MODEL", "link": 1} + ], + "outputs": [ + {"name": "model", "type": "MODEL", "links": [2], "shape": 3} + ], + "properties": {"Node name for S&R": "TeaCache"}, + "widgets_values": ["qwen-image", 0.35, 0.1, 1.0, "cuda"] + }, + { + "id": 3, + "type": "CLIPLoader", + "pos": [50, 150], + "size": {"0": 315, "1": 58}, + "flags": {}, + "order": 2, + "mode": 0, + "outputs": [ + {"name": "CLIP", "type": "CLIP", "links": [3, 10], "shape": 3} + ], + "properties": {"Node name for S&R": "CLIPLoader"}, + "widgets_values": ["qwen_2.5_vl_7b_fp8_scaled.safetensors", "qwen_image", "default"] + }, + { + "id": 4, + "type": "VAELoader", + "pos": [50, 250], + "size": {"0": 315, "1": 58}, + "flags": {}, + "order": 3, + "mode": 0, + "outputs": [ + {"name": "VAE", "type": "VAE", "links": [4], "shape": 3} + ], + "properties": {"Node name for S&R": "VAELoader"}, + "widgets_values": ["qwen_image_vae.safetensors"] + }, + { + "id": 5, + "type": "CLIPTextEncode", + "pos": [400, 250], + "size": {"0": 315, "1": 100}, + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [ + {"name": "clip", "type": "CLIP", "link": 3} + ], + "outputs": [ + {"name": "CONDITIONING", "type": "CONDITIONING", "links": [5], "shape": 3} + ], + "properties": {"Node name for S&R": "CLIPTextEncode"}, + "widgets_values": ["A beautiful landscape with mountains and a lake"] + }, + { + "id": 6, + "type": "EmptySD3LatentImage", + "pos": [750, 50], + "size": {"0": 315, "1": 100}, + "flags": {}, + "order": 5, + "mode": 0, + "outputs": [ + {"name": "LATENT", "type": "LATENT", "links": [6], "shape": 3} + ], + "properties": {"Node name for S&R": "EmptySD3LatentImage"}, + "widgets_values": [1024, 1024, 1] + }, + { + "id": 7, + "type": "KSampler", + "pos": [750, 200], + "size": {"0": 315, "1": 262}, + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + {"name": "model", "type": "MODEL", "link": 2}, + {"name": "positive", "type": "CONDITIONING", "link": 5}, + {"name": "negative", "type": "CONDITIONING", "link": 7}, + {"name": "latent_image", "type": "LATENT", "link": 6} + ], + "outputs": [ + {"name": "LATENT", "type": "LATENT", "links": [8], "shape": 3} + ], + "properties": {"Node name for S&R": "KSampler"}, + "widgets_values": [12345, "fixed", 20, 7.0, "euler", "normal", 1.0] + }, + { + "id": 8, + "type": "CLIPTextEncode", + "pos": [400, 400], + "size": {"0": 315, "1": 100}, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + {"name": "clip", "type": "CLIP", "link": 10} + ], + "outputs": [ + {"name": "CONDITIONING", "type": "CONDITIONING", "links": [7], "shape": 3} + ], + "properties": {"Node name for S&R": "CLIPTextEncode"}, + "widgets_values": [""] + }, + { + "id": 9, + "type": "VAEDecode", + "pos": [1100, 200], + "size": {"0": 210, "1": 46}, + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + {"name": "samples", "type": "LATENT", "link": 8},
+ {"name": "vae", "type": "VAE", "link": 4} + ], + "outputs": [ + {"name": "IMAGE", "type": "IMAGE", "links": [9], "shape": 3} + ], + "properties": {"Node name for S&R": "VAEDecode"} + }, + { + "id": 10, + "type": "SaveImage", + "pos": [1350, 200], + "size": {"0": 315, "1": 58}, + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + {"name": "images", "type": "IMAGE", "link": 9} + ], + "properties": {"Node name for S&R": "SaveImage"}, + "widgets_values": ["qwen_image_teacache"] + } + ], + "links": [ + [1, 1, 0, 2, 0, "MODEL"], + [2, 2, 0, 7, 0, "MODEL"], + [3, 3, 0, 5, 0, "CLIP"], + [10, 3, 0, 8, 0, "CLIP"], + [4, 4, 0, 9, 1, "VAE"], + [5, 5, 0, 7, 1, "CONDITIONING"], + [6, 6, 0, 7, 3, "LATENT"], + [7, 8, 0, 7, 2, "CONDITIONING"], + [8, 7, 0, 9, 0, "LATENT"], + [9, 9, 0, 10, 0, "IMAGE"] + ], + "groups": [], + "config": {}, + "extra": {}, + "version": 0.4 +} diff --git a/examples/qwen_image_edit.json b/examples/qwen_image_edit.json new file mode 100644 index 0000000..11b667a --- /dev/null +++ b/examples/qwen_image_edit.json @@ -0,0 +1,662 @@ +{ + "id": "00000000-0000-0000-0000-000000000000", + "revision": 0, + "last_node_id": 151, + "last_link_id": 73, + "nodes": [ + { + "id": 38, + "type": "CLIPLoader", + "pos": [ + -266.1712646484375, + 135.491943359375 + ], + "size": [ + 379.7156982421875, + 117.8611831665039 + ], + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "CLIP", + "type": "CLIP", + "links": [ + 41 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.49", + "Node name for S&R": "CLIPLoader" + }, + "widgets_values": [ + "qwen_2.5_vl_7b_fp8_scaled.safetensors", + "qwen_image", + "default" + ] + }, + { + "id": 39, + "type": "VAELoader", + "pos": [ + -261.932861328125, + 335.3167419433594 + ], + "size": [ + 270, + 58 + ], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "VAE", + "type": "VAE", + "links": [ + 19, + 42 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.49", + "Node name for S&R": "VAELoader" + }, + "widgets_values": [ + "qwen_image_vae.safetensors" + ] + }, + { + "id": 137, + "type": "ConditioningZeroOut", + "pos": [ + 978.658935546875, + 379.45233154296875 + ], + "size": [ + 140, + 26 + ], + "flags": { + "collapsed": true + }, + "order": 10, + "mode": 0, + "inputs": [ + { + "name": "conditioning", + "type": "CONDITIONING", + "link": 46 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 47 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.50", + "Node name for S&R": "ConditioningZeroOut" + }, + "widgets_values": [] + }, + { + "id": 100, + "type": "UNETLoader", + "pos": [ + -122.38922882080078, + 649.0233764648438 + ], + "size": [ + 382.58837890625, + 82 + ], + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 51 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.49", + "Node name for S&R": "UNETLoader" + }, + "widgets_values": [ + "qwen_image_edit_fp8_e4m3fn.safetensors", + "default" + ] + }, + { + "id": 8, + "type": "VAEDecode", + "pos": [ + 1002.855224609375, + 448.6126403808594 + ], + "size": [ + 140, + 46 + ], + "flags": {}, + "order": 12, + "mode": 0, + "inputs": [ + { + "name": "samples", + "type": "LATENT", + "link": 18 + }, + { + "name": "vae", + "type": "VAE", + "link": 19 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 22 + ] + }
+ ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.49", + "Node name for S&R": "VAEDecode" + }, + "widgets_values": [] + }, + { + "id": 147, + "type": "Note", + "pos": [ + -421.9145812988281, + 654.582763671875 + ], + "size": [ + 225.5645294189453, + 411.87225341796875 + ], + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [], + "outputs": [], + "properties": {}, + "widgets_values": [ + "Make some changes to the actions and scenes while maintaining the style and the appearance of the characters: A beautiful lady, her slender waist subtly highlighted by the oversized light blue sweater, sits languidly on a plush sofa. Her toned thighs, subtly parted beneath faded jeans, catch the warm glow from the television. A nervous smile plays on her slightly parted, glossy lips, revealing a hint of dewy collarbones beneath the soft fabric. Her hooded gaze, focused intently, conveys a gentle yet alluring expression.\\n\\nIn the foreground, a single rumpled silk cushion lies beside her, catching the flickering light from the screen. The midground reveals a cozy living room, with the soft glow of a movie playing on the television screen, reflecting in the polished surface of an antique coffee table. In the background, the blurred city skyline is visible through rain-streaked windows, casting liquid shadows that dance across the walls, enhancing the romantic tension of the evening" + ], + "color": "#432", + "bgcolor": "#653" + }, + { + "id": 3, + "type": "KSampler", + "pos": [ + 687.1221313476562, + 1006.6215209960938 + ], + "size": [ + 270, + 262 + ], + "flags": {}, + "order": 11, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 12 + }, + { + "name": "positive", + "type": "CONDITIONING", + "link": 45 + }, + { + "name": "negative", + "type": "CONDITIONING", + "link": 47 + }, + { + "name": "latent_image", + "type": "LATENT", + "link": 43 + } + ], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 18 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.49", + "Node name for S&R": "KSampler" + }, + "widgets_values": [ + 135908340840904, + "randomize", + 25, + 3, + "euler", + "simple", + 1 + ] + }, + { + "id": 58, + "type": "EmptySD3LatentImage", + "pos": [ + 689.9451293945312, + 814.2872314453125 + ], + "size": [ + 270, + 106 + ], + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": [ + 43 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.49", + "Node name for S&R": "EmptySD3LatentImage" + }, + "widgets_values": [ + 512, + 1024, + 1 + ] + }, + { + "id": 66, + "type": "ModelSamplingAuraFlow", + "pos": [ + 348.5574951171875, + 229.88905334472656 + ], + "size": [ + 210, + 58.961429595947266 + ], + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 51 + } + ], + "outputs": [ + { + "name": "MODEL", + "type": "MODEL", + "links": [ + 62 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.49", + "Node name for S&R": "ModelSamplingAuraFlow" + }, + "widgets_values": [ + 3.1000000000000005 + ] + }, + { + "id": 150, + "type": "Note", + "pos": [ + -904.476318359375, + 632.1348876953125 + ], + "size": [ + 225.5645294189453, + 411.87225341796875 + ], + "flags": {}, + "order": 5, + "mode": 0, + "inputs": [], + "outputs": [], + "properties": {}, + "widgets_values": [ + "把这张图片转成UV贴图的3D布线图,只使用黑白两色。白色只展示布线不需要贴图。" + ], + "color": "#432", + "bgcolor": "#653" + }, + { + "id": 
136, + "type": "TextEncodeQwenImageEdit", + "pos": [ + 307.0711975097656, + 538.617431640625 + ], + "size": [ + 443.0798034667969, + 204.4992218017578 + ], + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "clip", + "type": "CLIP", + "link": 41 + }, + { + "name": "vae", + "shape": 7, + "type": "VAE", + "link": 42 + }, + { + "name": "image", + "shape": 7, + "type": "IMAGE", + "link": 72 + } + ], + "outputs": [ + { + "name": "CONDITIONING", + "type": "CONDITIONING", + "links": [ + 45, + 46 + ] + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.50", + "Node name for S&R": "TextEncodeQwenImageEdit" + }, + "widgets_values": [ + "把图像转换成线稿" + ] + }, + { + "id": 127, + "type": "Save_as_webp", + "pos": [ + 314.478515625, + 833.6119384765625 + ], + "size": [ + 255.1419677734375, + 597.0552368164062 + ], + "flags": {}, + "order": 13, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 22 + } + ], + "outputs": [], + "properties": { + "aux_id": "Kaharos94/ComfyUI-Saveaswebp", + "ver": "183863f809b97b4422268ac0647d40dd7648cfff", + "Node name for S&R": "Save_as_webp" + }, + "widgets_values": [ + "ComfyUI", + "lossy", + 90 + ] + }, + { + "id": 102, + "type": "TeaCache", + "pos": [ + 991.4552612304688, + 744.995361328125 + ], + "size": [ + 270, + 154 + ], + "flags": {}, + "order": 9, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": 62 + } + ], + "outputs": [ + { + "name": "model", + "type": "MODEL", + "links": [ + 12 + ] + } + ], + "properties": { + "cnr_id": "teacache", + "ver": "91dff8e31684ca70a5fda309611484402d8fa192", + "Node name for S&R": "TeaCache" + }, + "widgets_values": [ + "qwen-image-edit", + 0.4, + 0.2, + 0.8, + "cuda" + ] + }, + { + "id": 139, + "type": "LoadImage", + "pos": [ + -49.56144332885742, + 819.1526489257812 + ], + "size": [ + 305.02783203125, + 595.702880859375 + ], + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 72 + ] + }, + { + "name": "MASK", + "type": "MASK", + "links": null + } + ], + "properties": { + "cnr_id": "comfy-core", + "ver": "0.3.50", + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "pasted/image (52).png", + "image" + ] + } + ], + "links": [ + [ + 12, + 102, + 0, + 3, + 0, + "MODEL" + ], + [ + 18, + 3, + 0, + 8, + 0, + "LATENT" + ], + [ + 19, + 39, + 0, + 8, + 1, + "VAE" + ], + [ + 22, + 8, + 0, + 127, + 0, + "IMAGE" + ], + [ + 41, + 38, + 0, + 136, + 0, + "CLIP" + ], + [ + 42, + 39, + 0, + 136, + 1, + "VAE" + ], + [ + 43, + 58, + 0, + 3, + 3, + "LATENT" + ], + [ + 45, + 136, + 0, + 3, + 1, + "CONDITIONING" + ], + [ + 46, + 136, + 0, + 137, + 0, + "CONDITIONING" + ], + [ + 47, + 137, + 0, + 3, + 2, + "CONDITIONING" + ], + [ + 51, + 100, + 0, + 66, + 0, + "MODEL" + ], + [ + 62, + 66, + 0, + 102, + 0, + "MODEL" + ], + [ + 72, + 139, + 0, + 136, + 2, + "IMAGE" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 1.0152559799480796, + "offset": [ + 559.060009645298, + -426.47256949801374 + ] + }, + "frontendVersion": "1.25.9", + "VHS_latentpreview": false, + "VHS_latentpreviewrate": 0, + "VHS_MetadataImage": true, + "VHS_KeepIntermediate": true + }, + "version": 0.4 +} \ No newline at end of file diff --git a/nodes.py b/nodes.py index 04bff4a..deaf784 100644 --- a/nodes.py +++ b/nodes.py @@ -1,1068 +1,1395 @@ -import math -import torch -import comfy.ldm.common_dit -import comfy.model_management as mm - -from torch import Tensor -from einops import repeat 
-from typing import Optional -from unittest.mock import patch - -from comfy.ldm.flux.layers import timestep_embedding, apply_mod -from comfy.ldm.lightricks.model import precompute_freqs_cis -from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords -from comfy.ldm.wan.model import sinusoidal_embedding_1d - - -SUPPORTED_MODELS_COEFFICIENTS = { - "flux": [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01], - "flux-kontext": [-1.04655119e+03, 3.12563399e+02, -1.69500694e+01, 4.10995971e-01, 3.74537863e-02], - "ltxv": [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03], - "lumina_2": [-8.74643948e+02, 4.66059906e+02, -7.51559762e+01, 5.32836175e+00, -3.27258296e-02], - "hunyuan_video": [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02], - "hidream_i1_full": [-3.13605009e+04, -7.12425503e+02, 4.91363285e+01, 8.26515490e+00, 1.08053901e-01], - "hidream_i1_dev": [1.39997273, -4.30130469, 5.01534416, -2.20504164, 0.93942874], - "hidream_i1_fast": [2.26509623, -6.88864563, 7.61123826, -3.10849353, 0.99927602], - "wan2.1_t2v_1.3B": [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01], - "wan2.1_t2v_14B": [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404], - "wan2.1_i2v_480p_14B": [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01], - "wan2.1_i2v_720p_14B": [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683], - "wan2.1_t2v_1.3B_ret_mode": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], - "wan2.1_t2v_14B_ret_mode": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], - "wan2.1_i2v_480p_14B_ret_mode": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], - "wan2.1_i2v_720p_14B_ret_mode": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], -} - -def poly1d(coefficients, x): - result = torch.zeros_like(x) - for i, coeff in enumerate(coefficients): - result += coeff * (x ** (len(coefficients) - 1 - i)) - return result - -def teacache_flux_forward( - self, - img: Tensor, - img_ids: Tensor, - txt: Tensor, - txt_ids: Tensor, - timesteps: Tensor, - y: Tensor, - guidance: Tensor = None, - control = None, - transformer_options={}, - attn_mask: Tensor = None, - ) -> Tensor: - patches_replace = transformer_options.get("patches_replace", {}) - rel_l1_thresh = transformer_options.get("rel_l1_thresh") - coefficients = transformer_options.get("coefficients") - enable_teacache = transformer_options.get("enable_teacache", True) - cache_device = transformer_options.get("cache_device") - - if y is None: - y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) - - if img.ndim != 3 or txt.ndim != 3: - raise ValueError("Input img and txt tensors must have 3 dimensions.") - - # running on sequences img - img = self.img_in(img) - vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype)) - if self.params.guidance_embed: - if guidance is not None: - vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) - - vec = vec + self.vector_in(y[:,:self.params.vec_in_dim]) - txt = self.txt_in(txt) - - if img_ids is not None: - ids = torch.cat((txt_ids, img_ids), dim=1) - pe = self.pe_embedder(ids) - else: - pe = None - - blocks_replace = patches_replace.get("dit", {}) - - # enable teacache - img_mod1, _ = 
self.double_blocks[0].img_mod(vec) - modulated_inp = self.double_blocks[0].img_norm1(img) - modulated_inp = apply_mod(modulated_inp, (1 + img_mod1.scale), img_mod1.shift).to(cache_device) - ca_idx = 0 - - if not hasattr(self, 'accumulated_rel_l1_distance'): - should_calc = True - self.accumulated_rel_l1_distance = 0 - else: - try: - self.accumulated_rel_l1_distance += poly1d(coefficients, ((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean())).abs() - if self.accumulated_rel_l1_distance < rel_l1_thresh: - should_calc = False - else: - should_calc = True - self.accumulated_rel_l1_distance = 0 - except: - should_calc = True - self.accumulated_rel_l1_distance = 0 - - self.previous_modulated_input = modulated_inp - - if not enable_teacache: - should_calc = True - - if not should_calc: - img += self.previous_residual.to(img.device) - else: - ori_img = img.to(cache_device) - for i, block in enumerate(self.double_blocks): - if ("double_block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["img"], out["txt"] = block(img=args["img"], - txt=args["txt"], - vec=args["vec"], - pe=args["pe"], - attn_mask=args.get("attn_mask")) - return out - - out = blocks_replace[("double_block", i)]({"img": img, - "txt": txt, - "vec": vec, - "pe": pe, - "attn_mask": attn_mask}, - {"original_block": block_wrap}) - txt = out["txt"] - img = out["img"] - else: - img, txt = block(img=img, - txt=txt, - vec=vec, - pe=pe, - attn_mask=attn_mask) - - if control is not None: # Controlnet - control_i = control.get("input") - if i < len(control_i): - add = control_i[i] - if add is not None: - img += add - - # PuLID attention - if getattr(self, "pulid_data", {}): - if i % self.pulid_double_interval == 0: - # Will calculate influence of all pulid nodes at once - for _, node_data in self.pulid_data.items(): - if torch.any((node_data['sigma_start'] >= timesteps) - & (timesteps >= node_data['sigma_end'])): - img = img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], img) - ca_idx += 1 - - if img.dtype == torch.float16: - img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504) - - img = torch.cat((txt, img), 1) - - for i, block in enumerate(self.single_blocks): - if ("single_block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["img"] = block(args["img"], - vec=args["vec"], - pe=args["pe"], - attn_mask=args.get("attn_mask")) - return out - - out = blocks_replace[("single_block", i)]({"img": img, - "vec": vec, - "pe": pe, - "attn_mask": attn_mask}, - {"original_block": block_wrap}) - img = out["img"] - else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) - - if control is not None: # Controlnet - control_o = control.get("output") - if i < len(control_o): - add = control_o[i] - if add is not None: - img[:, txt.shape[1] :, ...] += add - - # PuLID attention - if getattr(self, "pulid_data", {}): - real_img, txt = img[:, txt.shape[1]:, ...], img[:, :txt.shape[1], ...] - if i % self.pulid_single_interval == 0: - # Will calculate influence of all nodes at once - for _, node_data in self.pulid_data.items(): - if torch.any((node_data['sigma_start'] >= timesteps) - & (timesteps >= node_data['sigma_end'])): - real_img = real_img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], real_img) - ca_idx += 1 - img = torch.cat((txt, real_img), 1) - - img = img[:, txt.shape[1] :, ...] 
- self.previous_residual = img.to(cache_device) - ori_img - - img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) - - return img - -def teacache_hidream_forward( - self, - x: torch.Tensor, - t: torch.Tensor, - y: Optional[torch.Tensor] = None, - context: Optional[torch.Tensor] = None, - encoder_hidden_states_llama3=None, - image_cond=None, - control = None, - transformer_options = {}, - ) -> torch.Tensor: - rel_l1_thresh = transformer_options.get("rel_l1_thresh") - coefficients = transformer_options.get("coefficients") - cond_or_uncond = transformer_options.get("cond_or_uncond") - model_type = transformer_options.get("model_type") - enable_teacache = transformer_options.get("enable_teacache", True) - cache_device = transformer_options.get("cache_device") - - bs, c, h, w = x.shape - if image_cond is not None: - x = torch.cat([x, image_cond], dim=-1) - hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) - timesteps = t - pooled_embeds = y - T5_encoder_hidden_states = context - - img_sizes = None - - # spatial forward - batch_size = hidden_states.shape[0] - hidden_states_type = hidden_states.dtype - - # 0. time - timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device) - timesteps = self.t_embedder(timesteps, hidden_states_type) - p_embedder = self.p_embedder(pooled_embeds) - adaln_input = timesteps + p_embedder - - hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes) - if image_tokens_masks is None: - pH, pW = img_sizes[0] - img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device) - img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :] - img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) - hidden_states = self.x_embedder(hidden_states) - - # T5_encoder_hidden_states = encoder_hidden_states[0] - encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0) - encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers] - - if self.caption_projection is not None: - new_encoder_hidden_states = [] - for i, enc_hidden_state in enumerate(encoder_hidden_states): - enc_hidden_state = self.caption_projection[i](enc_hidden_state) - enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1]) - new_encoder_hidden_states.append(enc_hidden_state) - encoder_hidden_states = new_encoder_hidden_states - T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states) - T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) - encoder_hidden_states.append(T5_encoder_hidden_states) - - txt_ids = torch.zeros( - batch_size, - encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1], - 3, - device=img_ids.device, dtype=img_ids.dtype - ) - ids = torch.cat((img_ids, txt_ids), dim=1) - rope = self.pe_embedder(ids) - - # enable teacache - modulated_inp = timesteps.to(cache_device) if "full" in model_type else hidden_states.to(cache_device) - if not hasattr(self, 'teacache_state'): - self.teacache_state = { - 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, - 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} - } - - def update_cache_state(cache, 
modulated_inp): - if cache['previous_modulated_input'] is not None: - try: - cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) - if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: - cache['should_calc'] = False - else: - cache['should_calc'] = True - cache['accumulated_rel_l1_distance'] = 0 - except: - cache['should_calc'] = True - cache['accumulated_rel_l1_distance'] = 0 - cache['previous_modulated_input'] = modulated_inp - - b = int(len(hidden_states) / len(cond_or_uncond)) - - for i, k in enumerate(cond_or_uncond): - update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) - - if enable_teacache: - should_calc = False - for k in cond_or_uncond: - should_calc = (should_calc or self.teacache_state[k]['should_calc']) - else: - should_calc = True - - if not should_calc: - for i, k in enumerate(cond_or_uncond): - hidden_states[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(hidden_states.device) - else: - # 2. Blocks - ori_hidden_states = hidden_states.to(cache_device) - block_id = 0 - initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1) - initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1] - for bid, block in enumerate(self.double_stream_blocks): - cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] - cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1) - hidden_states, initial_encoder_hidden_states = block( - image_tokens = hidden_states, - image_tokens_masks = image_tokens_masks, - text_tokens = cur_encoder_hidden_states, - adaln_input = adaln_input, - rope = rope, - ) - initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] - block_id += 1 - - image_tokens_seq_len = hidden_states.shape[1] - hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1) - hidden_states_seq_len = hidden_states.shape[1] - if image_tokens_masks is not None: - encoder_attention_mask_ones = torch.ones( - (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]), - device=image_tokens_masks.device, dtype=image_tokens_masks.dtype - ) - image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1) - - for bid, block in enumerate(self.single_stream_blocks): - cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] - hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1) - hidden_states = block( - image_tokens=hidden_states, - image_tokens_masks=image_tokens_masks, - text_tokens=None, - adaln_input=adaln_input, - rope=rope, - ) - hidden_states = hidden_states[:, :hidden_states_seq_len] - block_id += 1 - - hidden_states = hidden_states[:, :image_tokens_seq_len, ...] 
- for i, k in enumerate(cond_or_uncond): - self.teacache_state[k]['previous_residual'] = (hidden_states.to(cache_device) - ori_hidden_states)[i*b:(i+1)*b] - - output = self.final_layer(hidden_states, adaln_input) - output = self.unpatchify(output, img_sizes) - return -output[:, :, :h, :w] - -def teacache_lumina_forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs): - rel_l1_thresh = transformer_options.get("rel_l1_thresh") - coefficients = transformer_options.get("coefficients") - cond_or_uncond = transformer_options.get("cond_or_uncond") - enable_teacache = transformer_options.get("enable_teacache", True) - cache_device = transformer_options.get("cache_device") - - t = 1.0 - timesteps - cap_feats = context - cap_mask = attention_mask - bs, c, h, w = x.shape - x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) - - t = self.t_embedder(t, dtype=x.dtype) # (N, D) - adaln_input = t - - cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute - - x_is_tensor = isinstance(x, torch.Tensor) - x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens) - freqs_cis = freqs_cis.to(x.device) - - # enable teacache - modulated_inp = t.to(cache_device) - if not hasattr(self, 'teacache_state'): - self.teacache_state = { - 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, - 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} - } - - def update_cache_state(cache, modulated_inp): - if cache['previous_modulated_input'] is not None: - try: - cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) - if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: - cache['should_calc'] = False - else: - cache['should_calc'] = True - cache['accumulated_rel_l1_distance'] = 0 - except: - cache['should_calc'] = True - cache['accumulated_rel_l1_distance'] = 0 - cache['previous_modulated_input'] = modulated_inp - - b = int(len(x) / len(cond_or_uncond)) - - for i, k in enumerate(cond_or_uncond): - update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) - - if enable_teacache: - should_calc = False - for k in cond_or_uncond: - should_calc = (should_calc or self.teacache_state[k]['should_calc']) - else: - should_calc = True - - if not should_calc: - for i, k in enumerate(cond_or_uncond): - x[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(x.device) - else: - ori_x = x.to(cache_device) - # 2. 
Blocks - for layer in self.layers: - x = layer(x, mask, freqs_cis, adaln_input) - for i, k in enumerate(cond_or_uncond): - self.teacache_state[k]['previous_residual'] = (x.to(cache_device) - ori_x)[i*b:(i+1)*b] - - x = self.final_layer(x, adaln_input) - x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w] - - return -x - -def teacache_hunyuanvideo_forward( - self, - img: Tensor, - img_ids: Tensor, - txt: Tensor, - txt_ids: Tensor, - txt_mask: Tensor, - timesteps: Tensor, - y: Tensor, - guidance: Tensor = None, - guiding_frame_index=None, - ref_latent=None, - control=None, - transformer_options={}, - ) -> Tensor: - patches_replace = transformer_options.get("patches_replace", {}) - rel_l1_thresh = transformer_options.get("rel_l1_thresh") - coefficients = transformer_options.get("coefficients") - enable_teacache = transformer_options.get("enable_teacache", True) - cache_device = transformer_options.get("cache_device") - - initial_shape = list(img.shape) - # running on sequences img - img = self.img_in(img) - vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) - - if ref_latent is not None: - ref_latent_ids = self.img_ids(ref_latent) - ref_latent = self.img_in(ref_latent) - img = torch.cat([ref_latent, img], dim=-2) - ref_latent_ids[..., 0] = -1 - ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1]) - img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2) - - if guiding_frame_index is not None: - token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0)) - vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) - vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1) - frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2]) - modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)] - modulation_dims_txt = [(0, None, 1)] - else: - vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) - modulation_dims = None - modulation_dims_txt = None - - if self.params.guidance_embed: - if guidance is not None: - vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) - - if txt_mask is not None and not torch.is_floating_point(txt_mask): - txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max - - txt = self.txt_in(txt, timesteps, txt_mask) - - ids = torch.cat((img_ids, txt_ids), dim=1) - pe = self.pe_embedder(ids) - - img_len = img.shape[1] - if txt_mask is not None: - attn_mask_len = img_len + txt.shape[1] - attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device) - attn_mask[:, 0, img_len:] = txt_mask - else: - attn_mask = None - - blocks_replace = patches_replace.get("dit", {}) - - # enable teacache - img_mod1, _ = self.double_blocks[0].img_mod(vec) - modulated_inp = self.double_blocks[0].img_norm1(img) - modulated_inp = apply_mod(modulated_inp, (1 + img_mod1.scale), img_mod1.shift, modulation_dims).to(cache_device) - - if not hasattr(self, 'accumulated_rel_l1_distance'): - should_calc = True - self.accumulated_rel_l1_distance = 0 - else: - try: - self.accumulated_rel_l1_distance += poly1d(coefficients, ((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean())) - if self.accumulated_rel_l1_distance < rel_l1_thresh: - should_calc = False - else: - should_calc = True - self.accumulated_rel_l1_distance = 0 - except: - should_calc = True - self.accumulated_rel_l1_distance = 0 - - 
self.previous_modulated_input = modulated_inp - - if not enable_teacache: - should_calc = True - - if not should_calc: - img += self.previous_residual.to(img.device) - else: - ori_img = img.to(cache_device) - for i, block in enumerate(self.double_blocks): - if ("double_block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"]) - return out - - out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap}) - txt = out["txt"] - img = out["img"] - else: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt) - - if control is not None: # Controlnet - control_i = control.get("input") - if i < len(control_i): - add = control_i[i] - if add is not None: - img += add - - img = torch.cat((img, txt), 1) - - for i, block in enumerate(self.single_blocks): - if ("single_block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"]) - return out - - out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap}) - img = out["img"] - else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) - - if control is not None: # Controlnet - control_o = control.get("output") - if i < len(control_o): - add = control_o[i] - if add is not None: - img[:, : img_len] += add - - img = img[:, : img_len] - self.previous_residual = (img.to(cache_device) - ori_img) - - if ref_latent is not None: - img = img[:, ref_latent.shape[1]:] - - img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels) - - shape = initial_shape[-3:] - for i in range(len(shape)): - shape[i] = shape[i] // self.patch_size[i] - img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size) - img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) - img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) - return img - -def teacache_ltxvmodel_forward( - self, - x, - timestep, - context, - attention_mask, - frame_rate=25, - transformer_options={}, - keyframe_idxs=None, - **kwargs - ): - patches_replace = transformer_options.get("patches_replace", {}) - rel_l1_thresh = transformer_options.get("rel_l1_thresh") - coefficients = transformer_options.get("coefficients") - cond_or_uncond = transformer_options.get("cond_or_uncond") - enable_teacache = transformer_options.get("enable_teacache", True) - cache_device = transformer_options.get("cache_device") - - orig_shape = list(x.shape) - - x, latent_coords = self.patchifier.patchify(x) - pixel_coords = latent_to_pixel_coords( - latent_coords=latent_coords, - scale_factors=self.vae_scale_factors, - causal_fix=self.causal_temporal_positioning, - ) - - if keyframe_idxs is not None: - pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs - - fractional_coords = pixel_coords.to(torch.float32) - fractional_coords[:, 0] = 
fractional_coords[:, 0] * (1.0 / frame_rate) - - x = self.patchify_proj(x) - timestep = timestep * 1000.0 - - if attention_mask is not None and not torch.is_floating_point(attention_mask): - attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max - - pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype) - - batch_size = x.shape[0] - timestep, embedded_timestep = self.adaln_single( - timestep.flatten(), - {"resolution": None, "aspect_ratio": None}, - batch_size=batch_size, - hidden_dtype=x.dtype, - ) - # Second dimension is 1 or number of tokens (if timestep_per_token) - timestep = timestep.view(batch_size, -1, timestep.shape[-1]) - embedded_timestep = embedded_timestep.view( - batch_size, -1, embedded_timestep.shape[-1] - ) - - # 2. Blocks - if self.caption_projection is not None: - batch_size = x.shape[0] - context = self.caption_projection(context) - context = context.view( - batch_size, -1, x.shape[-1] - ) - - blocks_replace = patches_replace.get("dit", {}) - - # enable teacache - inp = x.to(cache_device) - timestep_ = timestep.to(cache_device) - num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0] - ada_values = self.transformer_blocks[0].scale_shift_table[None, None].to(timestep_.device) + timestep_.reshape(batch_size, timestep_.size(1), num_ada_params, -1) - shift_msa, scale_msa, _, _, _, _ = ada_values.unbind(dim=2) - modulated_inp = comfy.ldm.common_dit.rms_norm(inp) - modulated_inp = modulated_inp * (1 + scale_msa) + shift_msa - - if not hasattr(self, 'teacache_state'): - self.teacache_state = { - 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, - 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} - } - - def update_cache_state(cache, modulated_inp): - if cache['previous_modulated_input'] is not None: - try: - cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) - if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: - cache['should_calc'] = False - else: - cache['should_calc'] = True - cache['accumulated_rel_l1_distance'] = 0 - except: - cache['should_calc'] = True - cache['accumulated_rel_l1_distance'] = 0 - cache['previous_modulated_input'] = modulated_inp - - b = int(len(x) / len(cond_or_uncond)) - - for i, k in enumerate(cond_or_uncond): - update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) - - if enable_teacache: - should_calc = False - for k in cond_or_uncond: - should_calc = (should_calc or self.teacache_state[k]['should_calc']) - else: - should_calc = True - - if not should_calc: - for i, k in enumerate(cond_or_uncond): - x[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(x.device) - else: - ori_x = x.to(cache_device) - for i, block in enumerate(self.transformer_blocks): - if ("double_block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"]) - return out - - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap}) - x = out["img"] - else: - x = block( - x, - context=context, - 
attention_mask=attention_mask, - timestep=timestep, - pe=pe - ) - - # 3. Output - scale_shift_values = ( - self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] - ) - shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] - x = self.norm_out(x) - # Modulation - x = x * (1 + scale) + shift - for i, k in enumerate(cond_or_uncond): - self.teacache_state[k]['previous_residual'] = (x.to(cache_device) - ori_x)[i*b:(i+1)*b] - - x = self.proj_out(x) - - x = self.patchifier.unpatchify( - latents=x, - output_height=orig_shape[3], - output_width=orig_shape[4], - output_num_frames=orig_shape[2], - out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size), - ) - - return x - -def teacache_wanmodel_forward( - self, - x, - t, - context, - clip_fea=None, - freqs=None, - transformer_options={}, - **kwargs, - ): - patches_replace = transformer_options.get("patches_replace", {}) - rel_l1_thresh = transformer_options.get("rel_l1_thresh") - coefficients = transformer_options.get("coefficients") - cond_or_uncond = transformer_options.get("cond_or_uncond") - model_type = transformer_options.get("model_type") - enable_teacache = transformer_options.get("enable_teacache", True) - cache_device = transformer_options.get("cache_device") - - # embeddings - x = self.patch_embedding(x.float()).to(x.dtype) - grid_sizes = x.shape[2:] - x = x.flatten(2).transpose(1, 2) - - # time embeddings - e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) - e0 = self.time_projection(e).unflatten(1, (6, self.dim)) - - # context - context = self.text_embedding(context) - - context_img_len = None - if clip_fea is not None: - if self.img_emb is not None: - context_clip = self.img_emb(clip_fea) # bs x 257 x dim - context = torch.concat([context_clip, context], dim=1) - context_img_len = clip_fea.shape[-2] - - blocks_replace = patches_replace.get("dit", {}) - - # enable teacache - modulated_inp = e0.to(cache_device) if "ret_mode" in model_type else e.to(cache_device) - if not hasattr(self, 'teacache_state'): - self.teacache_state = { - 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, - 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} - } - - def update_cache_state(cache, modulated_inp): - if cache['previous_modulated_input'] is not None: - try: - cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) - if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: - cache['should_calc'] = False - else: - cache['should_calc'] = True - cache['accumulated_rel_l1_distance'] = 0 - except: - cache['should_calc'] = True - cache['accumulated_rel_l1_distance'] = 0 - cache['previous_modulated_input'] = modulated_inp - - b = int(len(x) / len(cond_or_uncond)) - - for i, k in enumerate(cond_or_uncond): - update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) - - if enable_teacache: - should_calc = False - for k in cond_or_uncond: - should_calc = (should_calc or self.teacache_state[k]['should_calc']) - else: - should_calc = True - - if not should_calc: - for i, k in enumerate(cond_or_uncond): - x[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(x.device) - else: - ori_x = x.to(cache_device) - for i, block in enumerate(self.blocks): - if ("double_block", i) in 
blocks_replace: - def block_wrap(args): - out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) - return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options}) - x = out["img"] - else: - x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) - for i, k in enumerate(cond_or_uncond): - self.teacache_state[k]['previous_residual'] = (x.to(cache_device) - ori_x)[i*b:(i+1)*b] - - # head - x = self.head(x, e) - - # unpatchify - x = self.unpatchify(x, grid_sizes) - return x - -class TeaCache: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL", {"tooltip": "The diffusion model the TeaCache will be applied to."}), - "model_type": (["flux", "flux-kontext", "ltxv", "lumina_2", "hunyuan_video", "hidream_i1_full", "hidream_i1_dev", "hidream_i1_fast", "wan2.1_t2v_1.3B", "wan2.1_t2v_14B", "wan2.1_i2v_480p_14B", "wan2.1_i2v_720p_14B", "wan2.1_t2v_1.3B_ret_mode", "wan2.1_t2v_14B_ret_mode", "wan2.1_i2v_480p_14B_ret_mode", "wan2.1_i2v_720p_14B_ret_mode"], {"default": "flux", "tooltip": "Supported diffusion model."}), - "rel_l1_thresh": ("FLOAT", {"default": 0.4, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "How strongly to cache the output of diffusion model. This value must be non-negative."}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps that will apply TeaCache."}), - "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps that will apply TeaCache."}), - "cache_device": (["cuda", "cpu"], {"default": "cuda", "tooltip": "Device where the cache will reside."}), - } - } - - RETURN_TYPES = ("MODEL",) - RETURN_NAMES = ("model",) - FUNCTION = "apply_teacache" - CATEGORY = "TeaCache" - TITLE = "TeaCache" - - def apply_teacache(self, model, model_type: str, rel_l1_thresh: float, start_percent: float, end_percent: float, cache_device: str): - if rel_l1_thresh == 0: - return (model,) - - new_model = model.clone() - if 'transformer_options' not in new_model.model_options: - new_model.model_options['transformer_options'] = {} - new_model.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh - new_model.model_options["transformer_options"]["coefficients"] = SUPPORTED_MODELS_COEFFICIENTS[model_type] - new_model.model_options["transformer_options"]["model_type"] = model_type - new_model.model_options["transformer_options"]["cache_device"] = mm.get_torch_device() if cache_device == "cuda" else torch.device("cpu") - - diffusion_model = new_model.get_model_object("diffusion_model") - - if "flux" in model_type: - is_cfg = False - context = patch.multiple( - diffusion_model, - forward_orig=teacache_flux_forward.__get__(diffusion_model, diffusion_model.__class__) - ) - elif "lumina_2" in model_type: - is_cfg = True - context = patch.multiple( - diffusion_model, - forward=teacache_lumina_forward.__get__(diffusion_model, diffusion_model.__class__) - ) - elif "hidream_i1" in model_type: - is_cfg = True if "full" in model_type else False - context = patch.multiple( - diffusion_model, - forward=teacache_hidream_forward.__get__(diffusion_model, diffusion_model.__class__) - ) - elif "ltxv" in model_type: - is_cfg = True - context = patch.multiple( - diffusion_model, - 
forward=teacache_ltxvmodel_forward.__get__(diffusion_model, diffusion_model.__class__) - ) - elif "hunyuan_video" in model_type: - is_cfg = False - context = patch.multiple( - diffusion_model, - forward_orig=teacache_hunyuanvideo_forward.__get__(diffusion_model, diffusion_model.__class__) - ) - elif "wan2.1" in model_type: - is_cfg = True - context = patch.multiple( - diffusion_model, - forward_orig=teacache_wanmodel_forward.__get__(diffusion_model, diffusion_model.__class__) - ) - else: - raise ValueError(f"Unknown type {model_type}") - - def unet_wrapper_function(model_function, kwargs): - input = kwargs["input"] - timestep = kwargs["timestep"] - c = kwargs["c"] - # referenced from https://github.com/kijai/ComfyUI-KJNodes/blob/d126b62cebee81ea14ec06ea7cd7526999cb0554/nodes/model_optimization_nodes.py#L868 - sigmas = c["transformer_options"]["sample_sigmas"] - matched_step_index = (sigmas == timestep[0]).nonzero() - if len(matched_step_index) > 0: - current_step_index = matched_step_index.item() - else: - current_step_index = 0 - for i in range(len(sigmas) - 1): - # walk from beginning of steps until crossing the timestep - if (sigmas[i] - timestep[0]) * (sigmas[i + 1] - timestep[0]) <= 0: - current_step_index = i - break - - if current_step_index == 0: - if is_cfg: - # uncond -> 1, cond -> 0 - if hasattr(diffusion_model, 'teacache_state') and \ - diffusion_model.teacache_state[0]['previous_modulated_input'] is not None and \ - diffusion_model.teacache_state[1]['previous_modulated_input'] is not None: - delattr(diffusion_model, 'teacache_state') - else: - if hasattr(diffusion_model, 'teacache_state'): - delattr(diffusion_model, 'teacache_state') - if hasattr(diffusion_model, 'accumulated_rel_l1_distance'): - delattr(diffusion_model, 'accumulated_rel_l1_distance') - - current_percent = current_step_index / (len(sigmas) - 1) - c["transformer_options"]["current_percent"] = current_percent - if start_percent <= current_percent <= end_percent: - c["transformer_options"]["enable_teacache"] = True - else: - c["transformer_options"]["enable_teacache"] = False - - with context: - return model_function(input, timestep, **c) - - new_model.set_model_unet_function_wrapper(unet_wrapper_function) - - return (new_model,) - -def patch_optimized_module(): - try: - from torch._dynamo.eval_frame import OptimizedModule - except ImportError: - return - - if getattr(OptimizedModule, "_patched", False): - return - - def __getattribute__(self, name): - if name == "_orig_mod": - return object.__getattribute__(self, "_modules")[name] - if name in ( - "__class__", - "_modules", - "state_dict", - "load_state_dict", - "parameters", - "named_parameters", - "buffers", - "named_buffers", - "children", - "named_children", - "modules", - "named_modules", - ): - return getattr(object.__getattribute__(self, "_orig_mod"), name) - return object.__getattribute__(self, name) - - def __delattr__(self, name): - return delattr(self._orig_mod, name) - - @classmethod - def __instancecheck__(cls, instance): - return isinstance(instance, OptimizedModule) or issubclass( - object.__getattribute__(instance, "__class__"), cls - ) - - OptimizedModule.__getattribute__ = __getattribute__ - OptimizedModule.__delattr__ = __delattr__ - OptimizedModule.__instancecheck__ = __instancecheck__ - OptimizedModule._patched = True - -def patch_same_meta(): - try: - from torch._inductor.fx_passes import post_grad - except ImportError: - return - - same_meta = getattr(post_grad, "same_meta", None) - if same_meta is None: - return - - if getattr(same_meta, 
"_patched", False): - return - - def new_same_meta(a, b): - try: - return same_meta(a, b) - except Exception: - return False - - post_grad.same_meta = new_same_meta - new_same_meta._patched = True - -class CompileModel: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL", {"tooltip": "The diffusion model the torch.compile will be applied to."}), - "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), - "backend": (["inductor","cudagraphs", "eager", "aot_eager"], {"default": "inductor"}), - "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), - "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), - } - } - - RETURN_TYPES = ("MODEL",) - RETURN_NAMES = ("model",) - FUNCTION = "apply_compile" - CATEGORY = "TeaCache" - TITLE = "Compile Model" - - def apply_compile(self, model, mode: str, backend: str, fullgraph: bool, dynamic: bool): - patch_optimized_module() - patch_same_meta() - torch._dynamo.config.suppress_errors = True - - new_model = model.clone() - new_model.add_object_patch( - "diffusion_model", - torch.compile( - new_model.get_model_object("diffusion_model"), - mode=mode, - backend=backend, - fullgraph=fullgraph, - dynamic=dynamic - ) - ) - - return (new_model,) - - -NODE_CLASS_MAPPINGS = { - "TeaCache": TeaCache, - "CompileModel": CompileModel -} - -NODE_DISPLAY_NAME_MAPPINGS = {k: v.TITLE for k, v in NODE_CLASS_MAPPINGS.items()} +import math +import torch +import comfy.ldm.common_dit +import comfy.model_management as mm + +from torch import Tensor +from einops import repeat +from typing import Optional +from unittest.mock import patch + +from comfy.ldm.flux.layers import timestep_embedding, apply_mod +from comfy.ldm.lightricks.model import precompute_freqs_cis +from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords +from comfy.ldm.wan.model import sinusoidal_embedding_1d + + +def default(x, y): + if x is not None: + return x + return y + +SUPPORTED_MODELS_COEFFICIENTS = { + "flux": [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01], + "flux-kontext": [-1.04655119e+03, 3.12563399e+02, -1.69500694e+01, 4.10995971e-01, 3.74537863e-02], + "ltxv": [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03], + "lumina_2": [-8.74643948e+02, 4.66059906e+02, -7.51559762e+01, 5.32836175e+00, -3.27258296e-02], + "hunyuan_video": [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02], + "hidream_i1_full": [-3.13605009e+04, -7.12425503e+02, 4.91363285e+01, 8.26515490e+00, 1.08053901e-01], + "hidream_i1_dev": [1.39997273, -4.30130469, 5.01534416, -2.20504164, 0.93942874], + "hidream_i1_fast": [2.26509623, -6.88864563, 7.61123826, -3.10849353, 0.99927602], + "qwen-image": [-4.50000000e+02, 2.80000000e+02, -4.50000000e+01, 3.20000000e+00, -2.00000000e-02], + "qwen-image-edit": [-4.50000000e+02, 2.80000000e+02, -4.50000000e+01, 3.20000000e+00, -2.00000000e-02], + "wan2.1_t2v_1.3B": [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01], + "wan2.1_t2v_14B": [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404], + "wan2.1_i2v_480p_14B": [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01], + "wan2.1_i2v_720p_14B": [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683], + "wan2.1_t2v_1.3B_ret_mode": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, 
-4.99875664e-02], + "wan2.1_t2v_14B_ret_mode": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], + "wan2.1_i2v_480p_14B_ret_mode": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], + "wan2.1_i2v_720p_14B_ret_mode": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], +} + +def poly1d(coefficients, x): + result = torch.zeros_like(x) + for i, coeff in enumerate(coefficients): + result += coeff * (x ** (len(coefficients) - 1 - i)) + return result + +def teacache_flux_forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor = None, + control = None, + transformer_options={}, + attn_mask: Tensor = None, + ) -> Tensor: + patches_replace = transformer_options.get("patches_replace", {}) + rel_l1_thresh = transformer_options.get("rel_l1_thresh") + coefficients = transformer_options.get("coefficients") + enable_teacache = transformer_options.get("enable_teacache", True) + cache_device = transformer_options.get("cache_device") + + if y is None: + y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) + + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype)) + if self.params.guidance_embed: + if guidance is not None: + vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) + + vec = vec + self.vector_in(y[:,:self.params.vec_in_dim]) + txt = self.txt_in(txt) + + if img_ids is not None: + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + else: + pe = None + + blocks_replace = patches_replace.get("dit", {}) + + # enable teacache + img_mod1, _ = self.double_blocks[0].img_mod(vec) + modulated_inp = self.double_blocks[0].img_norm1(img) + modulated_inp = apply_mod(modulated_inp, (1 + img_mod1.scale), img_mod1.shift).to(cache_device) + ca_idx = 0 + + if not hasattr(self, 'accumulated_rel_l1_distance'): + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + try: + self.accumulated_rel_l1_distance += poly1d(coefficients, ((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean())).abs() + if self.accumulated_rel_l1_distance < rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + except: + should_calc = True + self.accumulated_rel_l1_distance = 0 + + self.previous_modulated_input = modulated_inp + + if not enable_teacache: + should_calc = True + + if not should_calc: + img += self.previous_residual.to(img.device) + else: + ori_img = img.to(cache_device) + for i, block in enumerate(self.double_blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"], out["txt"] = block(img=args["img"], + txt=args["txt"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("double_block", i)]({"img": img, + "txt": txt, + "vec": vec, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) + txt = out["txt"] + img = out["img"] + else: + img, txt = block(img=img, + txt=txt, + vec=vec, + pe=pe, + attn_mask=attn_mask) + + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] + if add is not None: + img += 
add + + # PuLID attention + if getattr(self, "pulid_data", {}): + if i % self.pulid_double_interval == 0: + # Will calculate influence of all pulid nodes at once + for _, node_data in self.pulid_data.items(): + if torch.any((node_data['sigma_start'] >= timesteps) + & (timesteps >= node_data['sigma_end'])): + img = img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], img) + ca_idx += 1 + + if img.dtype == torch.float16: + img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504) + + img = torch.cat((txt, img), 1) + + for i, block in enumerate(self.single_blocks): + if ("single_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("single_block", i)]({"img": img, + "vec": vec, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) + img = out["img"] + else: + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + + if control is not None: # Controlnet + control_o = control.get("output") + if i < len(control_o): + add = control_o[i] + if add is not None: + img[:, txt.shape[1] :, ...] += add + + # PuLID attention + if getattr(self, "pulid_data", {}): + real_img, txt = img[:, txt.shape[1]:, ...], img[:, :txt.shape[1], ...] + if i % self.pulid_single_interval == 0: + # Will calculate influence of all nodes at once + for _, node_data in self.pulid_data.items(): + if torch.any((node_data['sigma_start'] >= timesteps) + & (timesteps >= node_data['sigma_end'])): + real_img = real_img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], real_img) + ca_idx += 1 + img = torch.cat((txt, real_img), 1) + + img = img[:, txt.shape[1] :, ...] + self.previous_residual = img.to(cache_device) - ori_img + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + + return img + +def teacache_hidream_forward( + self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + encoder_hidden_states_llama3=None, + image_cond=None, + control = None, + transformer_options = {}, + ) -> torch.Tensor: + rel_l1_thresh = transformer_options.get("rel_l1_thresh") + coefficients = transformer_options.get("coefficients") + cond_or_uncond = transformer_options.get("cond_or_uncond") + model_type = transformer_options.get("model_type") + enable_teacache = transformer_options.get("enable_teacache", True) + cache_device = transformer_options.get("cache_device") + + bs, c, h, w = x.shape + if image_cond is not None: + x = torch.cat([x, image_cond], dim=-1) + hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + timesteps = t + pooled_embeds = y + T5_encoder_hidden_states = context + + img_sizes = None + + # spatial forward + batch_size = hidden_states.shape[0] + hidden_states_type = hidden_states.dtype + + # 0. 
time + timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device) + timesteps = self.t_embedder(timesteps, hidden_states_type) + p_embedder = self.p_embedder(pooled_embeds) + adaln_input = timesteps + p_embedder + + hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes) + if image_tokens_masks is None: + pH, pW = img_sizes[0] + img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + hidden_states = self.x_embedder(hidden_states) + + # T5_encoder_hidden_states = encoder_hidden_states[0] + encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0) + encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers] + + if self.caption_projection is not None: + new_encoder_hidden_states = [] + for i, enc_hidden_state in enumerate(encoder_hidden_states): + enc_hidden_state = self.caption_projection[i](enc_hidden_state) + enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1]) + new_encoder_hidden_states.append(enc_hidden_state) + encoder_hidden_states = new_encoder_hidden_states + T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states) + T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + encoder_hidden_states.append(T5_encoder_hidden_states) + + txt_ids = torch.zeros( + batch_size, + encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1], + 3, + device=img_ids.device, dtype=img_ids.dtype + ) + ids = torch.cat((img_ids, txt_ids), dim=1) + rope = self.pe_embedder(ids) + + # enable teacache + modulated_inp = timesteps.to(cache_device) if "full" in model_type else hidden_states.to(cache_device) + if not hasattr(self, 'teacache_state'): + self.teacache_state = { + 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, + 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} + } + + def update_cache_state(cache, modulated_inp): + if cache['previous_modulated_input'] is not None: + try: + cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) + if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: + cache['should_calc'] = False + else: + cache['should_calc'] = True + cache['accumulated_rel_l1_distance'] = 0 + except: + cache['should_calc'] = True + cache['accumulated_rel_l1_distance'] = 0 + cache['previous_modulated_input'] = modulated_inp + + b = int(len(hidden_states) / len(cond_or_uncond)) + + for i, k in enumerate(cond_or_uncond): + update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) + + if enable_teacache: + should_calc = False + for k in cond_or_uncond: + should_calc = (should_calc or self.teacache_state[k]['should_calc']) + else: + should_calc = True + + if not should_calc: + for i, k in enumerate(cond_or_uncond): + hidden_states[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(hidden_states.device) + else: + # 2. 
Blocks + ori_hidden_states = hidden_states.to(cache_device) + block_id = 0 + initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1) + initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1] + for bid, block in enumerate(self.double_stream_blocks): + cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] + cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1) + hidden_states, initial_encoder_hidden_states = block( + image_tokens = hidden_states, + image_tokens_masks = image_tokens_masks, + text_tokens = cur_encoder_hidden_states, + adaln_input = adaln_input, + rope = rope, + ) + initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] + block_id += 1 + + image_tokens_seq_len = hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1) + hidden_states_seq_len = hidden_states.shape[1] + if image_tokens_masks is not None: + encoder_attention_mask_ones = torch.ones( + (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]), + device=image_tokens_masks.device, dtype=image_tokens_masks.dtype + ) + image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1) + + for bid, block in enumerate(self.single_stream_blocks): + cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] + hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1) + hidden_states = block( + image_tokens=hidden_states, + image_tokens_masks=image_tokens_masks, + text_tokens=None, + adaln_input=adaln_input, + rope=rope, + ) + hidden_states = hidden_states[:, :hidden_states_seq_len] + block_id += 1 + + hidden_states = hidden_states[:, :image_tokens_seq_len, ...] + for i, k in enumerate(cond_or_uncond): + self.teacache_state[k]['previous_residual'] = (hidden_states.to(cache_device) - ori_hidden_states)[i*b:(i+1)*b] + + output = self.final_layer(hidden_states, adaln_input) + output = self.unpatchify(output, img_sizes) + return -output[:, :, :h, :w] + +def teacache_lumina_forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs): + rel_l1_thresh = transformer_options.get("rel_l1_thresh") + coefficients = transformer_options.get("coefficients") + cond_or_uncond = transformer_options.get("cond_or_uncond") + enable_teacache = transformer_options.get("enable_teacache", True) + cache_device = transformer_options.get("cache_device") + + t = 1.0 - timesteps + cap_feats = context + cap_mask = attention_mask + bs, c, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + + t = self.t_embedder(t, dtype=x.dtype) # (N, D) + adaln_input = t + + cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. 
redundant compute + + x_is_tensor = isinstance(x, torch.Tensor) + x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens) + freqs_cis = freqs_cis.to(x.device) + + # enable teacache + modulated_inp = t.to(cache_device) + if not hasattr(self, 'teacache_state'): + self.teacache_state = { + 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, + 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} + } + + def update_cache_state(cache, modulated_inp): + if cache['previous_modulated_input'] is not None: + try: + cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) + if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: + cache['should_calc'] = False + else: + cache['should_calc'] = True + cache['accumulated_rel_l1_distance'] = 0 + except: + cache['should_calc'] = True + cache['accumulated_rel_l1_distance'] = 0 + cache['previous_modulated_input'] = modulated_inp + + b = int(len(x) / len(cond_or_uncond)) + + for i, k in enumerate(cond_or_uncond): + update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) + + if enable_teacache: + should_calc = False + for k in cond_or_uncond: + should_calc = (should_calc or self.teacache_state[k]['should_calc']) + else: + should_calc = True + + if not should_calc: + for i, k in enumerate(cond_or_uncond): + x[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(x.device) + else: + ori_x = x.to(cache_device) + # 2. Blocks + for layer in self.layers: + x = layer(x, mask, freqs_cis, adaln_input) + for i, k in enumerate(cond_or_uncond): + self.teacache_state[k]['previous_residual'] = (x.to(cache_device) - ori_x)[i*b:(i+1)*b] + + x = self.final_layer(x, adaln_input) + x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w] + + return -x + +def teacache_hunyuanvideo_forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + txt_mask: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor = None, + guiding_frame_index=None, + ref_latent=None, + control=None, + transformer_options={}, + ) -> Tensor: + patches_replace = transformer_options.get("patches_replace", {}) + rel_l1_thresh = transformer_options.get("rel_l1_thresh") + coefficients = transformer_options.get("coefficients") + enable_teacache = transformer_options.get("enable_teacache", True) + cache_device = transformer_options.get("cache_device") + + initial_shape = list(img.shape) + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) + + if ref_latent is not None: + ref_latent_ids = self.img_ids(ref_latent) + ref_latent = self.img_in(ref_latent) + img = torch.cat([ref_latent, img], dim=-2) + ref_latent_ids[..., 0] = -1 + ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1]) + img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2) + + if guiding_frame_index is not None: + token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0)) + vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) + vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1) + frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2]) + 
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)] + modulation_dims_txt = [(0, None, 1)] + else: + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + modulation_dims = None + modulation_dims_txt = None + + if self.params.guidance_embed: + if guidance is not None: + vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) + + if txt_mask is not None and not torch.is_floating_point(txt_mask): + txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max + + txt = self.txt_in(txt, timesteps, txt_mask) + + ids = torch.cat((img_ids, txt_ids), dim=1) + pe = self.pe_embedder(ids) + + img_len = img.shape[1] + if txt_mask is not None: + attn_mask_len = img_len + txt.shape[1] + attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device) + attn_mask[:, 0, img_len:] = txt_mask + else: + attn_mask = None + + blocks_replace = patches_replace.get("dit", {}) + + # enable teacache + img_mod1, _ = self.double_blocks[0].img_mod(vec) + modulated_inp = self.double_blocks[0].img_norm1(img) + modulated_inp = apply_mod(modulated_inp, (1 + img_mod1.scale), img_mod1.shift, modulation_dims).to(cache_device) + + if not hasattr(self, 'accumulated_rel_l1_distance'): + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + try: + self.accumulated_rel_l1_distance += poly1d(coefficients, ((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean())) + if self.accumulated_rel_l1_distance < rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + except: + should_calc = True + self.accumulated_rel_l1_distance = 0 + + self.previous_modulated_input = modulated_inp + + if not enable_teacache: + should_calc = True + + if not should_calc: + img += self.previous_residual.to(img.device) + else: + ori_img = img.to(cache_device) + for i, block in enumerate(self.double_blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"]) + return out + + out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap}) + txt = out["txt"] + img = out["img"] + else: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt) + + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] + if add is not None: + img += add + + img = torch.cat((img, txt), 1) + + for i, block in enumerate(self.single_blocks): + if ("single_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"]) + return out + + out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap}) + img = out["img"] + else: + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) + + if control is not None: # Controlnet + control_o = 
control.get("output") + if i < len(control_o): + add = control_o[i] + if add is not None: + img[:, : img_len] += add + + img = img[:, : img_len] + self.previous_residual = (img.to(cache_device) - ori_img) + + if ref_latent is not None: + img = img[:, ref_latent.shape[1]:] + + img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels) + + shape = initial_shape[-3:] + for i in range(len(shape)): + shape[i] = shape[i] // self.patch_size[i] + img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size) + img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) + img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) + return img + +def teacache_ltxvmodel_forward( + self, + x, + timestep, + context, + attention_mask, + frame_rate=25, + transformer_options={}, + keyframe_idxs=None, + **kwargs + ): + patches_replace = transformer_options.get("patches_replace", {}) + rel_l1_thresh = transformer_options.get("rel_l1_thresh") + coefficients = transformer_options.get("coefficients") + cond_or_uncond = transformer_options.get("cond_or_uncond") + enable_teacache = transformer_options.get("enable_teacache", True) + cache_device = transformer_options.get("cache_device") + + orig_shape = list(x.shape) + + x, latent_coords = self.patchifier.patchify(x) + pixel_coords = latent_to_pixel_coords( + latent_coords=latent_coords, + scale_factors=self.vae_scale_factors, + causal_fix=self.causal_temporal_positioning, + ) + + if keyframe_idxs is not None: + pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs + + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + + x = self.patchify_proj(x) + timestep = timestep * 1000.0 + + if attention_mask is not None and not torch.is_floating_point(attention_mask): + attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max + + pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype) + + batch_size = x.shape[0] + timestep, embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=x.dtype, + ) + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view( + batch_size, -1, embedded_timestep.shape[-1] + ) + + # 2. 
Blocks + if self.caption_projection is not None: + batch_size = x.shape[0] + context = self.caption_projection(context) + context = context.view( + batch_size, -1, x.shape[-1] + ) + + blocks_replace = patches_replace.get("dit", {}) + + # enable teacache + inp = x.to(cache_device) + timestep_ = timestep.to(cache_device) + num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0] + ada_values = self.transformer_blocks[0].scale_shift_table[None, None].to(timestep_.device) + timestep_.reshape(batch_size, timestep_.size(1), num_ada_params, -1) + shift_msa, scale_msa, _, _, _, _ = ada_values.unbind(dim=2) + modulated_inp = comfy.ldm.common_dit.rms_norm(inp) + modulated_inp = modulated_inp * (1 + scale_msa) + shift_msa + + if not hasattr(self, 'teacache_state'): + self.teacache_state = { + 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, + 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} + } + + def update_cache_state(cache, modulated_inp): + if cache['previous_modulated_input'] is not None: + try: + cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) + if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: + cache['should_calc'] = False + else: + cache['should_calc'] = True + cache['accumulated_rel_l1_distance'] = 0 + except: + cache['should_calc'] = True + cache['accumulated_rel_l1_distance'] = 0 + cache['previous_modulated_input'] = modulated_inp + + b = int(len(x) / len(cond_or_uncond)) + + for i, k in enumerate(cond_or_uncond): + update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) + + if enable_teacache: + should_calc = False + for k in cond_or_uncond: + should_calc = (should_calc or self.teacache_state[k]['should_calc']) + else: + should_calc = True + + if not should_calc: + for i, k in enumerate(cond_or_uncond): + x[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(x.device) + else: + ori_x = x.to(cache_device) + for i, block in enumerate(self.transformer_blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"]) + return out + + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block( + x, + context=context, + attention_mask=attention_mask, + timestep=timestep, + pe=pe + ) + + # 3. 
Output + scale_shift_values = ( + self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + x = self.norm_out(x) + # Modulation + x = x * (1 + scale) + shift + for i, k in enumerate(cond_or_uncond): + self.teacache_state[k]['previous_residual'] = (x.to(cache_device) - ori_x)[i*b:(i+1)*b] + + x = self.proj_out(x) + + x = self.patchifier.unpatchify( + latents=x, + output_height=orig_shape[3], + output_width=orig_shape[4], + output_num_frames=orig_shape[2], + out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size), + ) + + return x + +def teacache_wanmodel_forward( + self, + x, + t, + context, + clip_fea=None, + freqs=None, + transformer_options={}, + **kwargs, + ): + patches_replace = transformer_options.get("patches_replace", {}) + rel_l1_thresh = transformer_options.get("rel_l1_thresh") + coefficients = transformer_options.get("coefficients") + cond_or_uncond = transformer_options.get("cond_or_uncond") + model_type = transformer_options.get("model_type") + enable_teacache = transformer_options.get("enable_teacache", True) + cache_device = transformer_options.get("cache_device") + + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + blocks_replace = patches_replace.get("dit", {}) + + # enable teacache + modulated_inp = e0.to(cache_device) if "ret_mode" in model_type else e.to(cache_device) + if not hasattr(self, 'teacache_state'): + self.teacache_state = { + 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, + 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} + } + + def update_cache_state(cache, modulated_inp): + if cache['previous_modulated_input'] is not None: + try: + cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) + if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: + cache['should_calc'] = False + else: + cache['should_calc'] = True + cache['accumulated_rel_l1_distance'] = 0 + except: + cache['should_calc'] = True + cache['accumulated_rel_l1_distance'] = 0 + cache['previous_modulated_input'] = modulated_inp + + b = int(len(x) / len(cond_or_uncond)) + + for i, k in enumerate(cond_or_uncond): + update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) + + if enable_teacache: + should_calc = False + for k in cond_or_uncond: + should_calc = (should_calc or self.teacache_state[k]['should_calc']) + else: + should_calc = True + + if not should_calc: + for i, k in enumerate(cond_or_uncond): + x[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(x.device) + else: + ori_x = x.to(cache_device) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = 
block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + for i, k in enumerate(cond_or_uncond): + self.teacache_state[k]['previous_residual'] = (x.to(cache_device) - ori_x)[i*b:(i+1)*b] + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x + +def teacache_qwen_image_forward( + self, + x, + timesteps, + context, + attention_mask=None, + guidance: torch.Tensor = None, + transformer_options={}, + **kwargs + ): + rel_l1_thresh = transformer_options.get("rel_l1_thresh") + coefficients = transformer_options.get("coefficients") + enable_teacache = transformer_options.get("enable_teacache", True) + cache_device = transformer_options.get("cache_device") + + timestep = timesteps + encoder_hidden_states = context + encoder_hidden_states_mask = attention_mask + + # Align with upstream Qwen-Image API: use process_img + pe_embedder + hidden_states, img_ids, orig_shape = self.process_img(x) + num_embeds = hidden_states.shape[1] + + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size))) + txt_ids = torch.linspace(txt_start, txt_start + encoder_hidden_states.shape[1], steps=encoder_hidden_states.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + + hidden_states = self.img_in(hidden_states) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + # TeaCache logic - use first transformer block's input as modulated input for change detection + # This is similar to how FLUX uses the first double_block's modulated input + if len(self.transformer_blocks) > 0: + # Get the first block to calculate modulated input + first_block = self.transformer_blocks[0] + # Use the processed hidden_states as modulated input (after img_in transformation) + modulated_inp = hidden_states.to(cache_device) + else: + modulated_inp = hidden_states.to(cache_device) + + # CFG-aware caching - maintain separate states for positive and negative prompts + # Detect CFG mode by checking encoder_hidden_states sequence length + is_positive_prompt = encoder_hidden_states.shape[1] > 50 # Long sequence = positive, short = negative + cache_key = 'positive' if is_positive_prompt else 'negative' + + if not hasattr(self, 'teacache_states'): + self.teacache_states = { + 'positive': {'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_encoder_residual': None, 'previous_hidden_residual': None}, + 'negative': {'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_encoder_residual': None, 'previous_hidden_residual': None} + } + + cache_state = self.teacache_states[cache_key] + + if cache_state['previous_modulated_input'] is None: + should_calc = True + else: + try: + 
cache_state['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache_state['previous_modulated_input']).abs().mean() / cache_state['previous_modulated_input'].abs().mean())).abs()
+            if cache_state['accumulated_rel_l1_distance'] < rel_l1_thresh:
+                should_calc = False
+            else:
+                should_calc = True
+                cache_state['accumulated_rel_l1_distance'] = 0
+        except:
+            should_calc = True
+            cache_state['accumulated_rel_l1_distance'] = 0
+
+    cache_state['previous_modulated_input'] = modulated_inp
+
+    if not enable_teacache:
+        should_calc = True
+
+    if not should_calc:
+        # Use CFG-aware cached residuals
+        if (cache_state['previous_encoder_residual'] is not None and
+            cache_state['previous_hidden_residual'] is not None):
+            # Check if cached residuals have compatible shapes
+            if (cache_state['previous_encoder_residual'].shape == encoder_hidden_states.shape and
+                cache_state['previous_hidden_residual'].shape == hidden_states.shape):
+                encoder_hidden_states += cache_state['previous_encoder_residual'].to(encoder_hidden_states.device)
+                hidden_states += cache_state['previous_hidden_residual'].to(hidden_states.device)
+            else:
+                # Shape mismatch, force recalculation
+                should_calc = True
+        else:
+            # No cached residuals available, force recalculation
+            should_calc = True
+
+    # Process through transformer_blocks if calculation is needed
+    if should_calc:
+        # Store original states for residual calculation
+        ori_encoder_hidden_states = encoder_hidden_states.to(cache_device)
+        ori_hidden_states = hidden_states.to(cache_device)
+
+        # Process through transformer_blocks (Qwen-Image architecture)
+        for block in self.transformer_blocks:
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_hidden_states_mask=encoder_hidden_states_mask,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+            )
+
+        # Store residuals for future use in CFG-aware cache state
+        cache_state['previous_encoder_residual'] = (encoder_hidden_states.to(cache_device) - ori_encoder_hidden_states)
+        cache_state['previous_hidden_residual'] = (hidden_states.to(cache_device) - ori_hidden_states)
+
+    hidden_states = self.norm_out(hidden_states, temb)
+    hidden_states = self.proj_out(hidden_states)
+
+    # Optional debug
+    if transformer_options.get("debug_teacache", False):
+        try:
+            print("[TeaCache][qwen-image] shapes:",
+                  "x=", tuple(x.shape),
+                  "orig_shape=", tuple(orig_shape),
+                  "after_proj=", tuple(hidden_states.shape),
+                  "num_embeds=", int(num_embeds))
+        except Exception:
+            pass
+
+    # Use only main image tokens for reconstruction
+    hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
+    hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
+    return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
+
+def teacache_qwen_image_edit_forward(
+    self,
+    x,
+    timesteps,
+    context,
+    attention_mask=None,
+    guidance: torch.Tensor = None,
+    ref_latents=None,
+    transformer_options={},
+    **kwargs
+    ):
+    rel_l1_thresh = transformer_options.get("rel_l1_thresh")
+    coefficients = transformer_options.get("coefficients")
+    enable_teacache = transformer_options.get("enable_teacache", True)
+    cache_device = transformer_options.get("cache_device")
+
+    timestep = timesteps
+    encoder_hidden_states = context
+
encoder_hidden_states_mask = attention_mask + + # Process main image tokens + hidden_states, img_ids, orig_shape = self.process_img(x) + main_num_embeds = hidden_states.shape[1] + + # Integrate reference latents like upstream edit forward + if ref_latents is not None: + h = 0 + w = 0 + index = 0 + index_ref_method = kwargs.get("ref_latents_method", "index") == "index" + for ref in ref_latents: + if index_ref_method: + index += 1 + h_offset = 0 + w_offset = 0 + else: + index = 1 + h_offset = 0 + w_offset = 0 + if ref.shape[-2] + h > ref.shape[-1] + w: + w_offset = w + else: + h_offset = h + h = max(h, ref.shape[-2] + h_offset) + w = max(w, ref.shape[-1] + w_offset) + + kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) + hidden_states = torch.cat([hidden_states, kontext], dim=1) + img_ids = torch.cat([img_ids, kontext_ids], dim=1) + + num_embeds = hidden_states.shape[1] + + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size))) + txt_ids = torch.linspace(txt_start, txt_start + encoder_hidden_states.shape[1], steps=encoder_hidden_states.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + + # For qwen-image-edit: use ALL tokens for cache detection when reference images exist + # This ensures reference image changes are properly detected + if hidden_states.shape[1] > main_num_embeds: + # Reference images exist - use all tokens for cache detection + modulated_inp = hidden_states.to(cache_device) + else: + # No reference images - use main tokens only + modulated_inp = hidden_states[:, :main_num_embeds].to(cache_device) + + hidden_states = self.img_in(hidden_states) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + # CFG-aware dual cache states (positive/negative) + is_positive_prompt = encoder_hidden_states.shape[1] > 50 + cache_key = 'positive' if is_positive_prompt else 'negative' + + if not hasattr(self, 'teacache_states'): + self.teacache_states = { + 'positive': {'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_encoder_residual': None, 'previous_hidden_residual': None}, + 'negative': {'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_encoder_residual': None, 'previous_hidden_residual': None} + } + + cache_state = self.teacache_states[cache_key] + + # Check if we have reference images - use stricter caching for reference scenarios + has_reference = hidden_states.shape[1] > main_num_embeds + + if has_reference: + # For reference images, use stricter threshold to ensure proper fusion + reference_rel_l1_thresh = rel_l1_thresh * 0.1 # Much stricter threshold + effective_thresh = reference_rel_l1_thresh + else: + effective_thresh = rel_l1_thresh + + if cache_state['previous_modulated_input'] is None: + should_calc = True + else: + try: + cache_state['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache_state['previous_modulated_input']).abs().mean() / cache_state['previous_modulated_input'].abs().mean())).abs() + if 
cache_state['accumulated_rel_l1_distance'] < effective_thresh: + should_calc = False + else: + should_calc = True + cache_state['accumulated_rel_l1_distance'] = 0 + except: + should_calc = True + cache_state['accumulated_rel_l1_distance'] = 0 + + cache_state['previous_modulated_input'] = modulated_inp + + if not enable_teacache: + should_calc = True + + if not should_calc: + if (cache_state['previous_encoder_residual'] is not None and + cache_state['previous_hidden_residual'] is not None): + if (cache_state['previous_encoder_residual'].shape == encoder_hidden_states.shape and + cache_state['previous_hidden_residual'].shape == hidden_states.shape): + encoder_hidden_states += cache_state['previous_encoder_residual'].to(encoder_hidden_states.device) + hidden_states += cache_state['previous_hidden_residual'].to(hidden_states.device) + else: + should_calc = True + else: + should_calc = True + + if should_calc: + ori_encoder_hidden_states = encoder_hidden_states.to(cache_device) + ori_hidden_states = hidden_states.to(cache_device) + + for i, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + cache_state['previous_encoder_residual'] = (encoder_hidden_states.to(cache_device) - ori_encoder_hidden_states) + cache_state['previous_hidden_residual'] = (hidden_states.to(cache_device) - ori_hidden_states) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # Follow original Qwen-Image model: use main_num_embeds for reconstruction + # Reference info should be fused into main tokens during transformer processing + + # Use main_num_embeds like the original model (line 441 in qwen_image/model.py) + hidden_states = hidden_states[:, :main_num_embeds].view( + orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2 + ) + hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5) + output = hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]] + + return output + +class TeaCache: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL", {"tooltip": "The diffusion model the TeaCache will be applied to."}), + "model_type": (["flux", "flux-kontext", "ltxv", "lumina_2", "hunyuan_video", "hidream_i1_full", "hidream_i1_dev", "hidream_i1_fast", "qwen-image", "qwen-image-edit", "wan2.1_t2v_1.3B", "wan2.1_t2v_14B", "wan2.1_i2v_480p_14B", "wan2.1_i2v_720p_14B", "wan2.1_t2v_1.3B_ret_mode", "wan2.1_t2v_14B_ret_mode", "wan2.1_i2v_480p_14B_ret_mode", "wan2.1_i2v_720p_14B_ret_mode"], {"default": "flux", "tooltip": "Supported diffusion model."}), + "rel_l1_thresh": ("FLOAT", {"default": 0.4, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "How strongly to cache the output of diffusion model. 
This value must be non-negative."}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps that will apply TeaCache."}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps that will apply TeaCache."}), + "cache_device": (["cuda", "cpu"], {"default": "cuda", "tooltip": "Device where the cache will reside."}), + } + } + + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("model",) + FUNCTION = "apply_teacache" + CATEGORY = "TeaCache" + TITLE = "TeaCache" + + def apply_teacache(self, model, model_type: str, rel_l1_thresh: float, start_percent: float, end_percent: float, cache_device: str): + if rel_l1_thresh == 0: + return (model,) + + new_model = model.clone() + if 'transformer_options' not in new_model.model_options: + new_model.model_options['transformer_options'] = {} + new_model.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh + new_model.model_options["transformer_options"]["coefficients"] = SUPPORTED_MODELS_COEFFICIENTS[model_type] + new_model.model_options["transformer_options"]["model_type"] = model_type + new_model.model_options["transformer_options"]["cache_device"] = mm.get_torch_device() if cache_device == "cuda" else torch.device("cpu") + + diffusion_model = new_model.get_model_object("diffusion_model") + + if "flux" in model_type: + is_cfg = False + context = patch.multiple( + diffusion_model, + forward_orig=teacache_flux_forward.__get__(diffusion_model, diffusion_model.__class__) + ) + elif "lumina_2" in model_type: + is_cfg = True + context = patch.multiple( + diffusion_model, + forward=teacache_lumina_forward.__get__(diffusion_model, diffusion_model.__class__) + ) + elif "hidream_i1" in model_type: + is_cfg = True if "full" in model_type else False + context = patch.multiple( + diffusion_model, + forward=teacache_hidream_forward.__get__(diffusion_model, diffusion_model.__class__) + ) + elif "ltxv" in model_type: + is_cfg = True + context = patch.multiple( + diffusion_model, + forward=teacache_ltxvmodel_forward.__get__(diffusion_model, diffusion_model.__class__) + ) + elif "hunyuan_video" in model_type: + is_cfg = False + context = patch.multiple( + diffusion_model, + forward_orig=teacache_hunyuanvideo_forward.__get__(diffusion_model, diffusion_model.__class__) + ) + elif "qwen-image-edit" in model_type: + is_cfg = False + context = patch.multiple( + diffusion_model, + forward=teacache_qwen_image_edit_forward.__get__(diffusion_model, diffusion_model.__class__) + ) + elif "qwen-image" in model_type: + is_cfg = False + context = patch.multiple( + diffusion_model, + forward=teacache_qwen_image_forward.__get__(diffusion_model, diffusion_model.__class__) + ) + elif "wan2.1" in model_type: + is_cfg = True + context = patch.multiple( + diffusion_model, + forward_orig=teacache_wanmodel_forward.__get__(diffusion_model, diffusion_model.__class__) + ) + else: + raise ValueError(f"Unknown type {model_type}") + + def unet_wrapper_function(model_function, kwargs): + input = kwargs["input"] + timestep = kwargs["timestep"] + c = kwargs["c"] + # referenced from https://github.com/kijai/ComfyUI-KJNodes/blob/d126b62cebee81ea14ec06ea7cd7526999cb0554/nodes/model_optimization_nodes.py#L868 + sigmas = c["transformer_options"]["sample_sigmas"] + matched_step_index = (sigmas == timestep[0]).nonzero() + if len(matched_step_index) > 0: + current_step_index = matched_step_index.item() + else: + current_step_index = 0 + for i in 
range(len(sigmas) - 1): + # walk from beginning of steps until crossing the timestep + if (sigmas[i] - timestep[0]) * (sigmas[i + 1] - timestep[0]) <= 0: + current_step_index = i + break + + if current_step_index == 0: + if is_cfg: + # uncond -> 1, cond -> 0 + if hasattr(diffusion_model, 'teacache_state') and \ + diffusion_model.teacache_state[0]['previous_modulated_input'] is not None and \ + diffusion_model.teacache_state[1]['previous_modulated_input'] is not None: + delattr(diffusion_model, 'teacache_state') + else: + if hasattr(diffusion_model, 'teacache_state'): + delattr(diffusion_model, 'teacache_state') + if hasattr(diffusion_model, 'teacache_states'): + delattr(diffusion_model, 'teacache_states') + if hasattr(diffusion_model, 'accumulated_rel_l1_distance'): + delattr(diffusion_model, 'accumulated_rel_l1_distance') + + current_percent = current_step_index / (len(sigmas) - 1) + c["transformer_options"]["current_percent"] = current_percent + if start_percent <= current_percent <= end_percent: + c["transformer_options"]["enable_teacache"] = True + else: + c["transformer_options"]["enable_teacache"] = False + + with context: + return model_function(input, timestep, **c) + + new_model.set_model_unet_function_wrapper(unet_wrapper_function) + + return (new_model,) + +def patch_optimized_module(): + try: + from torch._dynamo.eval_frame import OptimizedModule + except ImportError: + return + + if getattr(OptimizedModule, "_patched", False): + return + + def __getattribute__(self, name): + if name == "_orig_mod": + return object.__getattribute__(self, "_modules")[name] + if name in ( + "__class__", + "_modules", + "state_dict", + "load_state_dict", + "parameters", + "named_parameters", + "buffers", + "named_buffers", + "children", + "named_children", + "modules", + "named_modules", + ): + return getattr(object.__getattribute__(self, "_orig_mod"), name) + return object.__getattribute__(self, name) + + def __delattr__(self, name): + return delattr(self._orig_mod, name) + + @classmethod + def __instancecheck__(cls, instance): + return isinstance(instance, OptimizedModule) or issubclass( + object.__getattribute__(instance, "__class__"), cls + ) + + OptimizedModule.__getattribute__ = __getattribute__ + OptimizedModule.__delattr__ = __delattr__ + OptimizedModule.__instancecheck__ = __instancecheck__ + OptimizedModule._patched = True + +def patch_same_meta(): + try: + from torch._inductor.fx_passes import post_grad + except ImportError: + return + + same_meta = getattr(post_grad, "same_meta", None) + if same_meta is None: + return + + if getattr(same_meta, "_patched", False): + return + + def new_same_meta(a, b): + try: + return same_meta(a, b) + except Exception: + return False + + post_grad.same_meta = new_same_meta + new_same_meta._patched = True + +class CompileModel: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL", {"tooltip": "The diffusion model the torch.compile will be applied to."}), + "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), + "backend": (["inductor","cudagraphs", "eager", "aot_eager"], {"default": "inductor"}), + "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), + "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), + } + } + + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("model",) + FUNCTION = "apply_compile" + CATEGORY = "TeaCache" + TITLE = "Compile Model" + + def apply_compile(self, model, mode: str, backend: str, 
fullgraph: bool, dynamic: bool): + patch_optimized_module() + patch_same_meta() + torch._dynamo.config.suppress_errors = True + + new_model = model.clone() + new_model.add_object_patch( + "diffusion_model", + torch.compile( + new_model.get_model_object("diffusion_model"), + mode=mode, + backend=backend, + fullgraph=fullgraph, + dynamic=dynamic + ) + ) + + return (new_model,) + + +NODE_CLASS_MAPPINGS = { + "TeaCache": TeaCache, + "CompileModel": CompileModel +} + +NODE_DISPLAY_NAME_MAPPINGS = {k: v.TITLE for k, v in NODE_CLASS_MAPPINGS.items()} +
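The per-model forwards added above all share one caching rule: measure how much the block input changed since the previous step as a relative L1 distance, rescale that with the model-specific polynomial from SUPPORTED_MODELS_COEFFICIENTS, accumulate it, and skip the transformer blocks (re-applying the cached residual) while the accumulated value stays below rel_l1_thresh. Below is a minimal, standalone sketch of that rule; the tensors, the should_recompute helper and the step loop are illustrative only, and the coefficients are simply the "flux" row copied from the table above.

# Minimal sketch of the TeaCache skip decision used by the forwards above.
# Only torch is required; everything named here is a stand-in for the real
# modulated inputs inside the model, not part of the node pack itself.
import torch

coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
rel_l1_thresh = 0.4

def poly1d(coefficients, x):
    # same convention as numpy.poly1d: first coefficient scales the highest power
    result = torch.zeros_like(x)
    for i, coeff in enumerate(coefficients):
        result += coeff * (x ** (len(coefficients) - 1 - i))
    return result

accumulated = 0.0
previous = None

def should_recompute(modulated_inp):
    """Return True when the transformer blocks should actually run this step."""
    global accumulated, previous
    if previous is None:
        accumulated = 0.0
        previous, should_calc = modulated_inp, True
    else:
        rel_l1 = (modulated_inp - previous).abs().mean() / previous.abs().mean()
        accumulated += poly1d(coefficients, rel_l1).abs()
        if accumulated < rel_l1_thresh:
            should_calc = False          # reuse the cached residual instead of running the blocks
        else:
            should_calc, accumulated = True, 0.0
        previous = modulated_inp
    return should_calc

# toy "timesteps": the input barely changes, so some steps are skipped
torch.manual_seed(0)
x = torch.randn(1, 16, 64)
for step in range(4):
    print(step, should_recompute(x + 1e-4 * step * torch.randn_like(x)))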
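For the CFG-style models (lumina_2, ltxv, hidream_i1_full, wan2.1) the forwards keep one cache slot per entry of cond_or_uncond, so the conditional and unconditional halves of the batch are tracked separately, and residuals are only reused when no slot asks for a recompute. A rough sketch of that bookkeeping follows; apply_or_run, run_blocks and decide are made-up placeholders, only the slicing and keying pattern mirrors the code above.

# Sketch of the per-cond/uncond cache slots; dummy tensors and callables.
import torch

teacache_state = {
    0: {'should_calc': True, 'previous_residual': None},  # cond
    1: {'should_calc': True, 'previous_residual': None},  # uncond
}

def apply_or_run(x, cond_or_uncond, run_blocks, decide):
    b = x.shape[0] // len(cond_or_uncond)          # samples per cond/uncond slot
    for i, k in enumerate(cond_or_uncond):
        teacache_state[k]['should_calc'] = decide(k, x[i * b:(i + 1) * b])
    # reuse residuals only if no slot wants a real pass
    if not any(teacache_state[k]['should_calc'] for k in cond_or_uncond):
        for i, k in enumerate(cond_or_uncond):
            x[i * b:(i + 1) * b] += teacache_state[k]['previous_residual']
        return x
    ori_x = x.clone()
    x = run_blocks(x)
    for i, k in enumerate(cond_or_uncond):
        teacache_state[k]['previous_residual'] = (x - ori_x)[i * b:(i + 1) * b]
    return x

out = apply_or_run(
    torch.randn(2, 8, 4), cond_or_uncond=[0, 1],
    run_blocks=lambda t: t * 1.01,          # stand-in for the transformer blocks
    decide=lambda k, chunk: True,           # always recompute in this toy run
)
print(out.shape)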
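unet_wrapper_function gates caching by sampling progress: it locates the current sigma inside sample_sigmas, converts that index into a fraction of the schedule, and only sets enable_teacache when the fraction lies within [start_percent, end_percent]. The sketch below reproduces just that lookup with a made-up sigma schedule and a hypothetical current_percent helper.

# Sketch of the step-percentage gating; the sigma values are invented,
# the real schedule comes from c["transformer_options"]["sample_sigmas"].
import torch

def current_percent(sigmas: torch.Tensor, timestep: torch.Tensor) -> float:
    matched = (sigmas == timestep[0]).nonzero()
    if len(matched) > 0:
        idx = matched.item()
    else:
        idx = 0
        for i in range(len(sigmas) - 1):
            # the current sigma lies between two scheduled sigmas
            if (sigmas[i] - timestep[0]) * (sigmas[i + 1] - timestep[0]) <= 0:
                idx = i
                break
    return idx / (len(sigmas) - 1)

sigmas = torch.tensor([14.6, 7.0, 3.1, 1.2, 0.4, 0.0])
start_percent, end_percent = 0.0, 0.8
for t in (torch.tensor([7.0]), torch.tensor([0.2])):
    p = current_percent(sigmas, t)
    print(f"sigma={t.item():.2f} -> percent={p:.2f}, teacache={'on' if start_percent <= p <= end_percent else 'off'}")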
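apply_teacache swaps the model's forward (or forward_orig) in and out with unittest.mock.patch.multiple, entering the patch only for the duration of each wrapped model_function call. The toy class below is not part of the node pack; it only demonstrates how the __get__-bound replacement behaves inside and outside the with-block.

# Sketch of the method-swapping pattern used by apply_teacache.
from unittest.mock import patch

class DummyDiffusionModel:
    def forward(self, x):
        return x + 1

def teacache_forward(self, x):
    # stand-in for a teacache_*_forward: same call shape, different behaviour
    return x * 10

m = DummyDiffusionModel()
ctx = patch.multiple(m, forward=teacache_forward.__get__(m, m.__class__))

print(m.forward(1))        # 2  (original method)
with ctx:
    print(m.forward(1))    # 10 (patched, as inside unet_wrapper_function)
print(m.forward(1))        # 2  (restored once the with-block exits)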
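CompileModel ultimately just calls torch.compile on the cloned model's diffusion_model object with the selected mode, backend, fullgraph and dynamic flags. The stand-in below compiles a small nn.Module with backend="eager" so it runs without an Inductor toolchain; the node itself defaults to the inductor backend and forwards the mode string as well.

# Sketch of the torch.compile call the node performs, on a toy module.
import torch
import torch.nn as nn

module = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
compiled = torch.compile(module, backend="eager", fullgraph=False, dynamic=False)

x = torch.randn(2, 8)
# the compiled wrapper produces the same numerics through the compiled path
print(torch.allclose(module(x), compiled(x)))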