From 28e07be870eaa7746c945886b26978be438d7b53 Mon Sep 17 00:00:00 2001 From: Eric Hare Date: Thu, 6 Feb 2025 10:35:24 -0800 Subject: [PATCH] feat: Unified File Management API (#6100) * feat: First pass at file management API * [autofix.ci] apply automated fixes * Add delete and edit endpoints * [autofix.ci] apply automated fixes * Add file size and duplicate name handling * Ensure the File model has a unique name * Ensure count is before extension * [autofix.ci] apply automated fixes * Add the correct path to the return * Added function to handle list of paths in File component * [autofix.ci] apply automated fixes * Update input_mixin.py * Refactor to a v2 endpoint * Add unit tests * Update test_files.py * Update frontend.ts * [autofix.ci] apply automated fixes * Remove extension from name * Cast the string type for like * Update files.py * Update base.py * Update base.py --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Lucas Oliveira --- scripts/aws/lib/construct/frontend.ts | 1 + .../dd9e0804ebd1_add_v2_file_table.py | 49 ++++ src/backend/base/langflow/api/__init__.py | 4 +- src/backend/base/langflow/api/router.py | 8 + src/backend/base/langflow/api/schemas.py | 14 ++ src/backend/base/langflow/api/v2/__init__.py | 5 + src/backend/base/langflow/api/v2/files.py | 228 ++++++++++++++++++ .../base/langflow/base/data/base_file.py | 26 +- .../base/langflow/graph/vertex/base.py | 33 ++- .../base/langflow/inputs/input_mixin.py | 20 +- src/backend/base/langflow/main.py | 3 +- .../services/database/models/file/__init__.py | 5 + .../services/database/models/file/crud.py | 14 ++ .../services/database/models/file/model.py | 17 ++ .../base/langflow/services/storage/local.py | 12 + src/backend/tests/unit/api/v2/__init__.py | 0 src/backend/tests/unit/api/v2/test_files.py | 209 ++++++++++++++++ .../src/customization/config-constants.ts | 2 +- 18 files changed, 635 insertions(+), 15 deletions(-) create mode 100644 
src/backend/base/langflow/alembic/versions/dd9e0804ebd1_add_v2_file_table.py create mode 100644 src/backend/base/langflow/api/schemas.py create mode 100644 src/backend/base/langflow/api/v2/__init__.py create mode 100644 src/backend/base/langflow/api/v2/files.py create mode 100644 src/backend/base/langflow/services/database/models/file/__init__.py create mode 100644 src/backend/base/langflow/services/database/models/file/crud.py create mode 100644 src/backend/base/langflow/services/database/models/file/model.py create mode 100644 src/backend/tests/unit/api/v2/__init__.py create mode 100644 src/backend/tests/unit/api/v2/test_files.py diff --git a/scripts/aws/lib/construct/frontend.ts b/scripts/aws/lib/construct/frontend.ts index 78d516e1484c..85eec2c93f58 100644 --- a/scripts/aws/lib/construct/frontend.ts +++ b/scripts/aws/lib/construct/frontend.ts @@ -92,6 +92,7 @@ export class Web extends Construct { defaultBehavior: { origin: s3SpaOrigin }, additionalBehaviors: { '/api/v1/*': albBehaviorOptions, + '/api/v2/*': albBehaviorOptions, '/health' : albBehaviorOptions, }, enableLogging: true, // ログ出力設定 diff --git a/src/backend/base/langflow/alembic/versions/dd9e0804ebd1_add_v2_file_table.py b/src/backend/base/langflow/alembic/versions/dd9e0804ebd1_add_v2_file_table.py new file mode 100644 index 000000000000..2f9575b99f53 --- /dev/null +++ b/src/backend/base/langflow/alembic/versions/dd9e0804ebd1_add_v2_file_table.py @@ -0,0 +1,49 @@ +"""Add V2 File Table + +Revision ID: dd9e0804ebd1 +Revises: e3162c1804e6 +Create Date: 2025-02-03 11:47:16.101523 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel +from langflow.utils import migration + + +# revision identifiers, used by Alembic. 
+revision: str = 'dd9e0804ebd1' +down_revision: Union[str, None] = 'e3162c1804e6' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + conn = op.get_bind() + if not migration.table_exists("file", conn): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "file", + sa.Column("id", sqlmodel.sql.sqltypes.types.Uuid(), nullable=False), + sa.Column("user_id", sqlmodel.sql.sqltypes.types.Uuid(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False, unique=True), + sa.Column("path", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("size", sa.Integer(), nullable=False), + sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], name="fk_file_user_id_user"), + sa.UniqueConstraint("name"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + conn = op.get_bind() + # ### commands auto generated by Alembic - please adjust! 
### + if migration.table_exists("file", conn): + op.drop_table("file") + # ### end Alembic commands ### diff --git a/src/backend/base/langflow/api/__init__.py b/src/backend/base/langflow/api/__init__.py index 8150f0d9cac5..0efbd0bde442 100644 --- a/src/backend/base/langflow/api/__init__.py +++ b/src/backend/base/langflow/api/__init__.py @@ -1,5 +1,5 @@ from langflow.api.health_check_router import health_check_router from langflow.api.log_router import log_router -from langflow.api.router import router +from langflow.api.router import router, router_v2 -__all__ = ["health_check_router", "log_router", "router"] +__all__ = ["health_check_router", "log_router", "router", "router_v2"] diff --git a/src/backend/base/langflow/api/router.py b/src/backend/base/langflow/api/router.py index d2ce1905ada0..94d290e5a546 100644 --- a/src/backend/base/langflow/api/router.py +++ b/src/backend/base/langflow/api/router.py @@ -16,10 +16,16 @@ validate_router, variables_router, ) +from langflow.api.v2 import files_router as files_router_v2 router = APIRouter( prefix="/api/v1", ) + +router_v2 = APIRouter( + prefix="/api/v2", +) + router.include_router(chat_router) router.include_router(endpoints_router) router.include_router(validate_router) @@ -33,3 +39,5 @@ router.include_router(monitor_router) router.include_router(folders_router) router.include_router(starter_projects_router) + +router_v2.include_router(files_router_v2) diff --git a/src/backend/base/langflow/api/schemas.py b/src/backend/base/langflow/api/schemas.py new file mode 100644 index 000000000000..99fec0e6eb9f --- /dev/null +++ b/src/backend/base/langflow/api/schemas.py @@ -0,0 +1,14 @@ +from pathlib import Path +from uuid import UUID + +from pydantic import BaseModel + + +class UploadFileResponse(BaseModel): + """File upload response schema.""" + + id: UUID + name: str + path: Path + size: int + provider: str | None = None diff --git a/src/backend/base/langflow/api/v2/__init__.py 
b/src/backend/base/langflow/api/v2/__init__.py new file mode 100644 index 000000000000..2ada31ec9974 --- /dev/null +++ b/src/backend/base/langflow/api/v2/__init__.py @@ -0,0 +1,5 @@ +from langflow.api.v2.files import router as files_router + +__all__ = [ + "files_router", +] diff --git a/src/backend/base/langflow/api/v2/files.py b/src/backend/base/langflow/api/v2/files.py new file mode 100644 index 000000000000..66ce932d907a --- /dev/null +++ b/src/backend/base/langflow/api/v2/files.py @@ -0,0 +1,228 @@ +import uuid +from collections.abc import AsyncGenerator +from http import HTTPStatus +from pathlib import Path +from typing import Annotated + +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile +from fastapi.responses import StreamingResponse +from sqlmodel import String, cast, select + +from langflow.api.schemas import UploadFileResponse +from langflow.api.utils import CurrentActiveUser, DbSession +from langflow.services.database.models.file import File as UserFile +from langflow.services.deps import get_settings_service, get_storage_service +from langflow.services.storage.service import StorageService + +router = APIRouter(tags=["Files"], prefix="/files") + + +async def byte_stream_generator(file_bytes: bytes, chunk_size: int = 8192) -> AsyncGenerator[bytes, None]: + """Convert bytes object into an async generator that yields chunks.""" + for i in range(0, len(file_bytes), chunk_size): + yield file_bytes[i : i + chunk_size] + + +async def fetch_file_object(file_id: uuid.UUID, current_user: CurrentActiveUser, session: DbSession): + # Fetch the file from the DB + stmt = select(UserFile).where(UserFile.id == file_id) + results = await session.exec(stmt) + file = results.first() + + # Check if the file exists + if not file: + raise HTTPException(status_code=404, detail="File not found") + + # Make sure the user has access to the file + if file.user_id != current_user.id: + raise HTTPException(status_code=403, detail="You don't have access to 
this file") + + return file + + +@router.post("", status_code=HTTPStatus.CREATED) +async def upload_user_file( + file: Annotated[UploadFile, File(...)], + session: DbSession, + current_user: CurrentActiveUser, + storage_service=Depends(get_storage_service), + settings_service=Depends(get_settings_service), +) -> UploadFileResponse: + """Upload a file for the current user and track it in the database.""" + # Get the max allowed file size from settings (in MB) + try: + max_file_size_upload = settings_service.settings.max_file_size_upload + except Exception as e: + raise HTTPException(status_code=500, detail=f"Settings error: {e}") from e + + # Validate that a file is actually provided + if not file or not file.filename: + raise HTTPException(status_code=400, detail="No file provided") + + # Validate file size (convert MB to bytes) + if file.size > max_file_size_upload * 1024 * 1024: + raise HTTPException( + status_code=413, + detail=f"File size is larger than the maximum file size {max_file_size_upload}MB.", + ) + + # Read file content and create a unique file name + try: + # Create a unique file name + file_id = uuid.uuid4() + file_content = await file.read() + + # Get file extension of the file + file_extension = "." + file.filename.split(".")[-1] if file.filename and "." in file.filename else "" + anonymized_file_name = f"{file_id!s}{file_extension}" + + # Here we use the current user's id as the folder name + folder = str(current_user.id) + # Save the file using the storage service. + await storage_service.save_file(flow_id=folder, file_name=anonymized_file_name, data=file_content) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error saving file: {e}") from e + + # Create a new database record for the uploaded file. + try: + # Enforce unique constraint on name + # Name it as filename (1), (2), etc. 
+ # Check if the file name already exists + new_filename = file.filename + try: + root_filename, _ = new_filename.rsplit(".", 1) + except ValueError: + root_filename, _ = new_filename, "" + + # Find names starting with this base name; autoescape keeps any % or _ in it literal + stmt = select(UserFile).where(cast(UserFile.name, String).startswith(root_filename, autoescape=True)) + existing_files = await session.exec(stmt) + files = existing_files.all() # Fetch all matching records + + # If there are files with the same name, append a count to the filename + if files: + count = len(files) # Count occurrences + + # Append the count to the base name (the extension was stripped above) + root_filename = f"{root_filename} ({count})" + + # Compute the file size based on the path + file_size = await storage_service.get_file_size(flow_id=folder, file_name=anonymized_file_name) + + # Compute the file path + file_path = f"{folder}/{anonymized_file_name}" + + # Create a new file record + new_file = UserFile( + id=file_id, + user_id=current_user.id, + name=root_filename, + path=file_path, + size=file_size, + ) + session.add(new_file) + + await session.commit() + await session.refresh(new_file) + except Exception as e: + # Optionally, you could also delete the file from disk if the DB insert fails. 
+ raise HTTPException(status_code=500, detail=f"Database error: {e}") from e + + return UploadFileResponse(id=new_file.id, name=new_file.name, path=Path(new_file.path), size=new_file.size) + + +@router.get("") +async def list_files( + current_user: CurrentActiveUser, + session: DbSession, +) -> list[UserFile]: + """List the files available to the current user.""" + try: + # Fetch from the UserFile table + stmt = select(UserFile).where(UserFile.user_id == current_user.id) + results = await session.exec(stmt) + + return list(results) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error listing files: {e}") from e + + +@router.get("/{file_id}") +async def download_file( + file_id: uuid.UUID, + current_user: CurrentActiveUser, + session: DbSession, + storage_service: Annotated[StorageService, Depends(get_storage_service)], +): + """Download a file by its ID.""" + # Fetch the file outside the try block so its 404/403 are not rewrapped as 500 + file = await fetch_file_object(file_id, current_user, session) + + try: + # Get the basename of the file path + file_name = file.path.split("/")[-1] + + # Get file stream + file_stream = await storage_service.get_file(flow_id=str(current_user.id), file_name=file_name) + + # Ensure file_stream is an async iterator returning bytes + byte_stream = byte_stream_generator(file_stream) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error downloading file: {e}") from e + + # Return the file as a streaming response + return StreamingResponse( + byte_stream, + media_type="application/octet-stream", + headers={"Content-Disposition": f'attachment; filename="{file.name}"'}, + ) + + +@router.put("/{file_id}") +async def edit_file_name( + file_id: uuid.UUID, + name: str, + current_user: CurrentActiveUser, + session: DbSession, +) -> UploadFileResponse: + """Edit the name of a file by its ID.""" + # Fetch the file outside the try block so its 404/403 are not rewrapped as 500 + file = await fetch_file_object(file_id, current_user, session) + + try: + # Update the file name + file.name = name + 
await session.commit() + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error editing file: {e}") from e + + return UploadFileResponse(id=file.id, name=file.name, path=file.path, size=file.size) + + +@router.delete("/{file_id}") +async def delete_file( + file_id: uuid.UUID, + current_user: CurrentActiveUser, + session: DbSession, + storage_service: Annotated[StorageService, Depends(get_storage_service)], +): + """Delete a file by its ID.""" + # Fetch the file outside the try block so its 404/403 are not rewrapped as 500; + # fetch_file_object already raises 404 when the row is missing. + file = await fetch_file_object(file_id, current_user, session) + + try: + # Delete the file from the storage service. The storage call takes the basename + # inside the user's folder (as download_file does); file.path is "<user_id>/<basename>". + await storage_service.delete_file(flow_id=str(current_user.id), file_name=file.path.split("/")[-1]) + + # Delete from the database + await session.delete(file) + await session.flush() # Ensures delete is staged + await session.commit() # Commit deletion + + except Exception as e: + await session.rollback() # Rollback on failure + raise HTTPException(status_code=500, detail=f"Error deleting file: {e}") from e + + return {"message": "File deleted successfully"} diff --git a/src/backend/base/langflow/base/data/base_file.py b/src/backend/base/langflow/base/data/base_file.py index 26d3300c8525..8b9f1ad5bf07 100644 --- a/src/backend/base/langflow/base/data/base_file.py +++ b/src/backend/base/langflow/base/data/base_file.py @@ -100,7 +100,10 @@ def __str__(self): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Dynamically update FileInput to include valid extensions and bundles - self._base_inputs[0].file_types = [*self.valid_extensions, *self.SUPPORTED_BUNDLE_EXTENSIONS] + self._base_inputs[0].file_types = [ + *self.valid_extensions, + *self.SUPPORTED_BUNDLE_EXTENSIONS, + ] file_types = ", ".join(self.valid_extensions) bundles = ", ".join(self.SUPPORTED_BUNDLE_EXTENSIONS) @@ -342,8 +345,13 @@ def add_file(data: Data, path: str | Path, *, delete_after_processing: bool): if 
self.path and not file_path: # Wrap self.path into a Data object - data_obj = Data(data={self.SERVER_FILE_PATH_FIELDNAME: self.path}) - add_file(data=data_obj, path=self.path, delete_after_processing=False) + if isinstance(self.path, list): + for path in self.path: + data_obj = Data(data={self.SERVER_FILE_PATH_FIELDNAME: path}) + add_file(data=data_obj, path=path, delete_after_processing=False) + else: + data_obj = Data(data={self.SERVER_FILE_PATH_FIELDNAME: self.path}) + add_file(data=data_obj, path=self.path, delete_after_processing=False) elif file_path: for obj in file_path: server_file_path = obj.data.get(self.SERVER_FILE_PATH_FIELDNAME) @@ -384,7 +392,11 @@ def _unpack_and_collect_files(self, files: list[BaseFile]) -> list[BaseFile]: # Recurse into directories collected_files.extend( [ - BaseFileComponent.BaseFile(data, sub_path, delete_after_processing=delete_after_processing) + BaseFileComponent.BaseFile( + data, + sub_path, + delete_after_processing=delete_after_processing, + ) for sub_path in path.rglob("*") if sub_path.is_file() ] @@ -399,7 +411,11 @@ def _unpack_and_collect_files(self, files: list[BaseFile]) -> list[BaseFile]: self.log(f"Unpacked bundle {path.name} into {subpaths}") collected_files.extend( [ - BaseFileComponent.BaseFile(data, sub_path, delete_after_processing=delete_after_processing) + BaseFileComponent.BaseFile( + data, + sub_path, + delete_after_processing=delete_after_processing, + ) for sub_path in subpaths ] ) diff --git a/src/backend/base/langflow/graph/vertex/base.py b/src/backend/base/langflow/graph/vertex/base.py index d21a1f5982b8..f048fac5d05c 100644 --- a/src/backend/base/langflow/graph/vertex/base.py +++ b/src/backend/base/langflow/graph/vertex/base.py @@ -14,7 +14,12 @@ from loguru import logger from langflow.exceptions.component import ComponentBuildError -from langflow.graph.schema import INPUT_COMPONENTS, OUTPUT_COMPONENTS, InterfaceComponentTypes, ResultData +from langflow.graph.schema import ( + INPUT_COMPONENTS, + 
OUTPUT_COMPONENTS, + InterfaceComponentTypes, + ResultData, +) from langflow.graph.utils import UnbuiltObject, UnbuiltResult, log_transaction from langflow.interface import initialize from langflow.interface.listing import lazy_load_dict @@ -355,8 +360,16 @@ def build_params(self) -> None: if file_path := field.get("file_path"): storage_service = get_storage_service() try: - flow_id, file_name = os.path.split(file_path) - full_path = storage_service.build_full_path(flow_id, file_name) + full_path: str | list[str] = "" + if field.get("list"): + full_path = [] + for p in file_path: + flow_id, file_name = os.path.split(p) + path = storage_service.build_full_path(flow_id, file_name) + full_path.append(path) + else: + flow_id, file_name = os.path.split(file_path) + full_path = storage_service.build_full_path(flow_id, file_name) except ValueError as e: if "too many values to unpack" in str(e): full_path = file_path @@ -621,7 +634,12 @@ async def get_result(self, requester: Vertex, target_handle_name: str | None = N return await self._get_result(requester, target_handle_name) async def _log_transaction_async( - self, flow_id: str | UUID, source: Vertex, status, target: Vertex | None = None, error=None + self, + flow_id: str | UUID, + source: Vertex, + status, + target: Vertex | None = None, + error=None, ) -> None: """Log a transaction asynchronously with proper task handling and cancellation. 
@@ -723,7 +741,12 @@ def _extend_params_list_with_result(self, key, result) -> None: self.params[key].extend(result) async def _build_results( - self, custom_component, custom_params, base_type: str, *, fallback_to_env_vars=False + self, + custom_component, + custom_params, + base_type: str, + *, + fallback_to_env_vars=False, ) -> None: try: result = await initialize.loading.get_instance_results( diff --git a/src/backend/base/langflow/inputs/input_mixin.py b/src/backend/base/langflow/inputs/input_mixin.py index 4ccb8928db37..44ea4a121d74 100644 --- a/src/backend/base/langflow/inputs/input_mixin.py +++ b/src/backend/base/langflow/inputs/input_mixin.py @@ -132,9 +132,27 @@ class DatabaseLoadMixin(BaseModel): # Specific mixin for fields needing file interaction class FileMixin(BaseModel): - file_path: str | None = Field(default="") + file_path: list[str] | str | None = Field(default="") file_types: list[str] = Field(default=[], alias="fileTypes") + @field_validator("file_path") + @classmethod + def validate_file_path(cls, v): + if v is None or v == "": + return v + # If it's already a list, validate each element is a string + if isinstance(v, list): + for item in v: + if not isinstance(item, str): + msg = "All file paths must be strings" + raise TypeError(msg) + return v + # If it's a single string, that's also valid + if isinstance(v, str): + return v + msg = "file_path must be a string, list of strings, or None" + raise ValueError(msg) + @field_validator("file_types") @classmethod def validate_file_types(cls, v): diff --git a/src/backend/base/langflow/main.py b/src/backend/base/langflow/main.py index 172837b4c817..e5caf6bdeaa0 100644 --- a/src/backend/base/langflow/main.py +++ b/src/backend/base/langflow/main.py @@ -23,7 +23,7 @@ from rich import print as rprint from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from langflow.api import health_check_router, log_router, router +from langflow.api import health_check_router, log_router, 
router, router_v2 from langflow.initial_setup.setup import ( create_or_update_starter_projects, initialize_super_user_if_needed, @@ -239,6 +239,7 @@ async def flatten_query_string_lists(request: Request, call_next): router.include_router(mcp_router) app.include_router(router) + app.include_router(router_v2) app.include_router(health_check_router) app.include_router(log_router) diff --git a/src/backend/base/langflow/services/database/models/file/__init__.py b/src/backend/base/langflow/services/database/models/file/__init__.py new file mode 100644 index 000000000000..63fdee160eda --- /dev/null +++ b/src/backend/base/langflow/services/database/models/file/__init__.py @@ -0,0 +1,5 @@ +from .model import File + +__all__ = [ + "File", +] diff --git a/src/backend/base/langflow/services/database/models/file/crud.py b/src/backend/base/langflow/services/database/models/file/crud.py new file mode 100644 index 000000000000..e5ae1ad56c83 --- /dev/null +++ b/src/backend/base/langflow/services/database/models/file/crud.py @@ -0,0 +1,14 @@ +from uuid import UUID + +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from langflow.services.database.models.file.model import File + + +async def get_file_by_id(db: AsyncSession, file_id: UUID) -> File | None: + if isinstance(file_id, str): + file_id = UUID(file_id) + stmt = select(File).where(File.id == file_id) + + return (await db.exec(stmt)).first() diff --git a/src/backend/base/langflow/services/database/models/file/model.py b/src/backend/base/langflow/services/database/models/file/model.py new file mode 100644 index 000000000000..e2be55dc955a --- /dev/null +++ b/src/backend/base/langflow/services/database/models/file/model.py @@ -0,0 +1,17 @@ +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field, SQLModel + +from langflow.schema.serialize import UUIDstr + + +class File(SQLModel, table=True): # type: ignore[call-arg] + id: UUIDstr = 
Field(default_factory=uuid4, primary_key=True) + user_id: UUID = Field(foreign_key="user.id") + name: str = Field(unique=True, nullable=False) + path: str = Field(nullable=False) + size: int = Field(nullable=False) + provider: str | None = Field(default=None) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) diff --git a/src/backend/base/langflow/services/storage/local.py b/src/backend/base/langflow/services/storage/local.py index acfb42164d8c..743a60fdfb9c 100644 --- a/src/backend/base/langflow/services/storage/local.py +++ b/src/backend/base/langflow/services/storage/local.py @@ -108,3 +108,15 @@ async def delete_file(self, flow_id: str, file_name: str) -> None: async def teardown(self) -> None: """Perform any cleanup operations when the service is being torn down.""" # No specific teardown actions required for local + + async def get_file_size(self, flow_id: str, file_name: str) -> int: + """Return the size in bytes of a stored file; raise FileNotFoundError if it is missing.""" + # Get the file size from the file path + file_path = self.data_dir / flow_id / file_name + if not await file_path.exists(): + logger.warning(f"File {file_name} not found in flow {flow_id}.") + msg = f"File {file_name} not found in flow {flow_id}" + raise FileNotFoundError(msg) + + file_size_stat = await file_path.stat() + return file_size_stat.st_size diff --git a/src/backend/tests/unit/api/v2/__init__.py b/src/backend/tests/unit/api/v2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/backend/tests/unit/api/v2/test_files.py b/src/backend/tests/unit/api/v2/test_files.py new file mode 100644 index 000000000000..599604b5a7f1 --- /dev/null +++ b/src/backend/tests/unit/api/v2/test_files.py @@ -0,0 +1,209 @@ +import asyncio +import tempfile +from contextlib import suppress +from pathlib import Path + +# we need to import tmpdir +import anyio +import pytest +from asgi_lifespan import 
LifespanManager +from httpx import ASGITransport, AsyncClient +from langflow.main import create_app +from langflow.services.auth.utils import get_password_hash +from langflow.services.database.models.api_key.model import ApiKey +from langflow.services.database.models.user.model import User, UserRead +from langflow.services.database.utils import session_getter +from langflow.services.deps import get_db_service +from sqlalchemy.orm import selectinload +from sqlmodel import select + +from tests.conftest import _delete_transactions_and_vertex_builds + + +@pytest.fixture(name="files_created_api_key") +async def files_created_api_key(files_client, files_active_user): # noqa: ARG001 + hashed = get_password_hash("random_key") + api_key = ApiKey( + name="files_created_api_key", + user_id=files_active_user.id, + api_key="random_key", + hashed_api_key=hashed, + ) + db_manager = get_db_service() + async with session_getter(db_manager) as session: + stmt = select(ApiKey).where(ApiKey.api_key == api_key.api_key) + if existing_api_key := (await session.exec(stmt)).first(): + yield existing_api_key + return + session.add(api_key) + await session.commit() + await session.refresh(api_key) + yield api_key + # Clean up + await session.delete(api_key) + await session.commit() + + +@pytest.fixture(name="files_active_user") +async def files_active_user(files_client): # noqa: ARG001 + db_manager = get_db_service() + async with db_manager.with_session() as session: + user = User( + username="files_active_user", + password=get_password_hash("testpassword"), + is_active=True, + is_superuser=False, + ) + stmt = select(User).where(User.username == user.username) + if active_user := (await session.exec(stmt)).first(): + user = active_user + else: + session.add(user) + await session.commit() + await session.refresh(user) + user = UserRead.model_validate(user, from_attributes=True) + yield user + # Clean up + # Now cleanup transactions, vertex_build + async with db_manager.with_session() as 
session: + user = await session.get(User, user.id, options=[selectinload(User.flows)]) + await _delete_transactions_and_vertex_builds(session, user.flows) + await session.delete(user) + + await session.commit() + + +@pytest.fixture +def max_file_size_upload_fixture(monkeypatch): + monkeypatch.setenv("LANGFLOW_MAX_FILE_SIZE_UPLOAD", "1") + yield + monkeypatch.undo() + + +@pytest.fixture +def max_file_size_upload_10mb_fixture(monkeypatch): + monkeypatch.setenv("LANGFLOW_MAX_FILE_SIZE_UPLOAD", "10") + yield + monkeypatch.undo() + + +@pytest.fixture(name="files_client") +async def files_client_fixture( + monkeypatch, + request, +): + # Set the database url to a test database + if "noclient" in request.keywords: + yield + else: + + def init_app(): + db_dir = tempfile.mkdtemp() + db_path = Path(db_dir) / "test.db" + monkeypatch.setenv("LANGFLOW_DATABASE_URL", f"sqlite:///{db_path}") + monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false") + from langflow.services.manager import service_manager + + service_manager.factories.clear() + service_manager.services.clear() # Clear the services cache + app = create_app() + return app, db_path + + app, db_path = await asyncio.to_thread(init_app) + + async with ( + LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager, + AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client, + ): + yield client + # app.dependency_overrides.clear() + monkeypatch.undo() + # clear the temp db + with suppress(FileNotFoundError): + await anyio.Path(db_path).unlink() + + +async def test_upload_file(files_client, files_created_api_key): + headers = {"x-api-key": files_created_api_key.api_key} + + response = await files_client.post( + "api/v2/files", + files={"file": ("test.txt", b"test content")}, + headers=headers, + ) + assert response.status_code == 201, f"Expected 201, got {response.status_code}: {response.json()}" + + response_json = response.json() + assert "id" in response_json + + +async 
def test_download_file(files_client, files_created_api_key): + headers = {"x-api-key": files_created_api_key.api_key} + + # First upload a file + response = await files_client.post( + "api/v2/files", + files={"file": ("test.txt", b"test content")}, + headers=headers, + ) + assert response.status_code == 201 + upload_response = response.json() + + # Then try to download it + response = await files_client.get(f"api/v2/files/{upload_response['id']}", headers=headers) + + assert response.status_code == 200 + assert response.content == b"test content" + + +async def test_list_files(files_client, files_created_api_key): + headers = {"x-api-key": files_created_api_key.api_key} + + # First upload a file + response = await files_client.post( + "api/v2/files", + files={"file": ("test.txt", b"test content")}, + headers=headers, + ) + assert response.status_code == 201 + + # Then list the files + response = await files_client.get("api/v2/files", headers=headers) + assert response.status_code == 200 + files = response.json() + assert len(files) == 1 + + +async def test_delete_file(files_client, files_created_api_key): + headers = {"x-api-key": files_created_api_key.api_key} + + response = await files_client.post( + "api/v2/files", + files={"file": ("test.txt", b"test content")}, + headers=headers, + ) + assert response.status_code == 201 + upload_response = response.json() + + response = await files_client.delete(f"api/v2/files/{upload_response['id']}", headers=headers) + assert response.status_code == 200 + assert response.json() == {"message": "File deleted successfully"} + + +async def test_edit_file(files_client, files_created_api_key): + headers = {"x-api-key": files_created_api_key.api_key} + + # First upload a file + response = await files_client.post( + "api/v2/files", + files={"file": ("test.txt", b"test content")}, + headers=headers, + ) + assert response.status_code == 201 + upload_response = response.json() + + # Then rename the file + response = await 
files_client.put(f"api/v2/files/{upload_response['id']}?name=potato.txt", headers=headers) + assert response.status_code == 200 + file = response.json() + assert file["name"] == "potato.txt" diff --git a/src/frontend/src/customization/config-constants.ts b/src/frontend/src/customization/config-constants.ts index c27ebc013cf5..a6ec780faa6d 100644 --- a/src/frontend/src/customization/config-constants.ts +++ b/src/frontend/src/customization/config-constants.ts @@ -1,7 +1,7 @@ export const BASENAME = ""; export const PORT = 3000; export const PROXY_TARGET = "http://127.0.0.1:7860"; -export const API_ROUTES = ["^/api/v1/", "/health"]; +export const API_ROUTES = ["^/api/v1/", "/api/v2/", "/health"]; export const BASE_URL_API = "/api/v1/"; export const HEALTH_CHECK_URL = "/health_check"; export const DOCS_LINK = "https://docs.langflow.org";