-
Notifications
You must be signed in to change notification settings - Fork 0
/
base_strategy.py
124 lines (105 loc) Β· 3.84 KB
/
base_strategy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import asyncio
from abc import ABC, abstractmethod
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Tuple,
)
from langchain.chains.base import Chain
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
class BaseCustomStrategy(Chain, ABC):
    """Base chain for custom agent strategies that can yield several final answers.

    Subclasses implement :meth:`_run_strategy`, which yields
    ``(AgentFinish, intermediate_steps)`` pairs. Each pair is converted into an
    output dict (see :meth:`_return` / :meth:`_areturn`), and the dicts are then
    collated into ``{key: [value_from_each_finish, ...]}``.
    """

    # When True, each finish's output dict also carries its intermediate steps.
    return_intermediate_steps: bool = False
    # When True, each finish's output dict also carries ``AgentFinish.log``.
    return_finish_log: bool = False
    # Not read in this class; presumably an iteration cap for subclasses — confirm.
    max_iterations: int = 15
    verbose: bool = True

    @property
    def input_keys(self) -> List[str]:
        """Keys expected to be in the chain input."""
        return []

    @property
    def output_keys(self) -> List[str]:
        """Keys expected to be in the chain output."""
        return []

    @abstractmethod
    def _run_strategy(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Iterator[Tuple[AgentFinish, List[Tuple[AgentAction, str]]]]:
        """Yield ``(finish, intermediate_steps)`` pairs for ``inputs``."""
        ...

    async def _arun_strategy(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> AsyncIterator[Tuple[AgentFinish, List[Tuple[AgentAction, str]]]]:
        """Async fallback that drives the synchronous :meth:`_run_strategy`.

        Both the call to ``_run_strategy`` and every ``next()`` on the iterator it
        returns run in the default executor. A lazy (generator-based) strategy
        therefore does its work off the event loop, and results are still yielded
        one at a time. Subclasses with a native async implementation should
        override this method.

        Args:
            inputs: Chain inputs forwarded to ``_run_strategy``.
            run_manager: Async callback manager; its sync counterpart (via
                ``get_sync()``) is handed to ``_run_strategy``.

        Yields:
            The ``(finish, intermediate_steps)`` pairs produced by ``_run_strategy``.
        """
        loop = asyncio.get_running_loop()
        sync_run_manager = run_manager.get_sync() if run_manager is not None else None
        # Calling _run_strategy may itself do the heavy work (if it returns a
        # list), so it is executed on a worker thread as well as the iteration.
        iterator = await loop.run_in_executor(
            None, lambda: iter(self._run_strategy(inputs, sync_run_manager))
        )
        # StopIteration cannot be propagated through a Future, so exhaustion is
        # signalled with a private sentinel passed as next()'s default.
        sentinel = object()
        while True:
            item = await loop.run_in_executor(None, next, iterator, sentinel)
            if item is sentinel:
                break
            yield item

    def _build_final_output(
        self,
        output: AgentFinish,
        intermediate_steps: list,
    ) -> Dict[str, Any]:
        """Return a copy of ``output.return_values`` with the optional extras added."""
        # Copy so the caller's AgentFinish is not mutated by the extra keys.
        final_output = dict(output.return_values)
        if self.return_intermediate_steps:
            final_output["intermediate_steps"] = intermediate_steps
        if self.return_finish_log:
            final_output["finish_log"] = output.log
        return final_output

    @staticmethod
    def _collate_outputs(outputs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Merge per-finish dicts into ``{key: [value, ...]}``.

        The key set is taken from the first dict. Returns ``{}`` when ``outputs``
        is empty. Raises ``KeyError`` if a later dict lacks one of those keys.
        """
        if not outputs:
            return {}
        return {key: [output[key] for output in outputs] for key in outputs[0]}

    def _return(
        self,
        output: AgentFinish,
        intermediate_steps: list,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Report ``output`` to the callbacks and build its output dict."""
        if run_manager:
            run_manager.on_agent_finish(output, color="green", verbose=self.verbose)
        return self._build_final_output(output, intermediate_steps)

    async def _areturn(
        self,
        output: AgentFinish,
        intermediate_steps: list,
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Async counterpart of :meth:`_return`."""
        if run_manager:
            await run_manager.on_agent_finish(output, color="green", verbose=self.verbose)
        return self._build_final_output(output, intermediate_steps)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run text through and get agent response.

        Every finish yielded by :meth:`_run_strategy` is converted with
        :meth:`_return`, and the results are collated per key. The result is
        empty if the strategy yields nothing.
        """
        outputs = [
            self._return(output, intermediate_steps, run_manager=run_manager)
            for output, intermediate_steps in self._run_strategy(
                inputs=inputs,
                run_manager=run_manager,
            )
        ]
        return self._collate_outputs(outputs)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run text through and get agent response (async).

        Mirrors :meth:`_call`, using :meth:`_arun_strategy` and :meth:`_areturn`.
        """
        outputs = []
        async for output, intermediate_steps in self._arun_strategy(
            inputs=inputs,
            run_manager=run_manager,
        ):
            outputs.append(
                await self._areturn(output, intermediate_steps, run_manager=run_manager)
            )
        return self._collate_outputs(outputs)
# TODO: decide what the interface of the LangGraph strategy should be.
class BaseLangGraphStrategy(ABC):
    """Placeholder base class for LangGraph-based strategies.

    It declares no members yet; the interface is still to be designed.
    """