Skip to content

Commit

Permalink
Better layout
Browse files Browse the repository at this point in the history
  • Loading branch information
NolanTrem committed Nov 6, 2024
1 parent 3124e7b commit 61e696a
Show file tree
Hide file tree
Showing 14 changed files with 184 additions and 130 deletions.
3 changes: 3 additions & 0 deletions py/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@
# LLM provider
"CompletionConfig",
"CompletionProvider",
# User management provider
"UserManagementConfig",
"UserManagementProvider",
## UTILS
"RecursiveCharacterTextSplitter",
"TextSplitter",
Expand Down
3 changes: 3 additions & 0 deletions py/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@
# LLM provider
"CompletionConfig",
"CompletionProvider",
# User management provider
"UserManagementConfig",
"UserManagementProvider",
## UTILS
"RecursiveCharacterTextSplitter",
"TextSplitter",
Expand Down
6 changes: 4 additions & 2 deletions py/core/base/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from .database import (
CollectionHandler,
DatabaseConfig,
UserConfig,
DatabaseConnectionManager,
DatabaseProvider,
DocumentHandler,
Expand All @@ -22,6 +21,7 @@
from .ingestion import ChunkingStrategy, IngestionConfig, IngestionProvider
from .llm import CompletionConfig, CompletionProvider
from .orchestration import OrchestrationConfig, OrchestrationProvider, Workflow
from .user_management import UserManagementConfig, UserManagementProvider

__all__ = [
# Auth provider
Expand Down Expand Up @@ -52,7 +52,6 @@
"KGHandler",
"PromptHandler",
"FileHandler",
"UserConfig",
"DatabaseConfig",
"PostgresConfigurationSettings",
"DatabaseProvider",
Expand All @@ -66,4 +65,7 @@
"OrchestrationConfig",
"OrchestrationProvider",
"Workflow",
# User management provider
"UserManagementConfig",
"UserManagementProvider",
]
42 changes: 0 additions & 42 deletions py/core/base/providers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
KGEnrichmentEstimationResponse,
UserResponse,
)
from core.base.utils import _decorate_vector_type

from ..logger import RunInfoLog
from ..logger.base import RunType
Expand Down Expand Up @@ -79,15 +78,6 @@
logger = logging.getLogger()


def escape_braces(s: str) -> str:
"""
Escape braces in a string.
This is a placeholder function - implement the actual logic as needed.
"""
# Implement your escape_braces logic here
return s.replace("{", "{{").replace("}", "}}")


logger = logging.getLogger()


Expand Down Expand Up @@ -195,38 +185,6 @@ async def initialize(self, pool: Any):
pass


class RoleLimits(BaseModel):
max_files: Optional[int] = None
max_chunks: Optional[int] = None
max_queries: Optional[int] = None
max_queries_window: Optional[int] = None # in minutes
max_collections: Optional[int] = None
max_tokens_per_request: Optional[int] = None


class UserConfig(ProviderConfig):
default_role: str = "default"
roles: dict[str, RoleLimits] = {
"default": RoleLimits(),
}

def validate_config(self) -> None:
"""Validate the user configuration."""
if not self.default_role:
raise ValueError("default_role must be specified")
if not self.roles:
raise ValueError("roles must be specified")
if self.default_role not in self.roles:
raise ValueError(
f"default_role '{self.default_role}' must exist in roles"
)

@property
def supported_providers(self) -> list[str]:
"""Define supported providers - not applicable for UserConfig but required by ProviderConfig."""
return ["r2r"] # Only r2r provider is supported for user configuration


class Handler(ABC):
def __init__(
self, project_name: str, connection_manager: DatabaseConnectionManager
Expand Down
51 changes: 51 additions & 0 deletions py/core/base/providers/user_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional
from pydantic import BaseModel

from .base import Provider, ProviderConfig


class RoleLimits(BaseModel):
max_files: Optional[int] = None
max_chunks: Optional[int] = None
max_queries: Optional[int] = None
max_queries_window: Optional[int] = None


class UserManagementConfig(ProviderConfig):
default_role: str = "default"
roles: Dict[str, RoleLimits] = {
"default": RoleLimits(),
"basic": RoleLimits(
max_files=1000,
max_chunks=10000,
max_queries=1000,
max_queries_window=1440,
),
}

@property
def supported_providers(self) -> list[str]:
return ["r2r"]

def validate_config(self) -> None:
if not self.default_role in self.roles:
raise ValueError(
f"Default role '{self.default_role}' not found in roles configuration"
)

def get_role_limits(self, role: str) -> RoleLimits:
if role not in self.roles:
return self.roles[self.default_role]
return self.roles[role]


class UserManagementProvider(Provider, ABC):
def __init__(self, config: UserManagementConfig):
if not isinstance(config, UserManagementConfig):
raise ValueError(
"UserManagementProvider must be initialized with a UserManagementConfig"
)
print(f"UserManagementProvider config: {config}")
super().__init__(config)
self.config: UserManagementConfig = config
2 changes: 2 additions & 0 deletions py/core/main/assembly/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DatabaseProvider,
EmbeddingProvider,
OrchestrationProvider,
UserManagementProvider,
RunManager,
)
from core.pipelines import KGEnrichmentPipeline, RAGPipeline, SearchPipeline
Expand Down Expand Up @@ -47,6 +48,7 @@ class ProviderOverrides:
llm: Optional[CompletionProvider] = None
crypto: Optional[CryptoProvider] = None
orchestration: Optional[OrchestrationProvider] = None
user_management: Optional[UserManagementProvider] = None


@dataclass
Expand Down
30 changes: 27 additions & 3 deletions py/core/main/assembly/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
EmbeddingProvider,
IngestionConfig,
OrchestrationConfig,
UserManagementConfig,
)
from core.pipelines import RAGPipeline, SearchPipeline
from core.pipes import GeneratorPipe, MultiSearchPipe, SearchPipe
Expand Down Expand Up @@ -43,6 +44,7 @@
SupabaseAuthProvider,
UnstructuredIngestionConfig,
UnstructuredIngestionProvider,
R2RUserManagementProvider,
)


Expand Down Expand Up @@ -147,10 +149,22 @@ def create_orchestration_provider(
f"Orchestration provider {config.provider} not supported"
)

@staticmethod
def create_user_management_provider(
user_management_config: UserManagementConfig, *args, **kwargs
) -> R2RUserManagementProvider:
if user_management_config.provider == "r2r":
return R2RUserManagementProvider(user_management_config)
else:
raise ValueError(
f"User management provider {user_management_config.provider} not supported"
)

async def create_database_provider(
self,
db_config: DatabaseConfig,
crypto_provider: BCryptProvider,
user_management_provider: R2RUserManagementProvider,
*args,
**kwargs,
) -> PostgresDBProvider:
Expand All @@ -171,10 +185,11 @@ async def create_database_provider(
dimension,
crypto_provider=crypto_provider,
quantization_type=quantization_type,
user_management_provider=user_management_provider,
)
await database_provider.initialize()
logger.info(
f"Database provider initialized with user config: {self.config.user}"
f"Database provider initialized with user config: {self.config.user_management}"
)
return database_provider
else:
Expand Down Expand Up @@ -242,7 +257,7 @@ async def create_email_provider(
"""Creates an email provider based on configuration."""
if not email_config:
raise ValueError(
f"No email configuration provided for email provider, please add `[email]` to your `r2r.toml`."
"No email configuration provided for email provider, please add `[email]` to your `r2r.toml`."
)

if email_config.provider == "smtp":
Expand Down Expand Up @@ -298,10 +313,19 @@ async def create_providers(
crypto_provider_override
or self.create_crypto_provider(self.config.crypto, *args, **kwargs)
)

user_management_provider = self.create_user_management_provider(
self.config.user_management
)

database_provider = (
database_provider_override
or await self.create_database_provider(
self.config.database, crypto_provider, *args, **kwargs
self.config.database,
crypto_provider,
user_management_provider,
*args,
**kwargs,
)
)

Expand Down
9 changes: 5 additions & 4 deletions py/core/main/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
from ..base.providers import AppConfig
from ..base.providers.auth import AuthConfig
from ..base.providers.crypto import CryptoConfig
from ..base.providers.database import DatabaseConfig, UserConfig
from ..base.providers.database import DatabaseConfig
from ..base.providers.email import EmailConfig
from ..base.providers.embedding import EmbeddingConfig
from ..base.providers.ingestion import IngestionConfig
from ..base.providers.llm import CompletionConfig
from ..base.providers.orchestration import OrchestrationConfig
from ..base.providers.user_management import UserManagementConfig

logger = logging.getLogger()

Expand Down Expand Up @@ -56,7 +57,7 @@ class R2RConfig:
"database": ["provider"],
"agent": ["generation_config"],
"orchestration": ["provider"],
"user": ["default_role"],
"user_management": ["default_role"],
}

app: AppConfig
Expand All @@ -70,7 +71,7 @@ class R2RConfig:
logging: PersistentLoggingConfig
agent: AgentConfig
orchestration: OrchestrationConfig
user: UserConfig
user_management: UserManagementConfig

def __init__(self, config_data: dict[str, Any]):
"""
Expand Down Expand Up @@ -124,7 +125,7 @@ def __init__(self, config_data: dict[str, Any]):
self.logging = PersistentLoggingConfig.create(**self.logging, app=self.app) # type: ignore
self.agent = AgentConfig.create(**self.agent, app=self.app) # type: ignore
self.orchestration = OrchestrationConfig.create(**self.orchestration, app=self.app) # type: ignore
self.user = UserConfig.create(**self.user, app=self.app) # type: ignore
self.user_management = UserManagementConfig.create(**self.user_management, app=self.app) # type: ignore

# override GenerationConfig defaults
GenerationConfig.set_default(
Expand Down
3 changes: 3 additions & 0 deletions py/core/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
HatchetOrchestrationProvider,
SimpleOrchestrationProvider,
)
from .user_management import R2RUserManagementProvider

__all__ = [
# Auth
Expand Down Expand Up @@ -49,4 +50,6 @@
"LiteLLMCompletionProvider",
# Logging
"SqlitePersistentLoggingProvider",
# User Management
"R2RUserManagementProvider",
]
21 changes: 5 additions & 16 deletions py/core/providers/database/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
DatabaseProvider,
PostgresConfigurationSettings,
VectorQuantizationType,
UserConfig,
UserManagementConfig,
)
from core.providers import BCryptProvider
from core.providers.database.base import PostgresConnectionManager
Expand All @@ -23,6 +23,7 @@
from core.providers.database.tokens import PostgresTokenHandler
from core.providers.database.user import PostgresUserHandler
from core.providers.database.vector import PostgresVectorHandler
from core.providers.user_management import R2RUserManagementProvider

from .base import SemaphoreConnectionPool

Expand Down Expand Up @@ -54,7 +55,6 @@ class PostgresDBProvider(DatabaseProvider):
conn: Optional[Any]

crypto_provider: BCryptProvider
user_config: UserConfig
postgres_configuration_settings: PostgresConfigurationSettings
default_collection_name: str
default_collection_description: str
Expand All @@ -75,18 +75,13 @@ def __init__(
config: DatabaseConfig,
dimension: int,
crypto_provider: BCryptProvider,
user_management_provider: R2RUserManagementProvider,
quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
*args,
**kwargs,
):
super().__init__(config)

# Get user config from app config
self.user_config = getattr(config.app, "user", UserConfig())
logger.info(
f"Initialized database with user roles: {list(self.user_config.roles.keys())}"
)

env_vars = [
("user", "R2R_POSTGRES_USER", "POSTGRES_USER"),
("password", "R2R_POSTGRES_PASSWORD", "POSTGRES_PASSWORD"),
Expand Down Expand Up @@ -132,6 +127,7 @@ def __init__(
self.conn = None
self.config: DatabaseConfig = config
self.crypto_provider = crypto_provider
self.user_management_provider = user_management_provider
self.postgres_configuration_settings: PostgresConfigurationSettings = (
self._get_postgres_configuration_settings(config)
)
Expand All @@ -157,7 +153,7 @@ def __init__(
self.project_name,
self.connection_manager,
self.crypto_provider,
self.user_config,
self.user_management_provider,
)
self.vector_handler = PostgresVectorHandler(
self.project_name,
Expand All @@ -183,13 +179,6 @@ def __init__(
self.project_name, self.connection_manager
)

# Extract UserConfig from the main config
self.user_config = getattr(config.app, "user", UserConfig())
logger.info(
f"Initialized database with user roles: {list(self.user_config.roles.keys())}"
)
logger.info(f"Default role: {self.user_config.default_role}")

async def initialize(self):
logger.info("Initializing `PostgresDBProvider`.")
self.pool = SemaphoreConnectionPool(
Expand Down
Loading

0 comments on commit 61e696a

Please sign in to comment.