Skip to content

Commit

Permalink
include label_info in API results
Browse files Browse the repository at this point in the history
  • Loading branch information
fjxmlzn committed Dec 27, 2024
1 parent e6f4078 commit aed239a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion pe/api/image/improved_diffusion_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,12 @@ def random_api(self, label_info, num_samples):
{
IMAGE_DATA_COLUMN_NAME: list(samples),
IMAGE_MODEL_LABEL_COLUMN_NAME: list(labels),
LABEL_ID_COLUMN_NAME: 0,
}
)
metadata = {"label_info": [label_info]}
execution_logger.info(f"RANDOM API: finished creating {num_samples} samples for label {label_name}")
return Data(data_frame=data_frame)
return Data(data_frame=data_frame, metadata=metadata)

def variation_api(self, syn_data):
"""Generating variations of the synthetic data.
Expand Down Expand Up @@ -200,6 +202,7 @@ def variation_api(self, syn_data):
{
IMAGE_DATA_COLUMN_NAME: list(variations),
IMAGE_MODEL_LABEL_COLUMN_NAME: list(labels),
LABEL_ID_COLUMN_NAME: syn_data.data_frame[LABEL_ID_COLUMN_NAME].values,
}
)
if LABEL_ID_COLUMN_NAME in syn_data.data_frame.columns:
Expand Down
5 changes: 4 additions & 1 deletion pe/api/image/stable_diffusion_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,12 @@ def random_api(self, label_info, num_samples):
{
IMAGE_DATA_COLUMN_NAME: list(images),
IMAGE_PROMPT_COLUMN_NAME: prompt,
LABEL_ID_COLUMN_NAME: 0,
}
)
metadata = {"label_info": [label_info]}
execution_logger.info(f"RANDOM API: finished creating {num_samples} samples for label {label_name}")
return Data(data_frame=data_frame)
return Data(data_frame=data_frame, metadata=metadata)

def variation_api(self, syn_data):
"""Generating variations of the synthetic data.
Expand Down Expand Up @@ -199,6 +201,7 @@ def variation_api(self, syn_data):
{
IMAGE_DATA_COLUMN_NAME: list(variations),
IMAGE_PROMPT_COLUMN_NAME: prompts,
LABEL_ID_COLUMN_NAME: syn_data.data_frame[LABEL_ID_COLUMN_NAME].values,
}
)
if LABEL_ID_COLUMN_NAME in syn_data.data_frame.columns:
Expand Down

0 comments on commit aed239a

Please sign in to comment.