Skip to content

Commit e77eeee

Browse files
authored
core[patch]: add standard tracing params for retrievers (#25240)
1 parent 9927a48 commit e77eeee

File tree

6 files changed

+104
-3
lines changed

6 files changed

+104
-3
lines changed

libs/community/tests/unit_tests/retrievers/test_base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ async def test_fake_retriever_v1_upgrade_async(
7474
assert callbacks.retriever_errors == 0
7575

7676

77+
def test_fake_retriever_v1_standard_params(fake_retriever_v1: BaseRetriever) -> None:
    """The ``fake_retriever_v1`` fixture reports only its lower-cased name."""
    expected = {"ls_retriever_name": "fakeretrieverv1"}
    assert fake_retriever_v1._get_ls_params() == expected
80+
81+
7782
@pytest.fixture
7883
def fake_retriever_v1_with_kwargs() -> BaseRetriever:
7984
# Test for things like the Weaviate V1 Retriever.
@@ -213,3 +218,8 @@ async def test_fake_retriever_v2_async(
213218
await fake_erroring_retriever_v2.ainvoke(
214219
"Foo", config={"callbacks": [callbacks]}
215220
)
221+
222+
223+
def test_fake_retriever_v2_standard_params(fake_retriever_v2: BaseRetriever) -> None:
    """The ``fake_retriever_v2`` fixture reports only its lower-cased name."""
    expected = {"ls_retriever_name": "fakeretrieverv2"}
    assert fake_retriever_v2._get_ls_params() == expected

libs/community/tests/unit_tests/retrievers/test_bedrock.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ def test_create_client(amazon_retriever: AmazonKnowledgeBasesRetriever) -> None:
3333
amazon_retriever.create_client({})
3434

3535

36+
def test_standard_params(amazon_retriever: AmazonKnowledgeBasesRetriever) -> None:
    """The retriever name is the class name minus ``Retriever``, lower-cased."""
    expected = {"ls_retriever_name": "amazonknowledgebases"}
    assert amazon_retriever._get_ls_params() == expected
39+
40+
3641
def test_get_relevant_documents(
3742
amazon_retriever: AmazonKnowledgeBasesRetriever, mock_client: MagicMock
3843
) -> None:

libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,28 @@ def test_similarity_score_threshold(index_details: dict, threshold: float) -> No
633633
assert len(search_result) == 0
634634

635635

636+
@pytest.mark.requires("databricks", "databricks.vector_search")
def test_standard_params() -> None:
    """Check the tracing params of retrievers built from two index kinds.

    The direct-access index is expected to report an embedding provider; the
    delta-sync index with managed embeddings is expected to omit it.
    """
    common = {
        "ls_retriever_name": "vectorstore",
        "ls_vector_store_provider": "DatabricksVectorSearch",
    }
    scenarios = [
        (
            DIRECT_ACCESS_INDEX,
            {**common, "ls_embedding_provider": "FakeEmbeddingsWithDimension"},
        ),
        (DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS, common),
    ]

    for index_details, expected in scenarios:
        store = default_databricks_vector_search(mock_index(index_details))
        assert store.as_retriever()._get_ls_params() == expected
656+
657+
636658
@pytest.mark.requires("databricks", "databricks.vector_search")
637659
@pytest.mark.parametrize(
638660
"index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]

libs/community/tests/unit_tests/vectorstores/test_faiss.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,15 @@ def test_faiss() -> None:
4949
output = docsearch.similarity_search("foo", k=1)
5050
assert output == [Document(page_content="foo")]
5151

52+
# Retriever standard params
53+
retriever = docsearch.as_retriever()
54+
ls_params = retriever._get_ls_params()
55+
assert ls_params == {
56+
"ls_retriever_name": "vectorstore",
57+
"ls_vector_store_provider": "FAISS",
58+
"ls_embedding_provider": "FakeEmbeddings",
59+
}
60+
5261

5362
@pytest.mark.requires("faiss")
5463
async def test_faiss_afrom_texts() -> None:

libs/core/langchain_core/retrievers.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from inspect import signature
2727
from typing import TYPE_CHECKING, Any, Dict, List, Optional
2828

29+
from typing_extensions import TypedDict
30+
2931
from langchain_core._api import deprecated
3032
from langchain_core.documents import Document
3133
from langchain_core.load.dump import dumpd
@@ -50,6 +52,19 @@
5052
RetrieverOutputLike = Runnable[Any, RetrieverOutput]
5153

5254

55+
class LangSmithRetrieverParams(TypedDict, total=False):
    """LangSmith parameters for tracing.

    Declared with ``total=False``, so every key may be omitted.
    """

    # Retriever name.
    ls_retriever_name: str
    # Vector store provider.
    ls_vector_store_provider: Optional[str]
    # Embedding provider.
    ls_embedding_provider: Optional[str]
    # Embedding model.
    ls_embedding_model: Optional[str]
66+
67+
5368
class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
5469
"""Abstract base class for a Document retrieval system.
5570
@@ -167,6 +182,19 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
167182
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
168183
)
169184

185+
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
    """Get standard params for tracing.

    The default retriever name is ``self.get_name()`` with a leading
    ``"Retriever"`` removed or, failing that, a trailing ``"Retriever"``
    removed, and then lower-cased.

    Args:
        **kwargs: Accepted so overrides and callers can pass extra arguments;
            not used by this implementation.

    Returns:
        The tracing params, containing only ``ls_retriever_name``.
    """
    affix = "Retriever"
    retriever_name = self.get_name()

    # Strip the affix only when something remains: a name that is exactly
    # "Retriever" would otherwise be reported as an empty string.
    if len(retriever_name) > len(affix):
        if retriever_name.startswith(affix):
            retriever_name = retriever_name[len(affix) :]
        elif retriever_name.endswith(affix):
            retriever_name = retriever_name[: -len(affix)]

    return LangSmithRetrieverParams(ls_retriever_name=retriever_name.lower())
197+
170198
def invoke(
171199
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
172200
) -> List[Document]:
@@ -191,13 +219,17 @@ def invoke(
191219
from langchain_core.callbacks.manager import CallbackManager
192220

193221
config = ensure_config(config)
222+
inheritable_metadata = {
223+
**(config.get("metadata") or {}),
224+
**self._get_ls_params(**kwargs),
225+
}
194226
callback_manager = CallbackManager.configure(
195227
config.get("callbacks"),
196228
None,
197229
verbose=kwargs.get("verbose", False),
198230
inheritable_tags=config.get("tags"),
199231
local_tags=self.tags,
200-
inheritable_metadata=config.get("metadata"),
232+
inheritable_metadata=inheritable_metadata,
201233
local_metadata=self.metadata,
202234
)
203235
run_manager = callback_manager.on_retriever_start(
@@ -250,13 +282,17 @@ async def ainvoke(
250282
from langchain_core.callbacks.manager import AsyncCallbackManager
251283

252284
config = ensure_config(config)
285+
inheritable_metadata = {
286+
**(config.get("metadata") or {}),
287+
**self._get_ls_params(**kwargs),
288+
}
253289
callback_manager = AsyncCallbackManager.configure(
254290
config.get("callbacks"),
255291
None,
256292
verbose=kwargs.get("verbose", False),
257293
inheritable_tags=config.get("tags"),
258294
local_tags=self.tags,
259-
inheritable_metadata=config.get("metadata"),
295+
inheritable_metadata=inheritable_metadata,
260296
local_metadata=self.metadata,
261297
)
262298
run_manager = await callback_manager.on_retriever_start(

libs/core/langchain_core/vectorstores/base.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444

4545
from langchain_core.embeddings import Embeddings
4646
from langchain_core.pydantic_v1 import Field, root_validator
47-
from langchain_core.retrievers import BaseRetriever
47+
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
4848
from langchain_core.runnables.config import run_in_executor
4949

5050
if TYPE_CHECKING:
@@ -1014,6 +1014,25 @@ def validate_search_type(cls, values: Dict) -> Dict:
10141014
)
10151015
return values
10161016

1017+
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
    """Get standard params for tracing.

    Extends the parent's params with the vector store's class name and, when
    one can be found, the class name of its embeddings object.

    Args:
        **kwargs: Forwarded unchanged to the parent implementation.

    Returns:
        The parent's params plus ``ls_vector_store_provider`` and, when
        available, ``ls_embedding_provider``.
    """
    params = super()._get_ls_params(**kwargs)
    store = self.vectorstore
    params["ls_vector_store_provider"] = store.__class__.__name__

    # Prefer the ``embeddings`` attribute; otherwise fall back to an
    # ``embedding`` attribute, but only if it really is an ``Embeddings``.
    embedder = store.embeddings
    if not embedder:
        fallback = getattr(store, "embedding", None)
        embedder = fallback if isinstance(fallback, Embeddings) else None

    if embedder:
        params["ls_embedding_provider"] = embedder.__class__.__name__

    return params
1035+
10171036
def _get_relevant_documents(
10181037
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
10191038
) -> List[Document]:

0 commit comments

Comments
 (0)