Refactor reasoning pipeline #31

Merged
merged 5 commits on Apr 13, 2024
2 changes: 1 addition & 1 deletion libs/kotaemon/kotaemon/base/component.py
@@ -39,7 +39,7 @@ def set_output_queue(self, queue):
             if isinstance(node, BaseComponent):
                 node.set_output_queue(queue)

-    def report_output(self, output: Optional[dict]):
+    def report_output(self, output: Optional[Document]):
         if self._queue is not None:
             self._queue.put_nowait(output)
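Context for the change above: report_output now accepts a typed Document instead of a raw dict, so everything pushed onto the streaming queue carries the same structure. A minimal, self-contained sketch of the pattern; the Document dataclass and StreamingComponent below are illustrative stand-ins, not kotaemon's actual classes:

import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class Document:
    """Illustrative stand-in for kotaemon's Document type."""
    text: str
    channel: str = "chat"


class StreamingComponent:
    """Hypothetical component that forwards its outputs to a queue."""

    def __init__(self) -> None:
        self._queue: Optional[asyncio.Queue] = None

    def set_output_queue(self, queue: asyncio.Queue) -> None:
        self._queue = queue

    def report_output(self, output: Optional[Document]) -> None:
        # Outputs are typed Documents now, not plain dicts.
        if self._queue is not None:
            self._queue.put_nowait(output)


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    component = StreamingComponent()
    component.set_output_queue(queue)
    component.report_output(Document(text="partial answer"))
    print(await queue.get())


asyncio.run(main())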
3 changes: 2 additions & 1 deletion libs/kotaemon/kotaemon/llms/chats/openai.py
@@ -270,7 +270,7 @@ def prepare_client(self, async_version: bool = False):

     def openai_response(self, client, **kwargs):
         """Get the openai response"""
-        params = {
+        params_ = {
             "model": self.model,
             "temperature": self.temperature,
             "max_tokens": self.max_tokens,
@@ -285,6 +285,7 @@ def openai_response(self, client, **kwargs):
             "top_logprobs": self.top_logprobs,
             "top_p": self.top_p,
         }
+        params = {k: v for k, v in params_.items() if v is not None}
         params.update(kwargs)

         return client.chat.completions.create(**params)
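The params_/params split above builds the full parameter dict first and then drops every entry still set to None, presumably so the request never carries explicit nulls for options the user did not configure. A stand-alone sketch of the same filtering pattern; build_params and its arguments are illustrative helpers, not the actual class attributes:

from typing import Any, Optional


def build_params(
    model: str,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    top_p: Optional[float] = None,
    **overrides: Any,
) -> dict[str, Any]:
    """Collect request parameters, dropping anything left unset (None)."""
    params_ = {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
    }
    # Keep only explicitly set values, then let caller overrides win.
    params = {k: v for k, v in params_.items() if v is not None}
    params.update(overrides)
    return params


print(build_params("gpt-3.5-turbo", temperature=0.0))
# -> {'model': 'gpt-3.5-turbo', 'temperature': 0.0}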
16 changes: 7 additions & 9 deletions libs/ktem/ktem/llms/manager.py
@@ -5,7 +5,7 @@
 from theflow.settings import settings as flowsettings
 from theflow.utils.modules import deserialize

-from kotaemon.base import BaseComponent
+from kotaemon.llms import ChatLLM

 from .db import LLMTable, engine

@@ -14,7 +14,7 @@ class LLMManager:
     """Represent a pool of models"""

     def __init__(self):
-        self._models: dict[str, BaseComponent] = {}
+        self._models: dict[str, ChatLLM] = {}
         self._info: dict[str, dict] = {}
         self._default: str = ""
         self._vendors: list[Type] = []
@@ -63,17 +63,15 @@ def load_vendors(self):

         self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM]

-    def __getitem__(self, key: str) -> BaseComponent:
+    def __getitem__(self, key: str) -> ChatLLM:
         """Get model by name"""
         return self._models[key]

     def __contains__(self, key: str) -> bool:
         """Check if model exists"""
         return key in self._models

-    def get(
-        self, key: str, default: Optional[BaseComponent] = None
-    ) -> Optional[BaseComponent]:
+    def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]:
         """Get model by name with default value"""
         return self._models.get(key, default)
@@ -119,18 +117,18 @@ def get_default_name(self) -> str:

         return self._default

-    def get_random(self) -> BaseComponent:
+    def get_random(self) -> ChatLLM:
         """Get random model"""
         return self._models[self.get_random_name()]

-    def get_default(self) -> BaseComponent:
+    def get_default(self) -> ChatLLM:
         """Get default model

         In case there is no default model, choose random model from pool. In
         case there are multiple default models, choose random from them.

         Returns:
-            BaseComponent: model
+            ChatLLM: model
         """
         return self._models[self.get_default_name()]
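Narrowing LLMManager's annotations from BaseComponent to ChatLLM lets type checkers guarantee that whatever the pool returns exposes the chat-LLM interface. A compact sketch of the same idea; ChatLLM and EchoChat here are placeholders, not the real kotaemon classes:

import random
from typing import Optional


class ChatLLM:
    """Placeholder for kotaemon's chat-LLM base class."""

    def invoke(self, prompt: str) -> str:
        raise NotImplementedError


class EchoChat(ChatLLM):
    """Toy model used only for this example."""

    def invoke(self, prompt: str) -> str:
        return f"echo: {prompt}"


class LLMManager:
    """Pool of chat models; lookups are typed to return ChatLLM."""

    def __init__(self) -> None:
        self._models: dict[str, ChatLLM] = {}
        self._default: str = ""

    def add(self, name: str, model: ChatLLM, default: bool = False) -> None:
        self._models[name] = model
        if default or not self._default:
            self._default = name

    def __getitem__(self, key: str) -> ChatLLM:
        return self._models[key]

    def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]:
        return self._models.get(key, default)

    def get_default(self) -> ChatLLM:
        # Fall back to a random model when no default is marked.
        name = self._default or random.choice(list(self._models))
        return self._models[name]


manager = LLMManager()
manager.add("echo", EchoChat(), default=True)
print(manager.get_default().invoke("hello"))  # echo: hello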