From b4f343adf7228c9abcf87bfb3931ceaae85e500b Mon Sep 17 00:00:00 2001 From: Taofiqq Date: Tue, 18 Feb 2025 08:56:26 +0100 Subject: [PATCH] update retrievers --- langchain_permit/retrievers.py | 513 ++++++++++++++++++++------------- 1 file changed, 305 insertions(+), 208 deletions(-) diff --git a/langchain_permit/retrievers.py b/langchain_permit/retrievers.py index ad8883d..f107c13 100644 --- a/langchain_permit/retrievers.py +++ b/langchain_permit/retrievers.py @@ -1,233 +1,330 @@ -# retrievers.py -"""Langchain retrievers with Permit.io integration for authorization.""" - -from typing import List, Dict, Optional, Any -from langchain.retrievers.self_query.base import SelfQueryRetriever -from langchain.retrievers import EnsembleRetriever -from langchain_community.retrievers import BM25Retriever -from langchain.chains.query_constructor.base import AttributeInfo -from langchain_openai import ChatOpenAI, OpenAIEmbeddings -from langchain_chroma import Chroma -from langchain_core.documents import Document -from permit import Permit -from permit import Permit +"""Permit.io integration retrievers for Langchain.""" import os +from typing import Any, List, Optional, Dict, Callable +from pydantic import BaseModel, Field, field_validator, PrivateAttr +from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.documents import Document +from langchain_core.retrievers import BaseRetriever +from langchain.retrievers import SelfQueryRetriever, EnsembleRetriever +from permit import Permit, User, Action, Context +import asyncio +from langchain_core.language_models import BaseLanguageModel +from langchain_core.vectorstores import VectorStore +from langchain.chains.query_constructor.base import StructuredQueryOutputParser, get_query_constructor_prompt +from langchain.chains.query_constructor.schema import AttributeInfo -# Initialize Permit client -permit_client = Permit( - token=os.getenv("PERMIT_API_KEY"), - pdp=os.getenv("PERMIT_PDP_URL") -) -class ReBACSelfQueryRetriever(SelfQueryRetriever): - """A retriever that uses ReBAC (Relationship-Based Access Control) with self-query capabilities. - - This retriever extends the standard SelfQueryRetriever to include relationship-based - access control through Permit.io integration. It allows querying documents based on - both content and user relationships. +class PermitSelfQueryRetriever(SelfQueryRetriever, BaseModel): + """Retriever that uses natural language to query permitted documents with Permit.io authorization.""" - Example: - >>> retriever = ReBACSelfQueryRetriever( - ... llm=ChatOpenAI(), - ... vectorstore=vectorstore, - ... permit_client=permit_client - ... ) - >>> # Query with relationship context - >>> docs = retriever.get_relevant_documents( - ... "Find project proposals", - ... user_context={"user_id": "user-123", "relationships": ["team-a"]} - ... ) - """ + # Configuration fields + api_key: str = Field( + default_factory=lambda: os.getenv('PERMIT_API_KEY', ''), + description="Permit.io API key" + ) + pdp_url: Optional[str] = Field( + default_factory=lambda: os.getenv('PERMIT_PDP_URL'), + description="Optional PDP URL" + ) + # user: User = Field(..., description="User to check permissions for") + user: Dict[str, Any] = Field(..., description="User to check permissions for") + resource_type: str = Field(..., description="Type of resource to query") + action: str = Field(..., description="Action being performed") + llm: BaseLanguageModel = Field(..., description="Language model for query construction") + vectorstore: VectorStore = Field(..., description="Vector store for document retrieval") + enable_limit: bool = Field(default=False, description="Whether to enable limit in queries") - def __init__(self, permit_client: Permit, *args, **kwargs): - super().__init__(*args, **kwargs) - self.permit_client = permit_client - self.metadata_field_info = [ + # Private fields + _permit_client: Optional[Permit] = PrivateAttr(default=None) + _allowed_ids: List[str] = PrivateAttr(default_factory=list) + + class Config: + arbitrary_types_allowed = True + + def __init__(self, **data): + super().__init__(**data) + + # Initialize Permit client + self._permit_client = Permit( + token=self.api_key, + pdp=self.pdp_url + ) + + # Get initial allowed IDs + self._allowed_ids = self._get_permitted_ids() + + # Create metadata field info + metadata_field_info = [ AttributeInfo( - name="owner", - description="The owner of the document", + name="id", + description="The document identifier that must be in the allowed list", type="string", - ), - AttributeInfo( - name="relationships", - description="List of users who have a relationship with this document", - type="list[string]", + enum=self._allowed_ids ), AttributeInfo( name="resource_type", - description="Type of resource - document, file, etc.", - type="string", - ), + description="The type of resource", + type="string" + ) ] + + # Create query constructor chain + prompt = get_query_constructor_prompt( + document_content_description=f"Document of type {self.resource_type}", + metadata_field_info=metadata_field_info, + ) + output_parser = StructuredQueryOutputParser.from_components() + query_constructor = prompt | self.llm | output_parser + + # Initialize the SelfQueryRetriever + super(SelfQueryRetriever, self).__init__( + llm=self.llm, + vectorstore=self.vectorstore, + document_content_description=f"Document of type {self.resource_type}", + metadata_field_info=metadata_field_info, + structured_query_translator=self._create_translator(), + query_constructor=query_constructor, + enable_limit=self.enable_limit + ) - async def _aget_relevant_documents(self, query: str, user_context: Optional[Dict] = None) -> List[Document]: - """Get documents relevant to the query while enforcing ReBAC policies. + def _get_permitted_ids(self) -> List[str]: + """Get list of permitted document IDs from Permit.io.""" + permissions = self._permit_client.get_user_permissions( + user=self.user, + resource_types=[self.resource_type] + ) - Args: - query: User's query string - user_context: Dictionary containing user information and relationships - - Returns: - List of documents that the user has permission to access - """ - # First get relevant documents based on query - docs = await super()._aget_relevant_documents(query) + allowed_ids = [] + for resource in permissions.get("default", {}).get(self.resource_type, []): + if self.action in resource.get("actions", []): + allowed_ids.append(resource["id"]) - if not user_context: - return [] - - # Filter based on relationships - allowed_docs = [] - for doc in docs: - allowed = await self.permit_client.check( - user=user_context, - action="read", - resource={ - "type": doc.metadata.get("resource_type", "document"), - "attributes": doc.metadata + return allowed_ids + + def _create_translator(self): + """Create query translator that always includes ID filter.""" + base_translator = self.vectorstore.as_query_transformer() + + def wrapped_translator(structured_query): + # Add ID constraint to every query + if not structured_query.filter: + structured_query.filter = {"id": {"$in": self._allowed_ids}} + else: + structured_query.filter = { + "$and": [ + structured_query.filter, + {"id": {"$in": self._allowed_ids}} + ] } + return base_translator.visit_structured_query(structured_query) + + return wrapped_translator + + async def _aget_relevant_documents( + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun, + **kwargs: Any + ) -> List[Document]: + """Get relevant documents with permissions built into the query.""" + # Log retrieval start + run_manager.on_retriever_start( + query, + { + "user_id": self.user.key, + "resource_type": self.resource_type, + "action": self.action, + "allowed_ids_count": len(self._allowed_ids) + } + ) + + try: + # Refresh permissions before querying + self._allowed_ids = self._get_permitted_ids() + + # Get documents using parent method + docs = await super()._aget_relevant_documents( + query, + run_manager=run_manager, + **kwargs ) - if allowed: - allowed_docs.append(doc) - - return allowed_docs + + run_manager.on_retriever_end(docs) + return docs + + except Exception as e: + run_manager.on_retriever_error(f"{e.__class__.__name__}: {str(e)}") + raise + def get_relevant_documents( + self, + query: str, + *, + run_manager: Optional[CallbackManagerForRetrieverRun] = None, + **kwargs: Any + ) -> List[Document]: + """Synchronous entry point that wraps the async retrieval.""" + try: + return asyncio.run(self._aget_relevant_documents( + query, + run_manager=run_manager or CallbackManagerForRetrieverRun.get_noop_manager(), + **kwargs + )) + except RuntimeError: + # If there's an active event loop, fall back to get_event_loop() + loop = asyncio.get_event_loop() + return loop.run_until_complete(self._aget_relevant_documents( + query, + run_manager=run_manager or CallbackManagerForRetrieverRun.get_noop_manager(), + **kwargs + )) -class RBACEnsembleRetriever(EnsembleRetriever): - """An ensemble retriever that combines semantic search with RBAC/ABAC policies. - - This retriever uses multiple underlying retrievers and applies role-based and - attribute-based access control through Permit.io. - - Example: - >>> retriever = RBACEnsembleRetriever( - ... retrievers=[semantic_retriever, permission_retriever], - ... weights=[0.7, 0.3], - ... permit_client=permit_client - ... ) - >>> # Query with role context - >>> docs = retriever.get_relevant_documents( - ... "HR policies", - ... user_context={"roles": ["hr_staff"], "attributes": {"department": "HR"}} - ... ) +class PermitEnsembleRetriever(EnsembleRetriever, BaseModel): + """ + Ensemble retriever with Permit.io permission filtering. """ - - def __init__(self, permit_client: Permit, *args, **kwargs): - super().__init__(*args, **kwargs) - self.permit_client = permit_client - async def _aget_relevant_documents(self, query: str, user_context: Optional[Dict] = None) -> List[Document]: - """Get documents relevant to the query while enforcing RBAC/ABAC policies. - - Args: - query: User's query string - user_context: Dictionary containing user roles and attributes - - Returns: - List of documents that the user has permission to access based on roles and attributes - """ - # Get initial results from ensemble - docs = await super()._aget_relevant_documents(query) + # Instance configuration + api_key: str = Field( + default_factory=lambda: os.getenv('PERMIT_API_KEY', ''), + description="Permit.io API key" + ) + pdp_url: Optional[str] = Field( + default_factory=lambda: os.getenv('PERMIT_PDP_URL'), + description="Optional PDP URL" + ) + user: str = Field(..., description="User to check permissions for") + action: str = Field(..., description="Action being performed") + resource_type: str = Field(..., description="Type of resource being accessed") + retrievers: List[BaseRetriever] = Field(..., description="List of retrievers to ensemble") + weights: Optional[List[float]] = Field(default=None, description="Optional weights for retrievers") + + class Config: + arbitrary_types_allowed = True + + @field_validator('api_key') + def validate_api_key(cls, v): + if not v: + raise ValueError("PERMIT_API_KEY must be provided either through environment variable or directly") + return v + + def __init__(self, **data): + # Initialize base EnsembleRetriever first + EnsembleRetriever.__init__( + self, + retrievers=data.get('retrievers', []), + weights=data.get('weights') + ) + # Initialize Pydantic BaseModel + BaseModel.__init__(self, **data) + + # Create the Permit client + self._permit_client = Permit( + token=self.api_key, + pdp=self.pdp_url + ) + + async def _filter_by_permissions( + self, + documents: List[Document] + ) -> List[Document]: + """Filter documents by permissions.""" + # Extract document IDs + doc_ids = [doc.metadata.get("id") for doc in documents if "id" in doc.metadata] - if not user_context: + if not doc_ids: return [] + + try: + # Prepare resources for permission check + resources = [ + {"id": doc_id, "type": self.resource_type} + for doc_id in doc_ids + ] - # Filter based on RBAC/ABAC - allowed_docs = [] - for doc in docs: - allowed = await self.permit_client.check( - user=user_context, - action="read", - resource={ - "type": doc.metadata.get("resource_type", "document"), - "attributes": doc.metadata - } + # Check permissions through Permit.io + filtered_resources = await self._permit_client.filter_objects( + user=self.user, + action=self.action, + context=Context(), + resources=resources ) - if allowed: - allowed_docs.append(doc) - - return allowed_docs - -# Usage Examples -async def demo_rebac_retriever(): - """Example usage of ReBAC retriever""" - docs = [ - Document( - page_content="Confidential project proposal for Project X", - metadata={ - "owner": "user-123", - "relationships": ["team-a", "managers"], - "resource_type": "proposal" - } - ), - Document( - page_content="Public company announcement", - metadata={ - "owner": "user-456", - "relationships": ["all-employees"], - "resource_type": "announcement" - } - ) - ] - - # Initialize components - embeddings = OpenAIEmbeddings() - vectorstore = Chroma.from_documents(docs, embeddings) - - rebac_retriever = ReBACSelfQueryRetriever( - llm=ChatOpenAI(temperature=0), - vectorstore=vectorstore, - permit_client=permit_client - ) - - # Example queries - queries = [ - ("Find all project proposals", {"user_id": "user-123", "relationships": ["team-a"]}), - ("Show me company announcements", {"user_id": "user-789", "relationships": ["all-employees"]}) - ] - - for query, context in queries: - results = await rebac_retriever._aget_relevant_documents(query, context) - print(f"\nQuery: {query}") - print(f"Results: {[doc.page_content for doc in results]}") - -async def demo_rbac_ensemble_retriever(): - """Example usage of RBAC/ABAC ensemble retriever""" - docs = [ - Document( - page_content="HR Policy: Work from home guidelines", - metadata={ - "department": "HR", - "classification": "internal", - "required_role": "hr_staff" - } - ), - Document( - page_content="Employee handbook", - metadata={ - "department": "HR", - "classification": "public", - "required_role": "employee" + + # Get allowed IDs + allowed_ids = {r["id"] for r in filtered_resources} + + # Filter documents + return [ + doc for doc in documents + if doc.metadata.get("id") in allowed_ids + ] + + except Exception as e: + raise RuntimeError(f"Permission filtering failed: {str(e)}") + + async def _aget_relevant_documents( + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun, + **kwargs: Any + ) -> List[Document]: + """Get relevant documents from ensemble and filter by permissions.""" + # Start retrieval process + run_manager.on_retriever_start( + query, + { + "retriever_type": self.__class__.__name__, + "num_retrievers": len(self.retrievers), + "resource_type": self.resource_type, + "action": self.action } ) - ] - - # Initialize retrievers and create ensemble - semantic_retriever = Chroma.from_documents(docs, OpenAIEmbeddings()).as_retriever() - permission_retriever = BM25Retriever.from_documents(docs) - - rbac_retriever = RBACEnsembleRetriever( - retrievers=[semantic_retriever, permission_retriever], - weights=[0.6, 0.4], - permit_client=permit_client - ) - - # Example queries - queries = [ - ("HR policies", {"roles": ["hr_staff"], "attributes": {"department": "HR"}}), - ("Employee handbook", {"roles": ["employee"]}) - ] - - for query, context in queries: - results = await rbac_retriever._aget_relevant_documents(query, context) - print(f"\nQuery: {query}") - print(f"Results: {[doc.page_content for doc in results]}") \ No newline at end of file + + try: + # Get documents from ensemble retrievers + docs = await super()._aget_relevant_documents( + query, + run_manager=run_manager, + **kwargs + ) + + run_manager.on_event( + "ensemble_retrieval_complete", + {"retrieved_count": len(docs)} + ) + + # Apply permission filtering + filtered_docs = await self._filter_by_permissions(docs) + + run_manager.on_retriever_end( + filtered_docs, + { + "initial_count": len(docs), + "permitted_count": len(filtered_docs), + "filtered_out": len(docs) - len(filtered_docs) + } + ) + + return filtered_docs + + except Exception as e: + run_manager.on_retriever_error(f"{e.__class__.__name__}: {str(e)}") + raise + + def get_relevant_documents( + self, + query: str, + *, + run_manager: Optional[CallbackManagerForRetrieverRun] = None, + **kwargs: Any + ) -> List[Document]: + """Synchronous entry point that wraps the async retrieval.""" + import asyncio + try: + # Attempt to use asyncio.run() if no event loop is running. + return asyncio.run(self._aget_relevant_documents(query, **kwargs)) + except RuntimeError: + # If there's an active event loop, fall back to get_event_loop(). + loop = asyncio.get_event_loop() + return loop.run_until_complete(self._aget_relevant_documents(query, **kwargs)) \ No newline at end of file