diff --git a/src/api/batch_decoder.py b/src/api/batch_decoder.py index f9accff..ade56fa 100644 --- a/src/api/batch_decoder.py +++ b/src/api/batch_decoder.py @@ -150,7 +150,8 @@ def save_prediction_outputs( group_ids: List[str], image_ids: List[str], base_output_path: str, - image_metadata: List[Dict] + image_metadata: List[Dict], + temp_dir: str = None ) -> List[str]: """ Generate output texts based on predictions and save to files atomically. @@ -167,6 +168,9 @@ def save_prediction_outputs( Base path where prediction outputs should be saved. image_metadata: List[Dict] List of metadata dictionaries for each image. + temp_dir: str, optional + Path to use for temporary files. If None, a subdirectory of + base_output_path is used. Returns ------- @@ -182,14 +186,21 @@ def save_prediction_outputs( """ output_texts = [] - # Create a temporary directory for intermediate files - with tempfile.TemporaryDirectory(prefix="prediction_outputs_") as temp_dir: + # If no temp_dir is provided, create one as a subdirectory of + # base_output_path + if temp_dir is None: + temp_dir = os.path.join(base_output_path, '.temp_prediction_outputs') + + os.makedirs(temp_dir, exist_ok=True) + logging.debug(f"Using temporary directory: {temp_dir}") + + try: for prediction, group_id, image_id, metadata in zip( prediction_data, group_ids, image_ids, image_metadata ): confidence, predicted_text = prediction - output_text = (f"{image_id}\t{metadata}\t{confidence}" - f"\t{predicted_text}") + output_text = ( + f"{image_id}\t{metadata}\t{confidence}\t{predicted_text}") output_texts.append(output_text) group_output_dir = os.path.join(base_output_path, group_id) @@ -207,6 +218,11 @@ def save_prediction_outputs( logging.error("Failed to write file %s. Error: %s", output_file_path, e) raise + finally: + # Clean up the temporary directory + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + logging.debug(f"Cleaned up temporary directory: {temp_dir}") return output_texts @@ -239,8 +255,7 @@ def write_file_atomically(content: str, target_path: str, temp_dir: str) \ temp_file.write(content + "\n") temp_file_path = temp_file.name - # On POSIX systems, this is atomic. On Windows, it's the best we can - # do. + # On POSIX systems, this is atomic. On Windows, it's the best we can do os.replace(temp_file_path, target_path) except IOError as e: if temp_file_path and os.path.exists(temp_file_path):