diff --git a/README.md b/README.md index 391d236..070500f 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # ComfyUI-InferenceTimeScaling +![Version](https://img.shields.io/badge/version-0.0.3-blue.svg) A ComfyUI extension implementing "Inference-time scaling for diffusion models beyond scaling denoising steps" ([Ma et al., 2025](https://arxiv.org/abs/2501.09732)). This extension provides inference-time optimization techniques to enhance diffusion-based image generation quality through random search and zero-order optimization algorithms, along with an ensemble verification system. @@ -23,6 +24,24 @@ A ComfyUI extension implementing "Inference-time scaling for diffusion models be [View example workflows](workflows/) +## How It Works + +This extension implements two different search algorithms to find the best possible image for your prompt: + +1. **Random Search**: The simplest approach - generates multiple images with different random noises and evaluates them to explore the noise space. + +2. **Zero-Order Search**: A more sophisticated approach that performs local optimization. It starts with a random noise, generates nearby variations by perturbing noise, and iteratively moves toward better results based on evaluation. + +To explore the noise space, the quality of generated images is evaluated using an ensemble of three verifiers: + +- **CLIP Score**: Measures how well the image matches the text prompt using OpenAI's CLIP model +- **ImageReward**: Evaluates image quality and prompt alignment using a specialized reward model +- **Qwen VLM**: Uses a large vision-language model to provide detailed scoring across multiple aspects (visual quality, creativity, prompt accuracy, etc.) + +By exploring the noise space and using these verifiers to guide the search, it can produce images of higher quality and better prompt alignment than simply increasing denoising steps, with the tradeoff being increased time and compute during inference. + +For more detailed information about the algorithms and methodology, please refer to the original paper from Google DeepMind: ["Inference-time scaling for diffusion models beyond scaling denoising steps"](https://arxiv.org/abs/2501.09732). + ## Installation ### Prerequisites @@ -79,15 +98,19 @@ This is the main node implementing the random search and zero-order optimization - `vae`: (VAE) VAE model for decoding latents - `view_top_k`: (INT) Number of top images to show in grid - `search_algorithm`: Choice between "random" and "zero-order" -- `num_neighbors`: (INT) Number of neighbors per iteration in zero-order search (only used if search_algorithm is "zero-order") -- `lambda_threshold`: (FLOAT) Perturbation step size for zero-order search (only used if search_algorithm is "zero-order") + +> [!IMPORTANT] +> The following parameters are **only used for zero-order search** and have no effect when using random search: +> - `num_neighbors`: (INT) Number of neighbors per iteration in zero-order search +> - `lambda_threshold`: (FLOAT) Perturbation step size for zero-order search #### Optional Inputs: - `loaded_clip_score_verifier`: (CS_VERIFIER) CLIP model for scoring - `loaded_image_reward_verifier`: (IR_VERIFIER) ImageReward model - `loaded_qwen_verifier`: (QWN_VERIFIER) Qwen VLM model -**Note:** At least one verifier must be included! +> [!NOTE] +> The verifiers are optional - you can choose which ones to use by connecting them to the node. However, at least one verifier must be connected for the node to function! #### Outputs: - `Best Image`: The highest-scoring generated image @@ -103,6 +126,13 @@ Loads the Qwen VLM verifier model for image evaluation. #### Inputs: - `qwen_verifier_id`: Model identifier (default: "Qwen/Qwen2.5-VL-7B-Instruct") - `device`: Device to load model on ("cuda" or "cpu") +- `score_type`: Type of score to return from the evaluation (default: "overall_score"). Options: + - `overall_score`: Weighted average of all aspects + - `accuracy_to_prompt`: How well the image matches the text description + - `creativity_and_originality`: Uniqueness and creative interpretation + - `visual_quality_and_realism`: Overall visual quality, detail, and realism + - `consistency_and_cohesion`: Internal consistency and natural composition + - `emotional_or_thematic_resonance`: How well the image captures the intended mood/theme #### Outputs: - `qwen_verifier_instance`: Loaded Qwen verifier instance @@ -154,21 +184,13 @@ The model will be downloaded automatically on first use (you do not need to have ## Future Work +- [x] Enable configurable scoring criteria for Qwen VLM verifier + - Allow users to select specific aspects like visual quality, creativity, etc. + - Support individual aspect scoring - [ ] Add batch processing support for image generation (performance optimization) - [ ] Implement batched verification for multiple image-text pairs (speed optimization) -- [ ] Enable configurable scoring criteria for Qwen VLM verifier (currently only uses overall score) - - Allow users to select specific aspects like visual quality, creativity, etc. - - Support weighted combinations of multiple scoring criteria - -## Development +- [ ] Add support for image-to-image and image+text conditioning to image models (currently only supports text-to-image models) -To install development dependencies: - -```bash -cd inferencescale -pip install -e .[dev] -pre-commit install -``` ## License diff --git a/assets/qwen_verifier_node.jpeg b/assets/qwen_verifier_node.jpeg index f91b26a..7c73557 100644 Binary files a/assets/qwen_verifier_node.jpeg and b/assets/qwen_verifier_node.jpeg differ diff --git a/pyproject.toml b/pyproject.toml index 13f4ca2..1930550 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "ComfyUI-InferenceTimeScaling" -version = "0.0.2" +version = "0.0.3" description = "Inference-time techniques to enhance diffusion-based image generation quality through random search and zero-order optimization algorithms" authors = [ {name = "Max Clouser", email = "max@yrikka.com"} diff --git a/src/inferencescale/nodes.py b/src/inferencescale/nodes.py index 1d5271c..c4d6e80 100644 --- a/src/inferencescale/nodes.py +++ b/src/inferencescale/nodes.py @@ -149,7 +149,7 @@ def score_candidates(candidate_tensors: List[torch.Tensor], text_prompt: str, ve elif verifier_name == "image_reward": score = verifier.score(text_prompt, pil_img) elif verifier_name == "qwen_vlm_verifier": - score = verifier.get_overall_score(pil_img, text_prompt) + score = verifier.score(pil_img, text_prompt) else: logger.warning(f"Unknown verifier type: {verifier_name}") continue @@ -435,13 +435,19 @@ class LoadQwenVLMVerifier: def INPUT_TYPES(cls): return { "required": { - "qwen_verifier_id": (["Qwen/Qwen2.5-VL-7B-Instruct", None], { + "qwen_verifier_id": (["Qwen/Qwen2.5-VL-7B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2.5-VL-72B-Instruct"], { "default": "Qwen/Qwen2.5-VL-7B-Instruct", "tooltip": "Identifier for the Qwen VLM model." }), "device": ("STRING", { "default": "cuda" if torch.cuda.is_available() else "cpu", "tooltip": "Device to load the model onto." + }), + "score_type": (["overall_score", "accuracy_to_prompt", "creativity_and_originality", + "visual_quality_and_realism", "consistency_and_cohesion", + "emotional_or_thematic_resonance"], { + "default": "overall_score", + "tooltip": "Type of score to return from the Qwen model evaluation." }) } } @@ -452,7 +458,7 @@ def INPUT_TYPES(cls): CATEGORY = "InferenceTimeScaling" DESCRIPTION = "Loads the Qwen VLM verifier model. Downloads it if necessary." - def execute(self, qwen_verifier_id, device): + def execute(self, qwen_verifier_id, device, score_type): # Construct a local comfyui checkpoint path for the model model_checkpoint = os.path.join(folder_paths.models_dir, "LLM", os.path.basename(qwen_verifier_id)) if not os.path.exists(model_checkpoint): @@ -461,7 +467,7 @@ def execute(self, qwen_verifier_id, device): local_dir=model_checkpoint, local_dir_use_symlinks=False, ) - verifier_instance = QwenVLMVerifier(qwen_verifier_id, device) + verifier_instance = QwenVLMVerifier(model_checkpoint, device, score_type=score_type) return (verifier_instance,) @@ -489,7 +495,10 @@ class LoadCLIPScoreVerifier: def INPUT_TYPES(cls): return { "required": { - "clip_verifier_id": (["openai/clip-vit-base-patch32", "openai/clip-vit-large-patch14", None], { + "clip_verifier_id": (["openai/clip-vit-base-patch32", + "openai/clip-vit-large-patch14", + "openai/clip-vit-base-patch16", + "openai/clip-vit-large-patch14-336"], { "default": "openai/clip-vit-base-patch32", "tooltip": "Identifier for the CLIP model." }), @@ -528,7 +537,7 @@ class LoadImageRewardVerifier: def INPUT_TYPES(cls): return { "required": { - "ir_verifier_id": (["ImageReward-v1.0", None], { + "ir_verifier_id": (["ImageReward-v1.0"], { "default": "ImageReward-v1.0", "tooltip": "Identifier for the ImageReward model." }), diff --git a/src/inferencescale/qwen_verifier.py b/src/inferencescale/qwen_verifier.py index 1be7cc8..7f9d025 100644 --- a/src/inferencescale/qwen_verifier.py +++ b/src/inferencescale/qwen_verifier.py @@ -82,11 +82,25 @@ class Grading(BaseModel): class QwenVLMVerifier(): - def __init__(self, model_name, device='cpu'): + def __init__(self, model_name, device='cpu', score_type='overall_score'): logger.info(f"Initializing QwenVLMVerifier with model {model_name} on device {device}") self.model_name = model_name self.device = device - self.dtype = torch.float16 if "cuda" in self.device else torch.float16 + self.dtype = torch.float16 + + # Validate score_type + valid_score_types = { + 'accuracy_to_prompt', + 'creativity_and_originality', + 'visual_quality_and_realism', + 'consistency_and_cohesion', + 'emotional_or_thematic_resonance', + 'overall_score' + } + if score_type not in valid_score_types: + raise ValueError(f"Invalid score_type. Must be one of: {valid_score_types}") + + self.score_type = score_type self.load_model() @@ -96,28 +110,14 @@ def load_model(self): min_pixels = 256 * 28 * 28 max_pixels = 1280 * 28 * 28 - logger.debug(f"Loading model from {self.model_name}") - model = Qwen2_5_VLForConditionalGeneration.from_pretrained( - self.model_name, - torch_dtype=self.dtype, - device_map="auto", - low_cpu_mem_usage=True - ) - logger.info("Model loaded successfully") - - logger.debug("Loading processor") - processor = AutoProcessor.from_pretrained( - self.model_name, min_pixels=min_pixels, max_pixels=max_pixels - ) - logger.info("Processor loaded successfully") - logger.debug("Initializing transformers vision") self.qwen_model = transformers_vision( self.model_name, - model_class=model.__class__, + model_class=Qwen2_5_VLForConditionalGeneration, device=self.device, model_kwargs={"torch_dtype": self.dtype}, - processor_class=processor.__class__, + processor_class=AutoProcessor, + processor_kwargs={"min_pixels": min_pixels, "max_pixels": max_pixels} ) logger.info("Transformers vision initialized") @@ -125,8 +125,6 @@ def load_model(self): self.structured_qwen_generator = outlines.generate.json(self.qwen_model, Grading) logger.info("Structured generator setup complete") - del model - del processor torch.cuda.empty_cache() gc.collect() logger.info("Memory cleanup completed") @@ -173,21 +171,20 @@ def query_model(self, image, prompt: str, max_tokens: int = None, seed: int = 42 raise - def get_overall_score(self, image, prompt: str, max_tokens: int = None, seed: int = 42) -> float: - # TODO Extend to handle any score key (instead of just overall score) as input - logger.info("Getting overall score") + def score(self, image, prompt: str, max_tokens: int = None, seed: int = 42) -> float: + logger.info(f"Getting {self.score_type} score") try: outputs = self.query_model(image, prompt, max_tokens, seed) - overall_score = outputs["overall_score"]["score"] + score = outputs[self.score_type]["score"] - if overall_score: - logger.debug(f"Overall score calculated: {overall_score}") - return float(overall_score) + if score is not None: + logger.debug(f"{self.score_type} score calculated: {score}") + return float(score) - logger.warning("Overall score not found in model output") + logger.warning(f"{self.score_type} score not found in model output") return 0.0 except Exception as e: - logger.error(f"Error getting overall score: {str(e)}", exc_info=True) + logger.error(f"Error getting {self.score_type} score: {str(e)}", exc_info=True) return 0.0 @@ -196,35 +193,15 @@ def get_overall_score(self, image, prompt: str, max_tokens: int = None, seed: in model_name = "Qwen/Qwen2.5-VL-7B-Instruct" - model = QwenVLMVerifier(model_name=model_name, device=device) - # model.load_model() + model = QwenVLMVerifier(model_name=model_name, device=device, score_type='visual_quality_and_realism') image_path = "596F6DF4-2856-436E-A981-649ABFB15F1B.jpeg" image = Image.open(image_path).convert("RGB") test_prompt = "A red bird and a fish." - response = model.query_model(image, test_prompt) - print("Model Response:", response) - - aspect_keys = [ - "accuracy_to_prompt", - "creativity_and_originality", - "visual_quality_and_realism", - "consistency_and_cohesion", - "emotional_or_thematic_resonance" - ] - - scores = [] - for key in aspect_keys: - if key in response and "score" in response[key]: - scores.append(response[key]["score"]) - - if scores: - average_score = sum(scores) / len(scores) - print("Average Score:", average_score) - else: - print("No scores found to average.") + visual_quality_score = model.score(image, test_prompt) + print(f"Visual quality score: {visual_quality_score}") # model.to_device('cpu') # model.to_device('cuda') @@ -232,5 +209,5 @@ def get_overall_score(self, image, prompt: str, max_tokens: int = None, seed: in response = model.query_model(image, test_prompt) print("Model Response:", response) - overall_score = model.get_overall_score(image, test_prompt) + overall_score = model.score(image, test_prompt) print(f"Overall score: {overall_score}")