From 42b22c610c67943362d26a1a2affb5cc624f4704 Mon Sep 17 00:00:00 2001 From: Rounak Bhatia Date: Tue, 31 Oct 2023 17:06:02 +0530 Subject: [PATCH] fixes (#1353) --- main.py | 10 +--------- superagi/controllers/user.py | 6 ++++++ superagi/models/models_config.py | 10 +++++++++- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index cf4807483..bae2f1a87 100644 --- a/main.py +++ b/main.py @@ -216,13 +216,6 @@ def register_toolkit_for_master_organisation(): Organisation.id == marketplace_organisation_id).first() if marketplace_organisation is not None: register_marketplace_toolkits(session, marketplace_organisation) - - def local_llm_model_config(): - existing_models_config = session.query(ModelsConfig).filter(ModelsConfig.org_id == default_user.organisation_id, ModelsConfig.provider == 'Local LLM').first() - if existing_models_config is None: - models_config = ModelsConfig(org_id=default_user.organisation_id, provider='Local LLM', api_key="EMPTY") - session.add(models_config) - session.commit() IterationWorkflowSeed.build_single_step_agent(session) IterationWorkflowSeed.build_task_based_agents(session) @@ -246,8 +239,7 @@ def local_llm_model_config(): # AgentWorkflowSeed.doc_search_and_code(session) # AgentWorkflowSeed.build_research_email_workflow(session) replace_old_iteration_workflows(session) - local_llm_model_config() - + if env != "PROD": register_toolkit_for_all_organisation() else: diff --git a/superagi/controllers/user.py b/superagi/controllers/user.py index c550fd889..f4dce4b4f 100644 --- a/superagi/controllers/user.py +++ b/superagi/controllers/user.py @@ -14,6 +14,8 @@ from superagi.helper.auth import check_auth, get_current_user from superagi.lib.logger import logger +from superagi.models.models_config import ModelsConfig + # from superagi.types.db import UserBase, UserIn, UserOut router = APIRouter() @@ -73,6 +75,10 @@ def create_user(user: UserIn, organisation = Organisation.find_or_create_organisation(db.session, db_user) Project.find_or_create_default_project(db.session, organisation.id) logger.info("User created", db_user) + + #adding local llm configuration + ModelsConfig.add_llm_config(db.session, organisation.id) + return db_user diff --git a/superagi/models/models_config.py b/superagi/models/models_config.py index 0c8c13b95..998e8170c 100644 --- a/superagi/models/models_config.py +++ b/superagi/models/models_config.py @@ -145,4 +145,12 @@ def fetch_model_by_id_marketplace(cls, session, model_provider_id): if model is None: return {"error": "Model not found"} else: - return {"provider": model.provider} \ No newline at end of file + return {"provider": model.provider} + + @classmethod + def add_llm_config(cls, session, organisation_id): + existing_models_config = session.query(ModelsConfig).filter(ModelsConfig.org_id == organisation_id, ModelsConfig.provider == 'Local LLM').first() + if existing_models_config is None: + models_config = ModelsConfig(org_id=organisation_id, provider='Local LLM', api_key="EMPTY") + session.add(models_config) + session.commit() \ No newline at end of file