Skip to content

Commit

Permalink
update: config management for evals
Browse files Browse the repository at this point in the history
  • Loading branch information
soumik12345 committed Sep 4, 2024
1 parent f672e6d commit a612baf
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 11 deletions.
4 changes: 2 additions & 2 deletions hemm/eval_pipelines/eval_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __call__(self, dataset: Union[List[Dict], str]) -> Dict[str, float]:
dataset=dataset,
scorers=[metric_fn.evaluate_async for metric_fn in self.metric_functions],
)
with weave.attributes(self.evaluation_configs):
summary = asyncio.run(evaluation.evaluate(self.infer_async))
self.model.configs.update(self.evaluation_configs)
summary = asyncio.run(evaluation.evaluate(self.infer_async))
self.log_summary(summary)
return summary
13 changes: 4 additions & 9 deletions hemm/eval_pipelines/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class BaseDiffusionModel(weave.Model):
image_height: int = 512
image_width: int = 512
disable_safety_checker: bool = True
configs: Dict[str, Any] = {}
pipeline_configs: Dict[str, Any] = {}
_torch_dtype: torch.dtype = torch.float16
_pipeline: DiffusionPipeline = None
Expand All @@ -32,6 +33,7 @@ def __init__(
image_height: int = 512,
image_width: int = 512,
disable_safety_checker: bool = True,
configs: Dict[str, Any] = {},
pipeline_configs: Dict[str, Any] = {},
) -> None:
super().__init__(
Expand All @@ -40,6 +42,7 @@ def __init__(
image_height=image_height,
image_width=image_width,
disable_safety_checker=disable_safety_checker,
configs=configs,
pipeline_configs=pipeline_configs,
)
pipeline_init_kwargs = {
Expand All @@ -65,12 +68,4 @@ def predict(self, prompt: str, seed: int) -> Dict[str, Any]:
width=self.image_width,
generator=torch.Generator(device="cuda").manual_seed(seed),
)
result_dict = {
"image": pipeline_output.images[0],
}
result_dict["nsfw_content_detected"] = (
pipeline_output.nsfw_content_detected is not None
if hasattr(pipeline_output, "nsfw_content_detected")
else False
)
return result_dict
return {"image": pipeline_output.images[0]}

0 comments on commit a612baf

Please sign in to comment.