Skip to content

Commit 2866467

Browse files
committed
fixes
1 parent 2787d64 commit 2866467

File tree

5 files changed

+75
-46
lines changed

5 files changed

+75
-46
lines changed

src/backend/base/langflow/custom/custom_component/component.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -474,8 +474,6 @@ def _map_parameters_on_template(self, template: dict):
474474
for name, value in self._parameters.items():
475475
try:
476476
template[name]["value"] = value
477-
if value and "load_from_db" in template[name]:
478-
template[name]["load_from_db"] = False
479477
except KeyError:
480478
close_match = find_closest_match(name, list(template.keys()))
481479
if close_match:
@@ -499,6 +497,8 @@ def to_frontend_node(self):
499497
#! works and then update this later
500498
field_config = self.get_template_config(self)
501499
frontend_node = ComponentFrontendNode.from_inputs(**field_config)
500+
for key, value in self._inputs.items():
501+
frontend_node.set_field_load_from_db_in_template(key, False)
502502
self._map_parameters_on_frontend_node(frontend_node)
503503

504504
frontend_node_dict = frontend_node.to_dict(keep_name=False)

src/backend/base/langflow/interface/initialize/loading.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ def update_params_with_load_from_db_fields(
112112
try:
113113
key = custom_component.variables(params[field], field)
114114
except ValueError as e:
115-
# check if "User id is not set" is in the error message
116-
if "User id is not set" in str(e) and not fallback_to_env_vars:
115+
# check if "User id is not set" is in the error message, this is an internal bug
116+
if "User id is not set" in str(e):
117117
raise e
118118
logger.debug(str(e))
119119
if fallback_to_env_vars and key is None:

src/backend/base/langflow/template/frontend_node/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,9 @@ def set_field_value_in_template(self, field_name, value):
184184
if field.name == field_name:
185185
field.value = value
186186
break
187+
188+
def set_field_load_from_db_in_template(self, field_name, value):
189+
for field in self.template.fields:
190+
if field.name == field_name and hasattr(field, "load_from_db"):
191+
field.load_from_db = value
192+
break

src/backend/tests/integration/components/astra/test_astra_component.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
import os
2+
from typing import List
23

34
from astrapy.db import AstraDB
45
import pytest
56

7+
from langflow.components.embeddings import OpenAIEmbeddingsComponent
8+
from langflow.custom import Component
9+
from langflow.inputs import StrInput
10+
from langflow.template import Output
611
from tests.api_keys import get_astradb_application_token, get_astradb_api_endpoint, get_openai_api_key
7-
from tests.integration.utils import MockEmbeddings, check_env_vars, valid_nvidia_vectorize_region
12+
from tests.integration.utils import MockEmbeddings, check_env_vars, valid_nvidia_vectorize_region, ComponentInputHandle
813
from langchain_core.documents import Document
914

1015

@@ -38,46 +43,51 @@ def astradb_client(request):
3843

3944
@pytest.mark.api_key_required
4045
@pytest.mark.asyncio
41-
async def test_astra_setup(astradb_client: AstraDB):
46+
async def test_base(astradb_client: AstraDB):
4247
from langflow.components.embeddings import OpenAIEmbeddingsComponent
43-
open_ai_embeddings = OpenAIEmbeddingsComponent(openai_api_key=get_openai_api_key())
4448
application_token = get_astradb_application_token()
4549
api_endpoint = get_astradb_api_endpoint()
46-
#embedding = MockEmbeddings()
4750

4851

4952
results = await run_single_component(AstraVectorStoreComponent, inputs={
5053
"token": application_token,
5154
"api_endpoint": api_endpoint,
5255
"collection_name": BASIC_COLLECTION,
53-
"embedding": open_ai_embeddings,
56+
"embedding": ComponentInputHandle(clazz=OpenAIEmbeddingsComponent, inputs={"openai_api_key": get_openai_api_key()}, output_name="embeddings"),
5457
})
55-
print(results)
58+
from langchain_core.vectorstores import VectorStoreRetriever
59+
assert isinstance(results["base_retriever"], VectorStoreRetriever)
60+
assert results["vector_store"] is not None
61+
assert results["search_results"] == []
5662
assert astradb_client.collection(BASIC_COLLECTION)
5763

64+
65+
class TextToData(Component):
66+
inputs = [
67+
StrInput(name="text_data",is_list=True)
68+
]
69+
outputs = [
70+
Output(name="data", display_name="Data", method="create_data")
71+
]
72+
def create_data(self) -> List[Data]:
73+
return [Data(text=t) for t in self.text_data]
5874
@pytest.mark.api_key_required
59-
def test_astra_embeds_and_search():
75+
@pytest.mark.asyncio
76+
async def test_astra_embeds_and_search():
6077
application_token = get_astradb_application_token()
6178
api_endpoint = get_astradb_api_endpoint()
62-
embedding = MockEmbeddings()
63-
64-
documents = [Document(page_content="test1"), Document(page_content="test2")]
65-
records = [Data.from_document(d) for d in documents]
66-
67-
component = AstraVectorStoreComponent()
68-
component.build(
69-
token=application_token,
70-
api_endpoint=api_endpoint,
71-
collection_name=SEARCH_COLLECTION,
72-
embedding=embedding,
73-
ingest_data=records,
74-
search_input="test1",
75-
number_of_results=1,
76-
)
77-
component.build_vector_store()
78-
records = component.search_documents()
79-
80-
assert len(records) == 1
79+
80+
results = await run_single_component(AstraVectorStoreComponent, inputs={
81+
"token": application_token,
82+
"api_endpoint": api_endpoint,
83+
"collection_name": BASIC_COLLECTION,
84+
"number_of_results": 1,
85+
"search_input":"test1",
86+
"ingest_data": ComponentInputHandle(clazz=TextToData, inputs={"text_data": ["test1", "test2"]}, output_name="data"),
87+
"embedding": ComponentInputHandle(clazz=OpenAIEmbeddingsComponent,
88+
inputs={"openai_api_key": get_openai_api_key()}, output_name="embeddings"),
89+
})
90+
assert len(results["search_results"]) == 1
8191

8292

8393

src/backend/tests/integration/utils.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -127,25 +127,38 @@ async def run_flow(graph: Graph, run_input: Optional[Any] = None,
127127
return outputs
128128

129129

130-
async def run_single_component(clazz: type, inputs: dict = None, run_input: Optional[Any] = None,
131-
session_id: Optional[str] = None) -> dict[str, Any]:
132130

133-
raw_inputs = {}
134-
for key, value in inputs.items():
135-
if not isinstance(value, BaseComponent):
136-
raw_inputs[key] = value
137-
component = clazz(
138-
**raw_inputs
139-
)
131+
@dataclasses.dataclass
132+
class ComponentInputHandle:
133+
clazz: type
134+
inputs: dict
135+
output_name: str
140136

137+
async def run_single_component(clazz: type, inputs: dict = None, run_input: Optional[Any] = None,
138+
session_id: Optional[str] = None) -> dict[str, Any]:
139+
user_id = str(uuid.uuid4())
141140
flow_id = str(uuid.uuid4())
142-
graph = Graph(user_id=str(uuid.uuid4()), flow_id=flow_id)
143-
component_id = graph.add_component(component)
144-
for input_name, input_value in inputs.items():
145-
if isinstance(input_value, Component):
146-
graph.add_component(input_value)
147-
graph.add_component_edge(input_value._id, (input_value.outputs[0].name, input_name), component._id)
148-
print("added edge")
141+
graph = Graph(user_id=user_id, flow_id=flow_id)
142+
143+
def _add_component(clazz: type, inputs: dict = None) -> str:
144+
raw_inputs = {}
145+
for key, value in inputs.items():
146+
if not isinstance(value, ComponentInputHandle):
147+
raw_inputs[key] = value
148+
if isinstance(value, Component):
149+
raise ValueError("Component inputs must be wrapped in ComponentInputHandle")
150+
component = clazz(
151+
**raw_inputs,
152+
_user_id=user_id
153+
)
154+
component_id = graph.add_component(component)
155+
for input_name, handle in inputs.items():
156+
if isinstance(handle, ComponentInputHandle):
157+
handle_component_id = _add_component(handle.clazz, handle.inputs)
158+
graph.add_component_edge(handle_component_id, (handle.output_name, input_name), component_id)
159+
return component_id
160+
161+
component_id = _add_component(clazz, inputs)
149162
graph.prepare()
150163
if run_input:
151164
graph_run_inputs = [InputValueRequest(input_value=run_input, type="chat")]

0 commit comments

Comments
 (0)