diff --git a/py/core/__init__.py b/py/core/__init__.py index cb8c270bc..98d6ebdc8 100644 --- a/py/core/__init__.py +++ b/py/core/__init__.py @@ -147,6 +147,9 @@ # LLM provider "CompletionConfig", "CompletionProvider", + # User management provider + "UserManagementConfig", + "UserManagementProvider", ## UTILS "RecursiveCharacterTextSplitter", "TextSplitter", diff --git a/py/core/base/__init__.py b/py/core/base/__init__.py index 6e14b3edc..d9c3e22c4 100644 --- a/py/core/base/__init__.py +++ b/py/core/base/__init__.py @@ -124,6 +124,9 @@ # LLM provider "CompletionConfig", "CompletionProvider", + # User management provider + "UserManagementConfig", + "UserManagementProvider", ## UTILS "RecursiveCharacterTextSplitter", "TextSplitter", diff --git a/py/core/base/providers/__init__.py b/py/core/base/providers/__init__.py index 5268b3740..f46ee0748 100644 --- a/py/core/base/providers/__init__.py +++ b/py/core/base/providers/__init__.py @@ -4,7 +4,6 @@ from .database import ( CollectionHandler, DatabaseConfig, - UserConfig, DatabaseConnectionManager, DatabaseProvider, DocumentHandler, @@ -22,6 +21,7 @@ from .ingestion import ChunkingStrategy, IngestionConfig, IngestionProvider from .llm import CompletionConfig, CompletionProvider from .orchestration import OrchestrationConfig, OrchestrationProvider, Workflow +from .user_management import UserManagementConfig, UserManagementProvider __all__ = [ # Auth provider @@ -52,7 +52,6 @@ "KGHandler", "PromptHandler", "FileHandler", - "UserConfig", "DatabaseConfig", "PostgresConfigurationSettings", "DatabaseProvider", @@ -66,4 +65,7 @@ "OrchestrationConfig", "OrchestrationProvider", "Workflow", + # User management provider + "UserManagementConfig", + "UserManagementProvider", ] diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index 34f41ac44..fe7a7491f 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -50,7 +50,6 @@ KGEnrichmentEstimationResponse, UserResponse, ) -from core.base.utils import _decorate_vector_type from ..logger import RunInfoLog from ..logger.base import RunType @@ -79,15 +78,6 @@ logger = logging.getLogger() -def escape_braces(s: str) -> str: - """ - Escape braces in a string. - This is a placeholder function - implement the actual logic as needed. - """ - # Implement your escape_braces logic here - return s.replace("{", "{{").replace("}", "}}") - - logger = logging.getLogger() @@ -195,38 +185,6 @@ async def initialize(self, pool: Any): pass -class RoleLimits(BaseModel): - max_files: Optional[int] = None - max_chunks: Optional[int] = None - max_queries: Optional[int] = None - max_queries_window: Optional[int] = None # in minutes - max_collections: Optional[int] = None - max_tokens_per_request: Optional[int] = None - - -class UserConfig(ProviderConfig): - default_role: str = "default" - roles: dict[str, RoleLimits] = { - "default": RoleLimits(), - } - - def validate_config(self) -> None: - """Validate the user configuration.""" - if not self.default_role: - raise ValueError("default_role must be specified") - if not self.roles: - raise ValueError("roles must be specified") - if self.default_role not in self.roles: - raise ValueError( - f"default_role '{self.default_role}' must exist in roles" - ) - - @property - def supported_providers(self) -> list[str]: - """Define supported providers - not applicable for UserConfig but required by ProviderConfig.""" - return ["r2r"] # Only r2r provider is supported for user configuration - - class Handler(ABC): def __init__( self, project_name: str, connection_manager: DatabaseConnectionManager diff --git a/py/core/base/providers/user_management.py b/py/core/base/providers/user_management.py new file mode 100644 index 000000000..3b655f137 --- /dev/null +++ b/py/core/base/providers/user_management.py @@ -0,0 +1,51 @@ +from abc import ABC, abstractmethod +from typing import Dict, Optional +from pydantic import BaseModel + +from .base import Provider, ProviderConfig + + +class RoleLimits(BaseModel): + max_files: Optional[int] = None + max_chunks: Optional[int] = None + max_queries: Optional[int] = None + max_queries_window: Optional[int] = None + + +class UserManagementConfig(ProviderConfig): + default_role: str = "default" + roles: Dict[str, RoleLimits] = { + "default": RoleLimits(), + "basic": RoleLimits( + max_files=1000, + max_chunks=10000, + max_queries=1000, + max_queries_window=1440, + ), + } + + @property + def supported_providers(self) -> list[str]: + return ["r2r"] + + def validate_config(self) -> None: + if not self.default_role in self.roles: + raise ValueError( + f"Default role '{self.default_role}' not found in roles configuration" + ) + + def get_role_limits(self, role: str) -> RoleLimits: + if role not in self.roles: + return self.roles[self.default_role] + return self.roles[role] + + +class UserManagementProvider(Provider, ABC): + def __init__(self, config: UserManagementConfig): + if not isinstance(config, UserManagementConfig): + raise ValueError( + "UserManagementProvider must be initialized with a UserManagementConfig" + ) + print(f"UserManagementProvider config: {config}") + super().__init__(config) + self.config: UserManagementConfig = config diff --git a/py/core/main/assembly/builder.py b/py/core/main/assembly/builder.py index 9f15578df..111d46ae5 100644 --- a/py/core/main/assembly/builder.py +++ b/py/core/main/assembly/builder.py @@ -11,6 +11,7 @@ DatabaseProvider, EmbeddingProvider, OrchestrationProvider, + UserManagementProvider, RunManager, ) from core.pipelines import KGEnrichmentPipeline, RAGPipeline, SearchPipeline @@ -47,6 +48,7 @@ class ProviderOverrides: llm: Optional[CompletionProvider] = None crypto: Optional[CryptoProvider] = None orchestration: Optional[OrchestrationProvider] = None + user_management: Optional[UserManagementProvider] = None @dataclass diff --git a/py/core/main/assembly/factory.py b/py/core/main/assembly/factory.py index 87e3ab5ff..174c885f9 100644 --- a/py/core/main/assembly/factory.py +++ b/py/core/main/assembly/factory.py @@ -15,6 +15,7 @@ EmbeddingProvider, IngestionConfig, OrchestrationConfig, + UserManagementConfig, ) from core.pipelines import RAGPipeline, SearchPipeline from core.pipes import GeneratorPipe, MultiSearchPipe, SearchPipe @@ -43,6 +44,7 @@ SupabaseAuthProvider, UnstructuredIngestionConfig, UnstructuredIngestionProvider, + R2RUserManagementProvider, ) @@ -147,10 +149,22 @@ def create_orchestration_provider( f"Orchestration provider {config.provider} not supported" ) + @staticmethod + def create_user_management_provider( + user_management_config: UserManagementConfig, *args, **kwargs + ) -> R2RUserManagementProvider: + if user_management_config.provider == "r2r": + return R2RUserManagementProvider(user_management_config) + else: + raise ValueError( + f"User management provider {user_management_config.provider} not supported" + ) + async def create_database_provider( self, db_config: DatabaseConfig, crypto_provider: BCryptProvider, + user_management_provider: R2RUserManagementProvider, *args, **kwargs, ) -> PostgresDBProvider: @@ -171,10 +185,11 @@ async def create_database_provider( dimension, crypto_provider=crypto_provider, quantization_type=quantization_type, + user_management_provider=user_management_provider, ) await database_provider.initialize() logger.info( - f"Database provider initialized with user config: {self.config.user}" + f"Database provider initialized with user config: {self.config.user_management}" ) return database_provider else: @@ -242,7 +257,7 @@ async def create_email_provider( """Creates an email provider based on configuration.""" if not email_config: raise ValueError( - f"No email configuration provided for email provider, please add `[email]` to your `r2r.toml`." + "No email configuration provided for email provider, please add `[email]` to your `r2r.toml`." ) if email_config.provider == "smtp": @@ -298,10 +313,19 @@ async def create_providers( crypto_provider_override or self.create_crypto_provider(self.config.crypto, *args, **kwargs) ) + + user_management_provider = self.create_user_management_provider( + self.config.user_management + ) + database_provider = ( database_provider_override or await self.create_database_provider( - self.config.database, crypto_provider, *args, **kwargs + self.config.database, + crypto_provider, + user_management_provider, + *args, + **kwargs, ) ) diff --git a/py/core/main/config.py b/py/core/main/config.py index e0f3239a5..ad19c5dd9 100644 --- a/py/core/main/config.py +++ b/py/core/main/config.py @@ -12,12 +12,13 @@ from ..base.providers import AppConfig from ..base.providers.auth import AuthConfig from ..base.providers.crypto import CryptoConfig -from ..base.providers.database import DatabaseConfig, UserConfig +from ..base.providers.database import DatabaseConfig from ..base.providers.email import EmailConfig from ..base.providers.embedding import EmbeddingConfig from ..base.providers.ingestion import IngestionConfig from ..base.providers.llm import CompletionConfig from ..base.providers.orchestration import OrchestrationConfig +from ..base.providers.user_management import UserManagementConfig logger = logging.getLogger() @@ -56,7 +57,7 @@ class R2RConfig: "database": ["provider"], "agent": ["generation_config"], "orchestration": ["provider"], - "user": ["default_role"], + "user_management": ["default_role"], } app: AppConfig @@ -70,7 +71,7 @@ class R2RConfig: logging: PersistentLoggingConfig agent: AgentConfig orchestration: OrchestrationConfig - user: UserConfig + user_management: UserManagementConfig def __init__(self, config_data: dict[str, Any]): """ @@ -124,7 +125,7 @@ def __init__(self, config_data: dict[str, Any]): self.logging = PersistentLoggingConfig.create(**self.logging, app=self.app) # type: ignore self.agent = AgentConfig.create(**self.agent, app=self.app) # type: ignore self.orchestration = OrchestrationConfig.create(**self.orchestration, app=self.app) # type: ignore - self.user = UserConfig.create(**self.user, app=self.app) # type: ignore + self.user_management = UserManagementConfig.create(**self.user_management, app=self.app) # type: ignore # override GenerationConfig defaults GenerationConfig.set_default( diff --git a/py/core/providers/__init__.py b/py/core/providers/__init__.py index a970f83eb..1320afe16 100644 --- a/py/core/providers/__init__.py +++ b/py/core/providers/__init__.py @@ -19,6 +19,7 @@ HatchetOrchestrationProvider, SimpleOrchestrationProvider, ) +from .user_management import R2RUserManagementProvider __all__ = [ # Auth @@ -49,4 +50,6 @@ "LiteLLMCompletionProvider", # Logging "SqlitePersistentLoggingProvider", + # User Management + "R2RUserManagementProvider", ] diff --git a/py/core/providers/database/postgres.py b/py/core/providers/database/postgres.py index cedd8da78..8c1bd1f86 100644 --- a/py/core/providers/database/postgres.py +++ b/py/core/providers/database/postgres.py @@ -10,7 +10,7 @@ DatabaseProvider, PostgresConfigurationSettings, VectorQuantizationType, - UserConfig, + UserManagementConfig, ) from core.providers import BCryptProvider from core.providers.database.base import PostgresConnectionManager @@ -23,6 +23,7 @@ from core.providers.database.tokens import PostgresTokenHandler from core.providers.database.user import PostgresUserHandler from core.providers.database.vector import PostgresVectorHandler +from core.providers.user_management import R2RUserManagementProvider from .base import SemaphoreConnectionPool @@ -54,7 +55,6 @@ class PostgresDBProvider(DatabaseProvider): conn: Optional[Any] crypto_provider: BCryptProvider - user_config: UserConfig postgres_configuration_settings: PostgresConfigurationSettings default_collection_name: str default_collection_description: str @@ -75,18 +75,13 @@ def __init__( config: DatabaseConfig, dimension: int, crypto_provider: BCryptProvider, + user_management_provider: R2RUserManagementProvider, quantization_type: VectorQuantizationType = VectorQuantizationType.FP32, *args, **kwargs, ): super().__init__(config) - # Get user config from app config - self.user_config = getattr(config.app, "user", UserConfig()) - logger.info( - f"Initialized database with user roles: {list(self.user_config.roles.keys())}" - ) - env_vars = [ ("user", "R2R_POSTGRES_USER", "POSTGRES_USER"), ("password", "R2R_POSTGRES_PASSWORD", "POSTGRES_PASSWORD"), @@ -132,6 +127,7 @@ def __init__( self.conn = None self.config: DatabaseConfig = config self.crypto_provider = crypto_provider + self.user_management_provider = user_management_provider self.postgres_configuration_settings: PostgresConfigurationSettings = ( self._get_postgres_configuration_settings(config) ) @@ -157,7 +153,7 @@ def __init__( self.project_name, self.connection_manager, self.crypto_provider, - self.user_config, + self.user_management_provider, ) self.vector_handler = PostgresVectorHandler( self.project_name, @@ -183,13 +179,6 @@ def __init__( self.project_name, self.connection_manager ) - # Extract UserConfig from the main config - self.user_config = getattr(config.app, "user", UserConfig()) - logger.info( - f"Initialized database with user roles: {list(self.user_config.roles.keys())}" - ) - logger.info(f"Default role: {self.user_config.default_role}") - async def initialize(self): logger.info("Initializing `PostgresDBProvider`.") self.pool = SemaphoreConnectionPool( diff --git a/py/core/providers/database/user.py b/py/core/providers/database/user.py index 0c827278f..42b3db49d 100644 --- a/py/core/providers/database/user.py +++ b/py/core/providers/database/user.py @@ -3,7 +3,7 @@ from uuid import UUID from fastapi import HTTPException -from core.base import CryptoProvider, UserHandler, UserConfig +from core.base import CryptoProvider, UserHandler, UserManagementProvider from core.base.abstractions import R2RException, UserStats from core.base.api.models import UserResponse from core.utils import generate_user_id @@ -20,14 +20,12 @@ def __init__( project_name: str, connection_manager: PostgresConnectionManager, crypto_provider: CryptoProvider, - user_config: UserConfig, + user_config: UserManagementProvider, ): super().__init__(project_name, connection_manager) self.crypto_provider = crypto_provider self.user_config = user_config - print( - f"User handler initialized with roles: {list(user_config.roles.keys())}" - ) + print(f"The user config is: {vars(user_config)}") async def create_tables(self): query = f""" @@ -140,10 +138,10 @@ async def create_user( self, email: str, password: str, role: str = "default" ) -> UserResponse: """Modified create_user to include role""" - if role not in self.user_config.roles: - raise R2RException( - status_code=400, message=f"Invalid role: {role}" - ) + # if role not in self.user_config.roles: + # raise R2RException( + # status_code=400, message=f"Invalid role: {role}" + # ) try: if await self.get_user_by_email(email): @@ -632,46 +630,46 @@ async def get_user_verification_data( } } - async def check_file_limit(self, user_id: UUID) -> bool: - """Check if user has reached their file limit""" - user = await self.get_user_by_id(user_id) - role_limits = self.user_config.get_role_limits(user.role) - - if role_limits.max_files is None: - return True - - # Get current file count - query = f""" - SELECT COUNT(*) FROM {self._get_table_name('document_info')} - WHERE user_id = $1 - """ - result = await self.connection_manager.fetchrow_query(query, [user_id]) - return result[0] < role_limits.max_files - - async def check_query_limit(self, user_id: UUID) -> bool: - """Check if user has reached their query limit""" - user = await self.get_user_by_id(user_id) - role_limits = self.user_config.get_role_limits(user.role) - - if ( - role_limits.max_queries is None - or role_limits.max_queries_window is None - ): - return True - - # Get query count within window - window_start = datetime.utcnow() - timedelta( - minutes=role_limits.max_queries_window - ) - query = f""" - SELECT COUNT(*) FROM {self._get_table_name('logs')} - WHERE user_id = $1 AND created_at > $2 - AND run_type = 'query' - """ - result = await self.connection_manager.fetchrow_query( - query, [user_id, window_start] - ) - return result[0] < role_limits.max_queries + # async def check_file_limit(self, user_id: UUID) -> bool: + # """Check if user has reached their file limit""" + # user = await self.get_user_by_id(user_id) + # role_limits = self.user_config.get_role_limits(user.role) + + # if role_limits.max_files is None: + # return True + + # # Get current file count + # query = f""" + # SELECT COUNT(*) FROM {self._get_table_name('document_info')} + # WHERE user_id = $1 + # """ + # result = await self.connection_manager.fetchrow_query(query, [user_id]) + # return result[0] < role_limits.max_files + + # async def check_query_limit(self, user_id: UUID) -> bool: + # """Check if user has reached their query limit""" + # user = await self.get_user_by_id(user_id) + # role_limits = self.user_config.get_role_limits(user.role) + + # if ( + # role_limits.max_queries is None + # or role_limits.max_queries_window is None + # ): + # return True + + # # Get query count within window + # window_start = datetime.utcnow() - timedelta( + # minutes=role_limits.max_queries_window + # ) + # query = f""" + # SELECT COUNT(*) FROM {self._get_table_name('logs')} + # WHERE user_id = $1 AND created_at > $2 + # AND run_type = 'query' + # """ + # result = await self.connection_manager.fetchrow_query( + # query, [user_id, window_start] + # ) + # return result[0] < role_limits.max_queries async def increment_query_count(self, user_id: UUID) -> None: """Increment user's query count""" @@ -682,16 +680,16 @@ async def increment_query_count(self, user_id: UUID) -> None: """ await self.connection_manager.execute_query(query, [user_id]) - async def update_user_role(self, user_id: UUID, new_role: str) -> None: - """Update user's role""" - if new_role not in self.user_config.roles: - raise R2RException( - status_code=400, message=f"Invalid role: {new_role}" - ) - - query = f""" - UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)} - SET role = $1 - WHERE user_id = $2 - """ - await self.connection_manager.execute_query(query, [new_role, user_id]) + # async def update_user_role(self, user_id: UUID, new_role: str) -> None: + # """Update user's role""" + # if new_role not in self.user_config.roles: + # raise R2RException( + # status_code=400, message=f"Invalid role: {new_role}" + # ) + + # query = f""" + # UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)} + # SET role = $1 + # WHERE user_id = $2 + # """ + # await self.connection_manager.execute_query(query, [new_role, user_id]) diff --git a/py/core/providers/user_management/__init__.py b/py/core/providers/user_management/__init__.py new file mode 100644 index 000000000..9439e1b32 --- /dev/null +++ b/py/core/providers/user_management/__init__.py @@ -0,0 +1 @@ +from .r2r_user_management import R2RUserManagementProvider diff --git a/py/core/providers/user_management/r2r_user_management.py b/py/core/providers/user_management/r2r_user_management.py new file mode 100644 index 000000000..a5b3c5286 --- /dev/null +++ b/py/core/providers/user_management/r2r_user_management.py @@ -0,0 +1,18 @@ +from core.base.providers.user_management import ( + UserManagementProvider, + UserManagementConfig, + RoleLimits, +) + + +class R2RUserManagementProvider(UserManagementProvider): + def __init__(self, config: UserManagementConfig): + super().__init__(config) + self.roles = config.roles + self.default_role = config.default_role + print( + f"Initialized R2RUserManagementProvider with roles: {self.roles}" + ) + + def get_role_limits(self, role: str) -> RoleLimits: + return self.config.get_role_limits(role) diff --git a/py/r2r.toml b/py/r2r.toml index 4bb2551f4..8eddeb2f7 100644 --- a/py/r2r.toml +++ b/py/r2r.toml @@ -125,7 +125,8 @@ provider = "r2r" [email] provider = "console_mock" -[user] +[user_management] +provider = "r2r" default_role = "default" [user.default]