diff --git a/src/backend/app/s3.py b/src/backend/app/s3.py index 13a59036..49b120a2 100644 --- a/src/backend/app/s3.py +++ b/src/backend/app/s3.py @@ -2,6 +2,7 @@ from loguru import logger as log from minio import Minio from io import BytesIO +from typing import Any def s3_client(): @@ -46,20 +47,43 @@ def add_file_to_bucket(bucket_name: str, file_path: str, s3_path: str): file_path (str): The path to the file on the local filesystem. s3_path (str): The path in the S3 bucket where the file will be stored. """ + # Ensure s3_path starts with a forward slash + if not s3_path.startswith("/"): + s3_path = f"/{s3_path}" + client = s3_client() client.fput_object(bucket_name, file_path, s3_path) -def add_obj_to_bucket(bucket_name: str, file_obj: BytesIO, s3_path: str): +def add_obj_to_bucket( + bucket_name: str, + file_obj: BytesIO, + s3_path: str, + content_type: str = "application/octet-stream", + **kwargs: dict[str, Any], +): """Upload a BytesIO object to an S3 bucket. Args: bucket_name (str): The name of the S3 bucket. file_obj (BytesIO): A BytesIO object containing the data to be uploaded. s3_path (str): The path in the S3 bucket where the data will be stored. + content_type (str, optional): The content type of the uploaded file. + Default application/octet-stream. + kwargs (dict[str, Any]): Any other arguments to pass to client.put_object. + """ + # Strip "/" from start of s3_path (not required by put_object) + if s3_path.startswith("/"): + s3_path = s3_path.lstrip("/") + client = s3_client() - result = client.put_object(bucket_name, file_obj, s3_path) + # Set BytesIO object to start, prior to .read() + file_obj.seek(0) + + result = client.put_object( + bucket_name, s3_path, file_obj, file_obj.getbuffer().nbytes, **kwargs + ) log.debug( f"Created {result.object_name} object; etag: {result.etag}, " f"version-id: {result.version_id}" @@ -75,6 +99,10 @@ def get_file_from_bucket(bucket_name: str, s3_path: str, file_path: str): file_path (str): The path on the local filesystem where the S3 file will be saved. """ + # Ensure s3_path starts with a forward slash + if not s3_path.startswith("/"): + s3_path = f"/{s3_path}" + client = s3_client() client.fget_object(bucket_name, s3_path, file_path) @@ -89,10 +117,19 @@ def get_obj_from_bucket(bucket_name: str, s3_path: str) -> BytesIO: Returns: BytesIO: A BytesIO object containing the content of the downloaded S3 object. """ + # Ensure s3_path starts with a forward slash + if not s3_path.startswith("/"): + s3_path = f"/{s3_path}" + client = s3_client() + response = None try: response = client.get_object(bucket_name, s3_path) return BytesIO(response.read()) + except Exception as e: + log.warning(f"Failed attempted download from S3 path: {s3_path}") + raise ValueError(str(e)) from e finally: - response.close() - response.release_conn() + if response: + response.close() + response.release_conn()