refactor: Implement unified serialization function (#6044)
* feat: Implement serialization functions for various data types and add a unified serialize method

* feat: Enhance serialization by adding support for primitive types, enums, and generic types

* fix: Update Pinecone integration to use VectorStore and handle import errors gracefully

* test: Add hypothesis-based tests for serialization functions across various data types

* refactor: Replace custom serialization logic with unified serialize function for consistency and maintainability

* refactor: Replace recursive serialization function with unified serialize method for improved clarity and maintainability

* refactor: Replace custom serialization logic with unified serialize function for improved consistency and clarity

* refactor: Enhance serialization logic by adding instance handling and streamlining type checks

* refactor: Remove custom dictionary serialization from ResultDataResponse for streamlined handling

* refactor: Enhance serialization in ResultDataResponse by adding max_items_length for improved handling of outputs, logs, messages, and artifacts

* refactor: Move MAX_ITEMS_LENGTH and MAX_TEXT_LENGTH constants to serialization module for better organization

* refactor: Simplify message serialization in Log model by utilizing unified serialize function

* refactor: Remove unnecessary pytest marker from TestSerializationHypothesis class

* optimize _serialize_bytes

Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>

* feat: Add support for numpy integer type serialization

* feat: Enhance serialization with support for pandas and numpy types

* test: Add comprehensive serialization tests for numpy and pandas types

* fix: Update _serialize_dispatcher to return string representation for unsupported types

* fix: Update _serialize_dispatcher to return the object directly instead of its string representation

* optimize conditional

Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>

* optimize length check

Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>

* fix: Update string and list truncation to include ellipsis for clarity

* fix: Update _serialize_primitive to exclude string type from primitive handling

* feat: Enhance serialization to handle numpy types and introduce unserializable sentinel

* fix: Update test cases for serialization of numpy boolean values for consistency

---------

Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
ogabrielluiz and codeflash-ai[bot] authored Feb 3, 2025
1 parent 5bcf4d0 commit c73070c
Showing 20 changed files with 696 additions and 186 deletions.
62 changes: 7 additions & 55 deletions src/backend/base/langflow/api/v1/schemas.py
@@ -1,5 +1,4 @@
from datetime import datetime, timezone
from decimal import Decimal
from enum import Enum
from pathlib import Path
from typing import Any
@@ -11,13 +10,14 @@
from langflow.schema import dotdict
from langflow.schema.graph import Tweaks
from langflow.schema.schema import InputType, OutputType, OutputValue
from langflow.serialization.constants import MAX_ITEMS_LENGTH, MAX_TEXT_LENGTH
from langflow.serialization.serialization import serialize
from langflow.services.database.models.api_key.model import ApiKeyRead
from langflow.services.database.models.base import orjson_dumps
from langflow.services.database.models.flow import FlowCreate, FlowRead
from langflow.services.database.models.user import UserRead
from langflow.services.settings.feature_flags import FeatureFlags
from langflow.services.tracing.schema import Log
from langflow.utils.constants import MAX_ITEMS_LENGTH, MAX_TEXT_LENGTH
from langflow.utils.util_strings import truncate_long_strings


@@ -270,65 +270,17 @@ class ResultDataResponse(BaseModel):
@classmethod
def serialize_results(cls, v):
"""Serialize results with custom handling for special types and truncation."""
if isinstance(v, dict):
return {key: cls._serialize_and_truncate(val, max_length=MAX_TEXT_LENGTH) for key, val in v.items()}
return cls._serialize_and_truncate(v, max_length=MAX_TEXT_LENGTH)

@staticmethod
def _serialize_and_truncate(obj: Any, max_length: int = MAX_TEXT_LENGTH) -> Any:
"""Helper method to serialize and truncate values."""
if isinstance(obj, bytes):
obj = obj.decode("utf-8", errors="ignore")
if len(obj) > max_length:
return f"{obj[:max_length]}... [truncated]"
return obj
if isinstance(obj, str):
if len(obj) > max_length:
return f"{obj[:max_length]}... [truncated]"
return obj
if isinstance(obj, datetime):
return obj.replace(tzinfo=timezone.utc).isoformat()
if isinstance(obj, Decimal):
return float(obj)
if isinstance(obj, UUID):
return str(obj)
if isinstance(obj, OutputValue | Log):
# First serialize the model
serialized = obj.model_dump()
# Then recursively truncate all values in the serialized dict
for key, value in serialized.items():
# Handle string values directly to ensure proper truncation
if isinstance(value, str) and len(value) > max_length:
serialized[key] = f"{value[:max_length]}... [truncated]"
else:
serialized[key] = ResultDataResponse._serialize_and_truncate(value, max_length=max_length)
return serialized
if isinstance(obj, BaseModel):
# For other BaseModel instances, serialize all fields
serialized = obj.model_dump()
return {
k: ResultDataResponse._serialize_and_truncate(v, max_length=max_length) for k, v in serialized.items()
}
if isinstance(obj, dict):
return {k: ResultDataResponse._serialize_and_truncate(v, max_length=max_length) for k, v in obj.items()}
if isinstance(obj, list | tuple):
# If list is too long, truncate it
if len(obj) > MAX_ITEMS_LENGTH:
truncated_list = list(obj)[:MAX_ITEMS_LENGTH]
truncated_list.append(f"... [truncated {len(obj) - MAX_ITEMS_LENGTH} items]")
obj = truncated_list
return [ResultDataResponse._serialize_and_truncate(item, max_length=max_length) for item in obj]
return obj
return serialize(v, max_length=MAX_TEXT_LENGTH, max_items=MAX_ITEMS_LENGTH)

@model_serializer(mode="plain")
def serialize_model(self) -> dict:
"""Custom serializer for the entire model."""
return {
"results": self.serialize_results(self.results),
"outputs": self._serialize_and_truncate(self.outputs, max_length=MAX_TEXT_LENGTH),
"logs": self._serialize_and_truncate(self.logs, max_length=MAX_TEXT_LENGTH),
"message": self._serialize_and_truncate(self.message, max_length=MAX_TEXT_LENGTH),
"artifacts": self._serialize_and_truncate(self.artifacts, max_length=MAX_TEXT_LENGTH),
"outputs": serialize(self.outputs, max_length=MAX_TEXT_LENGTH, max_items=MAX_ITEMS_LENGTH),
"logs": serialize(self.logs, max_length=MAX_TEXT_LENGTH, max_items=MAX_ITEMS_LENGTH),
"message": serialize(self.message, max_length=MAX_TEXT_LENGTH, max_items=MAX_ITEMS_LENGTH),
"artifacts": serialize(self.artifacts, max_length=MAX_TEXT_LENGTH, max_items=MAX_ITEMS_LENGTH),
"timedelta": self.timedelta,
"duration": self.duration,
"used_frozen_result": self.used_frozen_result,
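
The hunk above collapses `ResultDataResponse`'s hand-rolled `_serialize_and_truncate` helper into a single entry point: `serialize(value, max_length=..., max_items=...)`. A minimal usage sketch follows, assuming langflow with this commit is installed; the exact truncation marker text is an assumption carried over from the removed helper and may differ in the new module.

```python
# Sketch only: the unified call that replaces the removed _serialize_and_truncate helper.
# Assumes langflow (with this commit) is installed; marker text is an assumption.
from langflow.serialization import serialize
from langflow.serialization.constants import MAX_ITEMS_LENGTH, MAX_TEXT_LENGTH

long_text = "x" * (MAX_TEXT_LENGTH + 100)
big_list = list(range(MAX_ITEMS_LENGTH + 5))

truncated_text = serialize(long_text, max_length=MAX_TEXT_LENGTH, max_items=MAX_ITEMS_LENGTH)
truncated_list = serialize(big_list, max_length=MAX_TEXT_LENGTH, max_items=MAX_ITEMS_LENGTH)

print(len(truncated_text), truncated_text[-15:])  # expected to end with a truncation marker
print(len(truncated_list), truncated_list[-1])    # expected: a summary entry for the dropped items
```
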
12 changes: 9 additions & 3 deletions src/backend/base/langflow/components/vectorstores/pinecone.py
@@ -1,5 +1,5 @@
import numpy as np
from langchain_pinecone import Pinecone
from langchain_core.vectorstores import VectorStore

from langflow.base.vectorstores.model import LCVectorStoreComponent, check_cached_vector_store
from langflow.helpers.data import docs_to_data
@@ -42,8 +42,14 @@ class PineconeVectorStoreComponent(LCVectorStoreComponent):
]

@check_cached_vector_store
def build_vector_store(self) -> Pinecone:
def build_vector_store(self) -> VectorStore:
"""Build and return a Pinecone vector store instance."""
try:
from langchain_pinecone import PineconeVectorStore
except ImportError as e:
msg = "langchain-pinecone is not installed. Please install it with `pip install langchain-pinecone`."
raise ValueError(msg) from e

try:
from langchain_pinecone._utilities import DistanceStrategy

@@ -55,7 +61,7 @@ def build_vector_store(self) -> Pinecone:
distance_strategy = DistanceStrategy[distance_strategy]

# Initialize Pinecone instance with wrapped embeddings
pinecone = Pinecone(
pinecone = PineconeVectorStore(
index_name=self.index_name,
embedding=wrapped_embeddings, # Use wrapped embeddings
text_key=self.text_key,
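
Two things change in this file: the return annotation widens from the concrete `Pinecone` class to the `VectorStore` base type, and the `langchain_pinecone` import moves inside `build_vector_store`, so the component module can be imported even when the optional dependency is missing. A stripped-down sketch of that deferred-import pattern (hypothetical helper, not langflow code):

```python
from langchain_core.vectorstores import VectorStore


def build_store(index_name: str, embedding, text_key: str = "text") -> VectorStore:
    """Construct the concrete Pinecone store only when it is actually needed."""
    try:
        # Optional dependency: imported lazily so this module loads without it.
        from langchain_pinecone import PineconeVectorStore
    except ImportError as e:
        msg = "langchain-pinecone is not installed. Please install it with `pip install langchain-pinecone`."
        raise ValueError(msg) from e
    return PineconeVectorStore(index_name=index_name, embedding=embedding, text_key=text_key)
```
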
6 changes: 3 additions & 3 deletions src/backend/base/langflow/graph/schema.py
@@ -3,8 +3,8 @@

from pydantic import BaseModel, Field, field_serializer, model_validator

from langflow.graph.utils import serialize_field
from langflow.schema.schema import OutputValue, StreamURL
from langflow.serialization import serialize
from langflow.utils.schemas import ChatOutputResponse, ContainsEnumMeta


@@ -23,8 +23,8 @@ class ResultData(BaseModel):
@field_serializer("results")
def serialize_results(self, value):
if isinstance(value, dict):
return {key: serialize_field(val) for key, val in value.items()}
return serialize_field(value)
return {key: serialize(val) for key, val in value.items()}
return serialize(value)

@model_validator(mode="before")
@classmethod
32 changes: 3 additions & 29 deletions src/backend/base/langflow/graph/utils.py
@@ -6,14 +6,12 @@
from typing import TYPE_CHECKING, Any
from uuid import UUID

from langchain_core.documents import Document
from loguru import logger
from pydantic import BaseModel
from pydantic.v1 import BaseModel as V1BaseModel

from langflow.interface.utils import extract_input_variables_from_prompt
from langflow.schema.data import Data
from langflow.schema.message import Message
from langflow.serialization import serialize
from langflow.services.database.models.transactions.crud import log_transaction as crud_log_transaction
from langflow.services.database.models.transactions.model import TransactionBase
from langflow.services.database.models.vertex_builds.crud import log_vertex_build as crud_log_vertex_build
@@ -68,30 +66,6 @@ def flatten_list(list_of_lists: list[list | Any]) -> list:
return new_list


def serialize_field(value):
"""Serialize field.
Unified serialization function for handling both BaseModel and Document types,
including handling lists of these types.
"""
if isinstance(value, list | tuple):
return [serialize_field(v) for v in value]
if isinstance(value, Document):
return value.to_json()
if isinstance(value, BaseModel):
return serialize_field(value.model_dump())
if isinstance(value, dict):
return {k: serialize_field(v) for k, v in value.items()}
if isinstance(value, V1BaseModel):
if hasattr(value, "to_json"):
return value.to_json()
return value.dict()
# Handle datetime objects
if hasattr(value, "isoformat"):
return value.isoformat()
return str(value)


def get_artifact_type(value, build_result) -> str:
result = ArtifactType.UNKNOWN
match value:
@@ -186,9 +160,9 @@ async def log_vertex_build(
valid=valid,
params=str(params) if params else None,
# Serialize data using our custom serializer
data=serialize_field(data),
data=serialize(data),
# Serialize artifacts using our custom serializer
artifacts=serialize_field(artifacts) if artifacts else None,
artifacts=serialize(artifacts) if artifacts else None,
)
async with session_getter(get_db_service()) as session:
inserted = await crud_log_vertex_build(session, vertex_build)
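
The removed `serialize_field` special-cased `Document`, pydantic v1 models, and datetime-like objects; those shapes now flow through the unified `serialize` at both `log_vertex_build` call sites. A minimal sketch of the call-site change; the output shape is an assumption, not verified against the new module:

```python
# Sketch only: the shapes serialize_field used to handle now go through serialize.
from datetime import datetime, timezone

from langchain_core.documents import Document
from langflow.serialization import serialize

payload = {
    "docs": [Document(page_content="hello", metadata={"source": "unit-test"})],
    "built_at": datetime.now(timezone.utc),
}

# Previously: data=serialize_field(payload); now: data=serialize(payload)
print(serialize(payload))
```
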
5 changes: 3 additions & 2 deletions src/backend/base/langflow/graph/vertex/vertex_types.py
@@ -10,13 +10,14 @@
from loguru import logger

from langflow.graph.schema import CHAT_COMPONENTS, RECORDS_COMPONENTS, InterfaceComponentTypes, ResultData
from langflow.graph.utils import UnbuiltObject, log_vertex_build, rewrite_file_path, serialize_field
from langflow.graph.utils import UnbuiltObject, log_vertex_build, rewrite_file_path
from langflow.graph.vertex.base import Vertex
from langflow.graph.vertex.exceptions import NoComponentInstanceError
from langflow.schema import Data
from langflow.schema.artifact import ArtifactType
from langflow.schema.message import Message
from langflow.schema.schema import INPUT_FIELD_NAME
from langflow.serialization import serialize
from langflow.template.field.base import UNDEFINED, Output
from langflow.utils.schemas import ChatOutputResponse, DataOutputResponse
from langflow.utils.util import unescape_string
@@ -478,6 +479,6 @@ def built_object_repr(self):


def dict_to_codeblock(d: dict) -> str:
serialized = {key: serialize_field(val) for key, val in d.items()}
serialized = {key: serialize(val) for key, val in d.items()}
json_str = json.dumps(serialized, indent=4)
return f"```json\n{json_str}\n```"
4 changes: 2 additions & 2 deletions src/backend/base/langflow/schema/artifact.py
@@ -9,7 +9,7 @@
from langflow.schema.dataframe import DataFrame
from langflow.schema.encoders import CUSTOM_ENCODERS
from langflow.schema.message import Message
from langflow.schema.serialize import recursive_serialize_or_str
from langflow.serialization.serialization import serialize


class ArtifactType(str, Enum):
@@ -56,7 +56,7 @@ def _to_list_of_dicts(raw):
raw_ = []
for item in raw:
if hasattr(item, "dict") or hasattr(item, "model_dump"):
raw_.append(recursive_serialize_or_str(item))
raw_.append(serialize(item))
else:
raw_.append(str(item))
return raw_
4 changes: 2 additions & 2 deletions src/backend/base/langflow/schema/schema.py
@@ -8,7 +8,7 @@
from langflow.schema.data import Data
from langflow.schema.dataframe import DataFrame
from langflow.schema.message import Message
from langflow.schema.serialize import recursive_serialize_or_str
from langflow.serialization.serialization import serialize

INPUT_FIELD_NAME = "input_value"

@@ -110,7 +110,7 @@ def build_output_logs(vertex, result) -> dict:
case LogType.ARRAY:
if isinstance(message, DataFrame):
message = message.to_dict(orient="records")
message = [recursive_serialize_or_str(item) for item in message]
message = [serialize(item) for item in message]
name = output.get("name", f"output_{index}")
outputs |= {name: OutputValue(message=message, type=type_).model_dump()}

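
In `build_output_logs`, a DataFrame-shaped output is first flattened to a list of records and each record is then passed through `serialize`. A tiny illustration of that branch using a plain pandas DataFrame; langflow's own `DataFrame` exposes the same `to_dict(orient="records")` call used above, but this is a standalone sketch, not langflow code:

```python
# Sketch only: mirrors the LogType.ARRAY branch above with a plain pandas DataFrame.
import pandas as pd

from langflow.serialization import serialize

df = pd.DataFrame({"name": ["a", "b"], "score": [0.1, 0.2]})
records = df.to_dict(orient="records")           # [{'name': 'a', 'score': 0.1}, ...]
message = [serialize(item) for item in records]  # one serialized dict per row
print(message)
```
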
43 changes: 1 addition & 42 deletions src/backend/base/langflow/schema/serialize.py
@@ -1,11 +1,7 @@
from collections.abc import AsyncIterator, Generator, Iterator
from datetime import datetime
from typing import Annotated
from uuid import UUID

from loguru import logger
from pydantic import BaseModel, BeforeValidator
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic import BeforeValidator


def str_to_uuid(v: str | UUID) -> UUID:
@@ -15,40 +11,3 @@ def str_to_uuid(v: str | UUID) -> UUID:


UUIDstr = Annotated[UUID, BeforeValidator(str_to_uuid)]


def recursive_serialize_or_str(obj):
try:
if isinstance(obj, type) and issubclass(obj, BaseModel | BaseModelV1):
# This a type BaseModel and not an instance of it
return repr(obj)
if isinstance(obj, str):
return obj
if isinstance(obj, datetime):
return obj.isoformat()
if isinstance(obj, dict):
return {k: recursive_serialize_or_str(v) for k, v in obj.items()}
if isinstance(obj, list):
return [recursive_serialize_or_str(v) for v in obj]
if isinstance(obj, BaseModel | BaseModelV1):
if hasattr(obj, "model_dump"):
obj_dict = obj.model_dump()
elif hasattr(obj, "dict"):
obj_dict = obj.dict()
return {k: recursive_serialize_or_str(v) for k, v in obj_dict.items()}

if isinstance(obj, AsyncIterator | Generator | Iterator):
# contain memory addresses
# without consuming the iterator
# return list(obj) consumes the iterator
# return f"{obj}" this generates '<generator object BaseChatModel.stream at 0x33e9ec770>'
# it is not useful
return "Unconsumed Stream"
if hasattr(obj, "dict") and not isinstance(obj, type):
return {k: recursive_serialize_or_str(v) for k, v in obj.dict().items()}
if hasattr(obj, "model_dump") and not isinstance(obj, type):
return {k: recursive_serialize_or_str(v) for k, v in obj.model_dump().items()}
return str(obj)
except Exception: # noqa: BLE001
logger.debug(f"Cannot serialize object {obj}")
return str(obj)
3 changes: 3 additions & 0 deletions src/backend/base/langflow/serialization/__init__.py
@@ -0,0 +1,3 @@
from .serialization import serialize

__all__ = ["serialize"]
2 changes: 2 additions & 0 deletions src/backend/base/langflow/serialization/constants.py
@@ -0,0 +1,2 @@
MAX_TEXT_LENGTH = 20000
MAX_ITEMS_LENGTH = 1000
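
The new `langflow/serialization/serialization.py` module itself is not among the files rendered above. Based on the commit messages (helpers such as `_serialize_dispatcher`, `_serialize_bytes`, and `_serialize_primitive`, numpy/pandas support, an unserializable sentinel) and the helpers it replaces, the following is a minimal single-function sketch of how such a dispatcher might be structured; every name and detail below is an approximation, not the actual implementation:

```python
# Sketch only: an approximation of a unified serialize dispatcher, assembled from
# the commit messages and the removed helpers shown above. The real module is not
# rendered in this diff excerpt and may differ in names, order, and edge cases.
from collections.abc import AsyncIterator, Generator, Iterator
from datetime import datetime, timezone
from decimal import Decimal
from enum import Enum
from typing import Any
from uuid import UUID

from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1

try:  # optional dependency, mirroring the numpy support mentioned in the commits
    import numpy as np
except ImportError:  # pragma: no cover
    np = None


def _truncate_text(text: str, max_length: int | None) -> str:
    if max_length is not None and len(text) > max_length:
        return f"{text[:max_length]}..."
    return text


def serialize(obj: Any, max_length: int | None = None, max_items: int | None = None) -> Any:
    """Recursively convert obj into JSON-friendly data, truncating long strings and lists."""
    try:
        if isinstance(obj, type) and issubclass(obj, BaseModel | BaseModelV1):
            return repr(obj)  # a model class, not an instance
        if isinstance(obj, bytes):
            return _truncate_text(obj.decode("utf-8", errors="ignore"), max_length)
        if isinstance(obj, str):
            return _truncate_text(obj, max_length)
        if isinstance(obj, Enum):
            return serialize(obj.value, max_length, max_items)
        if np is not None and isinstance(obj, np.integer | np.floating | np.bool_):
            return obj.item()  # plain Python scalar
        if np is not None and isinstance(obj, np.ndarray):
            return serialize(obj.tolist(), max_length, max_items)
        if obj is None or isinstance(obj, int | float | bool):
            return obj
        if isinstance(obj, datetime):
            return obj.replace(tzinfo=timezone.utc).isoformat()
        if isinstance(obj, Decimal):
            return float(obj)
        if isinstance(obj, UUID):
            return str(obj)
        if isinstance(obj, BaseModel | BaseModelV1):
            dumped = obj.model_dump() if hasattr(obj, "model_dump") else obj.dict()
            return {k: serialize(v, max_length, max_items) for k, v in dumped.items()}
        if isinstance(obj, dict):
            return {k: serialize(v, max_length, max_items) for k, v in obj.items()}
        if isinstance(obj, list | tuple):
            items = list(obj)
            if max_items is not None and len(items) > max_items:
                dropped = len(items) - max_items
                items = items[:max_items] + [f"... [truncated {dropped} items]"]
            return [serialize(item, max_length, max_items) for item in items]
        if isinstance(obj, AsyncIterator | Generator | Iterator):
            return "Unconsumed Stream"  # never consume the caller's iterator
        return str(obj)  # fallback for anything unsupported
    except Exception:  # noqa: BLE001
        return str(obj)
```
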