diff --git a/docker-compose.yaml b/docker-compose.yaml
index 94044916b..926b8515c 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -43,7 +43,7 @@ services:
- redis_data:/data
super__postgres:
- image: "docker.io/library/postgres:latest"
+ image: "docker.io/library/postgres:15"
environment:
- POSTGRES_USER=superagi
- POSTGRES_PASSWORD=password
diff --git a/gui/pages/Content/APM/ApmDashboard.js b/gui/pages/Content/APM/ApmDashboard.js
index 535de9ad8..a79dd3b3d 100644
--- a/gui/pages/Content/APM/ApmDashboard.js
+++ b/gui/pages/Content/APM/ApmDashboard.js
@@ -76,7 +76,7 @@ export default function ApmDashboard() {
const fetchData = async () => {
try {
const [metricsResponse, agentsResponse, activeRunsResponse, toolsUsageResponse] = await Promise.all([getMetrics(), getAllAgents(), getActiveRuns(), getToolsUsage()]);
- const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k', 'google-palm-bison-001'];
+ const models = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-32k', 'google-palm-bison-001', 'replicate-llama13b-v2-chat'];
assignDefaultDataPerModel(metricsResponse.data.agent_details.model_metrics, models);
assignDefaultDataPerModel(metricsResponse.data.tokens_details.model_metrics, models);
diff --git a/gui/pages/Content/Agents/AgentCreate.js b/gui/pages/Content/Agents/AgentCreate.js
index 234d903e0..f2e58d4bb 100644
--- a/gui/pages/Content/Agents/AgentCreate.js
+++ b/gui/pages/Content/Agents/AgentCreate.js
@@ -530,7 +530,7 @@ export default function AgentCreate({
const name = response.data.name;
const executionId = response.data.execution_id;
fetchAgents();
- getUserClick('Agent Created Successfully', {})
+ getUserClick('Agent Created Successfully', {'templateName': template?.id ? template.name : ''})
getUserClick('Agent Run created successfully', {})
uploadResources(agentId, name, executionId)
})
diff --git a/gui/pages/Content/Models/AddModel.js b/gui/pages/Content/Models/AddModel.js
index e596cb80c..0ef3d5497 100644
--- a/gui/pages/Content/Models/AddModel.js
+++ b/gui/pages/Content/Models/AddModel.js
@@ -1,14 +1,14 @@
import React, {useEffect, useState} from "react";
import ModelForm from "./ModelForm";
-export default function AddModel({internalId, getModels, sendModelData}){
+export default function AddModel({internalId, getModels, sendModelData, env}){
return(
diff --git a/gui/pages/Content/Models/ModelForm.js b/gui/pages/Content/Models/ModelForm.js
index d8b248c56..83fff48ee 100644
--- a/gui/pages/Content/Models/ModelForm.js
+++ b/gui/pages/Content/Models/ModelForm.js
@@ -1,12 +1,12 @@
import React, {useEffect, useRef, useState} from "react";
import {removeTab, openNewTab, createInternalId, getUserClick} from "@/utils/utils";
import Image from "next/image";
-import {fetchApiKey, storeModel, verifyEndPoint} from "@/pages/api/DashboardService";
+import {fetchApiKey, storeModel, testModel, verifyEndPoint} from "@/pages/api/DashboardService";
import {BeatLoader, ClipLoader} from "react-spinners";
import {ToastContainer, toast} from 'react-toastify';
-export default function ModelForm({internalId, getModels, sendModelData}){
- const models = ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm'];
+export default function ModelForm({internalId, getModels, sendModelData, env}){
+ const models = env === 'DEV' ? ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm', 'Local LLM'] : ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm'];
const [selectedModel, setSelectedModel] = useState('Select a Model');
const [modelName, setModelName] = useState('');
const [modelDescription, setModelDescription] = useState('');
@@ -14,9 +14,12 @@ export default function ModelForm({internalId, getModels, sendModelData}){
const [modelEndpoint, setModelEndpoint] = useState('');
const [modelDropdown, setModelDropdown] = useState(false);
const [modelVersion, setModelVersion] = useState('');
+ const [modelContextLength, setContextLength] = useState(4096);
const [tokenError, setTokenError] = useState(false);
const [lockAddition, setLockAddition] = useState(true);
const [isLoading, setIsLoading] = useState(false)
+ const [modelStatus, setModelStatus] = useState(null);
+ const [createClickable, setCreateClickable] = useState(true);
const modelRef = useRef(null);
useEffect(() => {
@@ -79,13 +82,31 @@ export default function ModelForm({internalId, getModels, sendModelData}){
})
}
+ const handleModelStatus = async () => {
+ try {
+ setCreateClickable(false);
+ const response = await testModel();
+ if(response.status === 200) {
+ setModelStatus(true);
+ setCreateClickable(true);
+ } else {
+ setModelStatus(false);
+ setCreateClickable(true);
+ }
+ } catch(error) {
+ console.log("Error Message:: " + error);
+ setModelStatus(false);
+ setCreateClickable(true);
+ }
+ }
+
const handleModelSuccess = (model) => {
model.contentType = 'Model'
sendModelData(model)
}
const storeModelDetails = (modelProviderId) => {
- storeModel(modelName,modelDescription, modelEndpoint, modelProviderId, modelTokenLimit, "Custom", modelVersion).then((response) =>{
+ storeModel(modelName,modelDescription, modelEndpoint, modelProviderId, modelTokenLimit, "Custom", modelVersion, modelContextLength).then((response) =>{
setIsLoading(false)
let data = response.data
if (data.error) {
@@ -122,7 +143,7 @@ export default function ModelForm({internalId, getModels, sendModelData}){
{modelDropdown &&
{models.map((model, index) => (
-
handleModelSelect(index)} style={{padding: '12px 14px', maxWidth: '100%'}}>
+
{setModelStatus(null); handleModelSelect(index)}} style={{padding: '12px 14px', maxWidth: '100%'}}>
{model}
))}
}
@@ -153,18 +174,42 @@ export default function ModelForm({internalId, getModels, sendModelData}){
onChange={(event) => setModelVersion(event.target.value)}/>
}
+ {(selectedModel === 'Local LLM') &&
+ Model Context Length
+ setContextLength(event.target.value)}/>
+
}
+
Token Limit
setModelTokenLimit(parseInt(event.target.value, 10))}/>
-
-
-
+ {selectedModel === 'Local LLM' && modelStatus===false &&
+
+
+ Test model failed
+
+
}
+
+ {selectedModel === 'Local LLM' && modelStatus===true &&
+
+
+ Test model successful
+
+
}
+
+
+ {selectedModel==='Local LLM' &&
}
+
+
+
+
diff --git a/gui/pages/Dashboard/Content.js b/gui/pages/Dashboard/Content.js
index 0611a7be0..5ad7740b6 100644
--- a/gui/pages/Dashboard/Content.js
+++ b/gui/pages/Dashboard/Content.js
@@ -470,7 +470,7 @@ export default function Content({env, selectedView, selectedProjectId, organisat
organisationId={organisationId} sendKnowledgeData={addTab}
sendAgentData={addTab} selectedProjectId={selectedProjectId} editAgentId={tab.id}
fetchAgents={getAgentList} toolkits={toolkits} template={null} edit={true} agents={agents}/>}
- {tab.contentType === 'Add_Model' &&
}
+ {tab.contentType === 'Add_Model' && }
{tab.contentType === 'Model' && }
}
diff --git a/gui/pages/Dashboard/TopBar.js b/gui/pages/Dashboard/TopBar.js
index d6601ca3e..b65aa70d9 100644
--- a/gui/pages/Dashboard/TopBar.js
+++ b/gui/pages/Dashboard/TopBar.js
@@ -24,6 +24,7 @@ export default function TopBar({selectedProject, userName, env}) {
// localStorage.removeItem('accessToken');
Cookies.set('accessToken', '', { expires: new Date(0),domain: '.superagi.com', path: '/'});
Cookies.set('Source', 'app.superagi', {domain: '.superagi.com', path: '/'});
+ Cookies.set('mixpanel_initialized', 'false', {domain: '.superagi.com', path: '/'});
refreshUrl();
router.reload();
};
diff --git a/gui/pages/_app.css b/gui/pages/_app.css
index 88e8d2dd3..973582b49 100644
--- a/gui/pages/_app.css
+++ b/gui/pages/_app.css
@@ -231,18 +231,6 @@ input[type="range"]::-moz-range-track {
z-index: 10;
}
-.dropdown_container_models {
- flex-direction: column;
- align-items: flex-start;
- border-radius: 8px;
- background: #2E293F;
- box-shadow: -2px 2px 24px rgba(0, 0, 0, 0.4);
- position: absolute;
- width: fit-content;
- height: fit-content;
- padding: 8px;
-}
-
.dropdown_container {
width: 150px;
height: auto;
@@ -783,7 +771,6 @@ p {
.mt_74{margin-top: 74px;}
.mt_80{margin-top: 80px;}
.mt_90{margin-top: 90px;}
-.mt_130{margin-top: 130px;}
.mb_1{margin-bottom: 1px;}
.mb_2{margin-bottom: 2px;}
@@ -991,22 +978,6 @@ p {
line-height: normal;
}
-.text_20 {
- color: #FFF;
- font-size: 20px;
- font-style: normal;
- font-weight: 400;
- line-height: normal;
-}
-
-.text_20 {
- color: #FFF;
- font-size: 20px;
- font-style: normal;
- font-weight: 400;
- line-height: normal;
-}
-
.text_20_bold{
color: #FFF;
font-size: 20px;
@@ -1107,7 +1078,6 @@ p {
.w_73{width: 73%}
.w_97{width: 97%}
.w_100{width: 100%}
-.w_99vw{width: 99vw}
.w_inherit{width: inherit}
.w_fit_content{width:fit-content}
.w_inherit{width: inherit}
@@ -1125,11 +1095,11 @@ p {
.h_80vh{height: 80vh}
.h_calc92{height: calc(100vh - 92px)}
.h_calc_add40{height: calc(80vh + 40px)}
-.h_calc_sub_60{height: calc(92.5vh - 60px)}
.mxh_78vh{max-height: 78vh}
.flex_dir_col{flex-direction: column}
+.flex_none{flex: none}
.justify_center{justify-content: center}
.justify_end{justify-content: flex-end}
@@ -1138,8 +1108,6 @@ p {
.display_flex{display: inline-flex}
.display_flex_container{display: flex}
-.display_none{display: none}
-.display_block{display: block}
.align_center{align-items: center}
.align_start{align-items: flex-start}
@@ -1178,8 +1146,6 @@ p {
.bt_white{border-top: 1px solid rgba(255, 255, 255, 0.08);}
-.bt_white{border-top: 1px solid rgba(255, 255, 255, 0.08);}
-
.color_white{color:#FFFFFF}
.color_gray{color:#888888}
@@ -1188,7 +1154,7 @@ p {
.lh_18{line-height: 18px;}
.lh_24{line-height: 24px;}
-.padding_0{padding: 0}
+.padding_0{padding: 0;}
.padding_5{padding: 5px;}
.padding_6{padding: 6px;}
.padding_8{padding: 8px;}
@@ -1505,7 +1471,6 @@ tr{
.bg_none{background: none;}
.bg_primary{background: #2E293F;}
.bg_secondary{background: #272335;}
-.bg_none{background: none}
.container {
height: 100%;
@@ -1871,6 +1836,13 @@ tr{
padding: 12px;
}
+.success_box{
+ border-radius: 8px;
+ padding: 12px;
+ border-left: 4px solid rgba(255, 255, 255, 0.60);
+ background: rgba(255, 255, 255, 0.08);
+}
+
.horizontal_line {
margin: 16px 0 16px -16px;
border: 1px solid #ffffff20;
@@ -1922,26 +1894,4 @@ tr{
.tooltip-class {
background-color: green;
border-radius: 6px;
-}
-
-.text_dropdown {
- color: #FFFFFF;
- font-family: Plus Jakarta Sans, sans-serif;
- font-style: normal;
- font-weight: 500;
- line-height: normal;
-}
-
-.text_dropdown_18 {
- font-size: 18px;
-}
-
-.vertical_divider {
- background: transparent;
- /*border-color: rgba(255, 255, 255, 0.08);*/
- border: 1.2px solid rgba(255, 255, 255, 0.08);;
- height: 20px;
- width: 0;
-}
-
-
+}
\ No newline at end of file
diff --git a/gui/pages/_app.js b/gui/pages/_app.js
index a5466dfb6..8d1a69e38 100644
--- a/gui/pages/_app.js
+++ b/gui/pages/_app.js
@@ -61,6 +61,7 @@ export default function App() {
});
}
+
const installFromMarketplace = () => {
const toolkitName = localStorage.getItem('toolkit_to_install') || null;
const agentTemplateId = localStorage.getItem('agent_to_install') || null;
@@ -110,10 +111,11 @@ export default function App() {
.then((response) => {
const env = response.data.env;
setEnv(env);
-
+ const mixpanelInitialized = Cookies.get('mixpanel_initialized') === 'true'
if (typeof window !== 'undefined') {
- if(response.data.env === 'PROD' && mixpanelId())
- mixpanel.init(mixpanelId(), { debug: false, track_pageview: true, persistence: 'localStorage' });
+ if(response.data.env === 'PROD' && mixpanelId()) {
+ mixpanel.init(mixpanelId(), {debug: false, track_pageview: !mixpanelInitialized, persistence: 'localStorage'});
+ }
localStorage.setItem('applicationEnvironment', env);
}
@@ -122,7 +124,7 @@ export default function App() {
const queryParams = router.asPath.split('?')[1];
const parsedParams = querystring.parse(queryParams);
let access_token = parsedParams.access_token || null;
- let first_login = parsedParams.first_time_login || false
+ let first_login = parsedParams.first_time_login || ''
const utmParams = getUTMParametersFromURL();
if (utmParams) {
@@ -136,7 +138,7 @@ export default function App() {
if (typeof window !== 'undefined' && access_token) {
// localStorage.setItem('accessToken', access_token);
- Cookies.set('accessToken', access_token, { domain: '.superagi.com', path: '/' });
+ Cookies.set('accessToken', access_token, {domain: '.superagi.com', path: '/'});
refreshUrl();
}
validateAccessToken()
@@ -145,15 +147,19 @@ export default function App() {
sendGAEvent(response.data.email, 'Signed Up Successfully', {'utm_source': signupSource || '', 'utm_medium': signupMedium || '', 'campaign': singupCampaign || ''})
if(mixpanelId())
mixpanel.identify(response.data.email)
- if(first_login)
+ if(first_login === 'True') {
getUserClick('New Sign Up', {})
- else
- getUserClick('User Logged In', {})
+ }
+ else {
+ if (first_login === 'False')
+ getUserClick('User Logged In', {})
+ }
if(signupSource) {
handleSignUpSource(signupSource)
}
fetchOrganisation(response.data.id);
+ Cookies.set('mixpanel_initialized', 'true', {domain: '.superagi.com', path: '/'});
})
.catch((error) => {
console.error('Error validating access token:', error);
diff --git a/gui/pages/api/DashboardService.js b/gui/pages/api/DashboardService.js
index 38e13c698..2e5f93869 100644
--- a/gui/pages/api/DashboardService.js
+++ b/gui/pages/api/DashboardService.js
@@ -358,8 +358,12 @@ export const verifyEndPoint = (model_api_key, end_point, model_provider) => {
});
}
-export const storeModel = (model_name, description, end_point, model_provider_id, token_limit, type, version) => {
- return api.post(`/models_controller/store_model`,{model_name, description, end_point, model_provider_id, token_limit, type, version});
+export const storeModel = (model_name, description, end_point, model_provider_id, token_limit, type, version, context_length) => {
+ return api.post(`/models_controller/store_model`,{model_name, description, end_point, model_provider_id, token_limit, type, version, context_length});
+}
+
+export const testModel = () => {
+ return api.get(`/models_controller/test_local_llm`);
}
export const fetchModels = () => {
@@ -389,7 +393,6 @@ export const getToolLogs = (toolName) => {
export const publishTemplateToMarketplace = (agentData) => {
return api.post(`/agent_templates/publish_template`, agentData);
};
-
export const getKnowledgeMetrics = (knowledgeName) => {
return api.get(`analytics/knowledge/${knowledgeName}/usage`)
}
diff --git a/main.py b/main.py
index 2f93cfa77..cdfb83486 100644
--- a/main.py
+++ b/main.py
@@ -52,6 +52,7 @@
from superagi.llms.replicate import Replicate
from superagi.llms.hugging_face import HuggingFace
from superagi.models.agent_template import AgentTemplate
+from superagi.models.models_config import ModelsConfig
from superagi.models.organisation import Organisation
from superagi.models.types.login_request import LoginRequest
from superagi.models.types.validate_llm_api_key_request import ValidateAPIKeyRequest
diff --git a/migrations/versions/9270eb5a8475_local_llms.py b/migrations/versions/9270eb5a8475_local_llms.py
new file mode 100644
index 000000000..7e6371e8a
--- /dev/null
+++ b/migrations/versions/9270eb5a8475_local_llms.py
@@ -0,0 +1,28 @@
+"""local_llms
+
+Revision ID: 9270eb5a8475
+Revises: 3867bb00a495
+Create Date: 2023-10-04 09:26:33.865424
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '9270eb5a8475'
+down_revision = '3867bb00a495'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('models', sa.Column('context_length', sa.Integer(), nullable=True))
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column('models', 'context_length')
+ # ### end Alembic commands ###
diff --git a/requirements.txt b/requirements.txt
index 4e00f97c6..a8a9f1dc6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -159,3 +159,4 @@ honeybadger==0.16.0
unstructured==0.8.1
ai21==1.2.6
typing-extensions==4.5.0
+llama_cpp_python==0.2.7
diff --git a/superagi/controllers/models_controller.py b/superagi/controllers/models_controller.py
index 6748c10c3..e35198295 100644
--- a/superagi/controllers/models_controller.py
+++ b/superagi/controllers/models_controller.py
@@ -1,7 +1,9 @@
+from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from superagi.helper.auth import check_auth, get_user_organisation
from superagi.helper.models_helper import ModelsHelper
from superagi.apm.call_log_helper import CallLogHelper
+from superagi.lib.logger import logger
from superagi.models.models import Models
from superagi.models.models_config import ModelsConfig
from superagi.config.config import get_config
@@ -9,6 +11,7 @@
from fastapi_sqlalchemy import db
import logging
from pydantic import BaseModel
+from superagi.helper.llm_loader import LLMLoader
router = APIRouter()
@@ -26,6 +29,7 @@ class StoreModelRequest(BaseModel):
token_limit: int
type: str
version: str
+ context_length: Optional[int]
class ModelName (BaseModel):
model: str
@@ -69,7 +73,12 @@ async def verify_end_point(model_api_key: str = None, end_point: str = None, mod
@router.post("/store_model", status_code=200)
async def store_model(request: StoreModelRequest, organisation=Depends(get_user_organisation)):
try:
- return Models.store_model_details(db.session, organisation.id, request.model_name, request.description, request.end_point, request.model_provider_id, request.token_limit, request.type, request.version)
+ # context_length is optional in the request; fall back to 0 below when absent
+ logger.info(request)
+ if 'context_length' in request.dict():
+ return Models.store_model_details(db.session, organisation.id, request.model_name, request.description, request.end_point, request.model_provider_id, request.token_limit, request.type, request.version, request.context_length)
+ else:
+ return Models.store_model_details(db.session, organisation.id, request.model_name, request.description, request.end_point, request.model_provider_id, request.token_limit, request.type, request.version, 0)
except Exception as e:
logging.error(f"Error storing the Model Details: {str(e)}")
raise HTTPException(status_code=500, detail="Internal Server Error")
@@ -173,4 +182,32 @@ def get_models_details(page: int = 0):
marketplace_models = Models.fetch_marketplace_list(page)
marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation_id,
ModelsTypes.MARKETPLACE.value)
- return marketplace_models_with_install
\ No newline at end of file
+ return marketplace_models_with_install
+
+@router.get("/test_local_llm", status_code=200)
+def test_local_llm():
+ try:
+ llm_loader = LLMLoader(context_length=4096)
+ llm_model = llm_loader.model
+ llm_grammar = llm_loader.grammar
+ if llm_model is None:
+ logger.error("Model not found.")
+ raise HTTPException(status_code=404, detail="Error while loading the model. Please check your model path and try again.")
+ if llm_grammar is None:
+ logger.error("Grammar not found.")
+ raise HTTPException(status_code=404, detail="Grammar not found.")
+
+ messages = [
+ {"role":"system",
+ "content":"You are an AI assistant. Give response in a proper JSON format"},
+ {"role":"user",
+ "content":"Hi!"}
+ ]
+ response = llm_model.create_chat_completion(messages=messages, grammar=llm_grammar)
+ content = response["choices"][0]["message"]["content"]
+ logger.info(content)
+ return "Model loaded successfully."
+
+ except Exception as e:
+ logger.info("Error: ",e)
+ raise HTTPException(status_code=404, detail="Error while loading the model. Please check your model path and try again.")
\ No newline at end of file
diff --git a/superagi/controllers/user.py b/superagi/controllers/user.py
index c550fd889..f4dce4b4f 100644
--- a/superagi/controllers/user.py
+++ b/superagi/controllers/user.py
@@ -14,6 +14,8 @@
from superagi.helper.auth import check_auth, get_current_user
from superagi.lib.logger import logger
+from superagi.models.models_config import ModelsConfig
+
# from superagi.types.db import UserBase, UserIn, UserOut
router = APIRouter()
@@ -73,6 +75,10 @@ def create_user(user: UserIn,
organisation = Organisation.find_or_create_organisation(db.session, db_user)
Project.find_or_create_default_project(db.session, organisation.id)
logger.info("User created", db_user)
+
+ # Add the Local LLM configuration for the new user's organisation
+ ModelsConfig.add_llm_config(db.session, organisation.id)
+
return db_user
diff --git a/superagi/helper/encyption_helper.py b/superagi/helper/encyption_helper.py
index f2cf6fd13..4939b1d0c 100644
--- a/superagi/helper/encyption_helper.py
+++ b/superagi/helper/encyption_helper.py
@@ -1,9 +1,14 @@
from cryptography.fernet import Fernet, InvalidToken, InvalidSignature
-
+from superagi.config.config import get_config
# Generate a key
# key = Fernet.generate_key()
-key = b'e3mp0E0Jr3jnVb96A31_lKzGZlSTPIp4-rPaVseyn58='
+key = get_config("ENCRYPTION_KEY")
+if key is None:
+ raise Exception("Encryption key not found in config file.")
+key = key.encode(
+ "utf-8"
+)
cipher_suite = Fernet(key)
diff --git a/superagi/helper/llm_loader.py b/superagi/helper/llm_loader.py
new file mode 100644
index 000000000..8d78337da
--- /dev/null
+++ b/superagi/helper/llm_loader.py
@@ -0,0 +1,38 @@
+from llama_cpp import Llama
+from llama_cpp import LlamaGrammar
+from superagi.config.config import get_config
+from superagi.lib.logger import logger
+
+
+class LLMLoader:
+ _instance = None
+ _model = None
+ _grammar = None
+
+ def __new__(cls, *args, **kwargs):
+ if cls._instance is None:
+ cls._instance = super(LLMLoader, cls).__new__(cls)
+ return cls._instance
+
+ def __init__(self, context_length):
+ self.context_length = context_length
+
+ @property
+ def model(self):
+ if self._model is None:
+ try:
+ self._model = Llama(
+ model_path="/app/local_model_path", n_ctx=self.context_length)
+ except Exception as e:
+ logger.error(e)
+ return self._model
+
+ @property
+ def grammar(self):
+ if self._grammar is None:
+ try:
+ self._grammar = LlamaGrammar.from_file(
+ "superagi/llms/grammar/json.gbnf")
+ except Exception as e:
+ logger.error(e)
+ return self._grammar
\ No newline at end of file
diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py
index 4418dd86e..45e91ee4d 100644
--- a/superagi/jobs/agent_executor.py
+++ b/superagi/jobs/agent_executor.py
@@ -1,6 +1,7 @@
from datetime import datetime, timedelta
from sqlalchemy.orm import sessionmaker
+from superagi.llms.local_llm import LocalLLM
import superagi.worker
from superagi.agent.agent_iteration_step_handler import AgentIterationStepHandler
@@ -12,7 +13,6 @@
from superagi.lib.logger import logger
from superagi.llms.google_palm import GooglePalm
from superagi.llms.hugging_face import HuggingFace
-from superagi.llms.replicate import Replicate
from superagi.llms.llm_model_factory import get_model
from superagi.llms.replicate import Replicate
from superagi.models.agent import Agent
@@ -27,8 +27,6 @@
from superagi.worker import execute_agent
from superagi.agent.types.agent_workflow_step_action_types import AgentWorkflowStepAction
from superagi.agent.types.agent_execution_status import AgentExecutionStatus
-from superagi.vector_store.redis import Redis
-from superagi.config.config import get_config
# from superagi.helper.tool_helper import get_tool_config_by_key
@@ -138,6 +136,8 @@ def get_embedding(cls, model_source, model_api_key):
return HuggingFace(api_key=model_api_key)
if "Replicate" in model_source:
return Replicate(api_key=model_api_key)
+ if "Custom" in model_source:
+ return LocalLLM()
return None
def _check_for_max_iterations(self, session, organisation_id, agent_config, agent_execution_id):
@@ -183,4 +183,4 @@ def execute_waiting_workflows(self):
AgentWaitStepHandler(session=session, agent_id=agent_execution.agent_id,
agent_execution_id=agent_execution.id).handle_next_step()
execute_agent.delay(agent_execution.id, datetime.now())
- session.close()
+ session.close()
\ No newline at end of file
diff --git a/superagi/llms/grammar/json.gbnf b/superagi/llms/grammar/json.gbnf
new file mode 100644
index 000000000..a9537cdf9
--- /dev/null
+++ b/superagi/llms/grammar/json.gbnf
@@ -0,0 +1,25 @@
+root ::= object
+value ::= object | array | string | number | ("true" | "false" | "null") ws
+
+object ::=
+ "{" ws (
+ string ":" ws value
+ ("," ws string ":" ws value)*
+ )? "}" ws
+
+array ::=
+ "[" ws (
+ value
+ ("," ws value)*
+ )? "]" ws
+
+string ::=
+ "\"" (
+ [^"\\] |
+ "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
+ )* "\"" ws
+
+number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
+
+# Optional space: by convention, applied in this grammar after literal chars when allowed
+ws ::= ([ \t\n] ws)?
diff --git a/superagi/llms/llm_model_factory.py b/superagi/llms/llm_model_factory.py
index 251c64a71..345c4f8c7 100644
--- a/superagi/llms/llm_model_factory.py
+++ b/superagi/llms/llm_model_factory.py
@@ -1,4 +1,5 @@
from superagi.llms.google_palm import GooglePalm
+from superagi.llms.local_llm import LocalLLM
from superagi.llms.openai import OpenAi
from superagi.llms.replicate import Replicate
from superagi.llms.hugging_face import HuggingFace
@@ -33,6 +34,9 @@ def get_model(organisation_id, api_key, model="gpt-3.5-turbo", **kwargs):
elif provider_name == 'Hugging Face':
print("Provider is Hugging Face")
return HuggingFace(model=model_instance.model_name, end_point=model_instance.end_point, api_key=api_key, **kwargs)
+ elif provider_name == 'Local LLM':
+ print("Provider is Local LLM")
+ return LocalLLM(model=model_instance.model_name, context_length=model_instance.context_length)
else:
print('Unknown provider.')
@@ -45,5 +49,7 @@ def build_model_with_api_key(provider_name, api_key):
return GooglePalm(api_key=api_key)
elif provider_name.lower() == 'hugging face':
return HuggingFace(api_key=api_key)
+ elif provider_name.lower() == 'local llm':
+ return LocalLLM(api_key=api_key)
else:
print('Unknown provider.')
\ No newline at end of file
diff --git a/superagi/llms/local_llm.py b/superagi/llms/local_llm.py
new file mode 100644
index 000000000..a146d7daa
--- /dev/null
+++ b/superagi/llms/local_llm.py
@@ -0,0 +1,92 @@
+from superagi.config.config import get_config
+from superagi.lib.logger import logger
+from superagi.llms.base_llm import BaseLlm
+from superagi.helper.llm_loader import LLMLoader
+
+
+class LocalLLM(BaseLlm):
+ def __init__(self, temperature=0.6, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT"), top_p=1,
+ frequency_penalty=0,
+ presence_penalty=0, number_of_results=1, model=None, api_key='EMPTY', context_length=4096):
+ """
+ Args:
+ model (str): The model.
+ temperature (float): The temperature.
+ max_tokens (int): The maximum number of tokens.
+ top_p (float): The top p.
+ frequency_penalty (float): The frequency penalty.
+ presence_penalty (float): The presence penalty.
+ number_of_results (int): The number of results.
+ """
+ self.model = model
+ self.api_key = api_key
+ self.temperature = temperature
+ self.max_tokens = max_tokens
+ self.top_p = top_p
+ self.frequency_penalty = frequency_penalty
+ self.presence_penalty = presence_penalty
+ self.number_of_results = number_of_results
+ self.context_length = context_length
+
+ llm_loader = LLMLoader(self.context_length)
+ self.llm_model = llm_loader.model
+ self.llm_grammar = llm_loader.grammar
+
+ def chat_completion(self, messages, max_tokens=get_config("MAX_MODEL_TOKEN_LIMIT")):
+ """
+ Call the chat completion.
+
+ Args:
+ messages (list): The messages.
+ max_tokens (int): The maximum number of tokens.
+
+ Returns:
+ dict: The response.
+ """
+ try:
+ if self.llm_model is None or self.llm_grammar is None:
+ logger.error("Model not found.")
+ return {"error": "Model loading error", "message": "Model not found. Please check your model path and try again."}
+ else:
+ response = self.llm_model.create_chat_completion(messages=messages, functions=None, function_call=None, temperature=self.temperature, top_p=self.top_p,
+ max_tokens=int(max_tokens), presence_penalty=self.presence_penalty, frequency_penalty=self.frequency_penalty, grammar=self.llm_grammar)
+ content = response["choices"][0]["message"]["content"]
+ logger.info(content)
+ return {"response": response, "content": content}
+
+ except Exception as exception:
+ logger.info("Exception:", exception)
+ return {"error": "ERROR", "message": "Error: "+str(exception)}
+
+ def get_source(self):
+ """
+ Get the source.
+
+ Returns:
+ str: The source.
+ """
+ return "Local LLM"
+
+ def get_api_key(self):
+ """
+ Returns:
+ str: The API key.
+ """
+ return self.api_key
+
+ def get_model(self):
+ """
+ Returns:
+ str: The model.
+ """
+ return self.model
+
+ def get_models(self):
+ """
+ Returns:
+ list: The models.
+ """
+ return self.model
+
+ def verify_access_key(self, api_key):
+ return True
\ No newline at end of file
diff --git a/superagi/models/models.py b/superagi/models/models.py
index 547cacf3a..47241afea 100644
--- a/superagi/models/models.py
+++ b/superagi/models/models.py
@@ -1,3 +1,4 @@
+import yaml
from sqlalchemy import Column, Integer, String, and_
from sqlalchemy.sql import func
from typing import List, Dict, Union
@@ -5,6 +6,7 @@
from superagi.controllers.types.models_types import ModelsTypes
from superagi.helper.encyption_helper import decrypt_data
import requests, logging
+from superagi.lib.logger import logger
marketplace_url = "https://app.superagi.com/api"
# marketplace_url = "http://localhost:8001"
@@ -39,6 +41,7 @@ class Models(DBBaseModel):
version = Column(String, nullable=False)
org_id = Column(Integer, nullable=False)
model_features = Column(String, nullable=False)
+ context_length = Column(Integer, nullable=True)
def __repr__(self):
"""
@@ -104,7 +107,7 @@ def fetch_model_tokens(cls, session, organisation_id) -> Dict[str, int]:
return {"error": "Unexpected Error Occured"}
@classmethod
- def store_model_details(cls, session, organisation_id, model_name, description, end_point, model_provider_id, token_limit, type, version):
+ def store_model_details(cls, session, organisation_id, model_name, description, end_point, model_provider_id, token_limit, type, version, context_length):
from superagi.models.models_config import ModelsConfig
if not model_name:
return {"error": "Model Name is empty or undefined"}
@@ -130,9 +133,12 @@ def store_model_details(cls, session, organisation_id, model_name, description,
return model # Return error message if model not found
# Check the 'provider' from ModelsConfig table
- if not end_point and model["provider"] not in ['OpenAI', 'Google Palm', 'Replicate']:
+ if not end_point and model["provider"] not in ['OpenAI', 'Google Palm', 'Replicate','Local LLM']:
return {"error": "End Point is empty or undefined"}
+ if context_length is None:
+ context_length = 0
+
try:
model = Models(
model_name=model_name,
@@ -143,7 +149,8 @@ def store_model_details(cls, session, organisation_id, model_name, description,
type=type,
version=version,
org_id=organisation_id,
- model_features=''
+ model_features='',
+ context_length=context_length
)
session.add(model)
session.commit()
@@ -229,4 +236,4 @@ def fetch_model_details(cls, session, organisation_id, model_id: int) -> Dict[st
except Exception as e:
logging.error(f"Unexpected Error Occured: {e}")
- return {"error": "Unexpected Error Occured"}
+ return {"error": "Unexpected Error Occured"}
\ No newline at end of file
diff --git a/superagi/models/models_config.py b/superagi/models/models_config.py
index 0c8c13b95..03833ef05 100644
--- a/superagi/models/models_config.py
+++ b/superagi/models/models_config.py
@@ -1,4 +1,5 @@
from sqlalchemy import Column, Integer, String, and_, distinct
+from superagi.lib.logger import logger
from superagi.models.base_model import DBBaseModel
from superagi.models.organisation import Organisation
from superagi.models.project import Project
@@ -69,6 +70,9 @@ def fetch_value_by_agent_id(cls, session, agent_id: int, model: str):
if not config:
return None
+ if config.provider == 'Local LLM':
+ return {"provider": config.provider, "api_key": config.api_key} if config else None
+
return {"provider": config.provider, "api_key": decrypt_data(config.api_key)} if config else None
@classmethod
@@ -102,7 +106,7 @@ def storeGptModels(cls, session, organisation_id, model_provider_id, model_api_k
for model in models:
if model not in installed_models and model in default_models:
result = Models.store_model_details(session, organisation_id, model, model, '',
- model_provider_id, default_models[model], 'Custom', '')
+ model_provider_id, default_models[model], 'Custom', '', 0)
@classmethod
def fetch_api_keys(cls, session, organisation_id):
@@ -123,8 +127,13 @@ def fetch_api_key(cls, session, organisation_id, model_provider):
api_key_data = session.query(ModelsConfig.id, ModelsConfig.provider, ModelsConfig.api_key).filter(
and_(ModelsConfig.org_id == organisation_id, ModelsConfig.provider == model_provider)).first()
+ logger.info(api_key_data)
if api_key_data is None:
return []
+ elif api_key_data.provider == 'Local LLM':
+ api_key = [{'id': api_key_data.id, 'provider': api_key_data.provider,
+ 'api_key': api_key_data.api_key}]
+ return api_key
else:
api_key = [{'id': api_key_data.id, 'provider': api_key_data.provider,
'api_key': decrypt_data(api_key_data.api_key)}]
@@ -145,4 +154,12 @@ def fetch_model_by_id_marketplace(cls, session, model_provider_id):
if model is None:
return {"error": "Model not found"}
else:
- return {"provider": model.provider}
\ No newline at end of file
+ return {"provider": model.provider}
+
+ @classmethod
+ def add_llm_config(cls, session, organisation_id):
+ existing_models_config = session.query(ModelsConfig).filter(ModelsConfig.org_id == organisation_id, ModelsConfig.provider == 'Local LLM').first()
+ if existing_models_config is None:
+ models_config = ModelsConfig(org_id=organisation_id, provider='Local LLM', api_key="EMPTY")
+ session.add(models_config)
+ session.commit()
\ No newline at end of file
diff --git a/superagi/types/model_source_types.py b/superagi/types/model_source_types.py
index f811a60c6..6e9de18ad 100644
--- a/superagi/types/model_source_types.py
+++ b/superagi/types/model_source_types.py
@@ -6,6 +6,7 @@ class ModelSourceType(Enum):
OpenAI = 'OpenAi'
Replicate = 'Replicate'
HuggingFace = 'Hugging Face'
+ LocalLLM = 'Local LLM'
@classmethod
def get_model_source_type(cls, name):
diff --git a/tests/unit_tests/controllers/test_models_controller.py b/tests/unit_tests/controllers/test_models_controller.py
index 489cff636..790229789 100644
--- a/tests/unit_tests/controllers/test_models_controller.py
+++ b/tests/unit_tests/controllers/test_models_controller.py
@@ -2,6 +2,11 @@
import pytest
from fastapi.testclient import TestClient
from main import app
+from llama_cpp import Llama
+from llama_cpp import LlamaGrammar
+import llama_cpp
+
+from superagi.helper.llm_loader import LLMLoader
client = TestClient(app)
@@ -50,7 +55,8 @@ def test_store_model_success(mock_get_db):
"model_provider_id": 1,
"token_limit": 10,
"type": "mock_type",
- "version": "mock_version"
+ "version": "mock_version",
+ "context_length":4096
}
with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \
patch('superagi.helper.auth.db') as mock_auth_db:
@@ -100,3 +106,13 @@ def test_get_marketplace_models_list_success(mock_get_db):
patch('superagi.helper.auth.db') as mock_auth_db:
response = client.get("/models_controller/marketplace/list/0")
assert response.status_code == 200
+
+def test_get_local_llm():
+ with(patch.object(LLMLoader, 'model', new_callable=MagicMock)) as mock_model:
+ with(patch.object(LLMLoader, 'grammar', new_callable=MagicMock)) as mock_grammar:
+
+ mock_model.create_chat_completion.return_value = {"choices": [{"message": {"content": "Hello!"}}]}
+
+ response = client.get("/models_controller/test_local_llm")
+
+ assert response.status_code == 200
\ No newline at end of file
diff --git a/tests/unit_tests/models/test_models.py b/tests/unit_tests/models/test_models.py
index 3bdc43075..d4880538c 100644
--- a/tests/unit_tests/models/test_models.py
+++ b/tests/unit_tests/models/test_models.py
@@ -133,6 +133,7 @@ def test_store_model_details_when_model_exists(mock_session):
token_limit=500,
type="type",
version="v1.0",
+ context_length=4096
)
# Assert
@@ -161,6 +162,7 @@ def test_store_model_details_when_model_not_exists(mock_session, monkeypatch):
token_limit=500,
type="type",
version="v1.0",
+ context_length=4096
)
# Assert
@@ -187,6 +189,7 @@ def test_store_model_details_when_unexpected_error_occurs(mock_session, monkeypa
token_limit=500,
type="type",
version="v1.0",
+ context_length=4096
)
# Assert
@@ -229,6 +232,4 @@ def test_fetch_model_details(mock_models_config, mock_session):
"token_limit": 100,
"type": "type1",
"model_provider": "example_provider"
- }
-
-
+ }
\ No newline at end of file