Skip to content

Commit 57a0292

Browse files
authored
Add validation for configurable keys passed to .with_config() (langchain-ai#11910)
- Fix some typing issues found while doing that <!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
2 parents 42cd2ef + 778e7c5 commit 57a0292

File tree

6 files changed

+58
-35
lines changed

6 files changed

+58
-35
lines changed

libs/langchain/langchain/runnables/openai_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing_extensions import TypedDict
55

66
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
7-
from langchain.schema.output import ChatGeneration
7+
from langchain.schema.messages import BaseMessage
88
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
99

1010

@@ -19,7 +19,7 @@ class OpenAIFunction(TypedDict):
1919
"""The parameters to the function."""
2020

2121

22-
class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]):
22+
class OpenAIFunctionsRouter(RunnableBinding[BaseMessage, Any]):
2323
"""A runnable that routes to the selected function."""
2424

2525
functions: Optional[List[OpenAIFunction]]

libs/langchain/langchain/schema/runnable/_locals.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
Union,
1212
)
1313

14-
from langchain.schema.runnable.base import Input, Output, RunnableSerializable
14+
from langchain.schema.runnable.base import Input, Other, Output, RunnableSerializable
1515
from langchain.schema.runnable.config import RunnableConfig
1616
from langchain.schema.runnable.passthrough import RunnablePassthrough
1717

@@ -36,7 +36,7 @@ def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None:
3636

3737
def _concat_put(
3838
self,
39-
input: Input,
39+
input: Other,
4040
*,
4141
config: Optional[RunnableConfig] = None,
4242
replace: bool = False,
@@ -68,35 +68,35 @@ def _concat_put(
6868
f"{(type(self.key))}."
6969
)
7070

71-
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
71+
def invoke(self, input: Other, config: Optional[RunnableConfig] = None) -> Other:
7272
self._concat_put(input, config=config, replace=True)
7373
return super().invoke(input, config=config)
7474

7575
async def ainvoke(
7676
self,
77-
input: Input,
77+
input: Other,
7878
config: Optional[RunnableConfig] = None,
7979
**kwargs: Optional[Any],
80-
) -> Input:
80+
) -> Other:
8181
self._concat_put(input, config=config, replace=True)
8282
return await super().ainvoke(input, config=config)
8383

8484
def transform(
8585
self,
86-
input: Iterator[Input],
86+
input: Iterator[Other],
8787
config: Optional[RunnableConfig] = None,
8888
**kwargs: Optional[Any],
89-
) -> Iterator[Input]:
89+
) -> Iterator[Other]:
9090
for chunk in super().transform(input, config=config):
9191
self._concat_put(chunk, config=config)
9292
yield chunk
9393

9494
async def atransform(
9595
self,
96-
input: AsyncIterator[Input],
96+
input: AsyncIterator[Other],
9797
config: Optional[RunnableConfig] = None,
9898
**kwargs: Optional[Any],
99-
) -> AsyncIterator[Input]:
99+
) -> AsyncIterator[Other]:
100100
async for chunk in super().atransform(input, config=config):
101101
self._concat_put(chunk, config=config)
102102
yield chunk

libs/langchain/langchain/schema/runnable/base.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2296,6 +2296,26 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
22962296
class Config:
22972297
arbitrary_types_allowed = True
22982298

2299+
def __init__(
2300+
self,
2301+
*,
2302+
bound: Runnable[Input, Output],
2303+
kwargs: Mapping[str, Any],
2304+
config: Optional[Mapping[str, Any]] = None,
2305+
**other_kwargs: Any,
2306+
) -> None:
2307+
config = config or {}
2308+
# config_specs contains the list of valid `configurable` keys
2309+
if configurable := config.get("configurable", None):
2310+
allowed_keys = set(s.id for s in bound.config_specs)
2311+
for key in configurable:
2312+
if key not in allowed_keys:
2313+
raise ValueError(
2314+
f"Configurable key '{key}' not found in runnable with"
2315+
f" config keys: {allowed_keys}"
2316+
)
2317+
super().__init__(bound=bound, kwargs=kwargs, config=config, **other_kwargs)
2318+
22992319
@property
23002320
def InputType(self) -> Type[Input]:
23012321
return self.bound.InputType

libs/langchain/langchain/schema/runnable/passthrough.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from langchain.pydantic_v1 import BaseModel, create_model
2424
from langchain.schema.runnable.base import (
25-
Input,
25+
Other,
2626
Runnable,
2727
RunnableParallel,
2828
RunnableSerializable,
@@ -33,17 +33,17 @@
3333
from langchain.utils.iter import safetee
3434

3535

36-
def identity(x: Input) -> Input:
36+
def identity(x: Other) -> Other:
3737
"""An identity function"""
3838
return x
3939

4040

41-
async def aidentity(x: Input) -> Input:
41+
async def aidentity(x: Other) -> Other:
4242
"""An async identity function"""
4343
return x
4444

4545

46-
class RunnablePassthrough(RunnableSerializable[Input, Input]):
46+
class RunnablePassthrough(RunnableSerializable[Other, Other]):
4747
"""A runnable to passthrough inputs unchanged or with additional keys.
4848
4949
This runnable behaves almost like the identity function, except that it
@@ -100,20 +100,20 @@ def fake_llm(prompt: str) -> str: # Fake LLM for the example
100100
# {'llm1': 'completion', 'llm2': 'completion', 'total_chars': 20}
101101
"""
102102

103-
input_type: Optional[Type[Input]] = None
103+
input_type: Optional[Type[Other]] = None
104104

105-
func: Optional[Callable[[Input], None]] = None
105+
func: Optional[Callable[[Other], None]] = None
106106

107-
afunc: Optional[Callable[[Input], Awaitable[None]]] = None
107+
afunc: Optional[Callable[[Other], Awaitable[None]]] = None
108108

109109
def __init__(
110110
self,
111111
func: Optional[
112-
Union[Callable[[Input], None], Callable[[Input], Awaitable[None]]]
112+
Union[Callable[[Other], None], Callable[[Other], Awaitable[None]]]
113113
] = None,
114-
afunc: Optional[Callable[[Input], Awaitable[None]]] = None,
114+
afunc: Optional[Callable[[Other], Awaitable[None]]] = None,
115115
*,
116-
input_type: Optional[Type[Input]] = None,
116+
input_type: Optional[Type[Other]] = None,
117117
**kwargs: Any,
118118
) -> None:
119119
if inspect.iscoroutinefunction(func):
@@ -161,17 +161,17 @@ def assign(
161161
"""
162162
return RunnableAssign(RunnableParallel(kwargs))
163163

164-
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
164+
def invoke(self, input: Other, config: Optional[RunnableConfig] = None) -> Other:
165165
if self.func is not None:
166166
self.func(input)
167167
return self._call_with_config(identity, input, config)
168168

169169
async def ainvoke(
170170
self,
171-
input: Input,
171+
input: Other,
172172
config: Optional[RunnableConfig] = None,
173173
**kwargs: Optional[Any],
174-
) -> Input:
174+
) -> Other:
175175
if self.afunc is not None:
176176
await self.afunc(input, **kwargs)
177177
elif self.func is not None:
@@ -180,10 +180,10 @@ async def ainvoke(
180180

181181
def transform(
182182
self,
183-
input: Iterator[Input],
183+
input: Iterator[Other],
184184
config: Optional[RunnableConfig] = None,
185185
**kwargs: Any,
186-
) -> Iterator[Input]:
186+
) -> Iterator[Other]:
187187
if self.func is None:
188188
for chunk in self._transform_stream_with_config(input, identity, config):
189189
yield chunk
@@ -202,10 +202,10 @@ def transform(
202202

203203
async def atransform(
204204
self,
205-
input: AsyncIterator[Input],
205+
input: AsyncIterator[Other],
206206
config: Optional[RunnableConfig] = None,
207207
**kwargs: Any,
208-
) -> AsyncIterator[Input]:
208+
) -> AsyncIterator[Other]:
209209
if self.afunc is None and self.func is None:
210210
async for chunk in self._atransform_stream_with_config(
211211
input, identity, config
@@ -231,19 +231,19 @@ async def atransform(
231231

232232
def stream(
233233
self,
234-
input: Input,
234+
input: Other,
235235
config: Optional[RunnableConfig] = None,
236236
**kwargs: Any,
237-
) -> Iterator[Input]:
237+
) -> Iterator[Other]:
238238
return self.transform(iter([input]), config, **kwargs)
239239

240240
async def astream(
241241
self,
242-
input: Input,
242+
input: Other,
243243
config: Optional[RunnableConfig] = None,
244244
**kwargs: Any,
245-
) -> AsyncIterator[Input]:
246-
async def input_aiter() -> AsyncIterator[Input]:
245+
) -> AsyncIterator[Other]:
246+
async def input_aiter() -> AsyncIterator[Other]:
247247
yield input
248248

249249
async for chunk in self.atransform(input_aiter(), config, **kwargs):

libs/langchain/langchain/schema/runnable/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
Union,
2525
)
2626

27-
Input = TypeVar("Input")
27+
Input = TypeVar("Input", contravariant=True)
2828
# Output type should implement __concat__, as eg str, list, dict do
29-
Output = TypeVar("Output")
29+
Output = TypeVar("Output", covariant=True)
3030

3131

3232
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:

libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,6 +1002,9 @@ def test_configurable_fields_example() -> None:
10021002
},
10031003
}
10041004

1005+
with pytest.raises(ValueError):
1006+
chain_configurable.with_config(configurable={"llm123": "chat"})
1007+
10051008
assert (
10061009
chain_configurable.with_config(configurable={"llm": "chat"}).invoke(
10071010
{"name": "John"}

0 commit comments

Comments
 (0)