Skip to content

Commit

Permalink
Export merge_configs function (langchain-ai#11916)
Browse files Browse the repository at this point in the history
<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
  • Loading branch information
nfcampos authored Oct 17, 2023
2 parents 57a0292 + 19319e1 commit 2a8ded6
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 30 deletions.
52 changes: 22 additions & 30 deletions libs/langchain/langchain/schema/runnable/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
get_callback_manager_for_config,
get_config_list,
get_executor_for_config,
merge_configs,
patch_config,
)
from langchain.schema.runnable.utils import (
Expand Down Expand Up @@ -564,7 +565,12 @@ def with_config(
Bind config to a Runnable, returning a new Runnable.
"""
return RunnableBinding(
bound=self, config={**(config or {}), **kwargs}, kwargs={}
bound=self,
config=cast(
RunnableConfig,
{**(config or {}), **kwargs},
), # type: ignore[misc]
kwargs={},
)

def with_retry(
Expand Down Expand Up @@ -2291,7 +2297,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]):

kwargs: Mapping[str, Any]

config: Mapping[str, Any] = Field(default_factory=dict)
config: RunnableConfig = Field(default_factory=dict)

class Config:
arbitrary_types_allowed = True
Expand All @@ -2301,7 +2307,7 @@ def __init__(
*,
bound: Runnable[Input, Output],
kwargs: Mapping[str, Any],
config: Optional[Mapping[str, Any]] = None,
config: Optional[RunnableConfig] = None,
**other_kwargs: Any,
) -> None:
config = config or {}
Expand Down Expand Up @@ -2347,22 +2353,6 @@ def is_lc_serializable(cls) -> bool:
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]

def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig:
copy = cast(RunnableConfig, dict(self.config))
if config:
for key in config:
if key == "metadata":
copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore
elif key == "tags":
copy[key] = (copy.get(key) or []) + config[key] # type: ignore
elif key == "configurable":
copy[key] = {**copy.get(key, {}), **config[key]} # type: ignore
else:
# Even though the keys aren't literals this is correct
# because both dicts are same type
copy[key] = config[key] or copy.get(key) # type: ignore
return copy

def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
Expand All @@ -2377,7 +2367,7 @@ def with_config(
return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config={**self.config, **(config or {}), **kwargs},
config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}),
)

def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
Expand All @@ -2395,7 +2385,7 @@ def invoke(
) -> Output:
return self.bound.invoke(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
)

Expand All @@ -2407,7 +2397,7 @@ async def ainvoke(
) -> Output:
return await self.bound.ainvoke(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
)

Expand All @@ -2421,11 +2411,12 @@ def batch(
) -> List[Output]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig], [self._merge_config(conf) for conf in config]
List[RunnableConfig],
[merge_configs(self.config, conf) for conf in config],
)
else:
configs = [
patch_config(self._merge_config(config), copy_locals=True)
patch_config(merge_configs(self.config, config), copy_locals=True)
for _ in range(len(inputs))
]
return self.bound.batch(
Expand All @@ -2445,11 +2436,12 @@ async def abatch(
) -> List[Output]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig], [self._merge_config(conf) for conf in config]
List[RunnableConfig],
[merge_configs(self.config, conf) for conf in config],
)
else:
configs = [
patch_config(self._merge_config(config), copy_locals=True)
patch_config(merge_configs(self.config, config), copy_locals=True)
for _ in range(len(inputs))
]
return await self.bound.abatch(
Expand All @@ -2467,7 +2459,7 @@ def stream(
) -> Iterator[Output]:
yield from self.bound.stream(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
)

Expand All @@ -2479,7 +2471,7 @@ async def astream(
) -> AsyncIterator[Output]:
async for item in self.bound.astream(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
):
yield item
Expand All @@ -2492,7 +2484,7 @@ def transform(
) -> Iterator[Output]:
yield from self.bound.transform(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
)

Expand All @@ -2504,7 +2496,7 @@ async def atransform(
) -> AsyncIterator[Output]:
async for item in self.bound.atransform(
input,
self._merge_config(config),
merge_configs(self.config, config),
**{**self.kwargs, **kwargs},
):
yield item
Expand Down
25 changes: 25 additions & 0 deletions libs/langchain/langchain/schema/runnable/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,31 @@ def patch_config(
return config


def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
base: RunnableConfig = {}
# Even though the keys aren't literals this is correct
# because both dicts are same type
for config in (c for c in configs if c is not None):
for key in config:
if key == "metadata":
base[key] = { # type: ignore
**base.get(key, {}), # type: ignore
**(config.get(key) or {}), # type: ignore
}
elif key == "tags":
base[key] = list( # type: ignore
set(base.get(key, []) + (config.get(key) or [])), # type: ignore
)
elif key == "configurable":
base[key] = { # type: ignore
**base.get(key, {}), # type: ignore
**(config.get(key) or {}), # type: ignore
}
else:
base[key] = config[key] or base.get(key) # type: ignore
return base


def call_func_with_variable_args(
func: Union[
Callable[[Input], Output],
Expand Down

0 comments on commit 2a8ded6

Please sign in to comment.