Conversation

@rattus128
Contributor

@rattus128 rattus128 commented Jan 13, 2026

To try it:

pip install -r requirements.txt
--fast dynamic_vram

NOTE: This work does not have any GGUF integration and GGUF will not see any benefits yet.

NOTE: I am aware of increased Windows RAM usage when a pagefile is not configured, due to commit quota exhaustion. If anyone is testing, please stay tuned for a major fix to Windows RAM usage incoming. The VRAM stuff on Windows is still testable. Linux is unaffected. (FIXED)

If you try it, please reply to the PR (if it hasn't been merged) with any issues, or feel free to make an issue ticket for bigger test cases with logs and numbers.

Features

  • A new ModelPatcher implementation which backs onto comfy-aimdo to implement varying model load levels that can be adjusted during model use. The patcher defers all load processing to lazily load the model during use (e.g. the first step of a ksampler) and automatically negotiates a load level during inference to maximize VRAM usage without OOMing. If inference requires more VRAM than is available, weights are offloaded to make space before the OOM happens.

  • This will eventually allow for development of ComfyUI without needing to estimate model VRAM usage at all.

  • Large RAM and Windows commit-charge savings. There is no need to load models fully into RAM. This also gives a much higher chance of having the model in disk cache, saving the user a disk-load delay on first run, as there is no longer a primary load into process memory displacing the disk cache.

  • Windows GPU shared memory usage avoidance

  • A deep copy of the model is eliminated from the safetensors save process (incidental improvement)

  • Reduced VRAM usage in the async offload stream with cuda malloc disabled (prerequisite improvement)

Implementation Details

Aimdo readme here: https://pypi.org/project/comfy-aimdo/

The long story on RAM: Aimdo's ability to simply evict weights means it's no longer possible to .to() a weight back and forth from the GPU. VRAM pressure can occur at any time during inference, and there is no clean way to .to() weights or modules back to the CPU while pytorch is stacked in the middle of a pending VRAM allocation. So, since we can never .to() a weight, we instead take the opportunity to leave the model parameter as known to pytorch on the CPU permanently, with assign=True state dict loading. Since it is never write-touched, it lives in mmap permanently and never consumes any process-allocated RAM. Several community developers have already flagged this as a possible major enhancement to comfy, and the needed changes to model load and unload align with the VRAM problems.
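
As a rough illustration of that loading pattern (a minimal sketch, not the PR's actual code; the path and module are placeholders):

```python
import torch
from safetensors.torch import load_file

def load_weights_lazily(module: torch.nn.Module, path: str) -> torch.nn.Module:
    # load_file() maps the .safetensors file; the returned CPU tensors are
    # backed by that mapping rather than copied into freshly allocated RAM.
    state_dict = load_file(path, device="cpu")
    # assign=True swaps the mapped tensors straight into the module's
    # parameters instead of copy_()-ing into pre-allocated storage. Because
    # the weights are never write-touched, they stay in the page cache and
    # cost essentially no private process RAM.
    module.load_state_dict(state_dict, assign=True)
    return module
```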

(NEW) Windows has extra RAM complications with its pessimistic allocation and the fact that it forbids overcommit other than through the pagefile. Two changes are made to drastically reduce commit charge. Linear nn.Modules are now constructed without the placeholder weight, as this consumes commit charge. The other change is a lightweight safetensors load that maps files in READ mode (the safetensors package uses CoW), which avoids getting commit-charged for the whole model on file load.
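
A minimal sketch of the "no placeholder weight" side of this (illustrative only, not the PR's actual ops change): parameters created on the meta device carry shape and dtype but no storage, so nothing is committed until real weights are assigned.

```python
import torch
import torch.nn as nn

# Build the layer with meta parameters: no real storage is allocated, so no
# Windows commit charge is consumed for the placeholder weight.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)
assert layer.weight.is_meta

# Later, assign=True loading replaces the meta parameters with real
# (e.g. mmap-backed) tensors in a single step.
layer.load_state_dict(
    {"weight": torch.empty(4096, 4096), "bias": torch.empty(4096)},
    assign=True,
)
```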

As for loading the weight onto the GPU, that happens via comfy_cast_weights, which is now used in all cases. cast_bias_weight checks whether the VBAR assigned to the model has space for the weight (based on the same load-priority semantics as the original ModelPatcher). If it does, the VRAM returned by the Aimdo allocator is used as the GPU-side parameter. The caster is responsible for populating the weight data. This is done using the usual offload_stream (which means we now have asynchronous loading overlapping first-use compute).
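
Very roughly, the decision in the caster looks like the sketch below. This is a hypothetical stand-in, not the real cast_bias_weight or the aimdo allocator API; used_bytes and vram_budget_bytes stand in for whatever the VBAR actually tracks.

```python
import torch

def place_weight(cpu_weight: torch.Tensor, used_bytes: int, vram_budget_bytes: int,
                 offload_stream: torch.cuda.Stream, dtype=torch.bfloat16):
    """Hypothetical sketch: put the weight on the GPU if the model's VRAM
    budget still has room, otherwise keep computing from the CPU copy."""
    nbytes = cpu_weight.numel() * torch.empty((), dtype=dtype).element_size()
    if used_bytes + nbytes > vram_budget_bytes:
        return cpu_weight, used_bytes              # weight stays offloaded
    with torch.cuda.stream(offload_stream):
        # H2D copy issued on the side stream; it only truly overlaps compute
        # when the source is pinned host memory.
        gpu_weight = cpu_weight.to("cuda", dtype=dtype, non_blocking=True)
    torch.cuda.current_stream().wait_stream(offload_stream)
    return gpu_weight, used_bytes + nbytes
```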

Pinning works a little differently. When a weight is detected during load as unable to fit, a pin is allocated at cast time and the weight as used by the layer is DMA'd back to the pin using the GPU DMA TX engine, also on the asynchronous offload streams. This means you get to pin the Lora-modified and requantized weights, which can be a major speedup for offload+quantize+lora use cases. This works around the JIT Lora + FP8 exclusion and brings FP8MM to heavy offloading users (who probably really need it with more modest GPUs). There is a performance risk in that a CPU+RAM patch has been replaced with a GPU+RAM patch, but my initial performance results look good. Most users are likely to have a GPU that outruns their CPU in these woods.
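
A sketch of the pin-back step under assumed names (not the PR's actual helper): allocate a page-locked host buffer for the weight as it now exists on the GPU (post-Lora, post-requantization) and copy it down asynchronously on the offload stream.

```python
import torch

def pin_back(gpu_weight: torch.Tensor, offload_stream: torch.cuda.Stream):
    # Page-locked host buffer: the D2H copy below can then run on the copy
    # (DMA) engine without blocking compute on the default stream.
    pinned = torch.empty(gpu_weight.shape, dtype=gpu_weight.dtype,
                         device="cpu", pin_memory=True)
    with torch.cuda.stream(offload_stream):
        pinned.copy_(gpu_weight, non_blocking=True)
    # Record an event so a later consumer can synchronize on just this
    # transfer instead of the whole device.
    done = torch.cuda.Event()
    done.record(offload_stream)
    return pinned, done
```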

Some common code is written to consolidate a layer's tensors for aimdo mapping, pinning, and DMA transfers. interpret_gathered_like() allows unpacking a raw buffer as a set of tensors. This is used consistently to bundle weights, quantization metadata (QuantizedTensor bits) and biases into one payload for DMA in the load process, reducing CUDA overhead a little. Some quantization metadata was missing async offload in some cases, which is now added. This also pins quantization metadata and consolidates the number of cuda_host_register calls (which can be expensive).
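
Conceptually, the gather/unpack step behaves like this toy analogue (illustrative only; the real interpret_gathered_like() also handles mixed dtypes, alignment and the quantization metadata):

```python
import math
import torch

def gather(tensors):
    """Pack several same-dtype tensors into one flat buffer so they can be
    pinned and DMA'd as a single transfer."""
    flat = torch.cat([t.contiguous().view(-1) for t in tensors])
    return flat, [t.shape for t in tensors]

def unpack_like(flat, shapes):
    """View slices of the flat buffer as tensors of the original shapes;
    no copies, everything aliases the one gathered allocation."""
    out, offset = [], 0
    for shape in shapes:
        n = math.prod(shape)
        out.append(flat[offset:offset + n].view(shape))
        offset += n
    return out

w, b = torch.randn(4, 4), torch.randn(4)
flat, shapes = gather([w, b])
w2, b2 = unpack_like(flat, shapes)
assert torch.equal(w2, w) and torch.equal(b2, b)
```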

Model saving is reworked to avoid the force_cast_weights flag, which doesn't make sense in ModelPatcherDynamic. This rework was able to cut a RAM copy of the model by doing on-the-fly model patching during the save process, which worked out to be a nice RAM saving while fixing my API problem.
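
The on-the-fly patching during save can be pictured with this hypothetical wrapper (not the PR's actual class; lora_up, lora_down and alpha are illustrative): the saver only sees the patched tensor when it asks for it, one layer at a time, and the Lora math is done on the GPU.

```python
import torch

class JITPatchedWeight:
    """Hypothetical stand-in for a weight wrapped during saving."""
    def __init__(self, base_cpu, lora_up, lora_down, alpha=1.0):
        self.base_cpu = base_cpu        # offloaded, possibly mmap-backed
        self.lora_up = lora_up
        self.lora_down = lora_down
        self.alpha = alpha

    def to(self, *args, **kwargs):
        dev = "cuda" if torch.cuda.is_available() else "cpu"
        # Patch on the GPU, then move the result to wherever the saver asked
        # for, so a fully patched CPU copy of the model never needs to exist.
        w = self.base_cpu.to(dev, dtype=torch.float32)
        patched = w + self.alpha * (self.lora_up.to(dev, torch.float32)
                                    @ self.lora_down.to(dev, torch.float32))
        return patched.to(*args, **kwargs)
```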

Aimdo (under the hood) links with Windows APIs to adjust load levels based on the WDDM target VRAM usage rather than the numbers reported by the pytorch/CUDA stack (which are WDDM's lies). This means that as soon as shared memory spilling occurs on Windows, weights will be unloaded until you get out of the spill state, and inference state moves back to VRAM.

Offload streams now have an accompanying single shared cast buffer that grows as needed. This avoids significant waste and fragmentation in the cast-buffer streams when offloading multiple weight sizes, since we don't have cuda_malloc and the pytorch allocator completely isolates memory by stream. So we go a little hands-on at the low level to keep those allocation pools minimized. This is also applied to the non --fast dynamic_vram case when cuda_malloc is disabled, as it does reduce VRAM, especially on Flux2 with those huge and varying weights.
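
Sketch of the growable shared buffer idea (illustrative; the real one is tied to each offload stream and deals with alignment):

```python
import math
import torch

class GrowableCastBuffer:
    """One reusable allocation per offload stream: grow when a bigger weight
    arrives, hand out views for smaller ones, so the caching allocator never
    keeps a pool of differently sized cast blocks around."""
    def __init__(self, device="cuda"):
        self.buf = torch.empty(0, dtype=torch.uint8, device=device)

    def view_for(self, shape, dtype):
        nbytes = math.prod(shape) * torch.empty((), dtype=dtype).element_size()
        if nbytes > self.buf.numel():
            # One allocation replaces the old one; no per-size pools build up.
            self.buf = torch.empty(nbytes, dtype=torch.uint8,
                                   device=self.buf.device)
        return self.buf[:nbytes].view(dtype).view(shape)
```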

Future Work

  • Pin RAM management could use more optimization and try to get pins behind the mmap of active models in priority under RAM pressure, to allow much more aggressive pin retention. The heuristic (currently free model x2) can also just be straight-up improved.
  • First iterations are slower than I'd hoped. Some multi-threading of the CPU/RAM bottleneck might allow for further run-ahead and bottleneck saturation.
    - The progress meter needs some work. It's jarring to have it stall on the first iteration when it's doing a slow model load. (DONE)

Example Test case:

Flux2 + Lora text to image.
RTX 5090 with 8GB of VRAM consumed by a non-comfy application (24GB effective)
PCIe 5 NVMe, 96GB RAM.
Disk caches warm with the model

image

Before:

________________________________________________________________________
Starting server

To see the GUI go to: http://0.0.0.0:8188
To see the GUI go to: http://[::]:8188
got prompt
Using pytorch attention in VAE
Using pytorch attention in VAE
VAE load device: cuda:0, offload device: cpu, dtype: torch.bfloat16
Found quantization metadata version 1
Using MixedPrecisionOps for text encoder
CLIP/text encoder model load device: cuda:0, offload device: cpu, current: cpu, dtype: torch.float16
Requested to load Flux2TEModel_
loaded completely; 21458.36 MB usable, 17180.59 MB loaded, full load: True
Found quantization metadata version 1
Detected mixed precision quantization
Using mixed precision operations
model weight dtype torch.bfloat16, manual cast: torch.bfloat16
model_type FLUX
Requested to load Flux2
loaded partially; 19731.54 MB usable, 18720.00 MB loaded, 15093.00 MB offloaded, 1152.00 MB buffer reserved, lowvram patches: 72
Initializing ControlAltAI Nodes
100%|██████████| 20/20 [00:25<00:00,  1.27s/it]
Requested to load AutoencoderKL
Unloaded partially: 1152.00 MB freed, 17568.00 MB remains loaded, 1152.00 MB buffer reserved, lowvram patches: 80
loaded completely; 190.98 MB usable, 160.31 MB loaded, full load: True
Prompt executed in 39.58 seconds

General Memory Usage

image

Peak VRAM:

image

After (--fast dynamic_vram)

________________________________________________________________________
Starting server

To see the GUI go to: http://0.0.0.0:8188
To see the GUI go to: http://[::]:8188
got prompt
Using pytorch attention in VAE
Using pytorch attention in VAE
VAE load device: cuda:0, offload device: cpu, dtype: torch.bfloat16
Found quantization metadata version 1
Using MixedPrecisionOps for text encoder
CLIP/text encoder model load device: cuda:0, offload device: cpu, current: cpu, dtype: torch.float16
Requested to load Flux2TEModel_
Model Flux2TEModel_ prepared for dynamic VRAM loading. 17180MB Staged. 0 patches attached.
Found quantization metadata version 1
Detected mixed precision quantization
Using mixed precision operations
model weight dtype torch.bfloat16, manual cast: torch.bfloat16
model_type FLUX
Requested to load Flux2
Model Flux2 prepared for dynamic VRAM loading. 33813MB Staged. 138 patches attached.
Initializing ControlAltAI Nodes
100%|██████████| 20/20 [00:28<00:00,  1.41s/it]
Requested to load AutoencoderKL
0 models unloaded.
Model AutoencoderKL prepared for dynamic VRAM loading. 320MB Staged. 0 patches attached.
Prompt executed in 34.79 seconds

General Memory usage:

image

Peak VRAM:

image

More test data to come. Most workflows I have run are faster with this.

I'm testing various things, updates, bugfixes, etc., but enough works for a PR.

Get the model saving logic away from force_patch_weights and instead do
the patching JIT during safetensors saving.

Firstly, switch off force_patch_weights in the load-for-save, which avoids creating CPU-side tensors with loras calculated.

Then at save time, wrap the tensor to catch safetensors' call to .to() and patch it live.

This avoids having to ever have a lora-calculated copy of offloaded
weights on the CPU.

Also take advantage of the presence of the GPU when doing this Lora calculation. The former force_patch_weights would just do everything on the CPU. It's generally faster to go to the GPU and back, even if it's just a Lora application.
This needs to be visible to ops which may want to do stochastic rounding on the fly.
Add a Python helper for managing pinned memory at the weight/bias module level. This allocates, pins, and attaches a tensor to a module as the pin for that module. It does not set the weight; it just allocates a single RAM buffer for population and bulk DMA transfer.
Dynamic load needs to adjust these numbers based on future movements, so wrap this in an MP API.
Add two API expansions: a flag for whether a model patcher is dynamic, and a very basic RAM-freeing system.

Implement the semantics of the dynamic model patcher which never frees
VRAM ahead of time for the sake of another dynamic model patcher.

At the same time, add an API for clearing out pins based on a model-size-x2 reservation heuristic, as pins consume RAM in their own right in the dynamic patcher.

This is actually less about OOMing RAM and more about performance: with assign=True load semantics there needs to be plenty of headroom for the OS to load models into disk cache on demand, so err on the side of kicking old pins out.
Non-comfy weights don't get async offload and have a few other performance limitations. Load them at top priority accordingly.
Implement a model patcher and caster for aimdo.

Use CoreModelPatcher for all internal ModelPatcher implementations. This drives
conditional use of the aimdo feature, while making sure custom node packs get
to keep ModelPatcher unchanged for the moment.
We need to do general pytorch cache defragmentation at an appropriate level for aimdo. Do it here on a per-node basis, which has a reasonable chance of purging stale shapes out of the pytorch caching allocator and saving VRAM without costing too much garbage-collector thrash.

This looks like a lot of GC, but because aimdo never fails allocations from pytorch's point of view, the pytorch allocator is saved from ever needing to defrag on demand; it still needs an oil change every now and then, so we do it here. Doing it here also means the pytorch temporaries are cleared from task manager VRAM usage, so user anxiety can go down a little when they see their VRAM drop back at the end of workflows in line with inference usage (rather than assuming full VRAM leaks).
Add the optional command line switch --fast dynamic_vram.

This is mutually exclusive with --high-vram and --gpu-only, which contradict aimdo's underlying feature.

Add an appropriate installation warning and a startup message, and match the comfy debug level in configuring aimdo.

Add the comfy-aimdo pip requirement. This will safely stub to a no-op on unsupported platforms.
Sync before deleting anything.
@socket-security

socket-security bot commented Jan 13, 2026

Review the following changes in direct dependencies. Learn more about Socket for GitHub.

Diff: Added | Package: comfy-aimdo@0.1.1 | Supply Chain Security: 100 | Vulnerability: 100 | Quality: 100 | Maintenance: 100 | License: 100

View full report

@rattus128 rattus128 marked this pull request as draft January 13, 2026 10:29
This is needed for aimdo, where the cache can't self-recover from fragmentation. It is, however, a good thing to do anyway after an OOM, so make it unconditional.
@MeiYi-dev

MeiYi-dev commented Jan 13, 2026

As a 16GB VRAM user using the LTX 2 model, the main issue for me currently is that before VAE decoding occurs, the whole model gets offloaded to RAM, which is already loaded with the TEs/VAE/latent upscalers, etc., so it gets overloaded and spools onto the pagefile. In reality the max VRAM use of the VAE decoding part is 4GB with the "VAE decoded tiled" node. Unloading the whole model (probably because of VAE estimations within comfyui not accounting for tiled decoding) is the biggest issue I have found, and is the only reason why us 16GB VRAM / 32GB RAM users are experiencing issues where the TE reloads from disk (because it got unloaded to make space for the model) after changing the prompt and the model reloads again, both of these contributing to huge slowdowns.

@rattus128
Contributor Author

rattus128 commented Jan 13, 2026

As a 16GB VRAM user using the LTX 2 model, the main issue for me currently is before VAE decoing occurs, the whole model gets offloaded to RAM, which is already loaded with the TEs/VAE/Latent Upscalers, etc, so it gets overloaded and spools onto pagefile. When in reality the max VRAM use of the VAE decoding part is 4GB with the "VAE decoded tiled" node, unloading the whole model (probably because of VAE estimations within comfyui not accounting for tiled decoding) is the biggest issue I have found and is the only reason why us 16GB VRAM / 32GB RAM users are experiencing issues where the TE reloads from disk (because it got unloaded to make space for the model) after changing the prompt and model reloads again, both of these contributing to huge slowdowns.

This loader doesn't unload back to RAM at all, so it won't spill to the pagefile. The idea is: if you don't have enough RAM, just dump it, because it's faster to read it from the file on disk again than to write and read to the pagefile. If you do have enough RAM, the OS will just leave the model in disk cache from the first load. So this should be faster for you.

Your margins are very low; you might do well with --disable-pinned-memory, but if you try it, try it both ways.

Kudos for LTX2 on 16GB; these performance points are what I'm really trying to make work here.

@MeiYi-dev

MeiYi-dev commented Jan 13, 2026

As a 16GB VRAM user using the LTX 2 model, the main issue for me currently is before VAE decoing occurs, the whole model gets offloaded to RAM, which is already loaded with the TEs/VAE/Latent Upscalers, etc, so it gets overloaded and spools onto pagefile. When in reality the max VRAM use of the VAE decoding part is 4GB with the "VAE decoded tiled" node, unloading the whole model (probably because of VAE estimations within comfyui not accounting for tiled decoding) is the biggest issue I have found and is the only reason why us 16GB VRAM / 32GB RAM users are experiencing issues where the TE reloads from disk (because it got unloaded to make space for the model) after changing the prompt and model reloads again, both of these contributing to huge slowdowns.

This loader doesn't unload back to RAM at all so it wont spill to pagefile. The idea is, if you dont have enough RAM, just dump it, because its faster to just read it from file on disk again than to write and read to pagefile. If you do have enough RAM, the OS will just leave the model in disk cache from the first load. So this should be faster for you.

You margins are very low, you might do well --disable-pinned-memory but if you try it, try it both ways.

kudos for LTX2 on 16 and these performance points is what im trying to really make work here.

Loading the model file from the disk again does seem to be a nice way to prevent useless writing to the pagefile, at least. It will be very useful for 16GB RAM users. I use the model without any changes to startup args, with GGUFs. It works perfectly, though the only issue currently is offloading the whole model to RAM and pagefile to make space for 4GB haha

@zwukong

zwukong commented Jan 13, 2026

The ComfyUI-ReservedVRAM node already lets comfyui run any model in any amount of VRAM. LTX2 can run in 1GB of VRAM. Maybe you should take a look.

@FurkanGozukara

Awesome, I hope this gets implemented.

@MeiYi-dev

MeiYi-dev commented Jan 13, 2026

The ComfyUI-ReservedVRAM node already let comfyui can run any model in any vram. LTX2 can run in 1G vram. Maybe you should take a look .

I am looking at the screenshots of this node: where does the model offload to? The RAM use doesn't increase even with the model only using 6GB of VRAM, so what's the caveat?

Edit: NVM it offloads to RAM

@MeiYi-dev

MeiYi-dev commented Jan 13, 2026

The ComfyUI-ReservedVRAM node already let comfyui can run any model in any vram. LTX2 can run in 1G vram. Maybe you should take a look .

This PR doesn't do any offloading to RAM like the node you mentioned. This PR just drops the model if enough RAM space is not found, and loads the files for each run (I think) using a faster way.

TLDR, it prevents useless writing to pagefile

@RandomGitUser321
Contributor

Would the Windows shared memory avoidance stuff have any effect when using WSL?

By default, Windows will only allow 1/2 of your system memory to be used by WSL (without modifying the .wslconfig with memory=24GB or whatever you want to set it to), so you're already going to run into issues quickly. But as far as I know, ComfyUI should pick up on that value, and if it does, then any other memory management math should also pick up on it as well. Though at the GPU driver level, I'm not sure how they handle shared memory when used with WSL.

@rattus128
Contributor Author

Would the Windows shared memory avoidance stuff have any effect when using WSL?

If not, and with the changes now maximising VRAM usage (--reserve-vram seemingly no longer functions with dynamic_vram enabled), would that make spilling over into shared memory more likely with WSL?

I've noticed some slowdowns/stalls on subsequent runs along with the normal signs of shared memory sluggishness (low temperatures, low power draw, 100% GPU usage, along with reported shared memory usage in task manager) when testing out the PR, and wonder if the changes might not be suited for WSL as-is.

You are right that I ignore --reserve-vram for the moment. It can be implemented with a bit of plumbing, and I'll take it as a feature request (along with --novram), but we might not do that one in V1 as you can just opt out for the interim.

Yeah, so WSL is actually a big problem and very difficult (maybe impossible) to fix with regard to shared memory spilling. When you are under WSL, you present as Linux to aimdo, which won't have its anti-spill logic in play, as that is Windows-specific. Even if we could detect WSL, we would not have access to the APIs needed to detect the spill, as they are only visible on the Windows host.

WSL has value from a Linux-familiarity point of view and solves some software packaging problems, but unfortunately the extra layer of indirection between comfy and the GPU creates multiple performance problems. If you're optimizing comfy performance and like the Linux env, I VERY strongly recommend a dual-boot setup, as I have observed major performance differences in offloading setups where Linux just beats Windows with all other variables held the same (I dual boot my day-to-day test machine between Ubuntu and Win11).

Be more tolerant of unsupported platforms and fall back properly. Fixes a crash when cuda is not installed at all.
@rattus128
Contributor Author

comfy_aimdo seems to break ROCM even when this is not in use. It tries to dynamically load libcuda.so.1 which does not exist.

Thanks for the test. This should be fixed. I trashed my libcuda.so and gracefully got past the aimdo init.

@rattus128
Contributor Author

I fixed a major consumer of RAM today WRT text encoders that are bundled in checkpoints (previously they would still do an assign=False load).

I tested LTX2 FP16 1080p with --disable-pinned-memory, which is a heavy offload flow even on the 5090.

Peak RAM is this:

image

With the models left in the disk cache.

If I pin the memory it's a little bit more. This screenshot shows the RAM levels increasing as it starts the upsampler steps, after it has had to evict weights to make space for the much larger working set.

Screenshot from 2026-01-15 17-47-14

@blepping
Contributor

blepping commented Jan 15, 2026

we are focusing on improving our own native quant system to make it better/faster than gguf.

This would be tough to do without essentially re-implementing GGUF. GGUF quants at the same BPW are significantly higher quality compared to quants like FP8 because it's a block/group based format. If ComfyUI's native quant system can support group/block based quants then the ideal case would be for GGUF support to be implemented with that so it can work seamlessly.

I also don't think you have to give up stuff like fused matmul (or even FP8 matmul specifically). Fused matmul kernels with GGUF are definitely possible, it is also probably possible to do a fast dequant to FP8 so FP8 matmul can be used.

By the way, I have a collection of GGUF Triton kernels that reduce the current GGUF dequant overhead pretty significantly (especially for complex quants like Q6_K). Unfortunately, I've had very little time to work on personal projects for some time now and haven't been able to complete it but these still are a major performance benefit on Nvidia (and by reports, ROCM as well): city96/ComfyUI-GGUF#336

@zwukong

zwukong commented Jan 15, 2026

I agree. I used FP8 Wan2.2 before; it's not as good as Q4, and it's easier to get bad results. GGUF is slower, but acceptable. Hope you can come up with better solutions. GGUF is the future, if its speed can match FP8 or nunchaku int4, which is two times faster than GGUF.

@MeiYi-dev

we are focusing on improving our own native quant system to make it better/faster than gguf.

This would be tough to do without essentially re-implementing GGUF. GGUF quants at the same BPW are significantly higher quality compared to quants like FP8 because it's a block/group based format. If ComfyUI's native quant system can support group/block based quants then the ideal case would be for GGUF support to be implemented with that so it can work seamlessly.

I also don't think you have to give up stuff like fused matmul (or even FP8 matmul specifically). Fused matmul kernels with GGUF are definitely possible, it is also probably possible to do a fast dequant to FP8 so FP8 matmul can be used.

By the way, I have a collection of GGUF Triton kernels that reduce the current GGUF dequant overhead pretty significantly (especially for complex quants like Q6_K). Unfortunately, I've had very little time to work on personal projects for some time now and haven't been able to complete it but these still are a major performance benefit on Nvidia (and by reports, ROCM as well): city96/ComfyUI-GGUF#336

"GGUFs with fp8 matmuls" this would be bonkers for us low VRAM/RAM users, hope someone implements it ❤️

@Haoming02
Contributor

The idea is, if you dont have enough RAM, just dump it, because its faster to just read it from file on disk again than to write and read to pagefile.

Does this new approach cause more or less wear on the SSD? (if both the pagefile and the models are stored on said SSD)

@anr2me

anr2me commented Jan 16, 2026

The idea is, if you dont have enough RAM, just dump it, because its faster to just read it from file on disk again than to write and read to pagefile.

Does this new appraoch cause more/less wear on the SSD? (if both pagefile and models were stored on said SSD)

It should be less wear; it's practically comparing write + read (when spilling to the swapfile) vs. just read (this PR).

@rattus128
Contributor Author

we are focusing on improving our own native quant system to make it better/faster than gguf.

This would be tough to do without essentially re-implementing GGUF. GGUF quants at the same BPW are significantly higher quality compared to quants like FP8 because it's a block/group based format. If ComfyUI's native quant system can support group/block based quants then the ideal case would be for GGUF support to be implemented with that so it can work seamlessly.

I also don't think you have to give up stuff like fused matmul (or even FP8 matmul specifically). Fused matmul kernels with GGUF are definitely possible, it is also probably possible to do a fast dequant to FP8 so FP8 matmul can be used.

I've thought about this and it would be kinda huge for the smaller GPUs if it works.

By the way, I have a collection of GGUF Triton kernels that reduce the current GGUF dequant overhead pretty significantly (especially for complex quants like Q6_K). Unfortunately, I've had very little time to work on personal projects for some time now and haven't been able to complete it but these still are a major performance benefit on Nvidia (and by reports, ROCM as well): city96/ComfyUI-GGUF#336

The idea is, if you dont have enough RAM, just dump it, because its faster to just read it from file on disk again than to write and read to pagefile.

Does this new appraoch cause more/less wear on the SSD? (if both pagefile and models were stored on said SSD)

Less wear. The situation this avoids is the model offloading back to RAM and hitting the pagefile, which means you end up writing the model you already have on disk back to the same disk. You win by just reading it from the file on disk again.

@rattus128
Contributor Author

FYI, I am aware of increased Windows RAM usage when a pagefile is not configured, due to commit quota exhaustion. If anyone is testing, please stay tuned for a major fix to Windows RAM usage incoming. The VRAM stuff on Windows is still testable.

@zwukong

zwukong commented Jan 16, 2026

I tried Flux 2 Klein NVFP4 on a 40-series card; it still works, like FP8 does. It would be super great to let 40-series cards get the same speed as the 50s, like nunchaku int4 does.

I wonder if 40-series cards can get the same boost through this PR using NVFP4.

@MeiYi-dev

MeiYi-dev commented Jan 16, 2026

I tried flux 2 klein nvfp4 in 40card, still can work like fp8. It will be super great to let 40card get the same speed as 50s like nunchaku int4.

I wonder if 40s can get same boost through this PR using nvfp4

That would require the nunchaku guys https://github.com/nunchaku-ai/ComfyUI-nunchaku , as 40 series cards don't support fp4 natively.

Also, please keep the discussion about this PR's subject only 🙏

@asagi4
Contributor

asagi4 commented Jan 16, 2026

@rattus128 is there any hope of getting a ROCM implementation of this at some point? I'm not sure what it does that requires interfacing with CUDA directly, and the model management code keeps getting further and further away from anything I am able to understand.

If running on Windows, defer creation of the layer parameters until the state dict is loaded. This avoids a massive spike in Windows commit charge when a model is created but not loaded.

This problem doesn't exist on Linux, as Linux allows RAM overcommit; Windows does not. Before the dynamic memory work this was also a non-issue, as every non-quant model would just immediately RAM-load and need the memory anyway.

Make the workaround Windows-specific, as there may be someone out there with some train-from-scratch workflow (which this might break), and assume said someone is on Linux.
The mmap as used by safetensors is hardcoded to CoW, which forcibly consumes Windows commit charge even on a zero-copy load. RIP. Implement safetensors loading in pytorch itself with a READ mmap so we don't get commit-charged for all our open models.
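
For illustration, a stripped-down READ-mode safetensors reader might look like the sketch below. This is a simplified assumption of the approach, not the PR's loader: only a few dtypes, no validation, and PyTorch will warn that the mapping is read-only, which is fine here since the weights are never written.

```python
import json
import mmap
import struct
import torch

# Map safetensors dtype strings to (torch dtype, itemsize) for a few types.
_DTYPES = {"F32": (torch.float32, 4), "F16": (torch.float16, 2),
           "BF16": (torch.bfloat16, 2)}

def load_safetensors_read(path):
    """Map the file with ACCESS_READ (no copy-on-write, so Windows does not
    commit-charge the whole model) and build zero-copy CPU tensors over it."""
    with open(path, "rb") as f:
        mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
    (header_len,) = struct.unpack_from("<Q", mm, 0)
    header = json.loads(mm[8:8 + header_len])
    data_start = 8 + header_len
    tensors = {}
    for name, info in header.items():
        if name == "__metadata__":
            continue
        dtype, itemsize = _DTYPES[info["dtype"]]
        begin, end = info["data_offsets"]
        flat = torch.frombuffer(mm, dtype=dtype,
                                count=(end - begin) // itemsize,
                                offset=data_start + begin)
        tensors[name] = flat.view(info["shape"])
    return tensors
```
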
@rattus128 rattus128 marked this pull request as ready for review January 19, 2026 14:00
@MeiYi-dev

MeiYi-dev commented Jan 20, 2026

@rattus128 https://blog.fal.ai/introducing-flashpack-lightning-fast-model-loading-for-pytorch/ Wouldn't FlashPack help a ton with loading from disk? It would speed up loading from disk immensely and complement this PR.

Future Flashpack support would be awesome to have.

This isn't worth it, and the likelihood of inference leaving a complex data structure with cyclic references behind is low. Remove it.

We could replace it with a condition on nodes that actually touch the GPU, which might be a win.
This is needed for deepcopy construction. We shouldn't really have deep copies of MP or MODynamic; however, there is a stray one in some controlnet flows.