Support fetching classic flow training dataset for classification models (#159)

* Support fetching classic flow training dataset for classification models

* Polish the code
AsiaCao authored Nov 14, 2023
1 parent 2d0d8d7 commit cb30fed
Showing 1 changed file with 55 additions and 20 deletions.
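For readers skimming the diff, here is a minimal usage sketch of the new capability. It assumes the `TrainingDataset` client referenced in the docstring below is constructed the way other `landingai.data_management` clients are (project ID plus API key); the IDs here are hypothetical.

```python
from pathlib import Path

from landingai.data_management.dataset import TrainingDataset

# Hypothetical project/job IDs; the constructor signature is assumed to
# match other landingai.data_management clients (project_id + api_key).
dataset = TrainingDataset(project_id=12345, api_key="land_sk_...")

# Fetch the classic (legacy) flow training dataset for a classification job.
df = dataset.get_legacy_training_dataset(
    output_dir=Path("./legacy_dataset"), job_id="6410"
)

# For classification jobs the dataframe carries one record per media:
# media_id, label_class, prediction_score, prediction_class, prediction_type.
print(df.head())
```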
75 changes: 55 additions & 20 deletions landingai/data_management/dataset.py
@@ -7,9 +7,8 @@
 import pandas as pd
 import requests
 from PIL import Image
-from typing_extensions import deprecated
-
 from tqdm import tqdm
+from typing_extensions import deprecated

 from landingai.common import decode_bitmap_rle
 from landingai.data_management.client import (
@@ -20,7 +19,6 @@
     LandingLens,
 )
 from landingai.data_management.metadata import Metadata
-from landingai.exceptions import HttpError

 _LOGGER = logging.getLogger(__name__)
 _PAGE_SIZE = 10  # 10 is the max page size, unfortunately
@@ -255,8 +253,9 @@ def get_legacy_training_dataset(
         self, output_dir: Path, job_id: str
     ) -> pd.DataFrame:
         """Get the training dataset from legacy training flow by job_id.
+        Currently, it only supports segmentation and classification datasets.
-        Example output of the returned dataframe:
+        Example output of the returned dataframe for a segmentation dataset:
         ```
         media_id  seg_mask_prediction_path          seg_mask_label_path
         0  10413664  /work/landingai-python/104136...  /work/landingai-python/104136...
@@ -267,17 +266,34 @@
         NOTE:
         1. This dataset has a similar format as the dataset returned by `TrainingDataset.get_training_dataset()`.
         2. Only difference is that the prediction mask is thresholded, i.e. the value of each pixel is either 0 or 1.
+        Example output of the returned dataframe for a classification dataset:
+        ```
+        media_id   label_class   prediction_score  prediction_class  prediction_type
+        0     9789913    black_spot        0.992697        black_spot        correct
+        1     9789914    black_spot        0.996753        black_spot        correct
+        ...       ...           ...             ...               ...            ...
+        1801  9791719  unclassified        0.969400      unclassified        correct
+        1802  9791720  unclassified        0.778278      unclassified        correct
+        ```
         """

         output_dir.mkdir(parents=True, exist_ok=True)
         data = _fetch_gt_and_predictions(
             self._project_id, self._cookie, job_id=job_id, offset=0
         )
+        if not data:
+            raise ValueError(
+                f"Failed to find a classic flow job by job id: {job_id} in project {self._project_id}. Please check the error log for more details and act accordingly."
+            )
+        dataset_type = data["type"]
         rows: List[Dict[str, Any]] = [
-            _extract_gt_and_predictions(d, output_dir) for d in data["details"]
+            _extract_gt_and_predictions(d, output_dir, dataset_type)
+            for d in data["details"]
         ]
         total = data["totalItems"]
-        _LOGGER.info(f"Found a total of {total} images:")
+        _LOGGER.info(f"Found {total} records from a {dataset_type} dataset:")
         if total > _PAGE_SIZE:
             new_offsets = list(range(0, total - _PAGE_SIZE, _PAGE_SIZE))
             new_offsets = [offset + _PAGE_SIZE for offset in new_offsets]
@@ -295,8 +311,10 @@
             ]
             for future in concurrent.futures.as_completed(futures):
                 new_data = future.result()
+                if not new_data:
+                    continue
                 new_rows = [
-                    _extract_gt_and_predictions(d, output_dir)
+                    _extract_gt_and_predictions(d, output_dir, dataset_type)
                     for d in new_data["details"]
                 ]
                 rows.extend(new_rows)
@@ -324,31 +342,48 @@ def _fetch_gt_and_predictions(
             f"Could not find a classic flow job by job id: {job_id} in project {project_id}. Please double check your job id and project id is correct, and it's a classic flow job."
         )
     error_message = resp.text
-    _LOGGER.warning(
-        f"Failed to fetch legacy training dataset predictions: {resp.status_code}, {error_message}"
-    )
-    raise HttpError(
+    _LOGGER.error(
+        f"Failed to fetch legacy training dataset: project_id {project_id}, job_id {job_id}, offset {offset}.\n"
         "HTTP request to LandingLens server failed with "
         f"code {resp.status_code}-{resp.reason} and error message: \n"
         f"{error_message}"
     )
+    return {}
     return cast(Dict[str, Any], resp.json()["data"])


 def _extract_gt_and_predictions(
     img_pred_gt_info: Dict[str, Any],
     output_dir: Path,
+    dataset_type: str,
 ) -> Dict[str, Any]:
+    assert dataset_type in {
+        "classification",
+        "segmentation",
+    }, f"Unsupported dataset type: {dataset_type}"
+
     media_id = int(img_pred_gt_info["mediaId"])
-    pred_bitmasks = img_pred_gt_info["prediction"]
-    pred_mask_path = _save_mask(pred_bitmasks, output_dir, media_id, save_suffix="pred")
-    gt_bitmasks = img_pred_gt_info["groundTruth"]
-    gt_mask_path = _save_mask(gt_bitmasks, output_dir, media_id, save_suffix="gt")
-    return {
-        "media_id": media_id,
-        "seg_mask_prediction_path": pred_mask_path.absolute().as_posix(),
-        "seg_mask_label_path": gt_mask_path.absolute().as_posix(),
-    }
+    if dataset_type == "classification":
+        pred = list(img_pred_gt_info["prediction"].values())[0]
+        return {
+            "media_id": media_id,
+            "label_class": list(img_pred_gt_info["groundTruth"].values())[0],
+            "prediction_score": pred["score"],
+            "prediction_class": pred["labelName"],
+            "prediction_type": pred["type"],
+        }
+    else:
+        pred_bitmasks = img_pred_gt_info["prediction"]
+        pred_mask_path = _save_mask(
+            pred_bitmasks, output_dir, media_id, save_suffix="pred"
+        )
+        gt_bitmasks = img_pred_gt_info["groundTruth"]
+        gt_mask_path = _save_mask(gt_bitmasks, output_dir, media_id, save_suffix="gt")
+        return {
+            "media_id": media_id,
+            "seg_mask_prediction_path": pred_mask_path.absolute().as_posix(),
+            "seg_mask_label_path": gt_mask_path.absolute().as_posix(),
+        }


 def _save_mask(
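A note on the pagination arithmetic in `get_legacy_training_dataset` above: the two-step `new_offsets` computation builds the offset of every page after the first. A self-contained sketch with a hypothetical `total` shows it is equivalent to a single `range` stride:

```python
_PAGE_SIZE = 10  # mirrors the module constant above

total = 35  # hypothetical number of records reported by the API

# Two-step form from the diff: page offsets after the first page.
new_offsets = list(range(0, total - _PAGE_SIZE, _PAGE_SIZE))
new_offsets = [offset + _PAGE_SIZE for offset in new_offsets]
print(new_offsets)  # [10, 20, 30]

# Equivalent single stride starting at the second page.
assert new_offsets == list(range(_PAGE_SIZE, total, _PAGE_SIZE))
```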


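Finally, a sketch of consuming the segmentation variant of the returned dataframe. `_save_mask` is collapsed out of this diff, so treating the mask paths as PIL-readable image files is an assumption:

```python
import numpy as np
import pandas as pd
from PIL import Image


def load_mask_pairs(df: pd.DataFrame) -> list[tuple[int, np.ndarray, np.ndarray]]:
    """Load (media_id, prediction_mask, label_mask) triples from the
    segmentation dataframe returned by get_legacy_training_dataset()."""
    pairs = []
    for row in df.itertuples():
        # Prediction masks are thresholded per the docstring: each pixel is 0 or 1.
        pred = np.asarray(Image.open(row.seg_mask_prediction_path))
        label = np.asarray(Image.open(row.seg_mask_label_path))
        pairs.append((row.media_id, pred, label))
    return pairs
```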