diff --git a/controller/src/controller/__main__.py b/controller/src/controller/__main__.py index 31f49ef..b9cee6c 100644 --- a/controller/src/controller/__main__.py +++ b/controller/src/controller/__main__.py @@ -46,7 +46,7 @@ def initdb(): """ click.echo("Running Init DB") db_session = client.get_db_session() - client.create_tables(True) + client.create_database(True) # Create admin user: click.echo("Creating admin user") diff --git a/controller/src/controller/config.py b/controller/src/controller/config.py index 09f21ca..1aff52c 100644 --- a/controller/src/controller/config.py +++ b/controller/src/controller/config.py @@ -32,11 +32,13 @@ class CtrlConfig(BaseModel): """Configuration for the agent.""" - verbose: bool = True log_level: str = "DEBUG" - # SQL Database + # database kwargs: + db: dict[str, str] = { + "db_url": default_db_path, + "verbose": True, + } db_type: str = "sql" - sql_connection_str: str = default_db_path application_url: str = "http://localhost:8000" def print(self): diff --git a/controller/src/controller/db/__init__.py b/controller/src/controller/db/__init__.py index 55c3bc1..1c1badb 100644 --- a/controller/src/controller/db/__init__.py +++ b/controller/src/controller/db/__init__.py @@ -12,10 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum + from controller.config import config -from controller.db.sqlclient import SqlClient client = None + +class DatabaseType(str, Enum): + SQL = "sql" + + if config.db_type == "sql": - client = SqlClient(config.sql_connection_str, verbose=config.verbose) + from controller.db.sql import SqlClient + + client = SqlClient(**config.db) diff --git a/controller/src/controller/db/client.py b/controller/src/controller/db/client.py new file mode 100644 index 0000000..f19875f --- /dev/null +++ b/controller/src/controller/db/client.py @@ -0,0 +1,732 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import List, Optional, Type, Union + +import genai_factory.schemas as api_models + + +class Client(ABC): + @abstractmethod + def get_local_session(self): + """ + Get a local session from the local session maker. + This is the session that is inserted into the API endpoints. + + :return: The session. + """ + pass + + @abstractmethod + def create_database(self, drop_old: bool = False, names: list = None): + """ + Create a new database. + + :param drop_old: Whether to drop the old data before creating the new data. + :param names: The names of the entities to create. If None, all entities will be created. + """ + pass + + @abstractmethod + def create_user( + self, user: Union[api_models.User, dict], **kwargs + ) -> api_models.User: + """ + Create a new user in the database. + + :param user: The user object to create. + + :return: The created user. + """ + pass + + @abstractmethod + def get_user( + self, uid: str = None, name: str = None, email: str = None, **kwargs + ) -> Optional[api_models.User]: + """ + Get a user from the database. + Either user_id or user_name or email must be provided. + + :param uid: The UID of the user to get. + :param name: The name of the user to get. + :param email: The email of the user to get. + + :return: The user. + """ + pass + + @abstractmethod + def update_user( + self, name: str, user: Union[api_models.User, dict], **kwargs + ) -> api_models.User: + """ + Update an existing user in the database. + + :param name: The name of the user to update. + :param user: The user object with the new data. + + :return: The updated user. + """ + pass + + @abstractmethod + def delete_user(self, name: str, **kwargs): + """ + Delete a user from the database. + + :param name: The name of the user to delete. + """ + pass + + @abstractmethod + def list_users( + self, + name: str = None, + email: str = None, + full_name: str = None, + labels_match: Union[list, str] = None, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, + **kwargs, + ) -> List[Optional[api_models.User]]: + """ + List users from the database. + + :param name: The name to filter the users by. + :param email: The email to filter the users by. + :param full_name: The full name to filter the users by. + :param labels_match: The labels to match, filter the users by labels. + :param output_mode: The output mode. + + :return: List of users. + """ + pass + + @abstractmethod + def create_project( + self, project: Union[api_models.Project, dict], **kwargs + ) -> api_models.Project: + """ + Create a new project in the database. + + :param project: The project object to create. + + :return: The created project. + """ + pass + + @abstractmethod + def get_project(self, name: str, **kwargs) -> Optional[api_models.Project]: + """ + Get a project from the database. + + :param name: The name of the project to get. + + :return: The requested project. + """ + pass + + @abstractmethod + def update_project( + self, name: str, project: Union[api_models.Project, dict], **kwargs + ) -> api_models.Project: + """ + Update an existing project in the database. + + :param name: The name of the project to update. + :param project: The project object with the new data. + + :return: The updated project. + """ + pass + + @abstractmethod + def delete_project(self, name: str, **kwargs): + """ + Delete a project from the database. + + :param name: The name of the project to delete. + """ + pass + + @abstractmethod + def list_projects( + self, + name: str = None, + owner_id: str = None, + version: str = None, + labels_match: Union[list, str] = None, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, + **kwargs, + ) -> List[Optional[api_models.Project]]: + """ + List projects from the database. + + :param name: The name to filter the projects by. + :param owner_id: The owner to filter the projects by. + :param version: The version to filter the projects by. + :param labels_match: The labels to match, filter the projects by labels. + :param output_mode: The output mode. + + :return: List of projects. + """ + pass + + @abstractmethod + def create_data_source( + self, data_source: Union[api_models.DataSource, dict], **kwargs + ) -> api_models.DataSource: + """ + Create a new data source in the database. + + :param data_source: The data source object to create. + + :return: The created data source. + """ + pass + + @abstractmethod + def get_data_source(self, name: str, **kwargs) -> Optional[api_models.DataSource]: + """ + Get a data source from the database. + + :param name: The name of the data source to get. + + :return: The requested data source. + """ + pass + + @abstractmethod + def update_data_source( + self, name: str, data_source: Union[api_models.DataSource, dict], **kwargs + ) -> api_models.DataSource: + """ + Update an existing data source in the database. + + :param name: The name of the data source to update. + :param data_source: The data source object with the new data. + + :return: The updated data source. + """ + pass + + @abstractmethod + def delete_data_source(self, name: str, **kwargs): + """ + Delete a data source from the database. + + :param name: The name of the data source to delete. + + :return: A response object with the success status. + """ + pass + + @abstractmethod + def list_data_sources( + self, + name: str = None, + owner_id: str = None, + version: str = None, + project_id: str = None, + data_source_type: Union[api_models.DataSourceType, str] = None, + labels_match: Union[list, str] = None, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, + **kwargs, + ) -> List[Optional[api_models.DataSource]]: + """ + List data sources from the database. + + :param name: The name to filter the data sources by. + :param owner_id: The owner to filter the data sources by. + :param version: The version to filter the data sources by. + :param project_id: The project to filter the data sources by. + :param data_source_type: The data source type to filter the data sources by. + :param labels_match: The labels to match, filter the data sources by labels. + :param output_mode: The output mode. + + :return: List of data sources. + """ + pass + + @abstractmethod + def create_dataset( + self, dataset: Union[api_models.Dataset, dict], **kwargs + ) -> api_models.Dataset: + """ + Create a new dataset in the database. + + :param dataset: The dataset object to create. + + :return: The created dataset. + """ + pass + + @abstractmethod + def get_dataset(self, name: str, **kwargs) -> Optional[api_models.Dataset]: + """ + Get a dataset from the database. + + :param name: The name of the dataset to get. + + :return: The requested dataset. + """ + pass + + @abstractmethod + def update_dataset( + self, name: str, dataset: Union[api_models.Dataset, dict], **kwargs + ) -> api_models.Dataset: + """ + Update an existing dataset in the database. + + :param name: The name of the dataset to update. + :param dataset: The dataset object with the new data. + + :return: The updated dataset. + """ + pass + + @abstractmethod + def delete_dataset(self, name: str, **kwargs): + """ + Delete a dataset from the database. + + :param name: The name of the dataset to delete. + """ + pass + + @abstractmethod + def list_datasets( + self, + name: str = None, + owner_id: str = None, + version: str = None, + project_id: str = None, + task: str = None, + labels_match: Union[list, str] = None, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, + **kwargs, + ) -> List[Optional[api_models.Dataset]]: + """ + List datasets from the database. + + :param name: The name to filter the datasets by. + :param owner_id: The owner to filter the datasets by. + :param version: The version to filter the datasets by. + :param project_id: The project to filter the datasets by. + :param task: The task to filter the datasets by. + :param labels_match: The labels to match, filter the datasets by labels. + :param output_mode: The output mode. + + :return: The list of datasets. + """ + pass + + @abstractmethod + def create_model(self, model: Union[api_models.Model, dict]) -> api_models.Model: + """ + Create a new model in the database. + + :param model: The model object to create. + + :return: The created model. + """ + pass + + @abstractmethod + def get_model(self, name: str, **kwargs) -> Optional[api_models.Model]: + """ + Get a model from the database. + + :param name: The name of the model to get. + + :return: The requested model. + """ + pass + + @abstractmethod + def update_model( + self, name: str, model: Union[api_models.Model, dict], **kwargs + ) -> api_models.Model: + """ + Update an existing model in the database. + + :param name: The name of the model to update. + :param model: The model object with the new data. + + :return: The updated model. + """ + pass + + @abstractmethod + def delete_model(self, name: str, **kwargs): + """ + Delete a model from the database. + + :param name: The name of the model to delete. + """ + pass + + @abstractmethod + def list_models( + self, + name: str = None, + owner_id: str = None, + version: str = None, + project_id: str = None, + model_type: str = None, + task: str = None, + labels_match: Union[list, str] = None, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, + **kwargs, + ) -> List[Optional[api_models.Model]]: + """ + List models from the database. + + :param name: The name to filter the models by. + :param owner_id: The owner to filter the models by. + :param version: The version to filter the models by. + :param project_id: The project to filter the models by. + :param model_type: The model type to filter the models by. + :param task: The task to filter the models by. + :param labels_match: The labels to match, filter the models by labels. + :param output_mode: The output mode. + + :return: The list of models. + """ + pass + + @abstractmethod + def create_prompt_template( + self, prompt_template: Union[api_models.PromptTemplate, dict], **kwargs + ) -> api_models.PromptTemplate: + """ + Create a new prompt template in the database. + + :param prompt_template: The prompt template object to create. + + :return: The created prompt template. + """ + pass + + @abstractmethod + def get_prompt_template( + self, name: str, **kwargs + ) -> Optional[api_models.PromptTemplate]: + """ + Get a prompt template from the database. + + :param name: The name of the prompt template to get. + + :return: The requested prompt template. + """ + pass + + @abstractmethod + def update_prompt_template( + self, + name: str, + prompt_template: Union[api_models.PromptTemplate, dict], + **kwargs, + ) -> api_models.PromptTemplate: + """ + Update an existing prompt template in the database. + + :param name: The name of the prompt template to update. + :param prompt_template: The prompt template object with the new data. + + :return: The updated prompt template. + """ + pass + + @abstractmethod + def delete_prompt_template(self, name: str, **kwargs): + """ + Delete a prompt template from the database. + + :param name: The name of the prompt template to delete. + """ + pass + + @abstractmethod + def list_prompt_templates( + self, + name: str = None, + owner_id: str = None, + version: str = None, + project_id: str = None, + labels_match: Union[list, str] = None, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, + **kwargs, + ) -> List[Optional[api_models.PromptTemplate]]: + """ + List prompt templates from the database. + + :param name: The name to filter the prompt templates by. + :param owner_id: The owner to filter the prompt templates by. + :param version: The version to filter the prompt templates by. + :param project_id: The project to filter the prompt templates by. + :param labels_match: The labels to match, filter the prompt templates by labels. + :param output_mode: The output mode. + + :return: The list of prompt templates. + """ + pass + + @abstractmethod + def create_document( + self, document: Union[api_models.Document, dict], **kwargs + ) -> api_models.Document: + """ + Create a new document in the database. + + :param document: The document object to create. + + :return: The created document. + """ + pass + + @abstractmethod + def get_document(self, name: str, **kwargs) -> Optional[api_models.Document]: + """ + Get a document from the database. + + :param name: The name of the document to get. + + :return: The requested document. + """ + pass + + @abstractmethod + def update_document( + self, name: str, document: Union[api_models.Document, dict], **kwargs + ) -> api_models.Document: + """ + Update an existing document in the database. + + :param name: The name of the document to update. + :param document: The document object with the new data. + + :return: The updated document. + """ + pass + + @abstractmethod + def delete_document(self, name: str, **kwargs): + """ + Delete a document from the database. + + :param name: The name of the document to delete. + """ + pass + + @abstractmethod + def list_documents( + self, + name: str = None, + owner_id: str = None, + version: str = None, + project_id: str = None, + labels_match: Union[list, str] = None, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, + **kwargs, + ) -> List[Optional[api_models.Document]]: + """ + List documents from the database. + + :param name: The name to filter the documents by. + :param owner_id: The owner to filter the documents by. + :param version: The version to filter the documents by. + :param project_id: The project to filter the documents by. + :param labels_match: The labels to match, filter the documents by labels. + :param output_mode: The output mode. + + :return: The list of documents. + """ + pass + + @abstractmethod + def create_workflow( + self, workflow: Union[api_models.Workflow, dict], **kwargs + ) -> api_models.Workflow: + """ + Create a new workflow in the database. + + :param workflow: The workflow object to create. + + :return: The created workflow. + """ + pass + + @abstractmethod + def get_workflow(self, name: str, **kwargs) -> Type[api_models.Base]: + """ + Get a workflow from the database. + + :param name: The name of the workflow to get. + + :return: The requested workflow. + """ + pass + + @abstractmethod + def update_workflow( + self, name: str, workflow: Union[api_models.Workflow, dict], **kwargs + ) -> api_models.Workflow: + """ + Update an existing workflow in the database. + + :param name: The name of the workflow to update. + :param workflow: The workflow object with the new data. + + :return: The updated workflow. + """ + pass + + @abstractmethod + def delete_workflow(self, name: str, **kwargs): + """ + Delete a workflow from the database. + + :param name: The name of the workflow to delete. + """ + pass + + @abstractmethod + def list_workflows( + self, + name: str = None, + owner_id: str = None, + version: str = None, + project_id: str = None, + workflow_type: Union[api_models.WorkflowType, str] = None, + labels_match: Union[list, str] = None, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, + **kwargs, + ) -> List[Optional[api_models.Workflow]]: + """ + List workflows from the database. + + :param name: The name to filter the workflows by. + :param owner_id: The owner to filter the workflows by. + :param version: The version to filter the workflows by. + :param project_id: The project to filter the workflows by. + :param workflow_type: The workflow type to filter the workflows by. + :param labels_match: The labels to match, filter the workflows by labels. + :param output_mode: The output mode. + + :return: The list of workflows. + """ + pass + + @abstractmethod + def create_session( + self, session: Union[api_models.ChatSession, dict], **kwargs + ) -> api_models.ChatSession: + """ + Create a new session in the database. + + :param session: The chat session object to create. + + :return: The created session. + """ + pass + + @abstractmethod + def get_session( + self, name: str = None, uid: str = None, user_id: str = None, **kwargs + ) -> Optional[api_models.ChatSession]: + """ + Get a session from the database. + + :param name: The name of the session to get. + :param uid: The ID of the session to get. + :param user_id: The UID of the user to get the last session for. + + :return: The requested session. + """ + pass + + @abstractmethod + def update_session( + self, name: str, session: Union[api_models.ChatSession, dict], **kwargs + ) -> api_models.ChatSession: + """ + Update a session in the database. + + :param name: The name of the session to update. + :param session: The session object with the new data. + + :return: The updated chat session. + """ + pass + + @abstractmethod + def delete_session(self, name: str, **kwargs): + """ + Delete a session from the database. + + :param name: The name of the session to delete. + """ + pass + + @abstractmethod + def list_sessions( + self, + name: str = None, + user_id: str = None, + workflow_id: str = None, + created_after=None, + last=0, + output_mode: api_models.OutputMode = api_models.OutputMode.DETAILS, + **kwargs, + ): + """ + List sessions from the database. + + :param name: The name to filter the chat sessions by. + :param user_id: The user ID to filter the chat sessions by. + :param workflow_id: The workflow ID to filter the chat sessions by. + :param created_after: The date to filter the chat sessions by. + :param last: The number of last chat sessions to return. + :param output_mode: The output mode. + + :return: The list of chat sessions. + """ + pass + + def _process_output( + self, + items, + obj_class, + mode: api_models.OutputMode = api_models.OutputMode.DETAILS, + ) -> Union[list, dict]: + """ + Process the output of a query. Use this method to convert the output to the desired format. + For example when listing. + + :param items: The items to process. + :param obj_class: The class of the items. + :param mode: The output mode. + + :return: The processed items. + """ + if mode == api_models.OutputMode.NAMES: + return [item.name for item in items] + items = [self._from_db_object(item, obj_class) for item in items] + if mode == api_models.OutputMode.DETAILS: + return items + short = mode == api_models.OutputMode.SHORT + return [item.to_dict(short=short) for item in items] diff --git a/controller/src/controller/db/sql/__init__.py b/controller/src/controller/db/sql/__init__.py new file mode 100644 index 0000000..a1d4b1b --- /dev/null +++ b/controller/src/controller/db/sql/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from controller.db.sql.sqlclient import SqlClient diff --git a/controller/src/controller/db/sqlclient.py b/controller/src/controller/db/sql/sqlclient.py similarity index 91% rename from controller/src/controller/db/sqlclient.py rename to controller/src/controller/db/sql/sqlclient.py index 3521bde..7ed0b00 100644 --- a/controller/src/controller/db/sqlclient.py +++ b/controller/src/controller/db/sql/sqlclient.py @@ -19,12 +19,13 @@ import sqlalchemy from sqlalchemy.orm import sessionmaker -import controller.db.sqldb as db +import controller.db.sql.sqldb as db import genai_factory.schemas as api_models from controller.config import logger +from controller.db.client import Client -class SqlClient: +class SqlClient(Client): """ This is the SQL client that interact with the SQL database. """ @@ -32,7 +33,7 @@ class SqlClient: def __init__(self, db_url: str, verbose: bool = False): self.db_url = db_url self.engine = sqlalchemy.create_engine( - db_url, echo=verbose, connect_args={"check_same_thread": False} + self.db_url, echo=verbose, connect_args={"check_same_thread": False} ) self._session_maker = sessionmaker(bind=self.engine) self._local_maker = sessionmaker( @@ -57,7 +58,97 @@ def get_local_session(self): """ return self._local_maker() - def create_tables(self, drop_old: bool = False, names: list = None): + @staticmethod + def _to_schema_object( + obj, schema_class: Type[api_models.Base] + ) -> Type[api_models.Base]: + """ + Convert an object from the database to an API object. + + :param obj: The object from the database. + :param schema_class: The API class of the object. + + :return: The API object. + """ + object_dict = {} + for field in obj.__table__.columns: + object_dict[field.name] = getattr(obj, field.name) + spec = object_dict.pop("spec", {}) + object_dict.update(spec) + if obj.labels: + object_dict["labels"] = {label.name: label.value for label in obj.labels} + return schema_class.from_dict(object_dict) + + @staticmethod + def _to_db_object(obj, obj_class, uid=None): + """ + Convert an API object to a database object. + + :param obj: The API object. + :param obj_class: The DB class of the object. + :param uid: The UID of the object. + + :return: The database object. + """ + struct = obj.to_dict(drop_none=False, short=False) + obj_dict = { + k: v + for k, v in struct.items() + if k in (api_models.metadata_fields + obj._top_level_fields) + and k not in ["created", "updated"] + } + obj_dict["spec"] = { + k: v + for k, v in struct.items() + if k not in api_models.metadata_fields + obj._top_level_fields + } + labels = obj_dict.pop("labels", None) + if uid: + obj_dict["uid"] = uid + obj = obj_class(**obj_dict) + if labels: + obj.labels.clear() + for name, value in labels.items(): + obj.labels.append(obj.Label(name=name, value=value, parent=obj.name)) + return obj + + @staticmethod + def _merge_into_db_object(obj, orm_object): + """ + Merge an API object into a database object. + + :param obj: The API object. + :param orm_object: The ORM object. + + :return: The updated ORM object. + """ + struct = obj.to_dict(drop_none=True) + spec = orm_object.spec or {} + labels = struct.pop("labels", None) + for k, v in struct.items(): + if k in (api_models.metadata_fields + obj._top_level_fields) and k not in [ + "created", + "updated", + ]: + setattr(orm_object, k, v) + if k not in [api_models.metadata_fields + obj._top_level_fields]: + spec[k] = v + orm_object.spec = spec + if labels: + old = {label.name: label for label in orm_object.labels} + orm_object.labels.clear() + for name, value in labels.items(): + if name in old: + if value is not None: # None means delete + old[name].value = value + orm_object.labels.append(old[name]) + else: + orm_object.labels.append( + orm_object.Label(name=name, value=value, parent=orm_object.name) + ) + return orm_object + + def create_database(self, drop_old: bool = False, names: list = None): """ Create the tables in the database. @@ -87,10 +178,10 @@ def _create( session = self.get_db_session(session) # try: uid = uuid.uuid4().hex - db_object = obj.to_orm_object(db_class, uid=uid) + db_object = self._to_db_object(obj, db_class, uid=uid) session.add(db_object) session.commit() - return obj.__class__.from_orm_object(db_object) + return self._to_schema_object(db_object, obj.__class__) def _get( self, session: sqlalchemy.orm.Session, db_class, api_class, **kwargs @@ -116,7 +207,7 @@ def _get( else: obj = query.one_or_none() if obj: - return api_class.from_orm_object(obj) + return self._to_schema_object(obj, api_class) def _update( self, session: sqlalchemy.orm.Session, db_class, api_object, **kwargs @@ -135,10 +226,10 @@ def _update( session = self.get_db_session(session) obj = session.query(db_class).filter_by(**kwargs).one_or_none() if obj: - api_object.merge_into_orm_object(obj) + obj = self._merge_into_db_object(api_object, obj) session.add(obj) session.commit() - return api_object.__class__.from_orm_object(obj) + return self._to_schema_object(obj, api_object.__class__) else: # Create a new object if not found logger.debug(f"Object not found, creating a new one: {api_object}") @@ -192,7 +283,7 @@ def _list( pass output = query.all() logger.debug(f"output: {output}") - return _process_output(output, api_class, output_mode) + return self._process_output(output, api_class, output_mode) @staticmethod def _drop_none(**kwargs): @@ -1268,22 +1359,4 @@ def list_sessions( query = query.order_by(db.Session.updated.desc()) if last > 0: query = query.limit(last) - return _process_output(query.all(), api_models.ChatSession, output_mode) - - -def _dict_to_object(cls, d): - if isinstance(d, dict): - return cls.from_dict(d) - return d - - -def _process_output( - items, obj_class, mode: api_models.OutputMode = api_models.OutputMode.DETAILS -) -> Union[list, dict]: - if mode == api_models.OutputMode.NAMES: - return [item.name for item in items] - items = [obj_class.from_orm_object(item) for item in items] - if mode == api_models.OutputMode.DETAILS: - return items - short = mode == api_models.OutputMode.SHORT - return [item.to_dict(short=short) for item in items] + return self._process_output(query.all(), api_models.ChatSession, output_mode) diff --git a/controller/src/controller/db/sqldb.py b/controller/src/controller/db/sql/sqldb.py similarity index 100% rename from controller/src/controller/db/sqldb.py rename to controller/src/controller/db/sql/sqldb.py diff --git a/examples/agent/workflow.py b/examples/agent/workflow.py index a3a1776..ee23c13 100644 --- a/examples/agent/workflow.py +++ b/examples/agent/workflow.py @@ -13,9 +13,9 @@ # limitations under the License. from examples.agent.agent import Agent +from genai_factory import workflow_server from genai_factory.chains.base import HistorySaver, SessionLoader from genai_factory.chains.refine import RefineQuery -from genai_factory import workflow_server workflow_graph = [ SessionLoader(), diff --git a/examples/quick_start/workflow.py b/examples/quick_start/workflow.py index 60d9647..4ec193c 100644 --- a/examples/quick_start/workflow.py +++ b/examples/quick_start/workflow.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from genai_factory import workflow_server from genai_factory.chains.base import HistorySaver, SessionLoader from genai_factory.chains.refine import RefineQuery from genai_factory.chains.retrieval import MultiRetriever -from genai_factory import workflow_server workflow_graph = [ SessionLoader(), diff --git a/genai_factory/src/genai_factory/chains/retrieval.py b/genai_factory/src/genai_factory/chains/retrieval.py index 7ea802c..bc246e7 100644 --- a/genai_factory/src/genai_factory/chains/retrieval.py +++ b/genai_factory/src/genai_factory/chains/retrieval.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging from typing import Dict, List, Optional from langchain.callbacks.base import BaseCallbackHandler @@ -25,9 +24,6 @@ from genai_factory.schemas import WorkflowEvent from genai_factory.utils import logger -#TODO use workflow server logger -logger = logging.getLogger(__name__) - class DocumentCallbackHandler(BaseCallbackHandler): """Callback handler that adds index numbers to retrieved documents.""" @@ -49,8 +45,8 @@ class DocumentRetriever: Example: vector_store = get_vector_db(config) llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo") - query = "What is an llm?" - dr = document_retrevial(llm, vector_store) + query = "What is an LLM?" + dr = document_retrieval(llm, vector_store) dr.get_answer(query) Args: diff --git a/genai_factory/src/genai_factory/schemas/__init__.py b/genai_factory/src/genai_factory/schemas/__init__.py index 6407044..fd9f4c7 100644 --- a/genai_factory/src/genai_factory/schemas/__init__.py +++ b/genai_factory/src/genai_factory/schemas/__init__.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from genai_factory.schemas.base import APIDictResponse, APIResponse, Base, OutputMode +from genai_factory.schemas.base import ( + APIDictResponse, + APIResponse, + Base, + OutputMode, + metadata_fields, +) from genai_factory.schemas.data_source import DataSource, DataSourceType from genai_factory.schemas.dataset import Dataset from genai_factory.schemas.document import Document diff --git a/genai_factory/src/genai_factory/schemas/base.py b/genai_factory/src/genai_factory/schemas/base.py index 3080eac..1ef6d23 100644 --- a/genai_factory/src/genai_factory/schemas/base.py +++ b/genai_factory/src/genai_factory/schemas/base.py @@ -69,69 +69,6 @@ def from_dict(cls, data: dict): return cls.parse_obj(data) # return cls.model_validate(data) # pydantic v2 - @classmethod - def from_orm_object(cls, obj): - object_dict = {} - for field in obj.__table__.columns: - object_dict[field.name] = getattr(obj, field.name) - spec = object_dict.pop("spec", {}) - object_dict.update(spec) - if obj.labels: - object_dict["labels"] = {label.name: label.value for label in obj.labels} - return cls.from_dict(object_dict) - - def merge_into_orm_object(self, orm_object): - struct = self.to_dict(drop_none=True) - spec = orm_object.spec or {} - labels = struct.pop("labels", None) - for k, v in struct.items(): - if k in (metadata_fields + self._top_level_fields) and k not in [ - "created", - "updated", - ]: - setattr(orm_object, k, v) - if k not in [metadata_fields + self._top_level_fields]: - spec[k] = v - orm_object.spec = spec - - if labels: - old = {label.name: label for label in orm_object.labels} - orm_object.labels.clear() - for name, value in labels.items(): - if name in old: - if value is not None: # None means delete - old[name].value = value - orm_object.labels.append(old[name]) - else: - orm_object.labels.append( - orm_object.Label(name=name, value=value, parent=orm_object.name) - ) - - return orm_object - - def to_orm_object(self, obj_class, uid=None): - struct = self.to_dict(drop_none=False, short=False) - obj_dict = { - k: v - for k, v in struct.items() - if k in (metadata_fields + self._top_level_fields) - and k not in ["created", "updated"] - } - obj_dict["spec"] = { - k: v - for k, v in struct.items() - if k not in metadata_fields + self._top_level_fields - } - labels = obj_dict.pop("labels", None) - if uid: - obj_dict["uid"] = uid - obj = obj_class(**obj_dict) - if labels: - obj.labels.clear() - for name, value in labels.items(): - obj.labels.append(obj.Label(name=name, value=value, parent=obj.name)) - return obj - def to_yaml(self, drop_none=True): return yaml.dump(self.to_dict(drop_none=drop_none))