Refactor reasoning pipeline (#31)
* Move the text rendering out for reusability

* Refactor common operations in the reasoning pipeline

* Add run method

* Provide dedicated method for invoke
trducng authored Apr 13, 2024
1 parent af38708 commit 0417610
Showing 6 changed files with 192 additions and 258 deletions.
libs/kotaemon/kotaemon/base/component.py (1 addition, 1 deletion)

@@ -39,7 +39,7 @@ def set_output_queue(self, queue):
             if isinstance(node, BaseComponent):
                 node.set_output_queue(queue)

-    def report_output(self, output: Optional[dict]):
+    def report_output(self, output: Optional[Document]):
         if self._queue is not None:
             self._queue.put_nowait(output)
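
For context: report_output pushes intermediate results onto a queue so callers can stream them, and this change types those results as Document rather than a plain dict. Below is a minimal self-contained sketch of the pattern; the Document dataclass here is a hypothetical stand-in for kotaemon.base.Document, not the real class.

import queue
from dataclasses import dataclass
from typing import Optional


@dataclass
class Document:  # hypothetical stand-in for kotaemon.base.Document
    text: str
    channel: Optional[str] = None


class Component:
    def __init__(self):
        self._queue: Optional[queue.Queue] = None

    def set_output_queue(self, q: queue.Queue):
        self._queue = q

    def report_output(self, output: Optional[Document]):
        # mirrors the patched method: only report when a queue is attached
        if self._queue is not None:
            self._queue.put_nowait(output)


q: queue.Queue = queue.Queue()
comp = Component()
comp.set_output_queue(q)
comp.report_output(Document(text="intermediate step", channel="debug"))
print(q.get_nowait())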
libs/kotaemon/kotaemon/llms/chats/openai.py (2 additions, 1 deletion)

@@ -270,7 +270,7 @@ def prepare_client(self, async_version: bool = False):

     def openai_response(self, client, **kwargs):
         """Get the openai response"""
-        params = {
+        params_ = {
             "model": self.model,
             "temperature": self.temperature,
             "max_tokens": self.max_tokens,
@@ -285,6 +285,7 @@ def openai_response(self, client, **kwargs):
             "top_logprobs": self.top_logprobs,
             "top_p": self.top_p,
         }
+        params = {k: v for k, v in params_.items() if v is not None}
         params.update(kwargs)

         return client.chat.completions.create(**params)
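
The net effect of this hunk: the full parameter dict is built first as params_, then every None entry is dropped before the request, so unset optional parameters are omitted instead of being sent as explicit nulls. A standalone sketch of the same filtering pattern, with the client call commented out so the snippet runs without an API key (the parameter values here are made up):

model = "gpt-4"
temperature = None  # unset: should be omitted, not sent as null
max_tokens = 512

params_ = {
    "model": model,
    "temperature": temperature,
    "max_tokens": max_tokens,
}
# Drop unset parameters so the API applies its own defaults.
params = {k: v for k, v in params_.items() if v is not None}
params.update({"stream": False})  # caller kwargs still take precedence

print(params)  # {'model': 'gpt-4', 'max_tokens': 512, 'stream': False}
# client.chat.completions.create(**params)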
libs/ktem/ktem/llms/manager.py (7 additions, 9 deletions)

@@ -5,7 +5,7 @@
 from theflow.settings import settings as flowsettings
 from theflow.utils.modules import deserialize

-from kotaemon.base import BaseComponent
+from kotaemon.llms import ChatLLM

 from .db import LLMTable, engine

@@ -14,7 +14,7 @@ class LLMManager:
     """Represent a pool of models"""

     def __init__(self):
-        self._models: dict[str, BaseComponent] = {}
+        self._models: dict[str, ChatLLM] = {}
         self._info: dict[str, dict] = {}
         self._default: str = ""
         self._vendors: list[Type] = []
@@ -63,17 +63,15 @@ def load_vendors(self):

         self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM]

-    def __getitem__(self, key: str) -> BaseComponent:
+    def __getitem__(self, key: str) -> ChatLLM:
         """Get model by name"""
         return self._models[key]

     def __contains__(self, key: str) -> bool:
         """Check if model exists"""
         return key in self._models

-    def get(
-        self, key: str, default: Optional[BaseComponent] = None
-    ) -> Optional[BaseComponent]:
+    def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]:
         """Get model by name with default value"""
         return self._models.get(key, default)

@@ -119,18 +117,18 @@ def get_default_name(self) -> str:

         return self._default

-    def get_random(self) -> BaseComponent:
+    def get_random(self) -> ChatLLM:
         """Get random model"""
         return self._models[self.get_random_name()]

-    def get_default(self) -> BaseComponent:
+    def get_default(self) -> ChatLLM:
         """Get default model

         In case there is no default model, choose random model from pool. In
         case there are multiple default models, choose random from them.

         Returns:
-            BaseComponent: model
+            ChatLLM: model
         """
         return self._models[self.get_default_name()]
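
Narrowing the annotations from BaseComponent to ChatLLM means a type checker can verify that anything retrieved from the manager exposes the chat-model interface. A small hypothetical illustration of the benefit; the classes below are stand-ins, not kotaemon's real ones:

from typing import Optional


class BaseComponent:  # stand-in base class
    pass


class ChatLLM(BaseComponent):  # stand-in chat-model interface
    def invoke(self, prompt: str) -> str:
        return f"echo: {prompt}"


class LLMManager:
    def __init__(self):
        self._models: dict[str, ChatLLM] = {}

    def add(self, name: str, model: ChatLLM):
        self._models[name] = model

    def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]:
        return self._models.get(key, default)


manager = LLMManager()
manager.add("default", ChatLLM())
llm = manager.get("default")
if llm is not None:
    # With the narrowed return type, .invoke() type-checks;
    # with BaseComponent it would be flagged as an unknown attribute.
    print(llm.invoke("hello"))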
(Diffs for the remaining 3 changed files did not load and are not shown.)
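
The commit message also mentions "Add run method" and "Provide dedicated method for invoke"; those changes live in the files not shown above. Purely as a hypothetical sketch of that kind of split, run acts as the public entry point while invoke holds the single-call logic:

class Pipeline:
    def invoke(self, message: str) -> str:
        # dedicated method holding the single-call logic
        return f"answer to: {message}"

    def run(self, message: str) -> str:
        # public entry point; shared concerns (logging, output
        # reporting, error handling) can be layered here around invoke
        return self.invoke(message)


print(Pipeline().run("What does this commit change?"))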
