Skip to content

Commit

Permalink
fix: fix typing errors in 6_push_dataset
Browse files: browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Sep 19, 2024
1 parent 726430e commit d77fe24
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions nutrition-detector/dataset-generation/6_push_dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,6 +14,7 @@
from label_studio_sdk.client import LabelStudio
from openfoodfacts.images import extract_barcode_from_url
from openfoodfacts.utils import get_image_from_url, get_logger
from PIL import Image

logger = get_logger()

Expand Down Expand Up @@ -65,7 +66,7 @@ def create_sample(task: Task, only_checked: bool = False) -> Optional[dict]:
current_bbox_id = None
tokens = []
bboxes = []
ner_tags = []
ner_tags: list[str] = []
for result in annotation_results:
result_value = result["value"]
if result["from_name"] in ("transcription", "label"):
Expand Down Expand Up @@ -159,7 +160,7 @@ def create_sample(task: Task, only_checked: bool = False) -> Optional[dict]:
)
return None

image = get_image_from_url(image_url, error_raise=False)
image: Image.Image | None = get_image_from_url(image_url, error_raise=False)

if image is None:
logger.info("Cannot load image from %s, skipping", image_url)
Expand All @@ -178,7 +179,10 @@ def create_sample(task: Task, only_checked: bool = False) -> Optional[dict]:


def get_tasks(
label_studio_url: str, api_key: str, project_id: int, batch_ids: list[int] = None
label_studio_url: str,
api_key: str,
project_id: int,
batch_ids: list[int] | None = None,
) -> Iterator[dict]:
"""Yield tasks (annotations) from Label Studio."""

Expand Down Expand Up @@ -243,19 +247,19 @@ def push_dataset(
):
logger.info("Fetching tasks from Label Studio, project %s", project_id)
if batch_ids:
batch_ids = list(map(int, batch_ids.split(",")))
logger.info("Fetching tasks for batches %s", batch_ids)
batch_ids_int = list(map(int, batch_ids.split(",")))
logger.info("Fetching tasks for batches %s", batch_ids_int)

created = 0
ner_tag_set = set()
ner_tag_set: set[str] = set()

with tempfile.TemporaryDirectory() as tmp_dir_str:
tmp_dir = Path(tmp_dir_str)
logger.info("Saving samples in temporary directory %s", tmp_dir)

for i, task in enumerate(
tqdm.tqdm(
get_tasks(label_studio_url, api_key, project_id, batch_ids),
get_tasks(label_studio_url, api_key, project_id, batch_ids_int),
desc="tasks",
)
):
Expand Down

0 comments on commit d77fe24

Please sign in to comment.