Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 37 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# ComfyUI-InferenceTimeScaling
![Version](https://img.shields.io/badge/version-0.0.3-blue.svg)

A ComfyUI extension implementing "Inference-time scaling for diffusion models beyond scaling denoising steps" ([Ma et al., 2025](https://arxiv.org/abs/2501.09732)). This extension provides inference-time optimization techniques to enhance diffusion-based image generation quality through random search and zero-order optimization algorithms, along with an ensemble verification system.

Expand All @@ -23,6 +24,24 @@ A ComfyUI extension implementing "Inference-time scaling for diffusion models be

[View example workflows](workflows/)

## How It Works

This extension implements two different search algorithms to find the best possible image for your prompt:

1. **Random Search**: The simplest approach - generates multiple images with different random noises and evaluates them to explore the noise space.

2. **Zero-Order Search**: A more sophisticated approach that performs local optimization. It starts with a random noise, generates nearby variations by perturbing noise, and iteratively moves toward better results based on evaluation.

To explore the noise space, the quality of generated images is evaluated using an ensemble of three verifiers:

- **CLIP Score**: Measures how well the image matches the text prompt using OpenAI's CLIP model
- **ImageReward**: Evaluates image quality and prompt alignment using a specialized reward model
- **Qwen VLM**: Uses a large vision-language model to provide detailed scoring across multiple aspects (visual quality, creativity, prompt accuracy, etc.)

By exploring the noise space and using these verifiers to guide the search, it can produce images of higher quality and better prompt alignment than simply increasing denoising steps, with the tradeoff being increased time and compute during inference.

For more detailed information about the algorithms and methodology, please refer to the original paper from Google DeepMind: ["Inference-time scaling for diffusion models beyond scaling denoising steps"](https://arxiv.org/abs/2501.09732).

## Installation

### Prerequisites
Expand Down Expand Up @@ -79,15 +98,19 @@ This is the main node implementing the random search and zero-order optimization
- `vae`: (VAE) VAE model for decoding latents
- `view_top_k`: (INT) Number of top images to show in grid
- `search_algorithm`: Choice between "random" and "zero-order"
- `num_neighbors`: (INT) Number of neighbors per iteration in zero-order search (only used if search_algorithm is "zero-order")
- `lambda_threshold`: (FLOAT) Perturbation step size for zero-order search (only used if search_algorithm is "zero-order")

> [!IMPORTANT]
> The following parameters are **only used for zero-order search** and have no effect when using random search:
> - `num_neighbors`: (INT) Number of neighbors per iteration in zero-order search
> - `lambda_threshold`: (FLOAT) Perturbation step size for zero-order search

#### Optional Inputs:
- `loaded_clip_score_verifier`: (CS_VERIFIER) CLIP model for scoring
- `loaded_image_reward_verifier`: (IR_VERIFIER) ImageReward model
- `loaded_qwen_verifier`: (QWN_VERIFIER) Qwen VLM model

**Note:** At least one verifier must be included!
> [!NOTE]
> The verifiers are optional - you can choose which ones to use by connecting them to the node. However, at least one verifier must be connected for the node to function!

#### Outputs:
- `Best Image`: The highest-scoring generated image
Expand All @@ -103,6 +126,13 @@ Loads the Qwen VLM verifier model for image evaluation.
#### Inputs:
- `qwen_verifier_id`: Model identifier (default: "Qwen/Qwen2.5-VL-7B-Instruct")
- `device`: Device to load model on ("cuda" or "cpu")
- `score_type`: Type of score to return from the evaluation (default: "overall_score"). Options:
- `overall_score`: Weighted average of all aspects
- `accuracy_to_prompt`: How well the image matches the text description
- `creativity_and_originality`: Uniqueness and creative interpretation
- `visual_quality_and_realism`: Overall visual quality, detail, and realism
- `consistency_and_cohesion`: Internal consistency and natural composition
- `emotional_or_thematic_resonance`: How well the image captures the intended mood/theme

#### Outputs:
- `qwen_verifier_instance`: Loaded Qwen verifier instance
Expand Down Expand Up @@ -154,21 +184,13 @@ The model will be downloaded automatically on first use (you do not need to have

## Future Work

- [x] Enable configurable scoring criteria for Qwen VLM verifier
- Allow users to select specific aspects like visual quality, creativity, etc.
- Support individual aspect scoring
- [ ] Add batch processing support for image generation (performance optimization)
- [ ] Implement batched verification for multiple image-text pairs (speed optimization)
- [ ] Enable configurable scoring criteria for Qwen VLM verifier (currently only uses overall score)
- Allow users to select specific aspects like visual quality, creativity, etc.
- Support weighted combinations of multiple scoring criteria

## Development
- [ ] Add support for image-to-image and image+text conditioning to image models (currently only supports text-to-image models)

To install development dependencies:

```bash
cd inferencescale
pip install -e .[dev]
pre-commit install
```

## License

Expand Down
Binary file modified assets/qwen_verifier_node.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "ComfyUI-InferenceTimeScaling"
version = "0.0.2"
version = "0.0.3"
description = "Inference-time techniques to enhance diffusion-based image generation quality through random search and zero-order optimization algorithms"
authors = [
{name = "Max Clouser", email = "max@yrikka.com"}
Expand Down
21 changes: 15 additions & 6 deletions src/inferencescale/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def score_candidates(candidate_tensors: List[torch.Tensor], text_prompt: str, ve
elif verifier_name == "image_reward":
score = verifier.score(text_prompt, pil_img)
elif verifier_name == "qwen_vlm_verifier":
score = verifier.get_overall_score(pil_img, text_prompt)
score = verifier.score(pil_img, text_prompt)
else:
logger.warning(f"Unknown verifier type: {verifier_name}")
continue
Expand Down Expand Up @@ -435,13 +435,19 @@ class LoadQwenVLMVerifier:
def INPUT_TYPES(cls):
return {
"required": {
"qwen_verifier_id": (["Qwen/Qwen2.5-VL-7B-Instruct", None], {
"qwen_verifier_id": (["Qwen/Qwen2.5-VL-7B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2.5-VL-72B-Instruct"], {
"default": "Qwen/Qwen2.5-VL-7B-Instruct",
"tooltip": "Identifier for the Qwen VLM model."
}),
"device": ("STRING", {
"default": "cuda" if torch.cuda.is_available() else "cpu",
"tooltip": "Device to load the model onto."
}),
"score_type": (["overall_score", "accuracy_to_prompt", "creativity_and_originality",
"visual_quality_and_realism", "consistency_and_cohesion",
"emotional_or_thematic_resonance"], {
"default": "overall_score",
"tooltip": "Type of score to return from the Qwen model evaluation."
})
}
}
Expand All @@ -452,7 +458,7 @@ def INPUT_TYPES(cls):
CATEGORY = "InferenceTimeScaling"
DESCRIPTION = "Loads the Qwen VLM verifier model. Downloads it if necessary."

def execute(self, qwen_verifier_id, device):
def execute(self, qwen_verifier_id, device, score_type):
# Construct a local comfyui checkpoint path for the model
model_checkpoint = os.path.join(folder_paths.models_dir, "LLM", os.path.basename(qwen_verifier_id))
if not os.path.exists(model_checkpoint):
Expand All @@ -461,7 +467,7 @@ def execute(self, qwen_verifier_id, device):
local_dir=model_checkpoint,
local_dir_use_symlinks=False,
)
verifier_instance = QwenVLMVerifier(qwen_verifier_id, device)
verifier_instance = QwenVLMVerifier(model_checkpoint, device, score_type=score_type)
return (verifier_instance,)


Expand Down Expand Up @@ -489,7 +495,10 @@ class LoadCLIPScoreVerifier:
def INPUT_TYPES(cls):
return {
"required": {
"clip_verifier_id": (["openai/clip-vit-base-patch32", "openai/clip-vit-large-patch14", None], {
"clip_verifier_id": (["openai/clip-vit-base-patch32",
"openai/clip-vit-large-patch14",
"openai/clip-vit-base-patch16",
"openai/clip-vit-large-patch14-336"], {
"default": "openai/clip-vit-base-patch32",
"tooltip": "Identifier for the CLIP model."
}),
Expand Down Expand Up @@ -528,7 +537,7 @@ class LoadImageRewardVerifier:
def INPUT_TYPES(cls):
return {
"required": {
"ir_verifier_id": (["ImageReward-v1.0", None], {
"ir_verifier_id": (["ImageReward-v1.0"], {
"default": "ImageReward-v1.0",
"tooltip": "Identifier for the ImageReward model."
}),
Expand Down
85 changes: 31 additions & 54 deletions src/inferencescale/qwen_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,25 @@ class Grading(BaseModel):

class QwenVLMVerifier():

def __init__(self, model_name, device='cpu'):
def __init__(self, model_name, device='cpu', score_type='overall_score'):
logger.info(f"Initializing QwenVLMVerifier with model {model_name} on device {device}")
self.model_name = model_name
self.device = device
self.dtype = torch.float16 if "cuda" in self.device else torch.float16
self.dtype = torch.float16

# Validate score_type
valid_score_types = {
'accuracy_to_prompt',
'creativity_and_originality',
'visual_quality_and_realism',
'consistency_and_cohesion',
'emotional_or_thematic_resonance',
'overall_score'
}
if score_type not in valid_score_types:
raise ValueError(f"Invalid score_type. Must be one of: {valid_score_types}")

self.score_type = score_type
self.load_model()


Expand All @@ -96,37 +110,21 @@ def load_model(self):
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28

logger.debug(f"Loading model from {self.model_name}")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
self.model_name,
torch_dtype=self.dtype,
device_map="auto",
low_cpu_mem_usage=True
)
logger.info("Model loaded successfully")

logger.debug("Loading processor")
processor = AutoProcessor.from_pretrained(
self.model_name, min_pixels=min_pixels, max_pixels=max_pixels
)
logger.info("Processor loaded successfully")

logger.debug("Initializing transformers vision")
self.qwen_model = transformers_vision(
self.model_name,
model_class=model.__class__,
model_class=Qwen2_5_VLForConditionalGeneration,
device=self.device,
model_kwargs={"torch_dtype": self.dtype},
processor_class=processor.__class__,
processor_class=AutoProcessor,
processor_kwargs={"min_pixels": min_pixels, "max_pixels": max_pixels}
)
logger.info("Transformers vision initialized")

logger.debug("Setting up structured generator")
self.structured_qwen_generator = outlines.generate.json(self.qwen_model, Grading)
logger.info("Structured generator setup complete")

del model
del processor
torch.cuda.empty_cache()
gc.collect()
logger.info("Memory cleanup completed")
Expand Down Expand Up @@ -173,21 +171,20 @@ def query_model(self, image, prompt: str, max_tokens: int = None, seed: int = 42
raise


def get_overall_score(self, image, prompt: str, max_tokens: int = None, seed: int = 42) -> float:
# TODO Extend to handle any score key (instead of just overall score) as input
logger.info("Getting overall score")
def score(self, image, prompt: str, max_tokens: int = None, seed: int = 42) -> float:
logger.info(f"Getting {self.score_type} score")
try:
outputs = self.query_model(image, prompt, max_tokens, seed)
overall_score = outputs["overall_score"]["score"]
score = outputs[self.score_type]["score"]

if overall_score:
logger.debug(f"Overall score calculated: {overall_score}")
return float(overall_score)
if score is not None:
logger.debug(f"{self.score_type} score calculated: {score}")
return float(score)

logger.warning("Overall score not found in model output")
logger.warning(f"{self.score_type} score not found in model output")
return 0.0
except Exception as e:
logger.error(f"Error getting overall score: {str(e)}", exc_info=True)
logger.error(f"Error getting {self.score_type} score: {str(e)}", exc_info=True)
return 0.0


Expand All @@ -196,41 +193,21 @@ def get_overall_score(self, image, prompt: str, max_tokens: int = None, seed: in

model_name = "Qwen/Qwen2.5-VL-7B-Instruct"

model = QwenVLMVerifier(model_name=model_name, device=device)
# model.load_model()
model = QwenVLMVerifier(model_name=model_name, device=device, score_type='visual_quality_and_realism')

image_path = "596F6DF4-2856-436E-A981-649ABFB15F1B.jpeg"
image = Image.open(image_path).convert("RGB")

test_prompt = "A red bird and a fish."

response = model.query_model(image, test_prompt)
print("Model Response:", response)

aspect_keys = [
"accuracy_to_prompt",
"creativity_and_originality",
"visual_quality_and_realism",
"consistency_and_cohesion",
"emotional_or_thematic_resonance"
]

scores = []
for key in aspect_keys:
if key in response and "score" in response[key]:
scores.append(response[key]["score"])

if scores:
average_score = sum(scores) / len(scores)
print("Average Score:", average_score)
else:
print("No scores found to average.")
visual_quality_score = model.score(image, test_prompt)
print(f"Visual quality score: {visual_quality_score}")

# model.to_device('cpu')
# model.to_device('cuda')

response = model.query_model(image, test_prompt)
print("Model Response:", response)

overall_score = model.get_overall_score(image, test_prompt)
overall_score = model.score(image, test_prompt)
print(f"Overall score: {overall_score}")
Loading