Skip to content

Commit

Permalink
fix: add default models to Anthropic and make sure template is updated (
Browse files Browse the repository at this point in the history
#5839)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
ogabrielluiz and autofix-ci[bot] authored Jan 21, 2025
1 parent 64d82d4 commit 050c12d
Show file tree
Hide file tree
Showing 19 changed files with 240 additions and 75 deletions.
35 changes: 4 additions & 31 deletions src/backend/base/langflow/api/v1/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,7 @@
from uuid import UUID

import sqlalchemy as sa
from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Depends,
HTTPException,
Request,
UploadFile,
status,
)
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Request, UploadFile, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import StreamingResponse
from loguru import logger
Expand All @@ -36,11 +27,7 @@
UploadFileResponse,
)
from langflow.custom.custom_component.component import Component
from langflow.custom.utils import (
build_custom_component_template,
get_instance_name,
update_component_build_config,
)
from langflow.custom.utils import build_custom_component_template, get_instance_name, update_component_build_config
from langflow.events.event_manager import create_stream_tokens_event_manager
from langflow.exceptions.api import APIException, InvalidChatInputError
from langflow.exceptions.serialization import SerializationError
Expand All @@ -55,16 +42,9 @@
from langflow.services.cache.utils import save_uploaded_file
from langflow.services.database.models.flow import Flow
from langflow.services.database.models.flow.model import FlowRead
from langflow.services.database.models.flow.utils import (
get_all_webhook_components_in_flow,
)
from langflow.services.database.models.flow.utils import get_all_webhook_components_in_flow
from langflow.services.database.models.user.model import User, UserRead
from langflow.services.deps import (
get_session_service,
get_settings_service,
get_task_service,
get_telemetry_service,
)
from langflow.services.deps import get_session_service, get_settings_service, get_task_service, get_telemetry_service
from langflow.services.settings.feature_flags import FEATURE_FLAGS
from langflow.services.telemetry.schema import RunPayload
from langflow.utils.version import get_version_info
Expand Down Expand Up @@ -720,7 +700,6 @@ async def custom_component_update(
user_id=user.id,
)

template_data = code_request.model_dump().get("template", {}).copy()
component_node["tool_mode"] = code_request.tool_mode

if hasattr(cc_instance, "set_attributes"):
Expand Down Expand Up @@ -749,12 +728,6 @@ async def custom_component_update(
)
component_node["template"] = updated_build_config

# Preserve previous field values by merging filtered template data into
# the component node's template. Only include entries where the value
# is a dictionary containing the key "value".
filtered_data = {k: v for k, v in template_data.items() if isinstance(v, dict) and "value" in v}
component_node["template"] |= filtered_data

if isinstance(cc_instance, Component):
await cc_instance.run_and_validate_update_outputs(
frontend_node=component_node,
Expand Down
17 changes: 6 additions & 11 deletions src/backend/base/langflow/base/models/model_input_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from langflow.components.models.nvidia import NVIDIAModelComponent
from langflow.components.models.openai import OpenAIModelComponent
from langflow.inputs.inputs import InputTypes, SecretStrInput
from langflow.template.field.base import Input


class ModelProvidersDict(TypedDict):
Expand All @@ -24,7 +25,7 @@ def get_filtered_inputs(component_class):
return [process_inputs(input_) for input_ in component_instance.inputs if input_.name not in base_input_names]


def process_inputs(component_data):
def process_inputs(component_data: Input):
if isinstance(component_data, SecretStrInput):
component_data.value = ""
component_data.load_from_db = False
Expand Down Expand Up @@ -61,8 +62,8 @@ def add_combobox_true(component_input):
return component_input


def create_input_fields_dict(inputs, prefix):
return {f"{prefix}{input_.name}": input_ for input_ in inputs}
def create_input_fields_dict(inputs: list[Input], prefix: str) -> dict[str, Input]:
return {f"{prefix}{input_.name}": input_.to_dict() for input_ in inputs}


def _get_openai_inputs_and_fields():
Expand All @@ -73,7 +74,7 @@ def _get_openai_inputs_and_fields():
except ImportError as e:
msg = "OpenAI is not installed. Please install it with `pip install langchain-openai`."
raise ImportError(msg) from e
return openai_inputs, {input_.name: input_ for input_ in openai_inputs}
return openai_inputs, create_input_fields_dict(openai_inputs, "")


def _get_azure_inputs_and_fields():
Expand Down Expand Up @@ -204,10 +205,4 @@ def _get_amazon_bedrock_inputs_and_fields():
MODEL_PROVIDERS = list(MODEL_PROVIDERS_DICT.keys())
ALL_PROVIDER_FIELDS: list[str] = [field for provider in MODEL_PROVIDERS_DICT.values() for field in provider["fields"]]

MODEL_DYNAMIC_UPDATE_FIELDS = [
"api_key",
"model",
"tool_model_enabled",
"base_url",
"model_name",
]
MODEL_DYNAMIC_UPDATE_FIELDS = ["api_key", "model", "tool_model_enabled", "base_url", "model_name"]
13 changes: 7 additions & 6 deletions src/backend/base/langflow/components/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
from langflow.base.models.model_utils import get_model_name
from langflow.components.helpers import CurrentDateComponent
from langflow.components.helpers.memory import MemoryComponent
from langflow.components.langchain_utilities.tool_calling import (
ToolCallingAgentComponent,
)
from langflow.components.langchain_utilities.tool_calling import ToolCallingAgentComponent
from langflow.custom.utils import update_component_build_config
from langflow.io import BoolInput, DropdownInput, MultilineInput, Output
from langflow.logging import logger
Expand Down Expand Up @@ -121,6 +119,8 @@ async def get_memory_data(self):
memory_kwargs = {
component_input.name: getattr(self, f"{component_input.name}") for component_input in self.memory_inputs
}
# filter out empty values
memory_kwargs = {k: v for k, v in memory_kwargs.items() if v}

return await MemoryComponent().set(**memory_kwargs).retrieve_messages()

Expand Down Expand Up @@ -177,13 +177,14 @@ async def update_build_config(
# Iterate over all providers in the MODEL_PROVIDERS_DICT
# Existing logic for updating build_config
if field_name in ("agent_llm",):
build_config["agent_llm"]["value"] = field_value
provider_info = MODEL_PROVIDERS_DICT.get(field_value)
if provider_info:
component_class = provider_info.get("component_class")
if component_class and hasattr(component_class, "update_build_config"):
# Call the component class's update_build_config method
build_config = await update_component_build_config(
component_class, build_config, field_value, field_name
component_class, build_config, field_value, "model_name"
)

provider_configs: dict[str, tuple[dict, list[dict]]] = {
Expand Down Expand Up @@ -261,6 +262,6 @@ async def update_build_config(
if isinstance(field_name, str) and isinstance(prefix, str):
field_name = field_name.replace(prefix, "")
build_config = await update_component_build_config(
component_class, build_config, field_value, field_name
component_class, build_config, field_value, "model_name"
)
return build_config
return {k: v.to_dict() if hasattr(v, "to_dict") else v for k, v in build_config.items()}
11 changes: 7 additions & 4 deletions src/backend/base/langflow/components/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ class AnthropicModelComponent(LCModelComponent):
DropdownInput(
name="model_name",
display_name="Model Name",
options=[],
options=ANTHROPIC_MODELS,
refresh_button=True,
value=ANTHROPIC_MODELS[0],
),
SecretStrInput(
name="api_key",
Expand Down Expand Up @@ -138,14 +139,16 @@ def _get_exception_message(self, exception: Exception) -> str | None:
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
if field_name in ("base_url", "model_name", "tool_model_enabled", "api_key") and field_value:
try:
if len(self.api_key) != 0:
if len(self.api_key) == 0:
ids = ANTHROPIC_MODELS
else:
try:
ids = self.get_models(tool_model_enabled=self.tool_model_enabled)
except (ImportError, ValueError, requests.exceptions.RequestException) as e:
logger.exception(f"Error getting model names: {e}")
ids = ANTHROPIC_MODELS
build_config["model_name"]["options"] = ids
build_config["model_name"]["value"] = ids[0]
build_config["model_name"]["options"] = ids
build_config["model_name"]["value"] = ids[0]
except Exception as e:
msg = f"Error getting model names: {e}"
raise ValueError(msg) from e
Expand Down
3 changes: 2 additions & 1 deletion src/backend/base/langflow/components/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ class GroqModel(LCModelComponent):
name="model_name",
display_name="Model",
info="The name of the model to use.",
options=[],
options=GROQ_MODELS,
value=GROQ_MODELS[0],
refresh_button=True,
real_time_refresh=True,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,11 @@ def _set_input_value(self, name: str, value: Any) -> None:
name, f"Input is connected to {input_value.__self__.display_name}.{input_value.__name__}"
)
raise ValueError(msg)
self._inputs[name].value = value
try:
self._inputs[name].value = value
except Exception as e:
msg = f"Error setting input value for {name}: {e}"
raise ValueError(msg) from e
if hasattr(self._inputs[name], "load_from_db"):
self._inputs[name].load_from_db = False
else:
Expand Down
Loading

0 comments on commit 050c12d

Please sign in to comment.