Skip to content

Commit

Permalink
refactor(api): change variable type annotations (#3592)
Browse files Browse the repository at this point in the history
* refactor(log_router.py): change variable type annotations from List to the builtin list for better consistency
refactor(utils.py): change variable type annotations from Dict to the builtin dict for better consistency
refactor(base.py): change variable type annotations from Optional[X] to the X | None union syntax for better clarity
refactor(callback.py): change variable type annotations from Dict to the builtin dict for better consistency
refactor(chat.py): change variable type annotations from Optional[X] to the X | None union syntax for better clarity
refactor(endpoints.py): change variable type annotations from Optional[X] to the X | None union syntax for better clarity
refactor(flows.py): change variable type annotations from List to the builtin list for better consistency

refactor(api): update response_model annotations to use lowercase list for consistency and improve readability

refactor(store.py): update type annotations for query parameters in get_components endpoint to improve code readability and maintainability
feat(store.py): add support for type hinting Union and list types in query parameters for better data validation and documentation

* run make format
  • Loading branch information
ogabrielluiz authored Aug 28, 2024
1 parent 7641e81 commit d7dbf1a
Show file tree
Hide file tree
Showing 12 changed files with 116 additions and 124 deletions.
4 changes: 2 additions & 2 deletions src/backend/base/langflow/api/log_router.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import json
from typing import List, Any
from typing import Any

from fastapi import APIRouter, Query, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
Expand All @@ -15,7 +15,7 @@ async def event_generator(request: Request):
last_read_item = None
current_not_sent = 0
while not await request.is_disconnected():
to_write: List[Any] = []
to_write: list[Any] = []
with log_buffer.get_write_lock():
if last_read_item is None:
last_read_item = log_buffer.buffer[len(log_buffer.buffer) - 1]
Expand Down
6 changes: 3 additions & 3 deletions src/backend/base/langflow/api/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import uuid
import warnings
from typing import TYPE_CHECKING, Any, Optional, Dict
from typing import TYPE_CHECKING, Any

from fastapi import HTTPException
from sqlmodel import Session
Expand Down Expand Up @@ -122,7 +122,7 @@ def format_elapsed_time(elapsed_time: float) -> str:
return f"{minutes} {minutes_unit}, {seconds} {seconds_unit}"


async def build_graph_from_data(flow_id: str, payload: Dict, **kwargs):
async def build_graph_from_data(flow_id: str, payload: dict, **kwargs):
"""Build and cache the graph."""
graph = Graph.from_payload(payload, flow_id, **kwargs)
for vertex_id in graph._has_session_id_vertices:
Expand All @@ -141,7 +141,7 @@ async def build_graph_from_data(flow_id: str, payload: Dict, **kwargs):

async def build_graph_from_db_no_cache(flow_id: str, session: Session):
"""Build and cache the graph."""
flow: Optional[Flow] = session.get(Flow, flow_id)
flow: Flow | None = session.get(Flow, flow_id)
if not flow or not flow.data:
raise ValueError("Invalid flow ID")
return await build_graph_from_data(flow_id, flow.data, flow_name=flow.name, user_id=str(flow.user_id))
Expand Down
8 changes: 3 additions & 5 deletions src/backend/base/langflow/api/v1/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional

from pydantic import BaseModel, field_validator, model_serializer

from langflow.template.frontend_node.base import FrontendNode
Expand All @@ -26,8 +24,8 @@ def serialize_model(self, handler):
class ValidatePromptRequest(BaseModel):
name: str
template: str
custom_fields: Optional[dict] = None
frontend_node: Optional[FrontendNodeRequest] = None
custom_fields: dict | None = None
frontend_node: FrontendNodeRequest | None = None


# Build ValidationResponse class for {"imports": {"errors": []}, "function": {"errors": []}}
Expand All @@ -49,4 +47,4 @@ def validate_function(cls, v):
class PromptValidationResponse(BaseModel):
input_variables: list
# object return for tweak call
frontend_node: Optional[FrontendNodeRequest] = None
frontend_node: FrontendNodeRequest | None = None
8 changes: 4 additions & 4 deletions src/backend/base/langflow/api/v1/callback.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any
from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler
from loguru import logger
Expand Down Expand Up @@ -32,7 +32,7 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())

async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
async def on_tool_start(self, serialized: dict[str, Any], input_str: str, **kwargs: Any) -> Any:
"""Run when tool starts running."""
resp = ChatResponse(
message="",
Expand Down Expand Up @@ -79,8 +79,8 @@ async def on_tool_error(
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
parent_run_id: UUID | None = None,
tags: list[str] | None = None,
**kwargs: Any,
) -> None:
"""Run when tool errors."""
Expand Down
28 changes: 14 additions & 14 deletions src/backend/base/langflow/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import traceback
import typing
import uuid
from typing import TYPE_CHECKING, Annotated, Optional
from typing import TYPE_CHECKING, Annotated

from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException
from fastapi.responses import StreamingResponse
Expand Down Expand Up @@ -68,9 +68,9 @@ async def try_running_celery_task(vertex, user_id):
async def retrieve_vertices_order(
flow_id: uuid.UUID,
background_tasks: BackgroundTasks,
data: Optional[Annotated[Optional[FlowDataRequest], Body(embed=True)]] = None,
stop_component_id: Optional[str] = None,
start_component_id: Optional[str] = None,
data: Annotated[FlowDataRequest | None, Body(embed=True)] | None = None,
stop_component_id: str | None = None,
start_component_id: str | None = None,
chat_service: "ChatService" = Depends(get_chat_service),
session=Depends(get_session),
telemetry_service: "TelemetryService" = Depends(get_telemetry_service),
Expand Down Expand Up @@ -141,12 +141,12 @@ async def retrieve_vertices_order(
async def build_flow(
background_tasks: BackgroundTasks,
flow_id: uuid.UUID,
inputs: Annotated[Optional[InputValueRequest], Body(embed=True)] = None,
data: Annotated[Optional[FlowDataRequest], Body(embed=True)] = None,
files: Optional[list[str]] = None,
stop_component_id: Optional[str] = None,
start_component_id: Optional[str] = None,
log_builds: Optional[bool] = True,
inputs: Annotated[InputValueRequest | None, Body(embed=True)] = None,
data: Annotated[FlowDataRequest | None, Body(embed=True)] = None,
files: list[str] | None = None,
stop_component_id: str | None = None,
start_component_id: str | None = None,
log_builds: bool | None = True,
chat_service: "ChatService" = Depends(get_chat_service),
current_user=Depends(get_current_active_user),
telemetry_service: "TelemetryService" = Depends(get_telemetry_service),
Expand Down Expand Up @@ -434,7 +434,7 @@ def __init__(
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
on_disconnect: Optional[typing.Callable] = None,
on_disconnect: typing.Callable | None = None,
):
super().__init__(content, status_code, headers, media_type, background)
self.on_disconnect = on_disconnect
Expand All @@ -453,8 +453,8 @@ async def build_vertex(
flow_id: uuid.UUID,
vertex_id: str,
background_tasks: BackgroundTasks,
inputs: Annotated[Optional[InputValueRequest], Body(embed=True)] = None,
files: Optional[list[str]] = None,
inputs: Annotated[InputValueRequest | None, Body(embed=True)] = None,
files: list[str] | None = None,
chat_service: "ChatService" = Depends(get_chat_service),
current_user=Depends(get_current_active_user),
telemetry_service: "TelemetryService" = Depends(get_telemetry_service),
Expand Down Expand Up @@ -606,7 +606,7 @@ async def build_vertex(
async def build_vertex_stream(
flow_id: uuid.UUID,
vertex_id: str,
session_id: Optional[str] = None,
session_id: str | None = None,
chat_service: "ChatService" = Depends(get_chat_service),
session_service: "SessionService" = Depends(get_session_service),
):
Expand Down
22 changes: 11 additions & 11 deletions src/backend/base/langflow/api/v1/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
from asyncio import Lock
from http import HTTPStatus
from typing import TYPE_CHECKING, Annotated, List, Optional, Union
from typing import TYPE_CHECKING, Annotated
from uuid import UUID

import sqlalchemy as sa
Expand Down Expand Up @@ -108,12 +108,12 @@ async def simple_run_flow(
flow: Flow,
input_request: SimplifiedAPIRequest,
stream: bool = False,
api_key_user: Optional[User] = None,
api_key_user: User | None = None,
):
if input_request.input_value is not None and input_request.tweaks is not None:
validate_input_and_tweaks(input_request)
try:
task_result: List[RunOutputs] = []
task_result: list[RunOutputs] = []
user_id = api_key_user.id if api_key_user else None
flow_id_str = str(flow.id)
if flow.data is None:
Expand Down Expand Up @@ -155,7 +155,7 @@ async def simple_run_flow_task(
flow: Flow,
input_request: SimplifiedAPIRequest,
stream: bool = False,
api_key_user: Optional[User] = None,
api_key_user: User | None = None,
):
"""
Run a flow task as a BackgroundTask, therefore it should not throw exceptions.
Expand Down Expand Up @@ -362,11 +362,11 @@ async def webhook_run_flow(
async def experimental_run_flow(
session: Annotated[Session, Depends(get_session)],
flow_id: UUID,
inputs: Optional[List[InputValueRequest]] = [InputValueRequest(components=[], input_value="")],
outputs: Optional[List[str]] = [],
tweaks: Annotated[Optional[Tweaks], Body(embed=True)] = None, # noqa: F821
inputs: list[InputValueRequest] | None = [InputValueRequest(components=[], input_value="")],
outputs: list[str] | None = [],
tweaks: Annotated[Tweaks | None, Body(embed=True)] = None, # noqa: F821
stream: Annotated[bool, Body(embed=True)] = False, # noqa: F821
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
session_id: Annotated[None | str, Body(embed=True)] = None, # noqa: F821
api_key_user: UserRead = Depends(api_key_security),
session_service: SessionService = Depends(get_session_service),
):
Expand Down Expand Up @@ -476,10 +476,10 @@ async def experimental_run_flow(
async def process(
session: Annotated[Session, Depends(get_session)],
flow_id: str,
inputs: Optional[Union[List[dict], dict]] = None,
tweaks: Optional[dict] = None,
inputs: list[dict] | dict | None = None,
tweaks: dict | None = None,
clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
session_id: Annotated[None | str, Body(embed=True)] = None, # noqa: F821
task_service: "TaskService" = Depends(get_task_service),
api_key_user: UserRead = Depends(api_key_security),
sync: Annotated[bool, Body(embed=True)] = True, # noqa: F821
Expand Down
9 changes: 4 additions & 5 deletions src/backend/base/langflow/api/v1/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import re
import zipfile
from datetime import datetime, timezone
from typing import List
from uuid import UUID

import orjson
Expand Down Expand Up @@ -273,7 +272,7 @@ def delete_flow(
return {"message": "Flow deleted successfully"}


@router.post("/batch/", response_model=List[FlowRead], status_code=201)
@router.post("/batch/", response_model=list[FlowRead], status_code=201)
def create_flows(
*,
session: Session = Depends(get_session),
Expand All @@ -293,7 +292,7 @@ def create_flows(
return db_flows


@router.post("/upload/", response_model=List[FlowRead], status_code=201)
@router.post("/upload/", response_model=list[FlowRead], status_code=201)
async def upload_file(
*,
session: Session = Depends(get_session),
Expand Down Expand Up @@ -322,7 +321,7 @@ async def upload_file(

@router.delete("/")
async def delete_multiple_flows(
flow_ids: List[UUID], user: User = Depends(get_current_active_user), db: Session = Depends(get_session)
flow_ids: list[UUID], user: User = Depends(get_current_active_user), db: Session = Depends(get_session)
):
"""
Delete multiple flows by their IDs.
Expand Down Expand Up @@ -357,7 +356,7 @@ async def delete_multiple_flows(

@router.post("/download/", status_code=200)
async def download_multiple_file(
flow_ids: List[UUID],
flow_ids: list[UUID],
user: User = Depends(get_current_active_user),
db: Session = Depends(get_session),
):
Expand Down
6 changes: 2 additions & 4 deletions src/backend/base/langflow/api/v1/folders.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

import orjson
from fastapi import APIRouter, Depends, File, HTTPException, Response, UploadFile, status
from sqlalchemy import or_, update
Expand Down Expand Up @@ -78,7 +76,7 @@ def create_folder(
raise HTTPException(status_code=500, detail=str(e))


@router.get("/", response_model=List[FolderRead], status_code=200)
@router.get("/", response_model=list[FolderRead], status_code=200)
def read_folders(
*,
session: Session = Depends(get_session),
Expand Down Expand Up @@ -211,7 +209,7 @@ async def download_file(
raise HTTPException(status_code=500, detail=str(e))


@router.post("/upload/", response_model=List[FlowRead], status_code=201)
@router.post("/upload/", response_model=list[FlowRead], status_code=201)
async def upload_file(
*,
session: Session = Depends(get_session),
Expand Down
17 changes: 8 additions & 9 deletions src/backend/base/langflow/api/v1/monitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import List, Optional
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query
Expand Down Expand Up @@ -44,13 +43,13 @@ async def delete_vertex_builds(
raise HTTPException(status_code=500, detail=str(e))


@router.get("/messages", response_model=List[MessageModelResponse])
@router.get("/messages", response_model=list[MessageModelResponse])
async def get_messages(
flow_id: Optional[str] = Query(None),
session_id: Optional[str] = Query(None),
sender: Optional[str] = Query(None),
sender_name: Optional[str] = Query(None),
order_by: Optional[str] = Query("timestamp"),
flow_id: str | None = Query(None),
session_id: str | None = Query(None),
sender: str | None = Query(None),
sender_name: str | None = Query(None),
order_by: str | None = Query("timestamp"),
session: Session = Depends(get_session),
):
try:
Expand All @@ -74,7 +73,7 @@ async def get_messages(

@router.delete("/messages", status_code=204)
async def delete_messages(
message_ids: List[UUID],
message_ids: list[UUID],
session: Session = Depends(get_session),
current_user: User = Depends(get_current_active_user),
):
Expand Down Expand Up @@ -125,7 +124,7 @@ async def delete_messages_session(
raise HTTPException(status_code=500, detail=str(e))


@router.get("/transactions", response_model=List[TransactionReadResponse])
@router.get("/transactions", response_model=list[TransactionReadResponse])
async def get_transactions(
flow_id: UUID = Query(),
session: Session = Depends(get_session),
Expand Down
Loading

0 comments on commit d7dbf1a

Please sign in to comment.