diff --git a/README.md b/README.md index 447cdfc..9d6a90e 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![](https://img.shields.io/badge/Hemm-docs-blue)](https://wandb.github.io/Hemm/) -Hemm is a library for performing comprehensive benchmark of text-to-image diffusion models on image quality and prompt comprehension integrated with [Weights & Biases](https://wandb.ai/site) and [Weave](https://wandb.github.io/weave/). +Hemm is a library for performing comprehensive benchmark of text-to-image diffusion models on image quality and prompt comprehension integrated with [Weave](https://wandb.github.io/weave/), a lightweight toolkit for tracking and evaluating LLM applications, built by [Weights & Biases](https://wandb.ai/site). Hemm is highly inspired by the following projects: - [Holistic Evaluation of Text-To-Image Models](https://crfm.stanford.edu/helm/heim/v1.0.0/) @@ -37,49 +37,32 @@ First, you need to publish your evaluation dataset to Weave. Check out [this tut Once you have a dataset on your Weave project, you can evaluate a text-to-image generation model on the metrics. ```python -import wandb +import asyncio import weave +from hemm.metrics.vqa import MultiModalLLMEvaluationMetric +from hemm.metrics.vqa.judges.mmllm_judges import OpenAIJudge +from hemm.models import DiffusersModel - -from hemm.eval_pipelines import EvaluationPipeline -from hemm.metrics.prompt_alignment import CLIPImageQualityScoreMetric, CLIPScoreMetric -from hemm.models import BaseDiffusionModel - - -# Initialize Weave and WandB -wandb.init(project="image-quality-leaderboard", job_type="evaluation") +# Initialize Weave weave.init(project_name="image-quality-leaderboard") +# The `DiffusersModel` is a `weave.Model` that uses a +# `diffusers.DiffusionPipeline` under the hood. +# You can write your own model `weave.Model` if your +# model is not diffusers compatible. +model = DiffusersModel( + diffusion_model_name_or_path="stabilityai/stable-diffusion-2-1", + image_height=1024, + image_width=1024, +) -# Initialize the diffusion model to be evaluated as a `weave.Model` using `BaseWeaveModel` -# The `BaseDiffusionModel` class uses a `diffusers.DiffusionPipeline` under the hood. -# You can write your own model `weave.Model` if your model is not diffusers compatible. -model = BaseDiffusionModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4") - - -# Add the model to the evaluation pipeline -evaluation_pipeline = EvaluationPipeline(model=model) - - -# Add PSNR Metric to the evaluation pipeline -psnr_metric = PSNRMetric(image_size=evaluation_pipeline.image_size) -evaluation_pipeline.add_metric(psnr_metric) - - -# Add SSIM Metric to the evaluation pipeline -ssim_metric = SSIMMetric(image_size=evaluation_pipeline.image_size) -evaluation_pipeline.add_metric(ssim_metric) - - -# Add LPIPS Metric to the evaluation pipeline -lpips_metric = LPIPSMetric(image_size=evaluation_pipeline.image_size) -evaluation_pipeline.add_metric(lpips_metric) - +# Define the metric +metric = MultiModalLLMEvaluationMetric(judge=OpenAIJudge()) # Get the Weave dataset reference -dataset = weave.ref("COCO:v0").get() - +dataset=weave.ref("Dataset:v2").get() # Evaluate! -evaluation_pipeline(dataset=dataset) +evaluation = weave.Evaluation(dataset=dataset, scorers=[metric]) +summary = asyncio.run(evaluation.evaluate(model)) ``` diff --git a/docs/eval_pipelines.md b/docs/eval_pipelines.md deleted file mode 100644 index ad09ec6..0000000 --- a/docs/eval_pipelines.md +++ /dev/null @@ -1,5 +0,0 @@ -# Evaluation Pipelines - -Hemm evaluation pipelines for [Diffusers pipelines](https://huggingface.co/docs/diffusers/using-diffusers/loading#diffusion-pipeline). - -::: hemm.eval_pipelines diff --git a/docs/index.md b/docs/index.md index 565ba03..972d172 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,6 +1,6 @@ # Hemm: Holistic Evaluation of Multi-modal Generative Models -Hemm is a library for performing comprehensive benchmark of text-to-image diffusion models on image quality and prompt comprehension integrated with [Weights & Biases](https://wandb.ai/site) and [Weave](https://wandb.github.io/weave/). +Hemm is a library for performing comprehensive benchmark of text-to-image diffusion models on image quality and prompt comprehension integrated with [Weave](https://wandb.github.io/weave/), a lightweight toolkit for tracking and evaluating LLM applications, built by [Weights & Biases](https://wandb.ai/site). Hemm is highly inspired by the following projects: @@ -39,48 +39,32 @@ First, you need to publish your evaluation dataset to Weave. Check out [this tut Once you have a dataset on your Weave project, you can evaluate a text-to-image generation model on the metrics. ```python -import wandb +import asyncio import weave +from hemm.metrics.vqa import MultiModalLLMEvaluationMetric +from hemm.metrics.vqa.judges.mmllm_judges import OpenAIJudge +from hemm.models import DiffusersModel - -from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline -from hemm.metrics.prompt_alignment import CLIPImageQualityScoreMetric, CLIPScoreMetric - - -# Initialize Weave and WandB -wandb.init(project="image-quality-leaderboard", job_type="evaluation") +# Initialize Weave weave.init(project_name="image-quality-leaderboard") +# The `DiffusersModel` is a `weave.Model` that uses a +# `diffusers.DiffusionPipeline` under the hood. +# You can write your own model `weave.Model` if your +# model is not diffusers compatible. +model = DiffusersModel( + diffusion_model_name_or_path="stabilityai/stable-diffusion-2-1", + image_height=1024, + image_width=1024, +) -# Initialize the diffusion model to be evaluated as a `weave.Model` using `BaseWeaveModel` -# The `BaseDiffusionModel` class uses a `diffusers.DiffusionPipeline` under the hood. -# You can write your own model `weave.Model` if your model is not diffusers compatible. -model = BaseDiffusionModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4") - - -# Add the model to the evaluation pipeline -evaluation_pipeline = EvaluationPipeline(model=model) - - -# Add PSNR Metric to the evaluation pipeline -psnr_metric = PSNRMetric(image_size=evaluation_pipeline.image_size) -evaluation_pipeline.add_metric(psnr_metric) - - -# Add SSIM Metric to the evaluation pipeline -ssim_metric = SSIMMetric(image_size=evaluation_pipeline.image_size) -evaluation_pipeline.add_metric(ssim_metric) - - -# Add LPIPS Metric to the evaluation pipeline -lpips_metric = LPIPSMetric(image_size=evaluation_pipeline.image_size) -evaluation_pipeline.add_metric(lpips_metric) - +# Define the metric +metric = MultiModalLLMEvaluationMetric(judge=OpenAIJudge()) # Get the Weave dataset reference -dataset = weave.ref("COCO:v0").get() - +dataset=weave.ref("Dataset:v2").get() # Evaluate! -evaluation_pipeline(dataset=dataset) +evaluation = weave.Evaluation(dataset=dataset, scorers=[metric]) +summary = asyncio.run(evaluation.evaluate(model)) ``` diff --git a/docs/metrics/spatial_relationship.md b/docs/metrics/spatial_relationship.md index 97b5ef8..0012575 100644 --- a/docs/metrics/spatial_relationship.md +++ b/docs/metrics/spatial_relationship.md @@ -20,21 +20,18 @@ This module aims to implement the Spatial relationship metric described in secti ## Step 2: Evaluate ```python - import wandb + import asyncio import weave - from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline + from hemm.models import DiffusersModel + from hemm.metrics.spatial_relationship import SpatialRelationshipMetric2D from hemm.metrics.image_quality import LPIPSMetric, PSNRMetric, SSIMMetric - # Initialize Weave and WandB - wandb.init(project="image-quality-leaderboard", job_type="evaluation") + # Initialize Weave weave.init(project_name="image-quality-leaderboard") - # Initialize the diffusion model to be evaluated as a `weave.Model` using `BaseWeaveModel` - model = BaseDiffusionModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4") - - # Add the model to the evaluation pipeline - evaluation_pipeline = EvaluationPipeline(model=model) + # Initialize the diffusion model to be evaluated as a `weave.Model` + model = DiffusersModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4") # Define the judge model for 2d spatial relationship metric judge = DETRSpatialRelationShipJudge( @@ -43,10 +40,11 @@ This module aims to implement the Spatial relationship metric described in secti # Add 2d spatial relationship Metric to the evaluation pipeline metric = SpatialRelationshipMetric2D(judge=judge, name="2d_spatial_relationship_score") - evaluation_pipeline.add_metric(metric) # Evaluate! - evaluation_pipeline(dataset="t2i_compbench_spatial_prompts:v0") + dataset = weave.ref("2d-spatial-prompts-mscoco:v0").get() + evaluation = weave.Evaluation(dataset=dataset, scorers=[metric]) + summary = asyncio.run(evaluation.evaluate(model)) ``` ## Metrics diff --git a/docs/metrics/vqa/disentangled_vqa.md b/docs/metrics/vqa/disentangled_vqa.md index deb1366..a73d982 100644 --- a/docs/metrics/vqa/disentangled_vqa.md +++ b/docs/metrics/vqa/disentangled_vqa.md @@ -33,25 +33,29 @@ This module aims to implement the Disentangled VQA metric inspired by Section 4. ## Step 2: Evaluate ```python - import wandb + import asyncio + import weave - wandb.init(project=project, entity=entity, job_type="evaluation") + from hemm.metrics.vqa import DisentangledVQAMetric + from hemm.metrics.vqa.judges import BlipVQAJudge + from hemm.models import DiffusersModel + weave.init(project_name=project) - diffusion_model = BaseDiffusionModel( + diffusion_model = DiffusersModel( diffusion_model_name_or_path=diffusion_model_address, enable_cpu_offfload=diffusion_model_enable_cpu_offfload, image_height=image_size[0], image_width=image_size[1], ) - evaluation_pipeline = EvaluationPipeline(model=diffusion_model) judge = BlipVQAJudge() metric = DisentangledVQAMetric(judge=judge, name="disentangled_blip_metric") evaluation_pipeline.add_metric(metric) - evaluation_pipeline(dataset=dataset) + evaluation = weave.Evaluation(dataset=dataset, scorers=[metric]) + asyncio.run(evaluation.evaluate(model)) ``` ## Metrics diff --git a/docs/metrics/vqa/multi_modal_llm.md b/docs/metrics/vqa/multi_modal_llm.md index ebfcaff..2861a27 100644 --- a/docs/metrics/vqa/multi_modal_llm.md +++ b/docs/metrics/vqa/multi_modal_llm.md @@ -22,31 +22,25 @@ This module aims to implement the Multi-modal LLM based metric inspired by ``` Finallly, you can run the following snippet to evaluate your model: ```python - import wandb + import asyncio + import weave - from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline from hemm.metrics.vqa import MultiModalLLMEvaluationMetric - from hemm.metrics.vqa.judges.mmllm_judges import OpenAIJudge, PromptCategory - - wandb.init(project="mllm-eval", job_type="evaluation") - weave.init(project_name="mllm-eval") + from hemm.metrics.vqa.judges.mmllm_judges import OpenAIJudge + from hemm.models import DiffusersModel - dataset = weave.ref(dataset_ref).get() + weave.init(project_name="hemm-eval/mllm-eval") - diffusion_model = BaseDiffusionModel( + model = DiffusersModel( diffusion_model_name_or_path="stabilityai/stable-diffusion-2-1", - enable_cpu_offfload=False, - image_height=512, - image_width=512, + image_height=1024, + image_width=1024, ) - evaluation_pipeline = EvaluationPipeline(model=diffusion_model) - - judge = OpenAIJudge(prompt_property=PromptCategory.complex) - metric = MultiModalLLMEvaluationMetric(judge=judge) - evaluation_pipeline.add_metric(metric) + metric = MultiModalLLMEvaluationMetric(judge=OpenAIJudge()) - evaluation_pipeline(dataset=dataset) + evaluation = weave.Evaluation(dataset=weave.ref("Dataset:v2").get(), scorers=[metric]) + asyncio.run(evaluation.evaluate(model)) ``` ## Metrics diff --git a/examples/2d_spatial_eval/evaluate_spatial_relationship_detr.py b/examples/2d_spatial_eval/evaluate_spatial_relationship_detr.py index d718e83..47f6c90 100644 --- a/examples/2d_spatial_eval/evaluate_spatial_relationship_detr.py +++ b/examples/2d_spatial_eval/evaluate_spatial_relationship_detr.py @@ -1,13 +1,12 @@ -from typing import Optional, Tuple +import asyncio +from typing import Optional import fire import weave -import wandb -from hemm.eval_pipelines import EvaluationPipeline from hemm.metrics.spatial_relationship import SpatialRelationshipMetric2D from hemm.metrics.spatial_relationship.judges import DETRSpatialRelationShipJudge -from hemm.models import BaseDiffusionModel +from hemm.models import DiffusersModel def main( @@ -17,33 +16,31 @@ def main( dataset_limit: Optional[int] = None, diffusion_model_address: str = "stabilityai/stable-diffusion-2-1", diffusion_model_enable_cpu_offfload: bool = False, - image_size: Tuple[int, int] = (1024, 1024), + image_height: int = 1024, + image_width: int = 1024, detr_model_address: str = "facebook/detr-resnet-50", detr_revision: str = "no_timm", + iou_threshold: Optional[float] = 0.1, ): - wandb.init(project=project, entity=entity, job_type="evaluation") weave.init(project_name=f"{entity}/{project}") dataset = weave.ref(dataset_ref).get() dataset = dataset.rows[:dataset_limit] if dataset_limit else dataset - diffusion_model = BaseDiffusionModel( + model = DiffusersModel( diffusion_model_name_or_path=diffusion_model_address, enable_cpu_offfload=diffusion_model_enable_cpu_offfload, - image_height=image_size[0], - image_width=image_size[1], + image_height=image_height, + image_width=image_width, ) - evaluation_pipeline = EvaluationPipeline(model=diffusion_model) judge = DETRSpatialRelationShipJudge( model_address=detr_model_address, revision=detr_revision ) - metric = SpatialRelationshipMetric2D( - judge=judge, name="2d_spatial_relationship_score" - ) - evaluation_pipeline.add_metric(metric) + metric = SpatialRelationshipMetric2D(judge=judge, iou_threshold=iou_threshold) - evaluation_pipeline(dataset=dataset) + evaluation = weave.Evaluation(dataset=dataset, scorers=[metric]) + asyncio.run(evaluation.evaluate(model)) if __name__ == "__main__": diff --git a/examples/disentangled_vqa/evaluate_disentangled_vqa.py b/examples/disentangled_vqa/evaluate_disentangled_vqa.py index afef6d7..61a85b7 100644 --- a/examples/disentangled_vqa/evaluate_disentangled_vqa.py +++ b/examples/disentangled_vqa/evaluate_disentangled_vqa.py @@ -1,13 +1,12 @@ -from typing import Optional, Tuple +import asyncio +from typing import Optional import fire import weave -import wandb -from hemm.eval_pipelines import EvaluationPipeline from hemm.metrics.vqa import DisentangledVQAMetric from hemm.metrics.vqa.judges import BlipVQAJudge -from hemm.models import BaseDiffusionModel +from hemm.models import DiffusersModel def main( @@ -17,29 +16,27 @@ def main( dataset_limit: Optional[int] = None, diffusion_model_address: str = "stabilityai/stable-diffusion-2-1", diffusion_model_enable_cpu_offfload: bool = False, - image_size: Tuple[int, int] = (1024, 1024), + image_height: int = 1024, + image_width: int = 1024, ): - wandb.init(project=project, entity=entity, job_type="evaluation") weave.init(project_name=f"{entity}/{project}") dataset = weave.ref(dataset_ref).get() dataset = dataset.rows[:dataset_limit] if dataset_limit else dataset - diffusion_model = BaseDiffusionModel( + model = DiffusersModel( diffusion_model_name_or_path=diffusion_model_address, enable_cpu_offfload=diffusion_model_enable_cpu_offfload, - image_height=image_size[0], - image_width=image_size[1], - pipeline_configs={"variant": "fp16", "use_safetensors": True}, + image_height=image_height, + image_width=image_width, ) - diffusion_model._pipeline.set_progress_bar_config(disable=True) - evaluation_pipeline = EvaluationPipeline(model=diffusion_model) + model._pipeline.set_progress_bar_config(disable=True) judge = BlipVQAJudge() metric = DisentangledVQAMetric(judge=judge, name="disentangled_blip_metric") - evaluation_pipeline.add_metric(metric) - evaluation_pipeline(dataset=dataset) + evaluation = weave.Evaluation(dataset=dataset, scorers=[metric]) + asyncio.run(evaluation.evaluate(model)) if __name__ == "__main__": diff --git a/examples/evaluate_weave_image_quality.py b/examples/evaluate_weave_image_quality.py index f1de969..ac82e34 100644 --- a/examples/evaluate_weave_image_quality.py +++ b/examples/evaluate_weave_image_quality.py @@ -1,38 +1,32 @@ +import asyncio + import fire import weave -import wandb -from hemm.eval_pipelines import EvaluationPipeline from hemm.metrics.image_quality import LPIPSMetric, PSNRMetric, SSIMMetric -from hemm.models import BaseDiffusionModel +from hemm.models import DiffusersModel def main( project_name: str = "image-quality", diffusion_model_name_or_path="stabilityai/stable-diffusion-2-1", dataset_ref: str = "COCO:v0", + image_height: int = 1024, + image_width: int = 1024, ): - wandb.init(project=project_name, job_type="evaluation") weave.init(project_name=project_name) - model = BaseDiffusionModel( - diffusion_model_name_or_path=diffusion_model_name_or_path - ) - evaluation_pipeline = EvaluationPipeline(model=model) - - # Add PSNR Metric - psnr_metric = PSNRMetric(image_size=evaluation_pipeline.image_size) - evaluation_pipeline.add_metric(psnr_metric) + model = DiffusersModel(diffusion_model_name_or_path=diffusion_model_name_or_path) - # Add SSIM Metric - ssim_metric = SSIMMetric(image_size=evaluation_pipeline.image_size) - evaluation_pipeline.add_metric(ssim_metric) + psnr_metric = PSNRMetric(image_size=(image_height, image_width)) + ssim_metric = SSIMMetric(image_size=(image_height, image_width)) + lpips_metric = LPIPSMetric(image_size=(image_height, image_width)) - # Add LPIPS Metric - lpips_metric = LPIPSMetric(image_size=evaluation_pipeline.image_size) - evaluation_pipeline.add_metric(lpips_metric) - - evaluation_pipeline(dataset=dataset_ref) + dataset = weave.ref(dataset_ref).get() + evaluation = weave.Evaluation( + dataset=dataset, scorers=[psnr_metric, ssim_metric, lpips_metric] + ) + asyncio.run(evaluation.evaluate(model)) if __name__ == "__main__": diff --git a/examples/evaluate_weave_prompt_alignment.py b/examples/evaluate_weave_prompt_alignment.py index ddacaaf..4fc9218 100644 --- a/examples/evaluate_weave_prompt_alignment.py +++ b/examples/evaluate_weave_prompt_alignment.py @@ -1,10 +1,10 @@ +import asyncio + import fire import weave -import wandb -from hemm.eval_pipelines import EvaluationPipeline from hemm.metrics.prompt_alignment import CLIPImageQualityScoreMetric, CLIPScoreMetric -from hemm.models import BaseDiffusionModel +from hemm.models import DiffusersModel def main( @@ -15,26 +15,22 @@ def main( dataset: str = "parti-prompts:v0", project: str = "propmpt-alignment", ): - wandb.init(project=project, job_type="evaluation") weave.init(project_name=project) - model = BaseDiffusionModel( + model = DiffusersModel( diffusion_model_name_or_path=diffusion_model_name_or_path, enable_cpu_offfload=diffusion_model_enable_cpu_offfload, ) - evaluation_pipeline = EvaluationPipeline(model=model) - # Add CLIP Scorer metric clip_scorer = CLIPScoreMetric(clip_model_name_or_path=clip_model_name_or_path) - evaluation_pipeline.add_metric(clip_scorer) - - # Add CLIP IQA Metric clip_iqa_scorer = CLIPImageQualityScoreMetric( clip_model_name_or_path=clip_iqa_model_name_or_path ) - evaluation_pipeline.add_metric(clip_iqa_scorer) - evaluation_pipeline(dataset=dataset) + evaluation = weave.Evaluation( + dataset=dataset, scorers=[clip_scorer, clip_iqa_scorer] + ) + asyncio.run(evaluation.evaluate(model)) if __name__ == "__main__": diff --git a/examples/multimodal_llm_eval/evaluate_mllm_metric_action.py b/examples/multimodal_llm_eval/evaluate_mllm_metric_action.py index 738387b..cbdae4c 100644 --- a/examples/multimodal_llm_eval/evaluate_mllm_metric_action.py +++ b/examples/multimodal_llm_eval/evaluate_mllm_metric_action.py @@ -1,13 +1,12 @@ +import asyncio from typing import Optional import fire import weave -import wandb -from hemm.eval_pipelines import EvaluationPipeline from hemm.metrics.vqa import MultiModalLLMEvaluationMetric from hemm.metrics.vqa.judges.mmllm_judges import OpenAIJudge, PromptCategory -from hemm.models import BaseDiffusionModel +from hemm.models import DiffusersModel def main( @@ -21,38 +20,28 @@ def main( image_height: int = 1024, image_width: int = 1024, num_inference_steps: int = 50, - mock_inference_dataset_address: Optional[str] = None, - save_inference_dataset_name: Optional[str] = None, ): - wandb.init(project=project, entity=entity, job_type="evaluation") weave.init(project_name=f"{entity}/{project}") dataset = weave.ref(dataset_ref).get() dataset = dataset.rows[:dataset_limit] if dataset_limit else dataset - diffusion_model = BaseDiffusionModel( + model = DiffusersModel( diffusion_model_name_or_path=diffusion_model_address, enable_cpu_offfload=diffusion_model_enable_cpu_offfload, image_height=image_height, image_width=image_width, num_inference_steps=num_inference_steps, ) - diffusion_model._pipeline.set_progress_bar_config(disable=True) - evaluation_pipeline = EvaluationPipeline( - model=diffusion_model, - mock_inference_dataset_address=mock_inference_dataset_address, - save_inference_dataset_name=save_inference_dataset_name, - ) + model._pipeline.set_progress_bar_config(disable=True) judge = OpenAIJudge( prompt_property=PromptCategory.action, openai_model=openai_judge_model ) metric = MultiModalLLMEvaluationMetric(judge=judge) - evaluation_pipeline.add_metric(metric) - evaluation_pipeline(dataset=dataset) - wandb.finish() - evaluation_pipeline.cleanup() + evaluation = weave.Evaluation(dataset=dataset, scorers=[metric]) + asyncio.run(evaluation.evaluate(model)) if __name__ == "__main__": diff --git a/examples/multimodal_llm_eval/evaluate_mllm_metric_complex.py b/examples/multimodal_llm_eval/evaluate_mllm_metric_complex.py index 738387b..5ad871e 100644 --- a/examples/multimodal_llm_eval/evaluate_mllm_metric_complex.py +++ b/examples/multimodal_llm_eval/evaluate_mllm_metric_complex.py @@ -1,13 +1,12 @@ +import asyncio from typing import Optional import fire import weave -import wandb -from hemm.eval_pipelines import EvaluationPipeline from hemm.metrics.vqa import MultiModalLLMEvaluationMetric from hemm.metrics.vqa.judges.mmllm_judges import OpenAIJudge, PromptCategory -from hemm.models import BaseDiffusionModel +from hemm.models import DiffusersModel def main( @@ -21,38 +20,28 @@ def main( image_height: int = 1024, image_width: int = 1024, num_inference_steps: int = 50, - mock_inference_dataset_address: Optional[str] = None, - save_inference_dataset_name: Optional[str] = None, ): - wandb.init(project=project, entity=entity, job_type="evaluation") weave.init(project_name=f"{entity}/{project}") dataset = weave.ref(dataset_ref).get() dataset = dataset.rows[:dataset_limit] if dataset_limit else dataset - diffusion_model = BaseDiffusionModel( + model = DiffusersModel( diffusion_model_name_or_path=diffusion_model_address, enable_cpu_offfload=diffusion_model_enable_cpu_offfload, image_height=image_height, image_width=image_width, num_inference_steps=num_inference_steps, ) - diffusion_model._pipeline.set_progress_bar_config(disable=True) - evaluation_pipeline = EvaluationPipeline( - model=diffusion_model, - mock_inference_dataset_address=mock_inference_dataset_address, - save_inference_dataset_name=save_inference_dataset_name, - ) + model._pipeline.set_progress_bar_config(disable=True) judge = OpenAIJudge( - prompt_property=PromptCategory.action, openai_model=openai_judge_model + prompt_property=PromptCategory.complex, openai_model=openai_judge_model ) metric = MultiModalLLMEvaluationMetric(judge=judge) - evaluation_pipeline.add_metric(metric) - evaluation_pipeline(dataset=dataset) - wandb.finish() - evaluation_pipeline.cleanup() + evaluation = weave.Evaluation(dataset=dataset, scorers=[metric]) + asyncio.run(evaluation.evaluate(model)) if __name__ == "__main__": diff --git a/hemm/eval_pipelines/__init__.py b/hemm/eval_pipelines/__init__.py deleted file mode 100644 index dea3c26..0000000 --- a/hemm/eval_pipelines/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .eval_pipeline import EvaluationPipeline - -__all__ = ["EvaluationPipeline"] diff --git a/hemm/eval_pipelines/eval_pipeline.py b/hemm/eval_pipelines/eval_pipeline.py deleted file mode 100644 index f241a4e..0000000 --- a/hemm/eval_pipelines/eval_pipeline.py +++ /dev/null @@ -1,192 +0,0 @@ -import asyncio -import os -import shutil -from abc import ABC -from typing import Dict, List, Optional, Union - -import weave -from PIL import Image - -import wandb - -from ..metrics.base import BaseMetric -from ..models import BaseDiffusionModel, FalAIModel, StabilityAPIModel - -MODEL_TYPE = Union[BaseDiffusionModel, FalAIModel, StabilityAPIModel] - - -class EvaluationPipeline(ABC): - """Evaluation pipeline to evaluate the a multi-modal generative model. - - Args: - model (BaseDiffusionModel): The model to evaluate. - seed (int): Seed value for the random number generator. - mock_inference_dataset_address (Optional[str]): A wandb dataset artifact address which if - provided will mock inference results. This prevents the need for redundant generations - when switching metrics/judges with the same evaluation datset(s). - save_inference_dataset_name (Optional[str]): A weave dataset name which if provided will - save inference results as a separate weave dataset. - """ - - def __init__( - self, - model: MODEL_TYPE, - seed: int = 42, - mock_inference_dataset_address: Optional[str] = None, - save_inference_dataset_name: Optional[str] = None, - ) -> None: - super().__init__() - self.model = model - - self.image_size = (self.model.image_height, self.model.image_width) - self.seed = seed - self.mock_inference_dataset_address = mock_inference_dataset_address - if mock_inference_dataset_address: - self.save_inference_dataset_name = None - artifact = wandb.use_artifact( - self.mock_inference_dataset_address, type="dataset" - ) - self.mock_inference_dataset_dir = artifact.download() - - else: - self.save_inference_dataset_name = save_inference_dataset_name - - if self.save_inference_dataset_name: - os.makedirs( - os.path.join("inference_dataset", self.save_inference_dataset_name), - exist_ok=True, - ) - - self.inference_counter = 0 - self.table_columns = ["model", "prompt", "generated_image"] - self.table_rows: List = [] - self.evaluation_table: wandb.Table = None - self.metric_functions: List[BaseMetric] = [] - - self.evaluation_configs = { - "pretrained_model_name_or_path": self.model.diffusion_model_name_or_path, - "torch_dtype": str(self.model._torch_dtype), - "enable_cpu_offfload": self.model.enable_cpu_offfload, - "image_size": { - "height": self.image_size[0], - "width": self.image_size[1], - }, - "seed": seed, - "diffusion_pipeline": dict(self.model._pipeline.config), - } - - def add_metric(self, metric_fn: BaseMetric): - """Add a metric function to the evaluation pipeline. - - Args: - metric_fn (BaseMetric): Metric function to evaluate the generated images. - """ - self.table_columns.append(metric_fn.__class__.__name__) - self.evaluation_configs.update(metric_fn.config) - self.metric_functions.append(metric_fn) - - @weave.op() - def infer(self, prompt: str) -> Dict[str, str]: - """Inference function to generate images for the given prompt. - - Args: - prompt (str): Prompt to generate the image. - - Returns: - Dict[str, str]: Dictionary containing base64 encoded image to be logged as - a Weave object. - """ - if self.inference_counter == 0: - self.evaluation_table = wandb.Table(columns=self.table_columns) - if self.mock_inference_dataset_address: - image = Image.open( - os.path.join( - self.mock_inference_dataset_dir, f"{self.inference_counter}.png" - ) - ) - output = {"image": image} - else: - output = self.model.predict(prompt, seed=self.seed) - self.table_rows.append( - [self.model.diffusion_model_name_or_path, prompt, output["image"]] - ) - if self.save_inference_dataset_name: - output["image"].save( - os.path.join( - "inference_dataset", - self.save_inference_dataset_name, - f"{self.inference_counter}.png", - ) - ) - self.inference_counter += 1 - return output - - @weave.op() - async def infer_async(self, prompt: str) -> Dict[str, str]: - """Async inference function to generate images for the given prompt. - - Args: - prompt (str): Prompt to generate the image. - - Returns: - Dict[str, str]: Dictionary containing base64 encoded image to be logged as - a Weave object. - """ - return self.infer(prompt) - - def log_summary(self, summary: Dict[str, float]) -> None: - """Log the evaluation summary to the Weights & Biases dashboard.""" - config = wandb.config - config.update(self.evaluation_configs) - for row_idx, row in enumerate(self.table_rows): - current_row = row - current_row[-1] = wandb.Image(current_row[-1]) - for metric_fn in self.metric_functions: - current_row.append(metric_fn.scores[row_idx]) - self.evaluation_table.add_data(*current_row) - summary_table = wandb.Table(columns=["summary"], data=[[summary]]) - wandb.log( - { - "evalution": self.evaluation_table, - "summary": summary_table, - "evaluation_summary": summary, - } - ) - - def save_inference_results(self): - artifact = wandb.Artifact(name=self.save_inference_dataset_name, type="dataset") - artifact.add_dir( - os.path.join("inference_dataset", self.save_inference_dataset_name) - ) - artifact.save() - - def cleanup(self): - """Cleanup the inference dataset directory. Should be called after the evaluation is complete - and `wandb.finish()` is called.""" - if os.path.exists("inference_dataset"): - shutil.rmtree("inference_dataset") - - def __call__( - self, dataset: Union[List[Dict], str], async_infer: bool = False - ) -> Dict[str, float]: - """Evaluate the Stable Diffusion model on the given dataset. - - Args: - dataset (Union[List[Dict], str]): Dataset to evaluate the model on. If a string is - passed, it is assumed to be a Weave dataset reference. - async_infer (bool, optional): Whether to use async inference. Defaults to False. - """ - dataset = weave.ref(dataset).get() if isinstance(dataset, str) else dataset - evaluation = weave.Evaluation( - dataset=dataset, - scorers=[ - metric_fn.evaluate_async if async_infer else metric_fn.evaluate - for metric_fn in self.metric_functions - ], - ) - self.model.configs.update(self.evaluation_configs) - summary = asyncio.run(evaluation.evaluate(self.infer_async)) - self.log_summary(summary) - if self.save_inference_dataset_name: - self.save_inference_results() - return summary diff --git a/hemm/metrics/image_quality/base.py b/hemm/metrics/image_quality/base.py deleted file mode 100644 index 34546d5..0000000 --- a/hemm/metrics/image_quality/base.py +++ /dev/null @@ -1,68 +0,0 @@ -from abc import abstractmethod -from typing import Any, Dict, Union - -from PIL import Image -from pydantic import BaseModel - -from ..base import BaseMetric - - -class ComputeMetricOutput(BaseModel): - """Output of the metric computation function.""" - - score: Union[float, Dict[str, float]] - ground_truth_image: str - - -class BaseImageQualityMetric(BaseMetric): - - def __init__(self, name: str) -> None: - """Base class for Image Quality Metrics. - - Args: - name (str): Name of the metric. - """ - super().__init__() - self.scores = [] - self.name = name - self.config = {} - - @abstractmethod - def compute_metric( - self, - ground_truth_pil_image: Image.Image, - generated_pil_image: Image.Image, - prompt: str, - ) -> ComputeMetricOutput: - """Compute the metric for the given images. This is an abstract - method and must be overriden by the child class implementation. - - Args: - ground_truth_pil_image (Image.Image): Ground truth image in PIL format. - generated_pil_image (Image.Image): Generated image in PIL format. - prompt (str): Prompt for the image generation. - - Returns: - ComputeMetricOutput: Output containing the metric score and ground truth image. - """ - pass - - def evaluate( - self, prompt: str, ground_truth_image: Image.Image, model_output: Dict[str, Any] - ) -> Dict[str, float]: - """Compute the metric for the given images. This method is used as the scorer - function for `weave.Evaluation` in the evaluation pipelines. - - Args: - prompt (str): Prompt for the image generation. - ground_truth_image (str): Ground truth image in base64 format. - model_output (Dict[str, Any]): Model output containing the generated image. - - Returns: - Union[float, Dict[str, float]]: Metric score. - """ - metric_output = self.compute_metric( - ground_truth_image, model_output["image"], prompt - ) - self.scores.append(metric_output.score) - return {self.name: metric_output.score} diff --git a/hemm/metrics/image_quality/lpips.py b/hemm/metrics/image_quality/lpips.py index fe67f95..12ae628 100644 --- a/hemm/metrics/image_quality/lpips.py +++ b/hemm/metrics/image_quality/lpips.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Dict, Literal, Optional, Tuple, Union +from typing import Any, Callable, Dict, Literal, Union import numpy as np import torch @@ -7,11 +7,8 @@ from PIL import Image from torchmetrics.functional.image import learned_perceptual_image_patch_similarity -from ...utils import base64_encode_image -from .base import BaseImageQualityMetric, ComputeMetricOutput - -class LPIPSMetric(BaseImageQualityMetric): +class LPIPSMetric(weave.Scorer): """LPIPS Metric to compute the Learned Perceptual Image Patch Similarity (LPIPS) score between two images. LPIPS essentially computes the similarity between the activations of two image patches for some pre-defined network. This measure has been shown to match @@ -20,32 +17,43 @@ class LPIPSMetric(BaseImageQualityMetric): Args: lpips_net_type (str): The network type to use for computing LPIPS. One of "alex", "vgg", or "squeeze". - image_size (Tuple[int, int]): The size to which images will be resized before computing - LPIPS. - name (str): The name of the metric. + image_height (int): The height to which images will be resized before computing LPIPS. + image_width (int): The width to which images will be resized before computing LPIPS. """ + lpips_net_type: Literal["alex", "vgg", "squeeze"] + image_height: int + image_width: int + _lpips_metric: Callable + def __init__( self, lpips_net_type: Literal["alex", "vgg", "squeeze"] = "alex", - image_size: Optional[Tuple[int, int]] = (512, 512), - name: str = "alexnet_learned_perceptual_image_patch_similarity", + image_height: int = 512, + image_width: int = 512, ) -> None: - super().__init__(name) - self.image_size = image_size - self.lpips_metric = partial( - learned_perceptual_image_patch_similarity, net_type=lpips_net_type + super().__init__( + lpips_net_type=lpips_net_type, + image_height=image_height, + image_width=image_width, + ) + self._lpips_metric = partial( + learned_perceptual_image_patch_similarity, net_type=self.lpips_net_type ) - self.config = {"lpips_net_type": lpips_net_type} @weave.op() def compute_metric( - self, ground_truth_pil_image: Image, generated_pil_image: Image, prompt: str - ) -> ComputeMetricOutput: + self, ground_truth_pil_image: Image, generated_pil_image: Image + ) -> Dict[str, float]: ground_truth_image = ( torch.from_numpy( np.expand_dims( - np.array(ground_truth_pil_image.resize(self.image_size)), axis=0 + np.array( + ground_truth_pil_image.resize( + (self.image_height, self.image_width) + ) + ), + axis=0, ).astype(np.uint8) ) .permute(0, 3, 2, 1) @@ -54,7 +62,12 @@ def compute_metric( generated_image = ( torch.from_numpy( np.expand_dims( - np.array(generated_pil_image.resize(self.image_size)), axis=0 + np.array( + generated_pil_image.resize( + (self.image_height, self.image_width) + ) + ), + axis=0, ).astype(np.uint8) ) .permute(0, 3, 2, 1) @@ -62,23 +75,17 @@ def compute_metric( ) ground_truth_image = (ground_truth_image / 127.5) - 1.0 generated_image = (generated_image / 127.5) - 1.0 - return ComputeMetricOutput( - score=float( - self.lpips_metric(generated_image, ground_truth_image).detach() + return { + "score": float( + self._lpips_metric(generated_image, ground_truth_image).detach() ), - ground_truth_image=base64_encode_image(ground_truth_pil_image), - ) - - @weave.op() - def evaluate( - self, prompt: str, ground_truth_image: Image.Image, model_output: Dict[str, Any] - ) -> Union[float, Dict[str, float]]: - _ = "LPIPSMetric" - return super().evaluate(prompt, ground_truth_image, model_output) + "ground_truth_image": ground_truth_pil_image, + } @weave.op() - async def evaluate_async( + def score( self, prompt: str, ground_truth_image: Image.Image, model_output: Dict[str, Any] ) -> Union[float, Dict[str, float]]: - _ = "LPIPSMetric" - return self.evaluate(prompt, ground_truth_image, model_output) + _ = prompt + metric_output = self.compute_metric(ground_truth_image, model_output["image"]) + return {"score": metric_output["score"]} diff --git a/hemm/metrics/image_quality/psnr.py b/hemm/metrics/image_quality/psnr.py index 1e15e42..30bbf8f 100644 --- a/hemm/metrics/image_quality/psnr.py +++ b/hemm/metrics/image_quality/psnr.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import numpy as np import torch @@ -7,72 +7,74 @@ from PIL import Image from torchmetrics.functional.image import peak_signal_noise_ratio -from ...utils import base64_encode_image -from .base import BaseImageQualityMetric, ComputeMetricOutput - -class PSNRMetric(BaseImageQualityMetric): +class PSNRMetric(weave.Scorer): """PSNR Metric to compute the Peak Signal-to-Noise Ratio (PSNR) between two images. Args: + psnr_base (float): The base of the logarithm in the PSNR formula. psnr_data_range (Optional[Union[float, Tuple[float, float]]]): The data range of the input image (min, max). If None, the data range is determined from the image data type. - psnr_base (float): The base of the logarithm in the PSNR formula. - image_size (Tuple[int, int]): The size to which images will be resized before computing - PSNR. - name (str): The name of the metric. + image_height (int): The height to which images will be resized before computing PSNR. + image_width (int): The width to which images will be resized before computing PSNR. """ + psnr_base: float + psnr_data_range: Optional[Union[float, Tuple[float, float]]] + image_height: int + image_width: int + _psnr_metric: Callable + def __init__( self, psnr_data_range: Optional[Union[float, Tuple[float, float]]] = None, psnr_base: float = 10.0, - image_size: Optional[Tuple[int, int]] = (512, 512), - name: str = "peak_signal_noise_ratio", + image_height: int = 512, + image_width: int = 512, ) -> None: - super().__init__(name) - self.image_size = image_size - self.psnr_metric = partial( - peak_signal_noise_ratio, data_range=psnr_data_range, base=psnr_base + super().__init__( + psnr_data_range=psnr_data_range, + psnr_base=psnr_base, + image_height=image_height, + image_width=image_width, + ) + self._psnr_metric = partial( + peak_signal_noise_ratio, + data_range=self.psnr_data_range, + base=self.psnr_base, ) - self.config = { - "psnr_base": psnr_base, - "psnr_data_range": psnr_data_range, - "image_size": image_size, - } @weave.op() def compute_metric( - self, - ground_truth_pil_image: Image.Image, - generated_pil_image: Image.Image, - prompt: str, - ) -> ComputeMetricOutput: + self, ground_truth_pil_image: Image.Image, generated_pil_image: Image.Image + ) -> Dict[str, float]: ground_truth_image = torch.from_numpy( np.expand_dims( - np.array(ground_truth_pil_image.resize(self.image_size)), axis=0 + np.array( + ground_truth_pil_image.resize((self.image_height, self.image_width)) + ), + axis=0, ).astype(np.uint8) ).float() generated_image = torch.from_numpy( np.expand_dims( - np.array(generated_pil_image.resize(self.image_size)), axis=0 + np.array( + generated_pil_image.resize((self.image_height, self.image_width)) + ), + axis=0, ).astype(np.uint8) ).float() - return ComputeMetricOutput( - score=float(self.psnr_metric(generated_image, ground_truth_image).detach()), - ground_truth_image=base64_encode_image(ground_truth_pil_image), - ) - - @weave.op() - def evaluate( - self, prompt: str, ground_truth_image: Image.Image, model_output: Dict[str, Any] - ) -> Union[float, Dict[str, float]]: - _ = "PSNRMetric" - return super().evaluate(prompt, ground_truth_image, model_output) + return { + "score": float( + self._psnr_metric(generated_image, ground_truth_image).detach() + ), + "ground_truth_image": ground_truth_pil_image, + } @weave.op() - async def evaluate_async( + def score( self, prompt: str, ground_truth_image: Image.Image, model_output: Dict[str, Any] ) -> Union[float, Dict[str, float]]: - _ = "PSNRMetric" - return self.evaluate(prompt, ground_truth_image, model_output) + _ = prompt + metric_output = self.compute_metric(ground_truth_image, model_output["image"]) + return {"score": metric_output["score"]} diff --git a/hemm/metrics/image_quality/ssim.py b/hemm/metrics/image_quality/ssim.py index 509eeb6..a4bb69b 100644 --- a/hemm/metrics/image_quality/ssim.py +++ b/hemm/metrics/image_quality/ssim.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, Tuple, Union import numpy as np import torch @@ -7,11 +7,8 @@ from PIL import Image from torchmetrics.functional.image import structural_similarity_index_measure -from ...utils import base64_encode_image -from .base import BaseImageQualityMetric, ComputeMetricOutput - -class SSIMMetric(BaseImageQualityMetric): +class SSIMMetric(weave.Scorer): """SSIM Metric to compute the [Structural Similarity Index Measure (SSIM)](https://en.wikipedia.org/wiki/Structural_similarity) between two images. @@ -24,11 +21,20 @@ class SSIMMetric(BaseImageQualityMetric): image (min, max). If None, the data range is determined from the image data type. ssim_k1 (float): The constant used to stabilize the SSIM numerator. ssim_k2 (float): The constant used to stabilize the SSIM denominator. - image_size (Tuple[int, int]): The size to which images will be resized before computing - SSIM. - name (str): The name of the metric. + image_height (int): The height to which images will be resized before computing SSIM. + image_width (int): The width to which images will be resized before computing SSIM. """ + ssim_gaussian_kernel: bool + ssim_sigma: float + ssim_kernel_size: int + ssim_data_range: Union[float, Tuple[float, float], None] + ssim_k1: float + ssim_k2: float + image_height: int + image_width: int + _ssim_metric: Callable + def __init__( self, ssim_gaussian_kernel: bool = True, @@ -37,12 +43,20 @@ def __init__( ssim_data_range: Union[float, Tuple[float, float], None] = None, ssim_k1: float = 0.01, ssim_k2: float = 0.03, - image_size: Optional[Tuple[int, int]] = (512, 512), - name: str = "structural_similarity_index_measure", + image_height: int = 512, + image_width: int = 512, ) -> None: - super().__init__(name) - self.image_size = image_size - self.ssim_metric = partial( + super().__init__( + ssim_gaussian_kernel=ssim_gaussian_kernel, + ssim_sigma=ssim_sigma, + ssim_kernel_size=ssim_kernel_size, + ssim_data_range=ssim_data_range, + ssim_k1=ssim_k1, + ssim_k2=ssim_k2, + image_height=image_height, + image_width=image_width, + ) + self._ssim_metric = partial( structural_similarity_index_measure, gaussian_kernel=ssim_gaussian_kernel, sigma=ssim_sigma, @@ -51,26 +65,20 @@ def __init__( k1=ssim_k1, k2=ssim_k2, ) - self.config = { - "ssim_gaussian_kernel": ssim_gaussian_kernel, - "ssim_sigma": ssim_sigma, - "ssim_kernel_size": ssim_kernel_size, - "ssim_data_range": ssim_data_range, - "ssim_k1": ssim_k1, - "ssim_k2": ssim_k2, - } @weave.op() def compute_metric( - self, - ground_truth_pil_image: Image.Image, - generated_pil_image: Image.Image, - prompt: str, - ) -> ComputeMetricOutput: + self, ground_truth_pil_image: Image.Image, generated_pil_image: Image.Image + ) -> Dict[str, float]: ground_truth_image = ( torch.from_numpy( np.expand_dims( - np.array(ground_truth_pil_image.resize(self.image_size)), axis=0 + np.array( + ground_truth_pil_image.resize( + (self.image_height, self.image_width) + ) + ), + axis=0, ).astype(np.uint8) ) .permute(0, 3, 1, 2) @@ -79,27 +87,26 @@ def compute_metric( generated_image = ( torch.from_numpy( np.expand_dims( - np.array(generated_pil_image.resize(self.image_size)), axis=0 + np.array( + generated_pil_image.resize( + (self.image_height, self.image_width) + ) + ), + axis=0, ).astype(np.uint8) ) .permute(0, 3, 1, 2) .float() ) - return ComputeMetricOutput( - score=float(self.ssim_metric(generated_image, ground_truth_image)), - ground_truth_image=base64_encode_image(ground_truth_pil_image), - ) - - @weave.op() - def evaluate( - self, prompt: str, ground_truth_image: Image.Image, model_output: Dict[str, Any] - ) -> Union[float, Dict[str, float]]: - _ = "SSIMMetric" - return super().evaluate(prompt, ground_truth_image, model_output) + return { + "score": float(self._ssim_metric(generated_image, ground_truth_image)), + "ground_truth_image": ground_truth_pil_image, + } @weave.op() - async def evaluate_async( + def score( self, prompt: str, ground_truth_image: Image.Image, model_output: Dict[str, Any] ) -> Union[float, Dict[str, float]]: - _ = "SSIMMetric" - return self.evaluate(prompt, ground_truth_image, model_output) + _ = prompt + metric_output = self.compute_metric(ground_truth_image, model_output["image"]) + return {"score": metric_output["score"]} diff --git a/hemm/metrics/prompt_alignment/base.py b/hemm/metrics/prompt_alignment/base.py index d60de18..5016c5f 100644 --- a/hemm/metrics/prompt_alignment/base.py +++ b/hemm/metrics/prompt_alignment/base.py @@ -1,23 +1,11 @@ from abc import abstractmethod from typing import Any, Dict, Union +import weave from PIL import Image -from ..base import BaseMetric - -class BasePromptAlignmentMetric(BaseMetric): - """Base class for Prompt Alignment Metrics. - - Args: - name (str): Name of the metric. - """ - - def __init__(self, name: str) -> None: - super().__init__() - self.scores = [] - self.name = name - self.config = {} +class BasePromptAlignmentMetric(weave.Scorer): @abstractmethod def compute_metric( @@ -47,5 +35,4 @@ def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float Dict[str, float]: Metric score. """ score = self.compute_metric(model_output["image"], prompt) - self.scores.append(score) - return {self.name: score} + return {"score": score} diff --git a/hemm/metrics/prompt_alignment/blip_score.py b/hemm/metrics/prompt_alignment/blip_score.py index 0caf9c1..4a32c03 100644 --- a/hemm/metrics/prompt_alignment/blip_score.py +++ b/hemm/metrics/prompt_alignment/blip_score.py @@ -1,60 +1,48 @@ -from typing import Any, Dict, Union +from typing import Any, Dict import weave -from PIL import Image from torch.nn import functional as F from transformers import BlipForConditionalGeneration, BlipProcessor -from .base import BasePromptAlignmentMetric +class BLIPScoreMertric(weave.Scorer): + model_name: str = "Salesforce/blip-image-captioning-base" + device: str = "cuda" + _blip_processor: BlipProcessor + _blip_model: BlipForConditionalGeneration -class BLIPScoreMertric(BasePromptAlignmentMetric): def __init__( self, - name: str = "blip_score", - blip_model_name_or_path: str = "Salesforce/blip-image-captioning-base", + model_name: str = "Salesforce/blip-image-captioning-base", device: str = "cuda", ) -> None: - super().__init__(name) - self.blip_processor = BlipProcessor.from_pretrained(blip_model_name_or_path) - self.blip_model = BlipForConditionalGeneration.from_pretrained( - blip_model_name_or_path - ).to(device) - self.config = {"blip_model_name_or_path": blip_model_name_or_path} + super().__init__(model_name=model_name, device=device) + self._blip_processor = BlipProcessor.from_pretrained(model_name) + self._blip_model = BlipForConditionalGeneration.from_pretrained(model_name).to( + device + ) @weave.op() - def compute_metric( - self, pil_image: Image, prompt: str - ) -> Union[float, Dict[str, Any]]: + def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]: pixel_values = self.blip_processor( - images=pil_image, return_tensors="pt" + images=model_output["image"], return_tensors="pt" ).pixel_values - text_input_ids = self.blip_processor( + text_input_ids = self._blip_processor( text=prompt, return_tensors="pt", padding=True, truncation=True ).input_ids - outputs = self.blip_model( + outputs = self._blip_model( pixel_values=pixel_values.to(self.device), input_ids=text_input_ids.to(self.device), ) logits = outputs.logits[:, :-1, :] shift_labels = text_input_ids[..., 1:].contiguous() - return float( - F.cross_entropy( - logits.view(-1, logits.size(-1)).to(self.device), - shift_labels.view(-1).to(self.device), + return { + "score": float( + F.cross_entropy( + logits.view(-1, logits.size(-1)).to(self.device), + shift_labels.view(-1).to(self.device), + ) + .detach() + .item() ) - .detach() - .item() - ) - - @weave.op() - def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]: - _ = "BLIPScoreMertric" - return super().evaluate(prompt, model_output) - - @weave.op() - async def evaluate_async( - self, prompt: str, model_output: Dict[str, Any] - ) -> Dict[str, float]: - _ = "BLIPScoreMertric" - return self.evaluate(prompt, model_output) + } diff --git a/hemm/metrics/prompt_alignment/clip_iqa_score.py b/hemm/metrics/prompt_alignment/clip_iqa_score.py index ba61660..8a7a1b0 100644 --- a/hemm/metrics/prompt_alignment/clip_iqa_score.py +++ b/hemm/metrics/prompt_alignment/clip_iqa_score.py @@ -1,17 +1,14 @@ from functools import partial -from typing import Any, Dict, Union +from typing import Any, Callable, Dict, List import numpy as np import torch import weave -from PIL import Image from torchmetrics.functional.multimodal import clip_image_quality_assessment from tqdm.auto import tqdm -from .base import BasePromptAlignmentMetric - -class CLIPImageQualityScoreMetric(BasePromptAlignmentMetric): +class CLIPImageQualityScoreMetric(weave.Scorer): """[CLIP Image Quality Assessment](https://arxiv.org/abs/2207.12396) metric for to measuring the visual content of images. @@ -28,65 +25,53 @@ class CLIPImageQualityScoreMetric(BasePromptAlignmentMetric): image is more similar to the first prompt than the second prompt. Args: - clip_model_name_or_path (str, optional): The name or path of the CLIP model to use. - Defaults to "clip_iqa". - name (str, optional): Name of the metric. Defaults to "clip_image_quality_assessment". + model_name (str, optional): The name or path of the CLIP model to use. """ - def __init__( - self, - clip_model_name_or_path: str = "clip_iqa", - name: str = "clip_image_quality_assessment", - ) -> None: - super().__init__(name) - self.clip_iqa_fn = partial( - clip_image_quality_assessment, model_name_or_path=clip_model_name_or_path + model_name: str + built_in_prompts: List[str] + _clip_iqa_fn: Callable + + def __init__(self, model_name: str = "clip_iqa") -> None: + super().__init__( + model_name=model_name, + built_in_prompts=[ + "quality", + "brightness", + "noisiness", + "colorfullness", + "sharpness", + "contrast", + "complexity", + "natural", + "happy", + "scary", + "new", + "real", + "beautiful", + "lonely", + "relaxing", + ], + ) + self._clip_iqa_fn = partial( + clip_image_quality_assessment, model_name_or_path=model_name ) - self.built_in_prompts = [ - "quality", - "brightness", - "noisiness", - "colorfullness", - "sharpness", - "contrast", - "complexity", - "natural", - "happy", - "scary", - "new", - "real", - "beautiful", - "lonely", - "relaxing", - ] - self.config = {"clip_model_name_or_path": clip_model_name_or_path} @weave.op() - def compute_metric( - self, pil_image: Image, prompt: str - ) -> Union[float, Dict[str, float]]: - images = np.expand_dims(np.array(pil_image), axis=0).astype(np.uint8) / 255.0 + def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]: + images = ( + np.expand_dims(np.array(model_output["image"]), axis=0).astype(np.uint8) + / 255.0 + ) score_dict = {} for prompt in tqdm( self.built_in_prompts, desc="Calculating IQA scores", leave=False ): clip_iqa_score = float( - self.clip_iqa_fn( + self._clip_iqa_fn( images=torch.from_numpy(images).permute(0, 3, 1, 2), prompts=tuple([prompt] * images.shape[0]), ).detach() ) score_dict[f"{self.name}_{prompt}"] = clip_iqa_score return score_dict - - @weave.op() - def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]: - _ = "CLIPImageQualityScoreMetric" - return super().evaluate(prompt, model_output) - - @weave.op() - async def evaluate_async( - self, prompt: str, model_output: Dict[str, Any] - ) -> Dict[str, float]: - _ = "CLIPImageQualityScoreMetric" - return self.evaluate(prompt, model_output) diff --git a/hemm/metrics/prompt_alignment/clip_score.py b/hemm/metrics/prompt_alignment/clip_score.py index bb1f79e..897f955 100644 --- a/hemm/metrics/prompt_alignment/clip_score.py +++ b/hemm/metrics/prompt_alignment/clip_score.py @@ -1,57 +1,36 @@ from functools import partial -from typing import Any, Dict, Union +from typing import Any, Callable, Dict import numpy as np import torch import weave -from PIL import Image from torchmetrics.functional.multimodal import clip_score -from .base import BasePromptAlignmentMetric - -class CLIPScoreMetric(BasePromptAlignmentMetric): +class CLIPScoreMetric(weave.Scorer): """[CLIP score](https://arxiv.org/abs/2104.08718) metric for text-to-image similarity. CLIP Score is a reference free metric that can be used to evaluate the correlation between a generated caption for an image and the actual content of the image. It has been found to be highly correlated with human judgement. Args: - name (str, optional): Name of the metric. Defaults to "clip_score". - clip_model_name_or_path (str, optional): The name or path of the CLIP model to use. - Defaults to "openai/clip-vit-base-patch16". + model_name (str, optional): The name or path of the CLIP model to use. """ - def __init__( - self, - clip_model_name_or_path: str = "openai/clip-vit-base-patch16", - name: str = "clip_score", - ) -> None: - super().__init__(name) - self.clip_score_fn = partial( - clip_score, model_name_or_path=clip_model_name_or_path - ) - self.config = {"clip_model_name_or_path": clip_model_name_or_path} - - @weave.op() - def compute_metric( - self, pil_image: Image.Image, prompt: str - ) -> Union[float, Dict[str, float]]: - images = np.expand_dims(np.array(pil_image), axis=0) - return float( - self.clip_score_fn( - torch.from_numpy(images).permute(0, 3, 1, 2), prompt - ).detach() - ) + model_name: str + _clip_score_fn: Callable - @weave.op() - def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]: - _ = "CLIPScoreMetric" - return super().evaluate(prompt, model_output) + def __init__(self, model_name: str = "openai/clip-vit-base-patch16") -> None: + super().__init__(model_name=model_name) + self._clip_score_fn = partial(clip_score, model_name_or_path=model_name) @weave.op() - async def evaluate_async( - self, prompt: str, model_output: Dict[str, Any] - ) -> Dict[str, float]: - _ = "CLIPScoreMetric" - return self.evaluate(prompt, model_output) + def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]: + images = np.expand_dims(np.array(model_output["image"]), axis=0) + return { + "score": float( + self._clip_score_fn( + torch.from_numpy(images).permute(0, 3, 1, 2), prompt + ).detach() + ) + } diff --git a/hemm/metrics/spatial_relationship/spatial_relationship_2d.py b/hemm/metrics/spatial_relationship/spatial_relationship_2d.py index dae173a..b8a14a3 100644 --- a/hemm/metrics/spatial_relationship/spatial_relationship_2d.py +++ b/hemm/metrics/spatial_relationship/spatial_relationship_2d.py @@ -1,49 +1,42 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Union import weave from PIL import Image -import wandb - -from ..base import BaseMetric -from .judges import DETRSpatialRelationShipJudge from .judges.commons import BoundingBox from .utils import annotate_with_bounding_box, get_iou -class SpatialRelationshipMetric2D(BaseMetric): +class SpatialRelationshipMetric2D(weave.Scorer): """Spatial relationship metric for image generation as proposed in Section 4.2 from the paper [T2I-CompBench: A Comprehensive Benchmark for Open-world Compositional Text-to-image Generation](https://arxiv.org/pdf/2307.06350). - ??? example "Sample usage" + !!! example "Sample usage" ```python - import wandb + import asyncio import weave - from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline - from hemm.metrics.image_quality import LPIPSMetric, PSNRMetric, SSIMMetric + from hemm.models import DiffusersModel + from hemm.metrics.spatial_relationship.judges import DETRSpatialRelationShipJudge - # Initialize Weave and WandB - wandb.init(project="image-quality-leaderboard", job_type="evaluation") + # Initialize Weave weave.init(project_name="image-quality-leaderboard") - # Initialize the diffusion model to be evaluated as a `weave.Model` using `BaseWeaveModel` - model = BaseDiffusionModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4") - - # Add the model to the evaluation pipeline - evaluation_pipeline = EvaluationPipeline(model=model) + # Initialize the diffusion model to be evaluated as a `weave.Model` + model = DiffusersModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4") # Define the judge model for 2d spatial relationship metric judge = DETRSpatialRelationShipJudge( model_address=detr_model_address, revision=detr_revision ) - # Add 2d spatial relationship Metric to the evaluation pipeline + # Define 2d spatial relationship Metric to the evaluation pipeline metric = SpatialRelationshipMetric2D(judge=judge, name="2d_spatial_relationship_score") - evaluation_pipeline.add_metric(metric) # Evaluate! - evaluation_pipeline(dataset="t2i_compbench_spatial_prompts:v0") + dataset = weave.ref("2d-spatial-t2i_compbench_spatial_prompts-mscoco:v0").get() + evaluation = weave.Evaluation(dataset=dataset, scorers=[metric]) + asyncio.run(evaluation.evaluate(model)) ``` Args: @@ -54,21 +47,9 @@ class SpatialRelationshipMetric2D(BaseMetric): name (Optional[str], optional): The name of the metric. """ - def __init__( - self, - judge: Union[weave.Model, DETRSpatialRelationShipJudge], - iou_threshold: Optional[float] = 0.1, - distance_threshold: Optional[float] = 150, - name: Optional[str] = "spatial_relationship_score", - ) -> None: - super().__init__() - self.judge = judge - self.judge_config = self.judge.model_dump(mode="json") - self.iou_threshold = iou_threshold - self.distance_threshold = distance_threshold - self.name = name - self.scores = [] - self.config = judge.model_dump() + judge: weave.Model + iou_threshold: float = 0.1 + distance_threshold: float = 150 @weave.op() def compose_judgement( @@ -187,23 +168,10 @@ def compose_judgement( score = self.iou_threshold / iou judgement["score"] = score - self.scores.append( - { - **judgement, - **{ - "judge_annotated_image": wandb.Image(annotated_image), - "judge_config": self.judge_config, - }, - } - ) - return { - **judgement, - **{"judge_annotated_image": annotated_image}, - "judge_config": self.judge_config, - } + return {**judgement, **{"judge_annotated_image": annotated_image}} @weave.op() - def evaluate( + def score( self, prompt: str, entity_1: str, @@ -231,14 +199,3 @@ def evaluate( prompt, image, entity_1, entity_2, relationship, boxes ) return {self.name: judgement["score"]} - - @weave.op() - async def evaluate_async( - self, - prompt: str, - entity_1: str, - entity_2: str, - relationship: str, - model_output: Dict[str, Any], - ) -> Dict[str, Union[bool, float, int]]: - return self.evaluate(prompt, entity_1, entity_2, relationship, model_output) diff --git a/hemm/metrics/vqa/disentangled_vqa.py b/hemm/metrics/vqa/disentangled_vqa.py index 1e4556a..446e3cf 100644 --- a/hemm/metrics/vqa/disentangled_vqa.py +++ b/hemm/metrics/vqa/disentangled_vqa.py @@ -1,12 +1,11 @@ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict import weave -from ..base import BaseMetric from .judges import BlipVQAJudge -class DisentangledVQAMetric(BaseMetric): +class DisentangledVQAMetric(weave.Scorer): """Disentangled VQA metric to evaluate the attribute-binding capability for image generation models as proposed in Section 4.1 from the paper [T2I-CompBench: A Comprehensive Benchmark for Open-world Compositional Text-to-image Generation](https://arxiv.org/pdf/2307.06350). @@ -39,22 +38,12 @@ class DisentangledVQAMetric(BaseMetric): Args: judge (Union[weave.Model, BlipVQAJudge]): The judge model to evaluate the attribute-binding capability. - name (Optional[str]): The name of the metric. Defaults to "disentangled_vlm_metric". """ - def __init__( - self, - judge: Union[weave.Model, BlipVQAJudge], - name: Optional[str] = "disentangled_vlm_metric", - ) -> None: - super().__init__() - self.judge = judge - self.config = self.judge.model_dump() - self.scores = [] - self.name = name + judge: BlipVQAJudge @weave.op() - def evaluate( + def score( self, prompt: str, adj_1: str, @@ -80,17 +69,4 @@ def evaluate( judgement = self.judge.predict( adj_1, noun_1, adj_2, noun_2, model_output["image"] ) - self.scores.append(judgement) return judgement - - @weave.op() - async def evaluate_async( - self, - prompt: str, - adj_1: str, - noun_1: str, - adj_2: str, - noun_2: str, - model_output: Dict[str, Any], - ) -> Dict[str, Any]: - return self.evaluate(prompt, adj_1, noun_1, adj_2, noun_2, model_output) diff --git a/hemm/metrics/vqa/judges/mmllm_judges/openai_judge.py b/hemm/metrics/vqa/judges/mmllm_judges/openai_judge.py index b3fbd75..a16e946 100644 --- a/hemm/metrics/vqa/judges/mmllm_judges/openai_judge.py +++ b/hemm/metrics/vqa/judges/mmllm_judges/openai_judge.py @@ -1,6 +1,6 @@ import os import subprocess -from typing import List +from typing import Dict, List import spacy import weave @@ -9,7 +9,7 @@ from pydantic import BaseModel from .....utils import base64_encode_image -from .commons import JudgeMent, JudgeQuestion, PromptCategory, TaggedPromptParts +from .commons import JudgeMent, PromptCategory, TaggedPromptParts class OpenAIJudgeMent(BaseModel): @@ -91,7 +91,7 @@ def extract_prompt_parts(self, prompt: str) -> List[TaggedPromptParts]: return tagged_prompt_parts @weave.op() - def frame_question(self, prompt: str, image: Image.Image) -> List[JudgeQuestion]: + def frame_question(self, prompt: str, image: Image.Image) -> List[Dict[str, str]]: """Frame the question corresponding to the given prompt and image for the chain-of-thought system of judgement. @@ -100,20 +100,21 @@ def frame_question(self, prompt: str, image: Image.Image) -> List[JudgeQuestion] image (Image.Image): The image to frame the question for. Returns: - List[JudgeQuestion]: List of questions to ask for the given prompt. + List[Dict[str, str]]: List of questions to ask for the given prompt. """ + prompt = str(prompt) if self.prompt_property in [PromptCategory.spatial, PromptCategory.spatial_3d]: self._total_score = 5 - question = JudgeQuestion( - image_desciption_system_prompt=""" + question = { + "image_desciption_system_prompt": """ You are a helpful assistant meant to describe images is detail. You should pay special attention to the objects and their spatial layout in the image. """, - judgement_question_system_prompt=""" + "judgement_question_system_prompt": """ You are a helpful assistant meant to identify objects and their spatial layout in the image. You have to extract the question, the score, and the explanation from the user's response. """, - judgement_question=f""" + "judgement_question": f""" Looking at the image and given a detailed description of the image, evaluate if the text \"{prompt}\" is correctly portrayed in the image. Give a score from 1 to 5, according to the following criteria: @@ -133,20 +134,20 @@ def frame_question(self, prompt: str, image: Image.Image) -> List[JudgeQuestion] 3. The spatial layout of the objects in the image should be consistent with the text prompt. You should deduct 1 point from the score if the spatial layout of the objects in the image is not consistent with the text prompt. """, - ) + } return [(question, image)] elif self.prompt_property == PromptCategory.action: self._total_score = 5 - question = JudgeQuestion( - image_desciption_system_prompt=""" + question = { + "image_desciption_system_prompt": """ You are a helpful assistant meant to describe images is detail. You should pay special attention to the the actions, events, objects and their relationships in the image. """, - judgement_question_system_prompt=""" + "judgement_question_system_prompt": """ You are a helpful assistant meant to identify the actions, events, objects and their relationships in the image. You have to extract the question, the score, and the explanation from the user's response. """, - judgement_question=f""" + "judgement_question": f""" Looking at the image and given a detailed description of the image, evaluate if the text \"{prompt}\" is correctly portrayed in the image. Give a score from 1 to 5, according to the following criteria: @@ -166,20 +167,20 @@ def frame_question(self, prompt: str, image: Image.Image) -> List[JudgeQuestion] 3. The spatial layout of the objects in the image should be consistent with the text prompt. You should deduct 1 point from the score if the spatial layout of the objects in the image is not consistent with the text prompt. """, - ) + } return [(question, image)] elif self.prompt_property == PromptCategory.numeracy: self._total_score = 5 - question = JudgeQuestion( - image_desciption_system_prompt=""" + question = { + "image_desciption_system_prompt": """ You are a helpful assistant meant to describe images is detail. You should pay special attention to the objects and their quantities in the image. """, - judgement_question_system_prompt=""" + "judgement_question_system_prompt": """ You are a helpful assistant meant to identify objects and their quantities in the image. You have to extract the question, the score, and the explanation from the user's response. """, - judgement_question=f""" + "judgement_question": f""" Looking at the image and given a detailed description of the image, evaluate how well the image aligns with the text prompt: \"{prompt}\" Give a score from 1 to 5, according to the following criteria: @@ -199,23 +200,23 @@ def frame_question(self, prompt: str, image: Image.Image) -> List[JudgeQuestion] 3. The spatial layout of the objects in the image should be consistent with the text prompt. You should deduct 1 point from the score if the spatial layout of the objects in the image is not consistent with the text prompt. """, - ) + } return [(question, image)] elif self.prompt_property == PromptCategory.complex: self._total_score = 5 - question = JudgeQuestion( - image_desciption_system_prompt=""" + question = { + "image_desciption_system_prompt": """ You are a helpful assistant meant to describe images is detail. You should pay special attention to the objects in the image and their attributes (such as color, shape, texture), spatial layout and action relationships. """, - judgement_question_system_prompt=""" + "judgement_question_system_prompt": """ You are a helpful assistant meant to evaluate the correspondence of the image to a given text prompt. Focus on the objects in the image and their attributes (such as color, shape, texture), spatial layout and action relationships. You have to extract the question, the score, and the explanation from the user's response. """, - judgement_question=f""" + "judgement_question": f""" Looking at the image and given a detailed description of the image, evaluate how well the image aligns with the text prompt: \"{prompt}\" Give a score from 1 to 5, according to the following criteria: @@ -235,21 +236,21 @@ def frame_question(self, prompt: str, image: Image.Image) -> List[JudgeQuestion] 3. The spatial layout of the objects in the image should be consistent with the text prompt. You should deduct 1 point from the score if the spatial layout of the objects in the image is not consistent with the text prompt. """, - ) + } return [(question, image)] tagged_prompt_parts = self.extract_prompt_parts(prompt) questions: List[str] = [] for tagged_prompt_part in tagged_prompt_parts: - question = JudgeQuestion( - image_desciption_system_prompt=f""" + question = { + "image_desciption_system_prompt": f""" You are a helpful assistant meant to describe images is detail. You should pay special attention to any objects and their {self.prompt_property.name} in the given image. """, - judgement_question_system_prompt=f""" + "judgement_question_system_prompt": f""" You are a helpful assistant meant to identify any objects and their {self.prompt_property.name} in the given image. You have to extract the question, the score, and the explanation from the user's response. """, - judgement_question=f""" + "judgement_question": f""" Looking at the image and given a detailed description of the image, evaluate if there is a {tagged_prompt_part.entity} in the image. Give a score from 1 to 4, according to the following criteria: @@ -268,13 +269,13 @@ def frame_question(self, prompt: str, image: Image.Image) -> List[JudgeQuestion] 3. The spatial layout of the objects in the image should be consistent with the text prompt. You should deduct 1 point from the score if the spatial layout of the objects in the image is not consistent with the text prompt. """, - ) + } questions.append((question, image)) return questions @weave.op def execute_chain_of_thought( - self, question: JudgeQuestion, image: Image.Image + self, question: Dict[str, str], image: Image.Image ) -> OpenAIJudgeMent: image_description_explanation = ( self._openai_client.chat.completions.create( @@ -283,7 +284,7 @@ def execute_chain_of_thought( messages=[ { "role": "system", - "content": question.image_desciption_system_prompt, + "content": question["image_desciption_system_prompt"], }, { "role": "user", @@ -299,7 +300,9 @@ def execute_chain_of_thought( .choices[0] .message.content ) - question.judgement_question += f""" + question[ + "judgement_question" + ] += f""" Here is a detailed explanation of the image: --- @@ -309,19 +312,19 @@ def execute_chain_of_thought( Provide your analysis and explanation to justify the score. """ judgement_response = ( - self._openai_client.beta.chat.completions.parse( + weave.op()(self._openai_client.beta.chat.completions.parse)( model=self.openai_model, response_format=JudgeMent, seed=self.seed, messages=[ { "role": "system", - "content": question.judgement_question_system_prompt, + "content": question["judgement_question_system_prompt"], }, { "role": "user", "content": [ - {"type": "text", "text": question.judgement_question}, + {"type": "text", "text": question["judgement_question"]}, { "type": "image_url", "image_url": {"url": base64_encode_image(image)}, diff --git a/hemm/metrics/vqa/multi_modal_llm_eval.py b/hemm/metrics/vqa/multi_modal_llm_eval.py index 94abe84..29af903 100644 --- a/hemm/metrics/vqa/multi_modal_llm_eval.py +++ b/hemm/metrics/vqa/multi_modal_llm_eval.py @@ -1,34 +1,22 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List import weave -from ..base import BaseMetric from .judges.mmllm_judges import OpenAIJudge from .judges.mmllm_judges.openai_judge import OpenAIJudgeMent -class MultiModalLLMEvaluationMetric(BaseMetric): +class MultiModalLLMEvaluationMetric(weave.Scorer): """Multi-modal LLM-based evaluation metric for an image-generation model. Args: - judge (Union[weave.Model, OpenAIJudge]): The judge LLM model to evaluate the generated images. - name (Optional[str]): Name of the evaluation. + judge (OpenAIJudge): The judge LLM model to evaluate the generated images. """ - def __init__( - self, - judge: Union[weave.Model, OpenAIJudge], - name: Optional[str] = "mmllm_eval_metric", - ) -> None: - super().__init__() - self.judge = judge - self.config = self.judge.model_dump() - self.prompt_property = judge.prompt_property - self.scores = [] - self.name = name + judge: OpenAIJudge @weave.op() - def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]: + def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]: """Evaluate the generated image using the judge LLM model. Args: @@ -44,11 +32,4 @@ def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]: "score": score / len(judgements), "fractional_score": fractional_score / len(judgements), } - self.scores.append(evaluation_dict) return evaluation_dict - - @weave.op() - async def evaluate_async( - self, prompt: str, model_output: Dict[str, Any] - ) -> Dict[str, Any]: - return self.evaluate(prompt, model_output) diff --git a/hemm/models/__init__.py b/hemm/models/__init__.py index 883d860..57271a4 100644 --- a/hemm/models/__init__.py +++ b/hemm/models/__init__.py @@ -1,5 +1,5 @@ -from .diffusion_model import BaseDiffusionModel +from .diffusion_model import DiffusersModel from .falai_model import FalAIModel from .stability_model import StabilityAPIModel -__all__ = ["BaseDiffusionModel", "FalAIModel", "StabilityAPIModel"] +__all__ = ["DiffusersModel", "FalAIModel", "StabilityAPIModel"] diff --git a/hemm/models/diffusion_model.py b/hemm/models/diffusion_model.py index 2cfe0ba..3d85ec0 100644 --- a/hemm/models/diffusion_model.py +++ b/hemm/models/diffusion_model.py @@ -5,7 +5,7 @@ from diffusers import DiffusionPipeline -class BaseDiffusionModel(weave.Model): +class DiffusersModel(weave.Model): """`weave.Model` wrapping `diffusers.DiffusionPipeline`. Args: @@ -16,19 +16,18 @@ class BaseDiffusionModel(weave.Model): num_inference_steps (int): The number of inference steps. disable_safety_checker (bool): Disable safety checker for the diffusion model. configs (Dict[str, Any]): Additional configs. - pipeline_configs (Dict[str, Any]): Diffusion pipeline configs. inference_kwargs (Dict[str, Any]): Inference kwargs. """ diffusion_model_name_or_path: str enable_cpu_offfload: bool = False - image_height: int = 512 - image_width: int = 512 - num_inference_steps: int = 50 - disable_safety_checker: bool = True - configs: Dict[str, Any] = {} - pipeline_configs: Dict[str, Any] = {} - inference_kwargs: Dict[str, Any] = {} + image_height: int + image_width: int + num_inference_steps: int + seed: int + disable_safety_checker: bool + configs: Dict[str, Any] + inference_kwargs: Dict[str, Any] _torch_dtype: torch.dtype = torch.float16 _pipeline: DiffusionPipeline = None @@ -39,9 +38,9 @@ def __init__( image_height: int = 512, image_width: int = 512, num_inference_steps: int = 50, + seed: int = 42, disable_safety_checker: bool = True, configs: Dict[str, Any] = {}, - pipeline_configs: Dict[str, Any] = {}, inference_kwargs: Dict[str, Any] = {}, ) -> None: super().__init__( @@ -50,17 +49,15 @@ def __init__( image_height=image_height, image_width=image_width, num_inference_steps=num_inference_steps, + seed=seed, disable_safety_checker=disable_safety_checker, configs=configs, - pipeline_configs=pipeline_configs, inference_kwargs=inference_kwargs, ) - self.configs["torch_dtype"] = str(self._torch_dtype) pipeline_init_kwargs = { "pretrained_model_name_or_path": self.diffusion_model_name_or_path, "torch_dtype": self._torch_dtype, } - pipeline_init_kwargs.update(self.pipeline_configs) if self.disable_safety_checker: pipeline_init_kwargs["safety_checker"] = None self._pipeline = DiffusionPipeline.from_pretrained(**pipeline_init_kwargs) @@ -70,14 +67,26 @@ def __init__( self._pipeline = self._pipeline.to("cuda") self._pipeline.set_progress_bar_config(leave=False, desc="Generating Image") + self.configs = { + **self.configs, + "torch_dtype": str(self._torch_dtype), + "pretrained_model_name_or_path": self.diffusion_model_name_or_path, + "enable_cpu_offfload": self.enable_cpu_offfload, + "image_size": { + "height": self.image_height, + "width": self.image_width, + }, + "diffusion_pipeline": dict(self._pipeline.config), + } + @weave.op() - def predict(self, prompt: str, seed: int) -> Dict[str, Any]: + def predict(self, prompt: str) -> Dict[str, Any]: pipeline_output = self._pipeline( prompt, num_images_per_prompt=1, height=self.image_height, width=self.image_width, - generator=torch.Generator(device="cuda").manual_seed(seed), + generator=torch.Generator(device="cuda").manual_seed(self.seed), num_inference_steps=self.num_inference_steps, **self.inference_kwargs, ) diff --git a/hemm/tests/test_2d_spatial_relationship_eval.py b/hemm/tests/test_2d_spatial_relationship_eval.py index 4cd3f2d..58ed26e 100644 --- a/hemm/tests/test_2d_spatial_relationship_eval.py +++ b/hemm/tests/test_2d_spatial_relationship_eval.py @@ -1,72 +1,46 @@ -import unittest +import asyncio import weave -import wandb -from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline from hemm.metrics.spatial_relationship import SpatialRelationshipMetric2D from hemm.metrics.spatial_relationship.judges import ( DETRSpatialRelationShipJudge, RTDETRSpatialRelationShipJudge, ) - - -class Test2DSpatialRelationshipEval(unittest.TestCase): - def __init__(self, methodName: str = "runTest") -> None: - super().__init__(methodName) - wandb.init( - project="unit-tests", - entity="hemm-eval", - job_type="test_2d_spatial_relationship_evaluation", - ) - weave.init(project_name="hemm-eval/unit-tests") - - def test_2d_spatial_relationship_evaluation_detr_judge(self): - model = BaseDiffusionModel( - diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4", - enable_cpu_offfload=False, - ) - evaluation_pipeline = EvaluationPipeline(model=model) - - judge = DETRSpatialRelationShipJudge( - model_address="facebook/detr-resnet-50", revision="no_timm" - ) - metric = SpatialRelationshipMetric2D( - judge=judge, name="2d_spatial_relationship_score" - ) - evaluation_pipeline.add_metric(metric) - - dataset = weave.ref("2d-spatial-prompts-mscoco:v0").get().rows[:2] - summary = evaluation_pipeline(dataset=dataset) - - self.assertGreater( - summary["SpatialRelationshipMetric2D.evaluate_async"][ - "2d_spatial_relationship_score" - ]["mean"], - 0.0, - ) - self.assertGreater(summary["model_latency"]["mean"], 0.0) - - def test_2d_spatial_relationship_evaluation_rt_detr_judge(self): - model = BaseDiffusionModel( - diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4", - enable_cpu_offfload=False, - ) - evaluation_pipeline = EvaluationPipeline(model=model) - - judge = RTDETRSpatialRelationShipJudge(model_address="PekingU/rtdetr_r50vd") - metric = SpatialRelationshipMetric2D( - judge=judge, name="2d_spatial_relationship_score" - ) - evaluation_pipeline.add_metric(metric) - - dataset = weave.ref("2d-spatial-prompts-mscoco:v0").get().rows[:2] - summary = evaluation_pipeline(dataset=dataset) - - self.assertGreater( - summary["SpatialRelationshipMetric2D.evaluate_async"][ - "2d_spatial_relationship_score" - ]["mean"], - 0.0, - ) - self.assertGreater(summary["model_latency"]["mean"], 0.0) +from hemm.models import DiffusersModel + + +def test_2d_spatial_relationship_evaluation_detr_judge(): + weave.init(project_name="hemm-eval/unit-tests") + model = DiffusersModel( + diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4", + enable_cpu_offfload=False, + ) + + judge = DETRSpatialRelationShipJudge( + model_address="facebook/detr-resnet-50", revision="no_timm" + ) + metric = SpatialRelationshipMetric2D( + judge=judge, name="2d_spatial_relationship_score" + ) + + dataset = weave.ref("2d-spatial-prompts-mscoco:v0").get().rows[:2] + evaluation = weave.Evaluation(dataset=dataset, scorers=[metric]) + asyncio.run(evaluation.evaluate(model)) + + +def test_2d_spatial_relationship_evaluation_rt_detr_judge(): + weave.init(project_name="hemm-eval/unit-tests") + model = DiffusersModel( + diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4", + enable_cpu_offfload=False, + ) + + judge = RTDETRSpatialRelationShipJudge(model_address="PekingU/rtdetr_r50vd") + metric = SpatialRelationshipMetric2D( + judge=judge, name="2d_spatial_relationship_score" + ) + + dataset = weave.ref("2d-spatial-prompts-mscoco:v0").get().rows[:2] + evaluation = weave.Evaluation(dataset=dataset, scorers=[metric]) + asyncio.run(evaluation.evaluate(model)) diff --git a/hemm/tests/test_disentangled_vqa.py b/hemm/tests/test_disentangled_vqa.py index 669355f..9a5df8a 100644 --- a/hemm/tests/test_disentangled_vqa.py +++ b/hemm/tests/test_disentangled_vqa.py @@ -1,35 +1,22 @@ -import unittest +import asyncio import weave -import wandb -from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline from hemm.metrics.vqa import DisentangledVQAMetric from hemm.metrics.vqa.judges import BlipVQAJudge +from hemm.models import DiffusersModel -class TestDisentangledVQA(unittest.TestCase): - def __init__(self, methodName: str = "runTest") -> None: - super().__init__(methodName) - wandb.init( - project="unit-tests", - entity="hemm-eval", - job_type="test_desentangled_vqa_evaluation", - ) - weave.init(project_name="hemm-eval/unit-tests") +def test_disentangled_vqa_evaluation(): + weave.init(project_name="hemm-eval/unit-tests") + model = DiffusersModel( + diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4", + enable_cpu_offfload=False, + ) - def test_desentangled_vqa_evaluation(self): - model = BaseDiffusionModel( - diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4", - enable_cpu_offfload=False, - ) - evaluation_pipeline = EvaluationPipeline(model=model) + judge = BlipVQAJudge() + metric = DisentangledVQAMetric(judge=judge, name="disentangled_blip_metric") - judge = BlipVQAJudge() - metric = DisentangledVQAMetric(judge=judge, name="disentangled_blip_metric") - evaluation_pipeline.add_metric(metric) - - dataset = weave.ref("attribute_binding_dataset:v0").get().rows[:2] - summary = evaluation_pipeline(dataset=dataset) - - self.assertGreater(summary["model_latency"]["mean"], 0.0) + dataset = weave.ref("attribute_binding_dataset:v0").get().rows[:2] + evaluation = weave.Evaluation(dataset=dataset, scorers=[metric]) + asyncio.run(evaluation.evaluate(model)) diff --git a/hemm/tests/test_image_quality_eval.py b/hemm/tests/test_image_quality_eval.py index f254ee9..0d11130 100644 --- a/hemm/tests/test_image_quality_eval.py +++ b/hemm/tests/test_image_quality_eval.py @@ -1,58 +1,28 @@ -import unittest +import asyncio import weave -import wandb -from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline from hemm.metrics.image_quality import LPIPSMetric, PSNRMetric, SSIMMetric - - -class TestImageQualityEvaluation(unittest.TestCase): - - def __init__(self, methodName: str = "runTest") -> None: - super().__init__(methodName) - wandb.init( - project="unit-tests", - entity="hemm-eval", - job_type="test_image_quality_evaluation", - ) - weave.init(project_name="hemm-eval/unit-tests") - - def test_image_quality_metrics(self): - model = BaseDiffusionModel( - diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4", - enable_cpu_offfload=False, - ) - evaluation_pipeline = EvaluationPipeline(model=model) - - # Add PSNR Metric - psnr_metric = PSNRMetric(image_size=evaluation_pipeline.image_size) - evaluation_pipeline.add_metric(psnr_metric) - - # Add SSIM Metric - ssim_metric = SSIMMetric(image_size=evaluation_pipeline.image_size) - evaluation_pipeline.add_metric(ssim_metric) - - # Add LPIPS Metric - lpips_metric = LPIPSMetric(image_size=evaluation_pipeline.image_size) - evaluation_pipeline.add_metric(lpips_metric) - - dataset = weave.ref("COCO:v1").get().rows[:2] - summary = evaluation_pipeline(dataset=dataset) - - self.assertGreater( - summary["PSNRMetric.evaluate_async"]["peak_signal_noise_ratio"]["mean"], 0.0 - ) - self.assertGreater( - summary["SSIMMetric.evaluate_async"]["structural_similarity_index_measure"][ - "mean" - ], - 0.0, - ) - self.assertGreater( - summary["LPIPSMetric.evaluate_async"][ - "alexnet_learned_perceptual_image_patch_similarity" - ]["mean"], - 0.0, - ) - self.assertGreater(summary["model_latency"]["mean"], 0.0) +from hemm.models import DiffusersModel + + +def test_image_quality_metrics(): + weave.init(project_name="hemm-eval/unit-tests") + model = DiffusersModel( + diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4", + enable_cpu_offfload=False, + ) + psnr_metric = PSNRMetric( + image_height=model.image_height, image_width=model.image_width + ) + ssim_metric = SSIMMetric( + image_height=model.image_height, image_width=model.image_width + ) + lpips_metric = LPIPSMetric( + image_height=model.image_height, image_width=model.image_width + ) + dataset = weave.ref("COCO:v1").get().rows[:2] + evaluation = weave.Evaluation( + dataset=dataset, scorers=[psnr_metric, ssim_metric, lpips_metric] + ) + asyncio.run(evaluation.evaluate(model)) diff --git a/hemm/tests/test_mllm_eval.py b/hemm/tests/test_mllm_eval.py index 415ed3b..7bb9a0a 100644 --- a/hemm/tests/test_mllm_eval.py +++ b/hemm/tests/test_mllm_eval.py @@ -1,40 +1,29 @@ -import unittest +import asyncio import weave -import wandb -from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline from hemm.metrics.vqa import MultiModalLLMEvaluationMetric from hemm.metrics.vqa.judges.mmllm_judges import OpenAIJudge, PromptCategory +from hemm.models import DiffusersModel -class TestMultiModalLLMEvaluation(unittest.TestCase): - def __init__(self, methodName: str = "runTest") -> None: - super().__init__(methodName) - wandb.init( - project="unit-tests", - entity="hemm-eval", - job_type="test_multimodal_llm_evaluation", - ) - weave.init(project_name="hemm-eval/unit-tests") +def test_multimodal_llm_evaluation(): + weave.init(project_name="hemm-eval/unit-tests") + model = DiffusersModel( + diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4", + enable_cpu_offfload=False, + image_height=1024, + image_width=1024, + ) - def test_multimodal_llm_evaluation(self): - model = BaseDiffusionModel( - diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4", - enable_cpu_offfload=False, - image_height=1024, - image_width=1024, - ) - evaluation_pipeline = EvaluationPipeline(model=model) + judge = OpenAIJudge(prompt_property=PromptCategory.complex) + metric = MultiModalLLMEvaluationMetric(judge=judge) - judge = OpenAIJudge(prompt_property=PromptCategory.complex) - metric = MultiModalLLMEvaluationMetric(judge=judge) - evaluation_pipeline.add_metric(metric) + dataset = [ + {"prompt": "The fluffy pillow was on the left of the striped blanket."}, + {"prompt": "The round clock was mounted on the white wall."}, + {"prompt": "The black chair is on the right of the wooden table."}, + ] - evaluation_pipeline( - dataset=[ - {"prompt": "The fluffy pillow was on the left of the striped blanket."}, - {"prompt": "The round clock was mounted on the white wall."}, - {"prompt": "The black chair is on the right of the wooden table."}, - ] - ) + evaluation = weave.Evaluation(dataset=dataset, scorers=[metric]) + asyncio.run(evaluation.evaluate(model)) diff --git a/hemm/tests/test_prompt_alignment_eval.py b/hemm/tests/test_prompt_alignment_eval.py index a876455..bb9b0ca 100644 --- a/hemm/tests/test_prompt_alignment_eval.py +++ b/hemm/tests/test_prompt_alignment_eval.py @@ -1,46 +1,21 @@ -import unittest +import asyncio import weave -import wandb -from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline from hemm.metrics.prompt_alignment import CLIPImageQualityScoreMetric, CLIPScoreMetric - - -class TestPromptAlignmentEvaluation(unittest.TestCase): - - def __init__(self, methodName: str = "runTest") -> None: - super().__init__(methodName) - wandb.init( - project="unit-tests", - entity="hemm-eval", - job_type="test_prompt_alignment_evaluation", - ) - weave.init(project_name="hemm-eval/unit-tests") - - def test_prompt_alignment_evaluation(self): - model = BaseDiffusionModel( - diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4", - enable_cpu_offfload=False, - ) - evaluation_pipeline = EvaluationPipeline(model=model) - - # Add CLIP Scorer metric - clip_scorer = CLIPScoreMetric( - clip_model_name_or_path="openai/clip-vit-base-patch16" - ) - evaluation_pipeline.add_metric(clip_scorer) - - # Add CLIP IQA Metric - clip_iqa_scorer = CLIPImageQualityScoreMetric( - clip_model_name_or_path="clip_iqa" - ) - evaluation_pipeline.add_metric(clip_iqa_scorer) - - dataset = weave.ref("parti-prompts:v0").get().rows[:2] - summary = evaluation_pipeline(dataset=dataset) - - self.assertGreater( - summary["CLIPScoreMetric.evaluate_async"]["clip_score"]["mean"], 0.0 - ) - self.assertGreater(summary["model_latency"]["mean"], 0.0) +from hemm.models import DiffusersModel + + +def test_prompt_alignment_evaluation(): + weave.init(project_name="hemm-eval/unit-tests") + model = DiffusersModel( + diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4", + enable_cpu_offfload=False, + ) + clip_scorer = CLIPScoreMetric(model_name="openai/clip-vit-base-patch16") + clip_iqa_scorer = CLIPImageQualityScoreMetric(model_name="clip_iqa") + dataset = weave.ref("parti-prompts:v0").get().rows[:2] + evaluation = weave.Evaluation( + dataset=dataset, scorers=[clip_scorer, clip_iqa_scorer] + ) + asyncio.run(evaluation.evaluate(model)) diff --git a/hemm/tests/test_utils.py b/hemm/tests/test_utils.py index 583c3ae..c89dd0f 100644 --- a/hemm/tests/test_utils.py +++ b/hemm/tests/test_utils.py @@ -1,35 +1,31 @@ -import unittest - import weave from hemm.utils import publish_dataset_to_weave -class TestUtils(unittest.TestCase): - def __init__(self, methodName: str = "runTest") -> None: - super().__init__(methodName) - weave.init(project_name="hemm-eval/unit-tests") +def test_parti_prompts(): + weave.init(project_name="hemm-eval/unit-tests") + dataset_reference = publish_dataset_to_weave( + dataset_path="nateraw/parti-prompts", + prompt_column="Prompt", + split="train", + data_limit=10, + ) + assert dataset_reference is not None - def test_parti_prompts(self): - dataset_reference = publish_dataset_to_weave( - dataset_path="nateraw/parti-prompts", - prompt_column="Prompt", - split="train", - data_limit=10, - ) - self.assertIsNotNone(dataset_reference) - def test_coco(self): - def preprocess_sentences_column(example): - example["sentences"] = example["sentences"]["raw"] - return example +def test_coco(): + def preprocess_sentences_column(example): + example["sentences"] = example["sentences"]["raw"] + return example - dataset_reference = publish_dataset_to_weave( - dataset_path="HuggingFaceM4/COCO", - prompt_column="sentences", - ground_truth_image_column="image", - split="validation", - dataset_transforms=[preprocess_sentences_column], - data_limit=10, - ) - self.assertIsNotNone(dataset_reference) + weave.init(project_name="hemm-eval/unit-tests") + dataset_reference = publish_dataset_to_weave( + dataset_path="HuggingFaceM4/COCO", + prompt_column="sentences", + ground_truth_image_column="image", + split="validation", + dataset_transforms=[preprocess_sentences_column], + data_limit=10, + ) + assert dataset_reference is not None diff --git a/mkdocs.yml b/mkdocs.yml index c66a0e9..7446e4f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -7,7 +7,7 @@ theme: # Palette toggle for light mode - scheme: default toggle: - icon: material/brightness-7 + icon: material/brightness-7 name: Switch to dark mode # Palette toggle for dark mode - scheme: slate @@ -57,7 +57,6 @@ extra_javascript: nav: - Home: 'index.md' - - Evaluation-Pipelines: 'eval_pipelines.md' - Models: - Diffusion-Models: 'models/diffusion_model.md' - FalAI-Models: 'models/falai_model.md' diff --git a/pyproject.toml b/pyproject.toml index 1fdbb9a..e6d62f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,23 +14,24 @@ dependencies = [ "instructor>=1.6.1", "jsonlines>=4.0.0", "openai>=1.51.2", + "opencv-python>=4.10.0.84", "sentencepiece>=0.2.0", "torcheval>=0.0.7", "torchmetrics[multimodal]>=1.4.1,<2.0.0", "transformers>=4.45.2", "spacy>=3.8.2", "wandb>=0.18.3", - "weave==0.51.14", + "weave==0.51.17", "uv>=0.4.22", "pip>=24.2", ] [project.optional-dependencies] dev = [ - "pytest>=8.3.3", "isort>=5.13.2", "black>=24.10.0", "ruff>=0.6.9", + "pytest>=8.3.3", ] docs = [