diff --git a/docker-compose.yaml b/docker-compose.yaml index 94044916b..926b8515c 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -43,7 +43,7 @@ services: - redis_data:/data super__postgres: - image: "docker.io/library/postgres:latest" + image: "docker.io/library/postgres:15" environment: - POSTGRES_USER=superagi - POSTGRES_PASSWORD=password diff --git a/gui/pages/Content/APM/ApmDashboard.js b/gui/pages/Content/APM/ApmDashboard.js index 535de9ad8..a79dd3b3d 100644 --- a/gui/pages/Content/APM/ApmDashboard.js +++ b/gui/pages/Content/APM/ApmDashboard.js @@ -76,7 +76,7 @@ export default function ApmDashboard() { const fetchData = async () => { try { const [metricsResponse, agentsResponse, activeRunsResponse, toolsUsageResponse] = await Promise.all([getMetrics(), getAllAgents(), getActiveRuns(), getToolsUsage()]); - const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k', 'google-palm-bison-001']; + const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k', 'google-palm-bison-001', 'replicate-llama13b-v2-chat']; assignDefaultDataPerModel(metricsResponse.data.agent_details.model_metrics, models); assignDefaultDataPerModel(metricsResponse.data.tokens_details.model_metrics, models); diff --git a/gui/pages/Content/Agents/AgentCreate.js b/gui/pages/Content/Agents/AgentCreate.js index 234d903e0..f2e58d4bb 100644 --- a/gui/pages/Content/Agents/AgentCreate.js +++ b/gui/pages/Content/Agents/AgentCreate.js @@ -530,7 +530,7 @@ export default function AgentCreate({ const name = response.data.name; const executionId = response.data.execution_id; fetchAgents(); - getUserClick('Agent Created Successfully', {}) + getUserClick('Agent Created Successfully', {'templateName': template?.id ? template.name : ''}) getUserClick('Agent Run created successfully', {}) uploadResources(agentId, name, executionId) }) diff --git a/gui/pages/Content/Models/AddModel.js b/gui/pages/Content/Models/AddModel.js index e596cb80c..0ef3d5497 100644 --- a/gui/pages/Content/Models/AddModel.js +++ b/gui/pages/Content/Models/AddModel.js @@ -1,14 +1,14 @@ import React, {useEffect, useState} from "react"; import ModelForm from "./ModelForm"; -export default function AddModel({internalId, getModels, sendModelData}){ +export default function AddModel({internalId, getModels, sendModelData, env}){ return(
-        <ModelForm internalId={internalId} getModels={getModels} sendModelData={sendModelData}/>
+        <ModelForm internalId={internalId} getModels={getModels} sendModelData={sendModelData} env={env}/>
diff --git a/gui/pages/Content/Models/ModelForm.js b/gui/pages/Content/Models/ModelForm.js index d8b248c56..83fff48ee 100644 --- a/gui/pages/Content/Models/ModelForm.js +++ b/gui/pages/Content/Models/ModelForm.js @@ -1,12 +1,12 @@ import React, {useEffect, useRef, useState} from "react"; import {removeTab, openNewTab, createInternalId, getUserClick} from "@/utils/utils"; import Image from "next/image"; -import {fetchApiKey, storeModel, verifyEndPoint} from "@/pages/api/DashboardService"; +import {fetchApiKey, storeModel, testModel, verifyEndPoint} from "@/pages/api/DashboardService"; import {BeatLoader, ClipLoader} from "react-spinners"; import {ToastContainer, toast} from 'react-toastify'; -export default function ModelForm({internalId, getModels, sendModelData}){ - const models = ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm']; +export default function ModelForm({internalId, getModels, sendModelData, env}){ + const models = env === 'DEV' ? ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm', 'Local LLM'] : ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm']; const [selectedModel, setSelectedModel] = useState('Select a Model'); const [modelName, setModelName] = useState(''); const [modelDescription, setModelDescription] = useState(''); @@ -14,9 +14,12 @@ export default function ModelForm({internalId, getModels, sendModelData}){ const [modelEndpoint, setModelEndpoint] = useState(''); const [modelDropdown, setModelDropdown] = useState(false); const [modelVersion, setModelVersion] = useState(''); + const [modelContextLength, setContextLength] = useState(4096); const [tokenError, setTokenError] = useState(false); const [lockAddition, setLockAddition] = useState(true); const [isLoading, setIsLoading] = useState(false) + const [modelStatus, setModelStatus] = useState(null); + const [createClickable, setCreateClickable] = useState(true); const modelRef = useRef(null); useEffect(() => { @@ -79,13 +82,31 @@ export default function ModelForm({internalId, getModels, sendModelData}){ }) } + const handleModelStatus = async () => { + try { + setCreateClickable(false); + const response = await testModel(); + if(response.status === 200) { + setModelStatus(true); + setCreateClickable(true); + } else { + setModelStatus(false); + setCreateClickable(true); + } + } catch(error) { + console.log("Error Message:: " + error); + setModelStatus(false); + setCreateClickable(true); + } + } + const handleModelSuccess = (model) => { model.contentType = 'Model' sendModelData(model) } const storeModelDetails = (modelProviderId) => { - storeModel(modelName,modelDescription, modelEndpoint, modelProviderId, modelTokenLimit, "Custom", modelVersion).then((response) =>{ + storeModel(modelName,modelDescription, modelEndpoint, modelProviderId, modelTokenLimit, "Custom", modelVersion, modelContextLength).then((response) =>{ setIsLoading(false) let data = response.data if (data.error) { @@ -122,7 +143,7 @@ export default function ModelForm({internalId, getModels, sendModelData}){
{modelDropdown &&
{models.map((model, index) => ( -
handleModelSelect(index)} style={{padding: '12px 14px', maxWidth: '100%'}}> +
{setModelStatus(null); handleModelSelect(index)}} style={{padding: '12px 14px', maxWidth: '100%'}}> {model}
))}
} @@ -153,18 +174,42 @@ export default function ModelForm({internalId, getModels, sendModelData}){ onChange={(event) => setModelVersion(event.target.value)}/>
} + {(selectedModel === 'Local LLM') &&
+ Model Context Length + setContextLength(event.target.value)}/> +
} +
Token Limit setModelTokenLimit(parseInt(event.target.value, 10))}/>
-
- - + {selectedModel === 'Local LLM' && modelStatus===false &&
+ error-icon +
+ Test model failed +
+
} + + {selectedModel === 'Local LLM' && modelStatus===true &&
+ +
+ Test model successful +
+
} + +
+ {selectedModel==='Local LLM' && } +
+ + +
diff --git a/gui/pages/Dashboard/Content.js b/gui/pages/Dashboard/Content.js index 0611a7be0..5ad7740b6 100644 --- a/gui/pages/Dashboard/Content.js +++ b/gui/pages/Dashboard/Content.js @@ -470,7 +470,7 @@ export default function Content({env, selectedView, selectedProjectId, organisat organisationId={organisationId} sendKnowledgeData={addTab} sendAgentData={addTab} selectedProjectId={selectedProjectId} editAgentId={tab.id} fetchAgents={getAgentList} toolkits={toolkits} template={null} edit={true} agents={agents}/>} - {tab.contentType === 'Add_Model' && } + {tab.contentType === 'Add_Model' && } {tab.contentType === 'Model' && }
}
diff --git a/gui/pages/Dashboard/TopBar.js b/gui/pages/Dashboard/TopBar.js index d6601ca3e..b65aa70d9 100644 --- a/gui/pages/Dashboard/TopBar.js +++ b/gui/pages/Dashboard/TopBar.js @@ -24,6 +24,7 @@ export default function TopBar({selectedProject, userName, env}) { // localStorage.removeItem('accessToken'); Cookies.set('accessToken', '', { expires: new Date(0),domain: '.superagi.com', path: '/'}); Cookies.set('Source', 'app.superagi', {domain: '.superagi.com', path: '/'}); + Cookies.set('mixpanel_initialized', 'false', {domain: '.superagi.com', path: '/'}); refreshUrl(); router.reload(); }; diff --git a/gui/pages/_app.css b/gui/pages/_app.css index 88e8d2dd3..973582b49 100644 --- a/gui/pages/_app.css +++ b/gui/pages/_app.css @@ -231,18 +231,6 @@ input[type="range"]::-moz-range-track { z-index: 10; } -.dropdown_container_models { - flex-direction: column; - align-items: flex-start; - border-radius: 8px; - background: #2E293F; - box-shadow: -2px 2px 24px rgba(0, 0, 0, 0.4); - position: absolute; - width: fit-content; - height: fit-content; - padding: 8px; -} - .dropdown_container { width: 150px; height: auto; @@ -783,7 +771,6 @@ p { .mt_74{margin-top: 74px;} .mt_80{margin-top: 80px;} .mt_90{margin-top: 90px;} -.mt_130{margin-top: 130px;} .mb_1{margin-bottom: 1px;} .mb_2{margin-bottom: 2px;} @@ -991,22 +978,6 @@ p { line-height: normal; } -.text_20 { - color: #FFF; - font-size: 20px; - font-style: normal; - font-weight: 400; - line-height: normal; -} - -.text_20 { - color: #FFF; - font-size: 20px; - font-style: normal; - font-weight: 400; - line-height: normal; -} - .text_20_bold{ color: #FFF; font-size: 20px; @@ -1107,7 +1078,6 @@ p { .w_73{width: 73%} .w_97{width: 97%} .w_100{width: 100%} -.w_99vw{width: 99vw} .w_inherit{width: inherit} .w_fit_content{width:fit-content} .w_inherit{width: inherit} @@ -1125,11 +1095,11 @@ p { .h_80vh{height: 80vh} .h_calc92{height: calc(100vh - 92px)} .h_calc_add40{height: calc(80vh + 40px)} -.h_calc_sub_60{height: calc(92.5vh - 60px)} .mxh_78vh{max-height: 78vh} .flex_dir_col{flex-direction: column} +.flex_none{flex: none} .justify_center{justify-content: center} .justify_end{justify-content: flex-end} @@ -1138,8 +1108,6 @@ p { .display_flex{display: inline-flex} .display_flex_container{display: flex} -.display_none{display: none} -.display_block{display: block} .align_center{align-items: center} .align_start{align-items: flex-start} @@ -1178,8 +1146,6 @@ p { .bt_white{border-top: 1px solid rgba(255, 255, 255, 0.08);} -.bt_white{border-top: 1px solid rgba(255, 255, 255, 0.08);} - .color_white{color:#FFFFFF} .color_gray{color:#888888} @@ -1188,7 +1154,7 @@ p { .lh_18{line-height: 18px;} .lh_24{line-height: 24px;} -.padding_0{padding: 0} +.padding_0{padding: 0;} .padding_5{padding: 5px;} .padding_6{padding: 6px;} .padding_8{padding: 8px;} @@ -1505,7 +1471,6 @@ tr{ .bg_none{background: none;} .bg_primary{background: #2E293F;} .bg_secondary{background: #272335;} -.bg_none{background: none} .container { height: 100%; @@ -1871,6 +1836,13 @@ tr{ padding: 12px; } +.success_box{ + border-radius: 8px; + padding: 12px; + border-left: 4px solid rgba(255, 255, 255, 0.60); + background: rgba(255, 255, 255, 0.08); +} + .horizontal_line { margin: 16px 0 16px -16px; border: 1px solid #ffffff20; @@ -1922,26 +1894,4 @@ tr{ .tooltip-class { background-color: green; border-radius: 6px; -} - -.text_dropdown { - color: #FFFFFF; - font-family: Plus Jakarta Sans, sans-serif; - font-style: normal; - font-weight: 500; - line-height: normal; -} - -.text_dropdown_18 { - font-size: 18px; 
-} - -.vertical_divider { - background: transparent; - /*border-color: rgba(255, 255, 255, 0.08);*/ - border: 1.2px solid rgba(255, 255, 255, 0.08);; - height: 20px; - width: 0; -} - - +} \ No newline at end of file diff --git a/gui/pages/_app.js b/gui/pages/_app.js index a5466dfb6..8d1a69e38 100644 --- a/gui/pages/_app.js +++ b/gui/pages/_app.js @@ -61,6 +61,7 @@ export default function App() { }); } + const installFromMarketplace = () => { const toolkitName = localStorage.getItem('toolkit_to_install') || null; const agentTemplateId = localStorage.getItem('agent_to_install') || null; @@ -110,10 +111,11 @@ export default function App() { .then((response) => { const env = response.data.env; setEnv(env); - + const mixpanelInitialized = Cookies.get('mixpanel_initialized') === 'true' if (typeof window !== 'undefined') { - if(response.data.env === 'PROD' && mixpanelId()) - mixpanel.init(mixpanelId(), { debug: false, track_pageview: true, persistence: 'localStorage' }); + if(response.data.env === 'PROD' && mixpanelId()) { + mixpanel.init(mixpanelId(), {debug: false, track_pageview: !mixpanelInitialized, persistence: 'localStorage'}); + } localStorage.setItem('applicationEnvironment', env); } @@ -122,7 +124,7 @@ export default function App() { const queryParams = router.asPath.split('?')[1]; const parsedParams = querystring.parse(queryParams); let access_token = parsedParams.access_token || null; - let first_login = parsedParams.first_time_login || false + let first_login = parsedParams.first_time_login || '' const utmParams = getUTMParametersFromURL(); if (utmParams) { @@ -136,7 +138,7 @@ export default function App() { if (typeof window !== 'undefined' && access_token) { // localStorage.setItem('accessToken', access_token); - Cookies.set('accessToken', access_token, { domain: '.superagi.com', path: '/' }); + Cookies.set('accessToken', access_token, {domain: '.superagi.com', path: '/'}); refreshUrl(); } validateAccessToken() @@ -145,15 +147,19 @@ export default function App() { sendGAEvent(response.data.email, 'Signed Up Successfully', {'utm_source': signupSource || '', 'utm_medium': signupMedium || '', 'campaign': singupCampaign || ''}) if(mixpanelId()) mixpanel.identify(response.data.email) - if(first_login) + if(first_login === 'True') { getUserClick('New Sign Up', {}) - else - getUserClick('User Logged In', {}) + } + else { + if (first_login === 'False') + getUserClick('User Logged In', {}) + } if(signupSource) { handleSignUpSource(signupSource) } fetchOrganisation(response.data.id); + Cookies.set('mixpanel_initialized', 'true', {domain: '.superagi.com', path: '/'}); }) .catch((error) => { console.error('Error validating access token:', error); diff --git a/gui/pages/api/DashboardService.js b/gui/pages/api/DashboardService.js index 38e13c698..2e5f93869 100644 --- a/gui/pages/api/DashboardService.js +++ b/gui/pages/api/DashboardService.js @@ -358,8 +358,12 @@ export const verifyEndPoint = (model_api_key, end_point, model_provider) => { }); } -export const storeModel = (model_name, description, end_point, model_provider_id, token_limit, type, version) => { - return api.post(`/models_controller/store_model`,{model_name, description, end_point, model_provider_id, token_limit, type, version}); +export const storeModel = (model_name, description, end_point, model_provider_id, token_limit, type, version, context_length) => { + return api.post(`/models_controller/store_model`,{model_name, description, end_point, model_provider_id, token_limit, type, version, context_length}); +} + +export const 
testModel = () => { + return api.get(`/models_controller/test_local_llm`); } export const fetchModels = () => { @@ -389,7 +393,6 @@ export const getToolLogs = (toolName) => { export const publishTemplateToMarketplace = (agentData) => { return api.post(`/agent_templates/publish_template`, agentData); }; - export const getKnowledgeMetrics = (knowledgeName) => { return api.get(`analytics/knowledge/${knowledgeName}/usage`) } diff --git a/main.py b/main.py index 2f93cfa77..cdfb83486 100644 --- a/main.py +++ b/main.py @@ -52,6 +52,7 @@ from superagi.llms.replicate import Replicate from superagi.llms.hugging_face import HuggingFace from superagi.models.agent_template import AgentTemplate +from superagi.models.models_config import ModelsConfig from superagi.models.organisation import Organisation from superagi.models.types.login_request import LoginRequest from superagi.models.types.validate_llm_api_key_request import ValidateAPIKeyRequest diff --git a/migrations/versions/9270eb5a8475_local_llms.py b/migrations/versions/9270eb5a8475_local_llms.py new file mode 100644 index 000000000..7e6371e8a --- /dev/null +++ b/migrations/versions/9270eb5a8475_local_llms.py @@ -0,0 +1,28 @@ +"""local_llms + +Revision ID: 9270eb5a8475 +Revises: 3867bb00a495 +Create Date: 2023-10-04 09:26:33.865424 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '9270eb5a8475' +down_revision = '3867bb00a495' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('models', sa.Column('context_length', sa.Integer(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('models', 'context_length') + # ### end Alembic commands ### diff --git a/requirements.txt b/requirements.txt index 4e00f97c6..a8a9f1dc6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -159,3 +159,4 @@ honeybadger==0.16.0 unstructured==0.8.1 ai21==1.2.6 typing-extensions==4.5.0 +llama_cpp_python==0.2.7 diff --git a/superagi/controllers/models_controller.py b/superagi/controllers/models_controller.py index 6748c10c3..e35198295 100644 --- a/superagi/controllers/models_controller.py +++ b/superagi/controllers/models_controller.py @@ -1,7 +1,9 @@ +from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query, Body from superagi.helper.auth import check_auth, get_user_organisation from superagi.helper.models_helper import ModelsHelper from superagi.apm.call_log_helper import CallLogHelper +from superagi.lib.logger import logger from superagi.models.models import Models from superagi.models.models_config import ModelsConfig from superagi.config.config import get_config @@ -9,6 +11,7 @@ from fastapi_sqlalchemy import db import logging from pydantic import BaseModel +from superagi.helper.llm_loader import LLMLoader router = APIRouter() @@ -26,6 +29,7 @@ class StoreModelRequest(BaseModel): token_limit: int type: str version: str + context_length: Optional[int] class ModelName (BaseModel): model: str @@ -69,7 +73,9 @@ async def verify_end_point(model_api_key: str = None, end_point: str = None, mod @router.post("/store_model", status_code=200) async def store_model(request: StoreModelRequest, organisation=Depends(get_user_organisation)): try: - return Models.store_model_details(db.session, organisation.id, request.model_name, request.description, request.end_point, request.model_provider_id, request.token_limit, request.type, request.version) + # context_length is an Optional field, so it is always present on the parsed request; fall back to 0 when it is omitted or null + context_length = request.context_length if request.context_length is not None else 0 + return Models.store_model_details(db.session, organisation.id, request.model_name, request.description, request.end_point, request.model_provider_id, request.token_limit, request.type, request.version, context_length) except Exception as e: logging.error(f"Error storing the Model Details: {str(e)}") raise HTTPException(status_code=500, detail="Internal Server Error") @@ -173,4 +179,32 @@ def get_models_details(page: int = 0): marketplace_models = Models.fetch_marketplace_list(page) marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation_id, ModelsTypes.MARKETPLACE.value) - return marketplace_models_with_install \ No newline at end of file + return marketplace_models_with_install + +@router.get("/test_local_llm", status_code=200) +def test_local_llm(): + try: + llm_loader = LLMLoader(context_length=4096) + llm_model = llm_loader.model + llm_grammar = llm_loader.grammar + if llm_model is None: + logger.error("Model not found.") + raise HTTPException(status_code=404, detail="Error while loading the model. Please check your model path and try again.") + if llm_grammar is None: + logger.error("Grammar not found.") + raise HTTPException(status_code=404, detail="Grammar not found.") + + messages = [ + {"role":"system", + "content":"You are an AI assistant.
Give response in a proper JSON format"}, + {"role":"user", + "content":"Hi!"} + ] + response = llm_model.create_chat_completion(messages=messages, grammar=llm_grammar) + content = response["choices"][0]["message"]["content"] + logger.info(content) + return "Model loaded successfully." + + except Exception as e: + logger.error("Error: " + str(e)) + raise HTTPException(status_code=404, detail="Error while loading the model. Please check your model path and try again.") \ No newline at end of file diff --git a/superagi/controllers/user.py b/superagi/controllers/user.py index c550fd889..f4dce4b4f 100644 --- a/superagi/controllers/user.py +++ b/superagi/controllers/user.py @@ -14,6 +14,8 @@ from superagi.helper.auth import check_auth, get_current_user from superagi.lib.logger import logger +from superagi.models.models_config import ModelsConfig + # from superagi.types.db import UserBase, UserIn, UserOut router = APIRouter() @@ -73,6 +75,10 @@ def create_user(user: UserIn, organisation = Organisation.find_or_create_organisation(db.session, db_user) Project.find_or_create_default_project(db.session, organisation.id) logger.info("User created", db_user) + + # Add the default Local LLM configuration for the new organisation + ModelsConfig.add_llm_config(db.session, organisation.id) + return db_user diff --git a/superagi/helper/encyption_helper.py b/superagi/helper/encyption_helper.py index f2cf6fd13..4939b1d0c 100644 --- a/superagi/helper/encyption_helper.py +++ b/superagi/helper/encyption_helper.py @@ -1,9 +1,12 @@ from cryptography.fernet import Fernet, InvalidToken, InvalidSignature - +from superagi.config.config import get_config # Generate a key # key = Fernet.generate_key() -key = b'e3mp0E0Jr3jnVb96A31_lKzGZlSTPIp4-rPaVseyn58=' +key = get_config("ENCRYPTION_KEY") +if key is None: + raise Exception("Encryption key not found in config file.") +key = key.encode("utf-8") cipher_suite = Fernet(key) diff --git a/superagi/helper/llm_loader.py b/superagi/helper/llm_loader.py new file mode 100644 index 000000000..8d78337da --- /dev/null +++ b/superagi/helper/llm_loader.py @@ -0,0 +1,38 @@ +from llama_cpp import Llama +from llama_cpp import LlamaGrammar +from superagi.config.config import get_config +from superagi.lib.logger import logger + + +class LLMLoader: + _instance = None + _model = None + _grammar = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super(LLMLoader, cls).__new__(cls) + return cls._instance + + def __init__(self, context_length): + self.context_length = context_length + + @property + def model(self): + if self._model is None: + try: + self._model = Llama( + model_path="/app/local_model_path", n_ctx=self.context_length) + except Exception as e: + logger.error(e) + return self._model + + @property + def grammar(self): + if self._grammar is None: + try: + self._grammar = LlamaGrammar.from_file( + "superagi/llms/grammar/json.gbnf") + except Exception as e: + logger.error(e) + return self._grammar \ No newline at end of file diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py index 4418dd86e..45e91ee4d 100644 --- a/superagi/jobs/agent_executor.py +++ b/superagi/jobs/agent_executor.py @@ -1,6 +1,7 @@ from datetime import datetime, timedelta from sqlalchemy.orm import sessionmaker +from superagi.llms.local_llm import LocalLLM import superagi.worker from superagi.agent.agent_iteration_step_handler import AgentIterationStepHandler @@ -12,7 +13,6 @@ from superagi.lib.logger import logger from superagi.llms.google_palm import GooglePalm from
superagi.llms.hugging_face import HuggingFace -from superagi.llms.replicate import Replicate from superagi.llms.llm_model_factory import get_model from superagi.llms.replicate import Replicate from superagi.models.agent import Agent @@ -27,8 +27,6 @@ from superagi.worker import execute_agent from superagi.agent.types.agent_workflow_step_action_types import AgentWorkflowStepAction from superagi.agent.types.agent_execution_status import AgentExecutionStatus -from superagi.vector_store.redis import Redis -from superagi.config.config import get_config # from superagi.helper.tool_helper import get_tool_config_by_key @@ -138,6 +136,8 @@ def get_embedding(cls, model_source, model_api_key): return HuggingFace(api_key=model_api_key) if "Replicate" in model_source: return Replicate(api_key=model_api_key) + if "Custom" in model_source: + return LocalLLM() return None def _check_for_max_iterations(self, session, organisation_id, agent_config, agent_execution_id): @@ -183,4 +183,4 @@ def execute_waiting_workflows(self): AgentWaitStepHandler(session=session, agent_id=agent_execution.agent_id, agent_execution_id=agent_execution.id).handle_next_step() execute_agent.delay(agent_execution.id, datetime.now()) - session.close() + session.close() \ No newline at end of file diff --git a/superagi/llms/grammar/json.gbnf b/superagi/llms/grammar/json.gbnf new file mode 100644 index 000000000..a9537cdf9 --- /dev/null +++ b/superagi/llms/grammar/json.gbnf @@ -0,0 +1,25 @@ +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? 
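Reviewer note: this grammar is what `LLMLoader.grammar` feeds into llama.cpp; during sampling, any token that would make the output underivable from `root` is masked out, so completions are forced to be valid JSON objects. A minimal sketch of the mechanism with `llama_cpp_python`, assuming a GGUF model at a hypothetical `./model.gguf` path:

```python
# Sketch only: GBNF-constrained decoding with llama-cpp-python.
from llama_cpp import Llama, LlamaGrammar

grammar = LlamaGrammar.from_file("superagi/llms/grammar/json.gbnf")
llm = Llama(model_path="./model.gguf", n_ctx=4096)  # hypothetical model path

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Describe yourself as a JSON object."}],
    grammar=grammar,   # masks tokens that would break the JSON grammar
    max_tokens=256,
)
# Parseable JSON unless the completion was truncated by max_tokens.
print(response["choices"][0]["message"]["content"])
```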
diff --git a/superagi/llms/llm_model_factory.py b/superagi/llms/llm_model_factory.py index 251c64a71..345c4f8c7 100644 --- a/superagi/llms/llm_model_factory.py +++ b/superagi/llms/llm_model_factory.py @@ -1,4 +1,5 @@ from superagi.llms.google_palm import GooglePalm +from superagi.llms.local_llm import LocalLLM from superagi.llms.openai import OpenAi from superagi.llms.replicate import Replicate from superagi.llms.hugging_face import HuggingFace @@ -33,6 +34,9 @@ def get_model(organisation_id, api_key, model="gpt-3.5-turbo", **kwargs): elif provider_name == 'Hugging Face': print("Provider is Hugging Face") return HuggingFace(model=model_instance.model_name, end_point=model_instance.end_point, api_key=api_key, **kwargs) + elif provider_name == 'Local LLM': + print("Provider is Local LLM") + return LocalLLM(model=model_instance.model_name, context_length=model_instance.context_length) else: print('Unknown provider.') @@ -45,5 +49,7 @@ def build_model_with_api_key(provider_name, api_key): return GooglePalm(api_key=api_key) elif provider_name.lower() == 'hugging face': return HuggingFace(api_key=api_key) + elif provider_name.lower() == 'local llm': + return LocalLLM(api_key=api_key) else: print('Unknown provider.') \ No newline at end of file diff --git a/superagi/llms/local_llm.py b/superagi/llms/local_llm.py new file mode 100644 index 000000000..a146d7daa --- /dev/null +++ b/superagi/llms/local_llm.py @@ -0,0 +1,92 @@ +from superagi.config.config import get_config +from superagi.lib.logger import logger +from superagi.llms.base_llm import BaseLlm +from superagi.helper.llm_loader import LLMLoader + + +class LocalLLM(BaseLlm): + def __init__(self, temperature=0.6, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT"), top_p=1, + frequency_penalty=0, + presence_penalty=0, number_of_results=1, model=None, api_key='EMPTY', context_length=4096): + """ + Args: + model (str): The model. + temperature (float): The temperature. + max_tokens (int): The maximum number of tokens. + top_p (float): The top p. + frequency_penalty (float): The frequency penalty. + presence_penalty (float): The presence penalty. + number_of_results (int): The number of results. + """ + self.model = model + self.api_key = api_key + self.temperature = temperature + self.max_tokens = max_tokens + self.top_p = top_p + self.frequency_penalty = frequency_penalty + self.presence_penalty = presence_penalty + self.number_of_results = number_of_results + self.context_length = context_length + + llm_loader = LLMLoader(self.context_length) + self.llm_model = llm_loader.model + self.llm_grammar = llm_loader.grammar + + def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT")): + """ + Call the chat completion. + + Args: + messages (list): The messages. + max_tokens (int): The maximum number of tokens. + + Returns: + dict: The response. + """ + try: + if self.llm_model is None or self.llm_grammar is None: + logger.error("Model not found.") + return {"error": "Model loading error", "message": "Model not found. 
Please check your model path and try again."} + else: + response = self.llm_model.create_chat_completion(messages=messages, functions=None, function_call=None, temperature=self.temperature, top_p=self.top_p, + max_tokens=int(max_tokens), presence_penalty=self.presence_penalty, frequency_penalty=self.frequency_penalty, grammar=self.llm_grammar) + content = response["choices"][0]["message"]["content"] + logger.info(content) + return {"response": response, "content": content} + + except Exception as exception: + logger.error("Exception: " + str(exception)) + return {"error": "ERROR", "message": "Error: " + str(exception)} + + def get_source(self): + """ + Get the source. + + Returns: + str: The source. + """ + return "Local LLM" + + def get_api_key(self): + """ + Returns: + str: The API key. + """ + return self.api_key + + def get_model(self): + """ + Returns: + str: The model. + """ + return self.model + + def get_models(self): + """ + Returns: + list: The models. + """ + return self.model + + def verify_access_key(self, api_key): + return True \ No newline at end of file diff --git a/superagi/models/models.py b/superagi/models/models.py index 547cacf3a..47241afea 100644 --- a/superagi/models/models.py +++ b/superagi/models/models.py @@ -1,3 +1,4 @@ +import yaml from sqlalchemy import Column, Integer, String, and_ from sqlalchemy.sql import func from typing import List, Dict, Union @@ -5,6 +6,7 @@ from superagi.controllers.types.models_types import ModelsTypes from superagi.helper.encyption_helper import decrypt_data import requests, logging +from superagi.lib.logger import logger marketplace_url = "https://app.superagi.com/api" # marketplace_url = "http://localhost:8001" @@ -39,6 +41,7 @@ class Models(DBBaseModel): version = Column(String, nullable=False) org_id = Column(Integer, nullable=False) model_features = Column(String, nullable=False) + context_length = Column(Integer, nullable=True) def __repr__(self): """ @@ -104,7 +107,7 @@ def fetch_model_tokens(cls, session, organisation_id) -> Dict[str, int]: return {"error": "Unexpected Error Occured"} @classmethod - def store_model_details(cls, session, organisation_id, model_name, description, end_point, model_provider_id, token_limit, type, version): + def store_model_details(cls, session, organisation_id, model_name, description, end_point, model_provider_id, token_limit, type, version, context_length): from superagi.models.models_config import ModelsConfig if not model_name: return {"error": "Model Name is empty or undefined"} @@ -130,9 +133,12 @@ def store_model_details(cls, session, organisation_id, model_name, description, return model # Return error message if model not found # Check the 'provider' from ModelsConfig table - if not end_point and model["provider"] not in ['OpenAI', 'Google Palm', 'Replicate']: + if not end_point and model["provider"] not in ['OpenAI', 'Google Palm', 'Replicate', 'Local LLM']: return {"error": "End Point is empty or undefined"} + if context_length is None: + context_length = 0 + try: model = Models( model_name=model_name, @@ -143,7 +149,8 @@ def store_model_details(cls, session, organisation_id, model_name, description, type=type, version=version, org_id=organisation_id, - model_features='' + model_features='', + context_length=context_length ) session.add(model) session.commit() @@ -229,4 +236,4 @@ def fetch_model_details(cls, session, organisation_id, model_id: int) -> Dict[st except Exception as e: logging.error(f"Unexpected Error Occured: {e}") - return {"error":
"Unexpected Error Occured"} \ No newline at end of file diff --git a/superagi/models/models_config.py b/superagi/models/models_config.py index 0c8c13b95..03833ef05 100644 --- a/superagi/models/models_config.py +++ b/superagi/models/models_config.py @@ -1,4 +1,5 @@ from sqlalchemy import Column, Integer, String, and_, distinct +from superagi.lib.logger import logger from superagi.models.base_model import DBBaseModel from superagi.models.organisation import Organisation from superagi.models.project import Project @@ -69,6 +70,9 @@ def fetch_value_by_agent_id(cls, session, agent_id: int, model: str): if not config: return None + if config.provider == 'Local LLM': + return {"provider": config.provider, "api_key": config.api_key} if config else None + return {"provider": config.provider, "api_key": decrypt_data(config.api_key)} if config else None @classmethod @@ -102,7 +106,7 @@ def storeGptModels(cls, session, organisation_id, model_provider_id, model_api_k for model in models: if model not in installed_models and model in default_models: result = Models.store_model_details(session, organisation_id, model, model, '', - model_provider_id, default_models[model], 'Custom', '') + model_provider_id, default_models[model], 'Custom', '', 0) @classmethod def fetch_api_keys(cls, session, organisation_id): @@ -123,8 +127,13 @@ def fetch_api_key(cls, session, organisation_id, model_provider): api_key_data = session.query(ModelsConfig.id, ModelsConfig.provider, ModelsConfig.api_key).filter( and_(ModelsConfig.org_id == organisation_id, ModelsConfig.provider == model_provider)).first() + logger.info(api_key_data) if api_key_data is None: return [] + elif api_key_data.provider == 'Local LLM': + api_key = [{'id': api_key_data.id, 'provider': api_key_data.provider, + 'api_key': api_key_data.api_key}] + return api_key else: api_key = [{'id': api_key_data.id, 'provider': api_key_data.provider, 'api_key': decrypt_data(api_key_data.api_key)}] @@ -145,4 +154,12 @@ def fetch_model_by_id_marketplace(cls, session, model_provider_id): if model is None: return {"error": "Model not found"} else: - return {"provider": model.provider} \ No newline at end of file + return {"provider": model.provider} + + @classmethod + def add_llm_config(cls, session, organisation_id): + existing_models_config = session.query(ModelsConfig).filter(ModelsConfig.org_id == organisation_id, ModelsConfig.provider == 'Local LLM').first() + if existing_models_config is None: + models_config = ModelsConfig(org_id=organisation_id, provider='Local LLM', api_key="EMPTY") + session.add(models_config) + session.commit() \ No newline at end of file diff --git a/superagi/types/model_source_types.py b/superagi/types/model_source_types.py index f811a60c6..6e9de18ad 100644 --- a/superagi/types/model_source_types.py +++ b/superagi/types/model_source_types.py @@ -6,6 +6,7 @@ class ModelSourceType(Enum): OpenAI = 'OpenAi' Replicate = 'Replicate' HuggingFace = 'Hugging Face' + LocalLLM = 'Local LLM' @classmethod def get_model_source_type(cls, name): diff --git a/tests/unit_tests/controllers/test_models_controller.py b/tests/unit_tests/controllers/test_models_controller.py index 489cff636..790229789 100644 --- a/tests/unit_tests/controllers/test_models_controller.py +++ b/tests/unit_tests/controllers/test_models_controller.py @@ -2,6 +2,11 @@ import pytest from fastapi.testclient import TestClient from main import app +from llama_cpp import Llama +from llama_cpp import LlamaGrammar +import llama_cpp + +from superagi.helper.llm_loader import LLMLoader client = 
TestClient(app) @@ -50,7 +55,8 @@ def test_store_model_success(mock_get_db): "model_provider_id": 1, "token_limit": 10, "type": "mock_type", - "version": "mock_version" + "version": "mock_version", + "context_length": 4096 } with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \ patch('superagi.helper.auth.db') as mock_auth_db: @@ -100,3 +106,13 @@ def test_get_marketplace_models_list_success(mock_get_db): patch('superagi.helper.auth.db') as mock_auth_db: response = client.get("/models_controller/marketplace/list/0") assert response.status_code == 200 + +def test_get_local_llm(): + with patch.object(LLMLoader, 'model', new_callable=MagicMock) as mock_model: + with patch.object(LLMLoader, 'grammar', new_callable=MagicMock) as mock_grammar: + + mock_model.create_chat_completion.return_value = {"choices": [{"message": {"content": "Hello!"}}]} + + response = client.get("/models_controller/test_local_llm") + + assert response.status_code == 200 \ No newline at end of file diff --git a/tests/unit_tests/models/test_models.py b/tests/unit_tests/models/test_models.py index 3bdc43075..d4880538c 100644 --- a/tests/unit_tests/models/test_models.py +++ b/tests/unit_tests/models/test_models.py @@ -133,6 +133,7 @@ def test_store_model_details_when_model_exists(mock_session): token_limit=500, type="type", version="v1.0", + context_length=4096 ) # Assert @@ -161,6 +162,7 @@ def test_store_model_details_when_model_not_exists(mock_session, monkeypatch): token_limit=500, type="type", version="v1.0", + context_length=4096 ) # Assert @@ -187,6 +189,7 @@ def test_store_model_details_when_unexpected_error_occurs(mock_session, monkeypa token_limit=500, type="type", version="v1.0", + context_length=4096 ) # Assert @@ -229,6 +232,4 @@ def test_fetch_model_details(mock_models_config, mock_session): "token_limit": 100, "type": "type1", "model_provider": "example_provider" - } - - + } \ No newline at end of file
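Reviewer note: end to end, the pieces added here compose as `LLMLoader` (process-wide singleton, model path hardcoded to `/app/local_model_path`) → `LocalLLM` → `get_model()` / `AgentExecutor.get_embedding()`. A minimal usage sketch, assuming a llama.cpp-compatible GGUF file is mounted at `/app/local_model_path` inside the container:

```python
# Sketch only: direct use of the LocalLLM wrapper introduced in this PR.
from superagi.llms.local_llm import LocalLLM

llm = LocalLLM(model="local-gguf-model", context_length=4096)  # loads via the LLMLoader singleton
result = llm.chat_completion([
    {"role": "system", "content": "You are an AI assistant. Give response in a proper JSON format"},
    {"role": "user", "content": "Hi!"},
])
print(result["content"])  # grammar-constrained JSON text

# Design note: LLMLoader caches the Llama instance at class level, so the first
# context_length used in the process wins; later LocalLLM(context_length=...) values
# update only the loader attribute, not the already-loaded model.
```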