[core] support device type device_maps to work with offloading. #12811

Merged
sayakpaul merged 18 commits into main from device-map-direct on Feb 16, 2026

@sayakpaul
Member

What does this PR do?

This PR allows users to pass device_map="cpu" when initializing a pipeline and then enable model CPU offloading.

This is beneficial when users want to initialize the models on CPU (think of low-VRAM environments) and then call enable_model_cpu_offload(). By default, quantized models initialize directly on a supported accelerator, which can lead to OOMs.
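
To make the memory argument concrete, here is a minimal sketch in plain Python. It assumes that model CPU offloading keeps only the currently running component on the accelerator, and it ignores activations and other runtime overhead. The component sizes are made-up placeholders (not measurements of any real checkpoint), and the function names are hypothetical.

```python
# Hypothetical component sizes in GB. Placeholder numbers for illustration only.
COMPONENT_SIZES_GB = {"text_encoder": 14, "transformer": 18, "vae": 1}


def peak_all_resident(sizes):
    """Peak accelerator memory if every component is initialized on the accelerator."""
    return sum(sizes.values())


def peak_with_model_offload(sizes, execution_order):
    """Peak accelerator memory if only the running component is resident.

    Assumes each component is moved to the accelerator right before it runs
    and moved back to CPU before the next component is brought in.
    """
    peak = 0
    for name in execution_order:
        peak = max(peak, sizes[name])
    return peak


all_resident = peak_all_resident(COMPONENT_SIZES_GB)
offloaded = peak_with_model_offload(
    COMPONENT_SIZES_GB, ["text_encoder", "transformer", "vae"]
)
print(f"all components on accelerator: {all_resident} GB")
print(f"one component at a time:       {offloaded} GB")
```

Under these assumptions, the peak drops from the sum of all component sizes to the size of the largest component, but only if nothing is placed on the accelerator at load time. That is the case device_map="cpu" is meant to cover.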

The diff below shows the usage change this PR enables:

import torch
from diffusers import Flux2Pipeline, AutoModel
from transformers import Mistral3ForConditionalGeneration

repo_id = "diffusers/FLUX.2-dev-bnb-4bit" # quantized text-encoder and DiT. VAE still in bf16
device = "cuda:0"
torch_dtype = torch.bfloat16

- text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
-     repo_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cpu"
- )
- dit = AutoModel.from_pretrained(
-     repo_id, subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cpu"
- )
- pipe = Flux2Pipeline.from_pretrained(
-     repo_id, text_encoder=text_encoder, transformer=dit, torch_dtype=torch_dtype
- )
- pipe.enable_model_cpu_offload()
+ pipe = Flux2Pipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16, device_map="cpu")
+ pipe.enable_model_cpu_offload()

prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL + Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
image = pipe(
    prompt=prompt,
    generator=torch.Generator(device=device).manual_seed(42),
    num_inference_steps=50,
    guidance_scale=4,
).images[0]

image.save("flux2_output.png")

cc: @asomoza @apolinario

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul sayakpaul requested review from DN6 and yiyixuxu December 9, 2025 06:43
@sayakpaul
Member Author

@yiyixuxu @DN6 a gentle ping.

Collaborator

@DN6 DN6 left a comment

LGTM 👍🏽

@sayakpaul sayakpaul added the roadmap (Add to current release roadmap) label Feb 16, 2026

f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
)

@require_torch_accelerator
Member Author

So that this runs on CPUs too, as that's supported.

Comment on lines +631 to +635
@pytest.mark.parametrize(
"config_name",
list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()),
ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()),
)
Member Author

This test is specific to bitsandbytes for now.
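
For readers unfamiliar with the parametrization pattern quoted above, here is a small self-contained sketch of how it behaves. `EXAMPLE_CONFIGS` is a hypothetical stand-in for `BitsAndBytesConfigMixin.BNB_CONFIGS`, whose actual keys and values are not shown in this thread, and the test name and body are illustrative rather than taken from the PR.

```python
import pytest

# Hypothetical stand-in for BitsAndBytesConfigMixin.BNB_CONFIGS.
# The real keys and values are not shown in this thread.
EXAMPLE_CONFIGS = {
    "4bit_nf4": {"load_in_4bit": True, "bnb_4bit_quant_type": "nf4"},
    "8bit": {"load_in_8bit": True},
}


@pytest.mark.parametrize(
    "config_name",
    list(EXAMPLE_CONFIGS.keys()),
    ids=list(EXAMPLE_CONFIGS.keys()),
)
def test_config_has_settings(config_name):
    # pytest generates one test case per dictionary key and uses the same
    # key as the readable test ID in its report.
    assert EXAMPLE_CONFIGS[config_name]
```

Passing the keys as both the argument values and `ids` makes each configuration show up under its own name in the test report, so a failure points directly at the configuration that triggered it.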

@sayakpaul sayakpaul merged commit 35086ac into main Feb 16, 2026
32 of 33 checks passed
@github-project-automation github-project-automation bot moved this from In Progress to Done in Diffusers Roadmap 0.37 Feb 16, 2026
@sayakpaul sayakpaul deleted the device-map-direct branch February 16, 2026 11:01