diff --git a/gui/pages/Content/APM/ApmDashboard.js b/gui/pages/Content/APM/ApmDashboard.js
index 535de9ad8..a79dd3b3d 100644
--- a/gui/pages/Content/APM/ApmDashboard.js
+++ b/gui/pages/Content/APM/ApmDashboard.js
@@ -76,7 +76,7 @@ export default function ApmDashboard() {
   const fetchData = async () => {
     try {
       const [metricsResponse, agentsResponse, activeRunsResponse, toolsUsageResponse] = await Promise.all([getMetrics(), getAllAgents(), getActiveRuns(), getToolsUsage()]);
-      const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k', 'google-palm-bison-001'];
+      const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k', 'google-palm-bison-001', 'replicate-llama13b-v2-chat'];
       assignDefaultDataPerModel(metricsResponse.data.agent_details.model_metrics, models);
       assignDefaultDataPerModel(metricsResponse.data.tokens_details.model_metrics, models);
 
diff --git a/gui/pages/_app.js b/gui/pages/_app.js
index 125ef27d0..ee3ee41ec 100644
--- a/gui/pages/_app.js
+++ b/gui/pages/_app.js
@@ -61,6 +61,7 @@ export default function App() {
     });
   }
 
+
   const installFromMarketplace = () => {
     const toolkitName = localStorage.getItem('toolkit_to_install') || null;
     const agentTemplateId = localStorage.getItem('agent_to_install') || null;
diff --git a/main.py b/main.py
index 2e4b11852..cf4807483 100644
--- a/main.py
+++ b/main.py
@@ -50,6 +50,7 @@
 from superagi.llms.replicate import Replicate
 from superagi.llms.hugging_face import HuggingFace
 from superagi.models.agent_template import AgentTemplate
+from superagi.models.models_config import ModelsConfig
 from superagi.models.organisation import Organisation
 from superagi.models.types.login_request import LoginRequest
 from superagi.models.types.validate_llm_api_key_request import ValidateAPIKeyRequest
@@ -215,6 +216,13 @@ def register_toolkit_for_master_organisation():
             Organisation.id == marketplace_organisation_id).first()
         if marketplace_organisation is not None:
             register_marketplace_toolkits(session, marketplace_organisation)
+
+    def local_llm_model_config():
+        existing_models_config = session.query(ModelsConfig).filter(ModelsConfig.org_id == default_user.organisation_id, ModelsConfig.provider == 'Local LLM').first()
+        if existing_models_config is None:
+            models_config = ModelsConfig(org_id=default_user.organisation_id, provider='Local LLM', api_key="EMPTY")
+            session.add(models_config)
+            session.commit()
 
     IterationWorkflowSeed.build_single_step_agent(session)
     IterationWorkflowSeed.build_task_based_agents(session)
@@ -238,7 +246,8 @@
     # AgentWorkflowSeed.doc_search_and_code(session)
     # AgentWorkflowSeed.build_research_email_workflow(session)
     replace_old_iteration_workflows(session)
-
+    local_llm_model_config()
+
     if env != "PROD":
         register_toolkit_for_all_organisation()
     else:
diff --git a/migrations/versions/9270eb5a8475_local_llms.py b/migrations/versions/9270eb5a8475_local_llms.py
new file mode 100644
index 000000000..7e6371e8a
--- /dev/null
+++ b/migrations/versions/9270eb5a8475_local_llms.py
@@ -0,0 +1,28 @@
+"""local_llms
+
+Revision ID: 9270eb5a8475
+Revises: 3867bb00a495
+Create Date: 2023-10-04 09:26:33.865424
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '9270eb5a8475'
+down_revision = '3867bb00a495'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column('models', sa.Column('context_length', sa.Integer(), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column('models', 'context_length')
+    # ### end Alembic commands ###
diff --git a/requirements.txt b/requirements.txt
index ab45bb1c7..1a2892c25 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -158,3 +158,4 @@ google-generativeai==0.1.0
 unstructured==0.8.1
 ai21==1.2.6
 typing-extensions==4.5.0
+llama_cpp_python==0.2.7
diff --git a/superagi/helper/llm_loader.py b/superagi/helper/llm_loader.py
new file mode 100644
index 000000000..8c2b19e45
--- /dev/null
+++ b/superagi/helper/llm_loader.py
@@ -0,0 +1,38 @@
+from llama_cpp import Llama
+from llama_cpp import LlamaGrammar
+from superagi.config.config import get_config
+from superagi.lib.logger import logger
+
+
+class LLMLoader:
+    _instance = None
+    _model = None
+    _grammar = None
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super(LLMLoader, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(self, context_length):
+        self.context_length = context_length
+
+    @property
+    def model(self):
+        if self._model is None:
+            try:
+                self._model = Llama(
+                    model_path="/app/local_model_path", n_ctx=self.context_length)
+            except Exception as e:
+                logger.error(e)
+        return self._model
+
+    @property
+    def grammar(self):
+        if self._grammar is None:
+            try:
+                self._grammar = LlamaGrammar.from_file(
+                    "superagi/llms/grammar/json.gbnf")
+            except Exception as e:
+                logger.error(e)
+        return self._grammar
diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py
index 4418dd86e..e47486a05 100644
--- a/superagi/jobs/agent_executor.py
+++ b/superagi/jobs/agent_executor.py
@@ -1,6 +1,7 @@
 from datetime import datetime, timedelta
 
 from sqlalchemy.orm import sessionmaker
+from superagi.llms.local_llm import LocalLLM
 
 import superagi.worker
 from superagi.agent.agent_iteration_step_handler import AgentIterationStepHandler
diff --git a/superagi/llms/grammar/json.gbnf b/superagi/llms/grammar/json.gbnf
new file mode 100644
index 000000000..a9537cdf9
--- /dev/null
+++ b/superagi/llms/grammar/json.gbnf
@@ -0,0 +1,25 @@
+root   ::= object
+value  ::= object | array | string | number | ("true" | "false" | "null") ws
+
+object ::=
+  "{" ws (
+            string ":" ws value
+    ("," ws string ":" ws value)*
+  )? "}" ws
+
+array  ::=
+  "[" ws (
+            value
+    ("," ws value)*
+  )? "]" ws
+
+string ::=
+  "\"" (
+    [^"\\] |
+    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
+  )* "\"" ws
+
+number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
+
+# Optional space: by convention, applied in this grammar after literal chars when allowed
+ws ::= ([ \t\n] ws)?
diff --git a/superagi/llms/llm_model_factory.py b/superagi/llms/llm_model_factory.py
index 251c64a71..af6cfedf6 100644
--- a/superagi/llms/llm_model_factory.py
+++ b/superagi/llms/llm_model_factory.py
@@ -1,4 +1,5 @@
 from superagi.llms.google_palm import GooglePalm
+from superagi.llms.local_llm import LocalLLM
 from superagi.llms.openai import OpenAi
 from superagi.llms.replicate import Replicate
 from superagi.llms.hugging_face import HuggingFace
diff --git a/superagi/llms/local_llm.py b/superagi/llms/local_llm.py
new file mode 100644
index 000000000..608afa289
--- /dev/null
+++ b/superagi/llms/local_llm.py
@@ -0,0 +1,92 @@
+from superagi.config.config import get_config
+from superagi.lib.logger import logger
+from superagi.llms.base_llm import BaseLlm
+from superagi.helper.llm_loader import LLMLoader
+
+
+class LocalLLM(BaseLlm):
+    def __init__(self, temperature=0.6, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT"), top_p=1,
+                 frequency_penalty=0,
+                 presence_penalty=0, number_of_results=1, model=None, api_key='EMPTY', context_length=4096):
+        """
+        Args:
+            model (str): The model.
+            temperature (float): The temperature.
+            max_tokens (int): The maximum number of tokens.
+            top_p (float): The top p.
+            frequency_penalty (float): The frequency penalty.
+            presence_penalty (float): The presence penalty.
+            number_of_results (int): The number of results.
+        """
+        self.model = model
+        self.api_key = api_key
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.top_p = top_p
+        self.frequency_penalty = frequency_penalty
+        self.presence_penalty = presence_penalty
+        self.number_of_results = number_of_results
+        self.context_length = context_length
+
+        llm_loader = LLMLoader(self.context_length)
+        self.llm_model = llm_loader.model
+        self.llm_grammar = llm_loader.grammar
+
+    def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT")):
+        """
+        Call the chat completion.
+
+        Args:
+            messages (list): The messages.
+            max_tokens (int): The maximum number of tokens.
+
+        Returns:
+            dict: The response.
+        """
+        try:
+            if self.llm_model is None or self.llm_grammar is None:
+                logger.error("Model not found.")
+                return {"error": "Model loading error", "message": "Model not found. Please check your model path and try again."}
+            else:
+                response = self.llm_model.create_chat_completion(messages=messages, functions=None, function_call=None, temperature=self.temperature, top_p=self.top_p,
+                                                                 max_tokens=int(max_tokens), presence_penalty=self.presence_penalty, frequency_penalty=self.frequency_penalty, grammar=self.llm_grammar)
+                content = response["choices"][0]["message"]["content"]
+                logger.info(content)
+                return {"response": response, "content": content}
+
+        except Exception as exception:
+            logger.error(f"Exception: {exception}")
+            return {"error": "ERROR", "message": "Error: " + str(exception)}
+
+    def get_source(self):
+        """
+        Get the source.
+
+        Returns:
+            str: The source.
+        """
+        return "Local LLM"
+
+    def get_api_key(self):
+        """
+        Returns:
+            str: The API key.
+        """
+        return self.api_key
+
+    def get_model(self):
+        """
+        Returns:
+            str: The model.
+        """
+        return self.model
+
+    def get_models(self):
+        """
+        Returns:
+            list: The models.
+        """
+        return [self.model]
+
+    def verify_access_key(self, api_key):
+        return True
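
For reference, a minimal, hypothetical usage sketch of the new LocalLLM class. It assumes a GGUF model is already mounted at /app/local_model_path (the path hard-coded in LLMLoader) and that MAX_MODEL_TOKEN_LIMIT is set in the config; the model name passed below is illustrative only, not a name the patch defines.

    from superagi.llms.local_llm import LocalLLM

    # LLMLoader is a process-wide singleton, so the first LocalLLM constructed
    # fixes the context length the model is loaded with.
    llm = LocalLLM(model="local-llama-13b", context_length=4096)

    result = llm.chat_completion(
        messages=[{"role": "user", "content": "Reply with a JSON greeting."}],
        max_tokens=256,
    )
    if "error" in result:
        print(result["message"])   # model failed to load, or inference raised
    else:
        print(result["content"])   # grammar-constrained JSON text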