diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e41e8e1..e8efc6b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -136,7 +136,30 @@ jobs: docker push ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/slack-bot:$TAG # ============================================ - # Stage 2b: Build jobs Docker Image + # Stage 2b: Build MCP Server Docker Image + # ============================================ + build-mcp: + needs: unit-tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: google-github-actions/auth@v2 + with: + credentials_json: ${{ secrets.GCP_SA_KEY }} + - uses: google-github-actions/setup-gcloud@v2 + - name: Configure Docker + run: gcloud auth configure-docker ${{ env.REGION }}-docker.pkg.dev + - name: Build and push mcp-server image + env: + TAG: ${{ github.sha }} + run: | + docker build -f deploy/docker/Dockerfile.mcp -t ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/mcp-server:$TAG . 
+ docker tag ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/mcp-server:$TAG ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/mcp-server:latest + docker push ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/mcp-server:$TAG + docker push ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/mcp-server:latest + + # ============================================ + # Stage 2c: Build jobs Docker Image # ============================================ build-jobs: needs: unit-tests @@ -166,11 +189,11 @@ jobs: docker push ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/jobs:$TAG # ============================================ - # Stage 2c: Build Gate (for branch protection) + # Stage 2d: Build Gate (for branch protection) # ============================================ build: # Gate job: requires builds + reviewer approvals (on PRs) before staging deploy - needs: [build-slack-bot, build-jobs, ai-review, security-review] + needs: [build-slack-bot, build-mcp, build-jobs, ai-review, security-review] if: always() runs-on: ubuntu-latest steps: @@ -178,17 +201,20 @@ jobs: env: EVENT_NAME: ${{ github.event_name }} BUILD_BOT: ${{ needs.build-slack-bot.result }} + BUILD_MCP: ${{ needs.build-mcp.result }} BUILD_JOBS: ${{ needs.build-jobs.result }} AI_REVIEW: ${{ needs.ai-review.result }} SEC_REVIEW: ${{ needs.security-review.result }} run: | echo "Build Slack Bot: $BUILD_BOT" + echo "Build MCP Server: $BUILD_MCP" echo "Build Jobs: $BUILD_JOBS" echo "AI Review: $AI_REVIEW" echo "Security Review: $SEC_REVIEW" # Build jobs must succeed [[ "$BUILD_BOT" == "success" ]] || { echo "::error::build-slack-bot failed"; exit 1; } + [[ "$BUILD_MCP" == "success" ]] || { echo "::error::build-mcp failed"; exit 1; } [[ "$BUILD_JOBS" == "success" ]] || { echo "::error::build-jobs failed"; exit 1; } # For PRs: security review is required, AI review is advisory @@ -231,12 +257,34 @@ jobs: ${{ env.REGION 
}}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/jobs:$TAG \ ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/jobs:staging + # ============================================ + # Stage 3b: Deploy MCP Server to Staging + # ============================================ + deploy-mcp-staging: + needs: build + if: always() && needs.build.result == 'success' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: google-github-actions/auth@v2 + with: + credentials_json: ${{ secrets.GCP_SA_KEY }} + - uses: google-github-actions/setup-gcloud@v2 + - name: Deploy mcp-server to staging + env: + TAG: ${{ github.sha }} + run: | + gcloud run services update kb-mcp-staging \ + --image=${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/mcp-server:$TAG \ + --region=${{ env.REGION }} \ + --project=${{ env.PROJECT_ID }} + # ============================================ # Stage 4: E2E Tests against Staging (ALL tests) # ============================================ e2e-tests-staging: - needs: deploy-staging - if: always() && needs.deploy-staging.result == 'success' + needs: [deploy-staging, deploy-mcp-staging] + if: always() && needs.deploy-staging.result == 'success' && needs.deploy-mcp-staging.result == 'success' runs-on: ubuntu-latest env: # Slack configuration (Pattern: {SERVICE}_{ENV}_*) @@ -344,3 +392,25 @@ jobs: gcloud artifacts docker tags add \ ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/jobs:$TAG \ ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/jobs:latest + + # ============================================ + # Stage 5b: Deploy MCP Server to Production (main only) + # ============================================ + deploy-mcp-production: + needs: e2e-tests-staging + if: always() && needs.e2e-tests-staging.result == 'success' && github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: google-github-actions/auth@v2 + with: + 
credentials_json: ${{ secrets.GCP_SA_KEY }} + - uses: google-github-actions/setup-gcloud@v2 + - name: Deploy mcp-server to production + env: + TAG: ${{ github.sha }} + run: | + gcloud run services update kb-mcp \ + --image=${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/mcp-server:$TAG \ + --region=${{ env.REGION }} \ + --project=${{ env.PROJECT_ID }} diff --git a/deploy/docker/Dockerfile.mcp b/deploy/docker/Dockerfile.mcp new file mode 100644 index 0000000..a0e1bee --- /dev/null +++ b/deploy/docker/Dockerfile.mcp @@ -0,0 +1,33 @@ +# ============================================================================= +# MCP Server Docker Image +# ============================================================================= +FROM python:3.11-slim + +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 + +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + && rm -rf /var/lib/apt/lists/* + +# Copy all source files +COPY pyproject.toml . +COPY src/ ./src/ + +# Install Python dependencies +RUN pip install --no-cache-dir . 
starlette + +# Create non-root user +RUN useradd --create-home --shell /bin/bash appuser && \ + chown -R appuser:appuser /app + +USER appuser + +ENV PORT=8080 + +CMD ["python", "-m", "knowledge_base.mcp.server"] diff --git a/deploy/terraform/cloudrun-mcp.tf b/deploy/terraform/cloudrun-mcp.tf new file mode 100644 index 0000000..004ce70 --- /dev/null +++ b/deploy/terraform/cloudrun-mcp.tf @@ -0,0 +1,398 @@ +# ============================================================================= +# MCP Server Cloud Run Services (Production + Staging) +# ============================================================================= + +# ----------------------------------------------------------------------------- +# Service Account +# ----------------------------------------------------------------------------- +resource "google_service_account" "mcp_server" { + account_id = "mcp-server" + display_name = "MCP Server Service Account" +} + +# Vertex AI access for embeddings and LLM +resource "google_project_iam_member" "mcp_server_vertex_ai" { + project = var.project_id + role = "roles/aiplatform.user" + member = "serviceAccount:${google_service_account.mcp_server.email}" +} + +# ----------------------------------------------------------------------------- +# Secret Manager - MCP OAuth Secrets +# ----------------------------------------------------------------------------- +resource "google_secret_manager_secret" "mcp_oauth_client_id" { + secret_id = "mcp-oauth-client-id" + + replication { + auto {} + } + + labels = { + environment = var.environment + purpose = "mcp" + } +} + +resource "google_secret_manager_secret_version" "mcp_oauth_client_id" { + secret = google_secret_manager_secret.mcp_oauth_client_id.id + secret_data = "REPLACE_ME" + + lifecycle { + ignore_changes = [secret_data] + } +} + +resource "google_secret_manager_secret" "mcp_oauth_client_secret" { + secret_id = "mcp-oauth-client-secret" + + replication { + auto {} + } + + labels = { + environment = var.environment + 
purpose = "mcp" + } +} + +resource "google_secret_manager_secret_version" "mcp_oauth_client_secret" { + secret = google_secret_manager_secret.mcp_oauth_client_secret.id + secret_data = "REPLACE_ME" + + lifecycle { + ignore_changes = [secret_data] + } +} + +# IAM: MCP server SA can access OAuth secrets +resource "google_secret_manager_secret_iam_member" "mcp_oauth_client_id_access" { + secret_id = google_secret_manager_secret.mcp_oauth_client_id.secret_id + role = "roles/secretmanager.secretAccessor" + member = "serviceAccount:${google_service_account.mcp_server.email}" +} + +resource "google_secret_manager_secret_iam_member" "mcp_oauth_client_secret_access" { + secret_id = google_secret_manager_secret.mcp_oauth_client_secret.secret_id + role = "roles/secretmanager.secretAccessor" + member = "serviceAccount:${google_service_account.mcp_server.email}" +} + +# ----------------------------------------------------------------------------- +# Production MCP Server +# ----------------------------------------------------------------------------- +resource "google_cloud_run_v2_service" "mcp_server" { + name = "kb-mcp" + location = var.region + + template { + scaling { + min_instance_count = 1 + max_instance_count = 5 + } + + containers { + image = "${var.region}-docker.pkg.dev/${var.project_id}/knowledge-base/mcp-server:latest" + + resources { + limits = { + cpu = "1" + memory = "512Mi" + } + } + + # Graph Database Configuration (Graphiti + Neo4j) + env { + name = "GRAPH_BACKEND" + value = "neo4j" + } + + env { + name = "GRAPH_ENABLE_GRAPHITI" + value = "true" + } + + # Neo4j connection - using internal GCE VM IP via VPC connector + env { + name = "NEO4J_URI" + value = "bolt://${google_compute_instance.neo4j_prod.network_interface[0].network_ip}:7687" + } + + env { + name = "NEO4J_USER" + value = "neo4j" + } + + env { + name = "NEO4J_PASSWORD" + value = random_password.neo4j_prod_password.result + } + + # LLM Configuration + env { + name = "LLM_PROVIDER" + value = "gemini" 
+ } + + env { + name = "GEMINI_CONVERSATION_MODEL" + value = "gemini-2.5-flash" + } + + env { + name = "EMBEDDING_PROVIDER" + value = "vertex-ai" + } + + env { + name = "GCP_PROJECT_ID" + value = var.project_id + } + + env { + name = "VERTEX_AI_PROJECT" + value = var.project_id + } + + env { + name = "VERTEX_AI_LOCATION" + value = var.region + } + + env { + name = "GOOGLE_GENAI_USE_VERTEXAI" + value = "true" + } + + # MCP OAuth Configuration + env { + name = "MCP_OAUTH_CLIENT_ID" + value_source { + secret_key_ref { + secret = google_secret_manager_secret.mcp_oauth_client_id.secret_id + version = "latest" + } + } + } + + env { + name = "MCP_OAUTH_CLIENT_SECRET" + value_source { + secret_key_ref { + secret = google_secret_manager_secret.mcp_oauth_client_secret.secret_id + version = "latest" + } + } + } + + env { + name = "MCP_OAUTH_RESOURCE_IDENTIFIER" + value = "https://kb-mcp.${var.base_domain}" + } + + env { + name = "MCP_DEV_MODE" + value = "false" + } + + # Health check + startup_probe { + http_get { + path = "/health" + port = 8080 + } + initial_delay_seconds = 5 + period_seconds = 10 + failure_threshold = 3 + } + } + + vpc_access { + connector = google_vpc_access_connector.connector.id + egress = "PRIVATE_RANGES_ONLY" + } + + service_account = google_service_account.mcp_server.email + } + + traffic { + type = "TRAFFIC_TARGET_ALLOCATION_TYPE_LATEST" + percent = 100 + } + + depends_on = [ + google_secret_manager_secret_version.mcp_oauth_client_id, + google_secret_manager_secret_version.mcp_oauth_client_secret, + ] +} + +# Allow unauthenticated access (OAuth is handled at application level) +resource "google_cloud_run_v2_service_iam_member" "mcp_server_invoker" { + project = var.project_id + location = var.region + name = google_cloud_run_v2_service.mcp_server.name + role = "roles/run.invoker" + member = "allUsers" +} + +# ----------------------------------------------------------------------------- +# Staging MCP Server +# 
----------------------------------------------------------------------------- +resource "google_cloud_run_v2_service" "mcp_server_staging" { + name = "kb-mcp-staging" + location = var.region + + template { + scaling { + min_instance_count = 0 + max_instance_count = 3 + } + + containers { + image = "${var.region}-docker.pkg.dev/${var.project_id}/knowledge-base/mcp-server:staging" + + resources { + limits = { + cpu = "1" + memory = "512Mi" + } + } + + # Graph Database Configuration (Graphiti + Neo4j) + env { + name = "GRAPH_BACKEND" + value = "neo4j" + } + + env { + name = "GRAPH_ENABLE_GRAPHITI" + value = "true" + } + + # Neo4j staging connection - using internal GCE VM IP via VPC connector + env { + name = "NEO4J_URI" + value = "bolt://${google_compute_instance.neo4j_staging.network_interface[0].network_ip}:7687" + } + + env { + name = "NEO4J_USER" + value = "neo4j" + } + + env { + name = "NEO4J_PASSWORD" + value = random_password.neo4j_staging_password.result + } + + # LLM Configuration + env { + name = "LLM_PROVIDER" + value = "gemini" + } + + env { + name = "GEMINI_CONVERSATION_MODEL" + value = "gemini-2.5-flash" + } + + env { + name = "EMBEDDING_PROVIDER" + value = "vertex-ai" + } + + env { + name = "GCP_PROJECT_ID" + value = var.project_id + } + + env { + name = "VERTEX_AI_PROJECT" + value = var.project_id + } + + env { + name = "VERTEX_AI_LOCATION" + value = var.region + } + + env { + name = "GOOGLE_GENAI_USE_VERTEXAI" + value = "true" + } + + # MCP OAuth Configuration (reuse same secrets for staging) + env { + name = "MCP_OAUTH_CLIENT_ID" + value_source { + secret_key_ref { + secret = google_secret_manager_secret.mcp_oauth_client_id.secret_id + version = "latest" + } + } + } + + env { + name = "MCP_OAUTH_CLIENT_SECRET" + value_source { + secret_key_ref { + secret = google_secret_manager_secret.mcp_oauth_client_secret.secret_id + version = "latest" + } + } + } + + env { + name = "MCP_OAUTH_RESOURCE_IDENTIFIER" + value = 
"https://kb-mcp-staging.${var.staging_domain}" + } + + env { + name = "MCP_DEV_MODE" + value = "false" + } + + env { + name = "ENVIRONMENT" + value = "staging" + } + + # Health check + startup_probe { + http_get { + path = "/health" + port = 8080 + } + initial_delay_seconds = 5 + period_seconds = 10 + failure_threshold = 3 + } + } + + vpc_access { + connector = google_vpc_access_connector.connector.id + egress = "PRIVATE_RANGES_ONLY" + } + + service_account = google_service_account.mcp_server.email + } + + traffic { + type = "TRAFFIC_TARGET_ALLOCATION_TYPE_LATEST" + percent = 100 + } + + depends_on = [ + google_secret_manager_secret_version.mcp_oauth_client_id, + google_secret_manager_secret_version.mcp_oauth_client_secret, + google_compute_instance.neo4j_staging, + ] +} + +# Allow unauthenticated access (OAuth is handled at application level) +resource "google_cloud_run_v2_service_iam_member" "mcp_server_staging_invoker" { + project = var.project_id + location = var.region + name = google_cloud_run_v2_service.mcp_server_staging.name + role = "roles/run.invoker" + member = "allUsers" +} diff --git a/pyproject.toml b/pyproject.toml index 072cc70..ae2c512 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,10 @@ dependencies = [ "google-genai>=1.0.0", # For Gemini LLM client in Graphiti "kuzu>=0.4.0", "neo4j>=5.26.0", + # MCP server dependencies + "mcp>=1.25.0", + "authlib>=1.6.0", + "PyJWT>=2.10.0", ] [project.optional-dependencies] diff --git a/src/knowledge_base/core/__init__.py b/src/knowledge_base/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/knowledge_base/core/qa.py b/src/knowledge_base/core/qa.py new file mode 100644 index 0000000..5705624 --- /dev/null +++ b/src/knowledge_base/core/qa.py @@ -0,0 +1,131 @@ +"""Core Q&A logic for knowledge base search and answer generation. + +Extracted from the Slack bot to be reusable across interfaces (Slack, MCP, API). 
+""" + +import logging + +from knowledge_base.rag.factory import get_llm +from knowledge_base.rag.exceptions import LLMError +from knowledge_base.search import HybridRetriever, SearchResult + +logger = logging.getLogger(__name__) + + +async def search_knowledge(query: str, limit: int = 5) -> list[SearchResult]: + """Search for relevant chunks using Graphiti hybrid search. + + Uses HybridRetriever which delegates to Graphiti's unified search: + - Semantic similarity (embeddings) + - BM25 keyword matching + - Graph relationships + + Returns SearchResult objects with content and metadata. + """ + logger.info(f"Searching for: '{query[:100]}...'") + + try: + retriever = HybridRetriever() + health = await retriever.check_health() + logger.info(f"Hybrid search health: {health}") + + # Use Graphiti hybrid search + results = await retriever.search(query, k=limit) + logger.info(f"Hybrid search returned {len(results)} results") + + # Log first result for debugging + if results: + first = results[0] + logger.info( + f"First result: chunk_id={first.chunk_id}, " + f"title={first.page_title}, content_len={len(first.content)}" + ) + + return results + + except Exception as e: + logger.error(f"Hybrid search FAILED (returning 0 results): {e}", exc_info=True) + + return [] + + +async def generate_answer( + question: str, + chunks: list[SearchResult], + conversation_history: list[dict[str, str]] | None = None, +) -> str: + """Generate an answer using LLM with retrieved chunks. + + Args: + question: The user's question + chunks: SearchResult objects from Graphiti containing content and metadata + conversation_history: Previous messages in the conversation thread + """ + if not chunks: + return "I couldn't find relevant information in the knowledge base to answer your question." 
+ + # Build context from chunks (SearchResult has page_title property and content attribute) + context_parts = [] + for i, chunk in enumerate(chunks, 1): + context_parts.append( + f"[Source {i}: {chunk.page_title}]\n{chunk.content[:1000]}" + ) + context = "\n\n---\n\n".join(context_parts) + + # Build conversation history section + conversation_section = "" + if conversation_history: + history_parts = [] + for msg in conversation_history[-6:]: # Last 6 messages for context + role = "User" if msg["role"] == "user" else "Assistant" + # Truncate long messages in history + content = msg["content"][:500] + "..." if len(msg["content"]) > 500 else msg["content"] + history_parts.append(f"{role}: {content}") + if history_parts: + conversation_section = f""" +PREVIOUS CONVERSATION: +{chr(10).join(history_parts)} + +(Use this context to understand what the user is asking about and provide continuity) +""" + + prompt = f"""You are Keboola's internal knowledge base assistant. Answer questions ONLY based on the provided context documents. + +CRITICAL RULES: +- ONLY use information explicitly stated in the context documents below. +- Do NOT make up, assume, or hallucinate any information not in the documents. +- If the context doesn't contain enough information to answer, say so clearly. +- When referencing information, mention which source it came from. +{conversation_section} +CONTEXT DOCUMENTS: +{context} + +CURRENT QUESTION: {question} + +INSTRUCTIONS: +- Answer based strictly on the context documents above. +- Be concise and helpful. Use bullet points for multiple items. +- If the documents only partially answer the question, share what IS available and note what's missing. +- Do NOT invent tool names, process steps, or policies not mentioned in the documents. 
+ +Provide your answer:""" + + try: + llm = await get_llm() + logger.info(f"Using LLM provider: {llm.provider_name}") + + # Skip health check - generate() has proper retry logic and error handling + answer = await llm.generate(prompt) + return answer.strip() + except LLMError as e: + logger.error(f"LLM provider error: {e}") + return ( + f"I found {len(chunks)} relevant documents but couldn't generate " + f"an answer at this time. Please try again later." + ) + except Exception as e: + logger.error(f"LLM generation failed: {e}") + return ( + f"I found {len(chunks)} relevant documents but couldn't generate " + f"an answer at this time. Please try again later." + ) diff --git a/src/knowledge_base/mcp/__init__.py b/src/knowledge_base/mcp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/knowledge_base/mcp/config.py b/src/knowledge_base/mcp/config.py new file mode 100644 index 0000000..32f9ff8 --- /dev/null +++ b/src/knowledge_base/mcp/config.py @@ -0,0 +1,105 @@ +"""MCP server configuration using pydantic-settings.""" + +from pydantic import field_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class MCPSettings(BaseSettings): + """MCP server settings loaded from environment variables. + + Required variables (no defaults - fail fast if missing): + MCP_OAUTH_CLIENT_ID: Google OAuth client ID + MCP_OAUTH_RESOURCE_IDENTIFIER: Resource server identifier (e.g. 
        "https://kb-mcp.keboola.com")
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # OAuth 2.1 Configuration (Google OAuth)
    MCP_OAUTH_CLIENT_ID: str  # Required - fail fast if missing
    MCP_OAUTH_CLIENT_SECRET: str  # Required - needed for token exchange with Google

    @field_validator("MCP_OAUTH_CLIENT_ID", "MCP_OAUTH_CLIENT_SECRET")
    @classmethod
    def must_be_non_empty(cls, v: str, info) -> str:
        """Ensure OAuth credentials are non-empty strings."""
        # Guards against the secret being present but blank (e.g. an
        # uninitialized Secret Manager version) rather than missing entirely.
        if not v or not v.strip():
            raise ValueError(
                f"{info.field_name} must be a non-empty string"
            )
        return v

    # Defaults below are Google's published OAuth 2.0 / OIDC endpoints;
    # overridable via env vars for a different authorization server.
    MCP_OAUTH_AUTHORIZATION_SERVER: str = "https://accounts.google.com"
    MCP_OAUTH_AUTHORIZATION_ENDPOINT: str = (
        "https://accounts.google.com/o/oauth2/v2/auth"
    )
    MCP_OAUTH_TOKEN_ENDPOINT: str = "https://oauth2.googleapis.com/token"
    MCP_OAUTH_JWKS_URI: str = "https://www.googleapis.com/oauth2/v3/certs"
    MCP_OAUTH_ISSUER: str = "https://accounts.google.com"
    MCP_OAUTH_RESOURCE_IDENTIFIER: str  # Required - e.g. "https://kb-mcp.keboola.com"
    MCP_OAUTH_SCOPES: str = "openid email profile"

    # Rate Limiting
    MCP_RATE_LIMIT_READ_PER_MINUTE: int = 30
    MCP_RATE_LIMIT_WRITE_PER_HOUR: int = 20

    # Server
    MCP_HOST: str = "0.0.0.0"
    MCP_PORT: int = 8080
    MCP_DEV_MODE: bool = False
    MCP_DEBUG: bool = False


# OAuth scope definitions
OAUTH_SCOPES: dict[str, str] = {
    "openid": "OpenID Connect authentication",
    "email": "User email address",
    "profile": "User profile information",
    "kb.read": "Search and query the knowledge base",
    "kb.write": "Create knowledge and submit feedback",
}

# Required scopes per MCP tool
TOOL_SCOPE_REQUIREMENTS: dict[str, list[str]] = {
    "ask_question": ["kb.read"],
    "search_knowledge": ["kb.read"],
    "create_knowledge": ["kb.write"],
    "ingest_document": ["kb.write"],
    "submit_feedback": ["kb.write"],
    "check_health": ["kb.read"],
}

# Tools that perform write operations (subject to stricter rate limits).
# Derived from TOOL_SCOPE_REQUIREMENTS so the two stay in sync automatically.
WRITE_TOOLS: list[str] = [
    tool
    for tool, scopes in TOOL_SCOPE_REQUIREMENTS.items()
    if "kb.write" in scopes
]

# Rate limit configuration keyed by operation type.
# NOTE(review): the literal request counts duplicate the MCPSettings defaults
# above — keep them in sync, or derive them from a settings instance.
RATE_LIMITS: dict[str, dict[str, int]] = {
    "read": {
        "requests": 30,  # Default, overridden by MCPSettings.MCP_RATE_LIMIT_READ_PER_MINUTE
        "window_seconds": 60,
    },
    "write": {
        "requests": 20,  # Default, overridden by MCPSettings.MCP_RATE_LIMIT_WRITE_PER_HOUR
        "window_seconds": 3600,
    },
}


def check_scope_access(required_scopes: list[str], granted_scopes: list[str]) -> bool:
    """Check if any of the required scopes are in the granted scopes.

    Args:
        required_scopes: Scopes required by the tool (at least one must match).
        granted_scopes: Scopes granted to the authenticated user/token.

    Returns:
        True if at least one required scope is present in granted scopes.
    """
    # OR semantics: possession of any one required scope grants access.
    return any(scope in granted_scopes for scope in required_scopes)


# ===== new file: src/knowledge_base/mcp/oauth/__init__.py (empty) =====

# ===== new file: src/knowledge_base/mcp/oauth/metadata.py =====
"""
RFC 9728 OAuth Protected Resource Metadata

Implements the Protected Resource Metadata endpoint for OAuth 2.1 resource servers.
"""

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ProtectedResourceMetadata:
    """
    RFC 9728 Protected Resource Metadata.

    Describes the OAuth-protected resource server and its requirements.
    This metadata is served at /.well-known/oauth-protected-resource
    """

    # Required: The resource identifier (typically the server URL)
    resource: str

    # Required: List of authorization servers that can issue tokens
    authorization_servers: list[str]

    # Optional: Supported OAuth scopes
    scopes_supported: list[str] = field(default_factory=list)

    # Optional: Bearer token methods supported
    bearer_methods_supported: list[str] = field(
        default_factory=lambda: ["header"]
    )

    # Optional: Resource documentation URL
    resource_documentation: Optional[str] = None

    # Optional: Additional resource metadata
    resource_signing_alg_values_supported: list[str] = field(
        default_factory=lambda: ["RS256", "ES256"]
    )

    def to_dict(self) -> dict:
        """Serialize metadata to dictionary for JSON response.

        Required fields are always emitted; optional fields are emitted only
        when non-empty, keeping the published document minimal.
        """
        result = {
            "resource": self.resource,
            "authorization_servers": self.authorization_servers,
        }

        if self.scopes_supported:
            result["scopes_supported"] = self.scopes_supported

        if self.bearer_methods_supported:
            result["bearer_methods_supported"] = self.bearer_methods_supported

        if self.resource_documentation:
            result["resource_documentation"] = self.resource_documentation

        if self.resource_signing_alg_values_supported:
            result["resource_signing_alg_values_supported"] = (
                self.resource_signing_alg_values_supported
            )

        return result

    def to_json(self) -> str:
        """Serialize metadata to JSON string."""
        # Local import keeps json out of module scope; it is only needed here.
        import json
        return json.dumps(self.to_dict(), indent=2)


# ===== new file: src/knowledge_base/mcp/oauth/resource_server.py =====
"""
OAuth 2.1 Resource Server Implementation

Provides FastAPI middleware for OAuth token validation and user context extraction.
"""

import logging
from dataclasses import dataclass, field
from typing import Any, Callable, Optional

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from .token_validator import TokenValidator, TokenValidationError
from .metadata import ProtectedResourceMetadata

logger = logging.getLogger(__name__)


def extract_user_context(claims: dict[str, Any]) -> dict[str, Any]:
    """
    Extract user context from validated token claims.

    Args:
        claims: Decoded JWT claims

    Returns:
        User context dictionary with:
        - email: User's email address
        - scopes: List of granted scopes
        - sub: Subject identifier
        - claims: The original claims dict, passed through unmodified

    Scope logic:
    - All verified Google users get: kb.read
    - Verified @keboola.com users additionally get: kb.write
    - Non-Google tokens: scopes extracted from the 'scope' claim as-is

    Note:
        Google ID tokens don't include a 'scope' claim. For Google OAuth,
        we grant default scopes based on email verification and domain.
    """
    # Extract scopes (space-separated string to list)
    scope_string = claims.get("scope", "")
    scopes = scope_string.split() if scope_string else []

    # For Google OAuth tokens without scope claim, grant default scopes.
    # Google tokens have 'iss' = 'https://accounts.google.com' and 'email_verified'.
    # NOTE(review): this branch only fires when the token carried NO 'scope'
    # claim at all — a Google token with an explicit scope claim is taken as-is.
    if not scopes and claims.get("iss") == "https://accounts.google.com":
        email = claims.get("email", "")
        email_verified = claims.get("email_verified", False)

        if email_verified and email:
            # Grant read access to all verified Google users
            scopes = [
                "openid",
                "email",
                "profile",
                "kb.read",
            ]
            logger.info("Google OAuth: granted default scopes for verified user")

            # Grant write access for @keboola.com domain (internal users)
            # email_verified is already checked above, but enforce explicitly for write scope
            if email.endswith("@keboola.com") and email_verified:
                scopes.append("kb.write")
                logger.info("Google OAuth: granted write scope for internal user")

    return {
        "sub": claims.get("sub"),
        # Fall back to the subject identifier when the token carries no email.
        "email": claims.get("email", claims.get("sub")),
        "scopes": scopes,
        "claims": claims,
    }


@dataclass
class OAuthResourceServer:
    """
    OAuth 2.1 Resource Server configuration.

    Manages token validation and protected resource metadata.
+ """ + + resource: str + authorization_servers: list[str] + audience: str + scopes_supported: list[str] = field(default_factory=list) + + # Internal components + _validator: Optional[TokenValidator] = field(default=None, repr=False) + _metadata: Optional[ProtectedResourceMetadata] = field(default=None, repr=False) + + def __post_init__(self): + """Initialize internal components.""" + if self.authorization_servers: + self._validator = TokenValidator( + issuer=self.authorization_servers[0], + audience=self.audience, + ) + + self._metadata = ProtectedResourceMetadata( + resource=self.resource, + authorization_servers=self.authorization_servers, + scopes_supported=self.scopes_supported, + ) + + @property + def metadata(self) -> ProtectedResourceMetadata: + """Get protected resource metadata.""" + return self._metadata + + @property + def validator(self) -> TokenValidator: + """Get token validator.""" + if not self._validator: + raise RuntimeError("No authorization server configured") + return self._validator + + def validate_token(self, token: str) -> dict[str, Any]: + """Validate access token and return claims.""" + return self.validator.validate(token) + + async def validate_token_async(self, token: str) -> dict[str, Any]: + """Validate access token asynchronously.""" + return await self.validator.validate_async(token) + + +class OAuthMiddleware(BaseHTTPMiddleware): + """ + FastAPI/Starlette middleware for OAuth token validation. + + Validates Bearer tokens in Authorization header and adds + user context to request.state. + """ + + def __init__( + self, + app, + resource_server: Optional[OAuthResourceServer] = None, + exclude_paths: Optional[list[str]] = None, + dev_mode: bool = False, + ): + """ + Initialize OAuth middleware. 
+ + Args: + app: FastAPI/Starlette application + resource_server: OAuth resource server configuration + exclude_paths: Paths to exclude from auth (e.g., /health) + dev_mode: If True, skip validation (for development only) + """ + super().__init__(app) + self.resource_server = resource_server + self.exclude_paths = exclude_paths or [ + "/health", + "/.well-known/oauth-protected-resource", + "/callback", + ] + self.dev_mode = dev_mode + + def _extract_token(self, request: Request) -> Optional[str]: + """Extract Bearer token from Authorization header.""" + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + return auth_header[7:] + return None + + def _should_skip_auth(self, request: Request) -> bool: + """Check if path should skip authentication.""" + path = request.url.path + return any( + path == excluded or path.startswith(f"{excluded}/") + for excluded in self.exclude_paths + ) + + def _unauthorized_response(self, error: str = "Unauthorized") -> Response: + """Return 401 Unauthorized response with WWW-Authenticate header.""" + return JSONResponse( + status_code=401, + content={"error": "unauthorized", "error_description": error}, + headers={ + "WWW-Authenticate": 'Bearer realm="knowledge-base-mcp", error="invalid_token"' + }, + ) + + def _forbidden_response(self, error: str = "Forbidden") -> Response: + """Return 403 Forbidden response.""" + return JSONResponse( + status_code=403, + content={"error": "insufficient_scope", "error_description": error}, + ) + + async def dispatch( + self, request: Request, call_next: Callable + ) -> Response: + """Process request through OAuth validation.""" + + # Skip auth for excluded paths + if self._should_skip_auth(request): + return await call_next(request) + + # Extract token + token = self._extract_token(request) + if not token: + return self._unauthorized_response("Missing Bearer token") + + # Dev mode: skip validation + if self.dev_mode: + request.state.user = { + "sub": 
"dev-user", + "email": "dev@example.com", + "scopes": ["openid", "kb.read", "kb.write"], + "claims": {}, + } + return await call_next(request) + + # Validate token + if not self.resource_server: + return self._unauthorized_response("OAuth not configured") + + try: + claims = await self.resource_server.validate_token_async(token) + request.state.user = extract_user_context(claims) + return await call_next(request) + except TokenValidationError: + return self._unauthorized_response("Token validation failed") + except Exception: + return self._unauthorized_response("Token validation failed") + + +def require_scopes(*required_scopes: str): + """ + Dependency for requiring specific OAuth scopes. + + Usage: + @app.get("/api/search") + async def search(user: dict = Depends(require_scopes("kb.read"))): + ... + """ + from fastapi import HTTPException, Request + + async def dependency(request: Request) -> dict: + user = getattr(request.state, "user", None) + if not user: + raise HTTPException(status_code=401, detail="Not authenticated") + + user_scopes = user.get("scopes", []) + has_scope = any(scope in user_scopes for scope in required_scopes) + + if not has_scope: + raise HTTPException( + status_code=403, + detail=f"Required scope: {' or '.join(required_scopes)}", + ) + + return user + + return dependency diff --git a/src/knowledge_base/mcp/oauth/token_validator.py b/src/knowledge_base/mcp/oauth/token_validator.py new file mode 100644 index 0000000..9aa647b --- /dev/null +++ b/src/knowledge_base/mcp/oauth/token_validator.py @@ -0,0 +1,271 @@ +""" +JWT Token Validator for OAuth 2.1 Resource Server + +Validates access tokens issued by the authorization server. 
+""" + +import time +from dataclasses import dataclass, field +from typing import Any, Optional +import logging + +import httpx + +logger = logging.getLogger(__name__) + + +class TokenValidationError(Exception): + """Base exception for token validation errors.""" + pass + + +class TokenExpiredError(TokenValidationError): + """Token has expired.""" + pass + + +class InvalidTokenError(TokenValidationError): + """Token is invalid (bad signature, format, etc.).""" + pass + + +class InvalidIssuerError(TokenValidationError): + """Token issuer doesn't match expected issuer.""" + pass + + +class InvalidAudienceError(TokenValidationError): + """Token audience doesn't match expected audience.""" + pass + + +@dataclass +class TokenValidator: + """ + Validates JWT access tokens using JWKS from authorization server. + + Supports: + - RS256, RS384, RS512 (RSA) + - ES256, ES384, ES512 (ECDSA) + - Google OAuth (ID tokens with client_id as audience) + """ + + issuer: str + audience: str + jwks_uri: Optional[str] = None + jwks: dict = field(default_factory=dict) + _jwks_cache_time: float = field(default=0.0, repr=False) + _jwks_cache_ttl: int = 3600 # 1 hour + # Google OAuth specific + is_google: bool = False + authorized_party: Optional[str] = None # For Google: must match client_id + + def __post_init__(self): + """Initialize JWKS URI from issuer if not provided.""" + if not self.jwks_uri: + # Google uses a different JWKS endpoint + if self.issuer == "https://accounts.google.com": + self.jwks_uri = "https://www.googleapis.com/oauth2/v3/certs" + self.is_google = True + else: + self.jwks_uri = f"{self.issuer.rstrip('/')}/.well-known/jwks.json" + + async def fetch_jwks(self) -> dict: + """ + Fetch JWKS from authorization server. + + Caches the JWKS for _jwks_cache_ttl seconds. 
+ """ + now = time.time() + if self.jwks and (now - self._jwks_cache_time) < self._jwks_cache_ttl: + return self.jwks + + async with httpx.AsyncClient() as client: + response = await client.get(self.jwks_uri) + response.raise_for_status() + self.jwks = response.json() + self._jwks_cache_time = now + + return self.jwks + + def validate(self, token: str) -> dict[str, Any]: + """ + Validate a token synchronously. + + For Google OAuth, handles both: + - JWT id_tokens (3 parts separated by dots) + - Opaque access_tokens (validated via Google's tokeninfo endpoint) + + Args: + token: Token string (JWT or opaque) + + Returns: + Token claims if valid + + Raises: + TokenValidationError: If token is invalid + """ + try: + import jwt + from jwt import PyJWKClient + except ImportError: + raise TokenValidationError( + "PyJWT library required for token validation" + ) + + # Check if this is a JWT (3 parts) or opaque token + parts = token.split(".") + if len(parts) != 3: + # Not a JWT - for Google, validate via tokeninfo endpoint + if self.is_google: + return self._validate_google_access_token(token) + raise InvalidTokenError("Invalid JWT format") + + try: + # Get signing key from JWKS + jwks_client = PyJWKClient(self.jwks_uri) + signing_key = jwks_client.get_signing_key_from_jwt(token) + + # Google ID tokens use client_id as audience + if self.is_google: + # For Google: audience is the client_id + claims = jwt.decode( + token, + signing_key.key, + algorithms=["RS256"], + issuer=self.issuer, + audience=self.audience, # This is the Google client_id + options={ + "require": ["exp", "iss", "aud", "email"], + "verify_exp": True, + "verify_iss": True, + "verify_aud": True, + } + ) + + # Additional Google-specific checks + # Verify azp (authorized party) if provided + if self.authorized_party and claims.get("azp"): + if claims["azp"] != self.authorized_party: + raise InvalidTokenError( + f"Invalid authorized party: {claims['azp']}" + ) + + # Google tokens should have email_verified + 
return claims + else: + # Standard OAuth 2.0 token validation + claims = jwt.decode( + token, + signing_key.key, + algorithms=["RS256", "RS384", "RS512", "ES256", "ES384", "ES512"], + issuer=self.issuer, + audience=self.audience, + options={ + "require": ["exp", "iss", "aud"], + "verify_exp": True, + "verify_iss": True, + "verify_aud": True, + } + ) + + return claims + + except jwt.ExpiredSignatureError: + raise TokenExpiredError("Token has expired") + except jwt.InvalidIssuerError: + raise InvalidIssuerError(f"Invalid issuer, expected {self.issuer}") + except jwt.InvalidAudienceError: + raise InvalidAudienceError(f"Invalid audience, expected {self.audience}") + except jwt.InvalidTokenError as e: + raise InvalidTokenError(f"Invalid token: {type(e).__name__}") + except Exception as e: + raise TokenValidationError(f"Token validation failed: {type(e).__name__}") + + def _validate_google_access_token(self, token: str) -> dict[str, Any]: + """ + Validate a Google opaque access token using tokeninfo endpoint. + + Google access tokens are not JWTs, so we validate them by calling + Google's tokeninfo endpoint which returns user info if valid. + + Args: + token: Google access token (opaque string) + + Returns: + Token claims including email, sub, etc. 
+ + Raises: + TokenValidationError: If token is invalid + """ + import httpx + + try: + # Call Google's tokeninfo endpoint + response = httpx.get( + "https://www.googleapis.com/oauth2/v3/tokeninfo", + params={"access_token": token}, + timeout=10.0 + ) + + if response.status_code != 200: + error_data = response.json() if response.text else {} + raise InvalidTokenError( + f"Google token validation failed: {error_data.get('error_description', 'Unknown error')}" + ) + + claims = response.json() + + # Verify audience matches our client_id + token_aud = claims.get("aud") or claims.get("azp") + if token_aud != self.audience: + raise InvalidAudienceError( + f"Invalid audience: expected {self.audience}, got {token_aud}" + ) + + # Check expiration + exp = claims.get("expires_in") + if exp is not None and int(exp) <= 0: + raise TokenExpiredError("Token has expired") + + # Normalize claims to match JWT format + normalized = { + "sub": claims.get("sub"), + "email": claims.get("email"), + "email_verified": claims.get("email_verified") == "true", + "aud": token_aud, + "iss": "https://accounts.google.com", + "scope": claims.get("scope", ""), + } + + return normalized + + except httpx.RequestError as e: + raise TokenValidationError(f"Failed to validate token with Google: {type(e).__name__}") + + async def validate_async(self, token: str) -> dict[str, Any]: + """ + Validate a token asynchronously. + + Fetches JWKS if not cached before validation. + """ + await self.fetch_jwks() + return self.validate(token) + + def get_claims(self, token: str) -> dict[str, Any]: + """ + Extract claims from token without full validation. + + WARNING: Only use this for debugging/logging. + Always use validate() for actual token verification. 
+ """ + try: + import jwt + # Decode without verification + claims = jwt.decode( + token, + options={"verify_signature": False} + ) + return claims + except Exception as e: + raise InvalidTokenError(f"Cannot decode token: {e}") diff --git a/src/knowledge_base/mcp/server.py b/src/knowledge_base/mcp/server.py new file mode 100644 index 0000000..7f5acac --- /dev/null +++ b/src/knowledge_base/mcp/server.py @@ -0,0 +1,517 @@ +"""MCP HTTP Server with OAuth 2.1 for Knowledge Base. + +Provides Streamable HTTP transport for MCP protocol with Google OAuth authentication. +Acts as both OAuth Authorization Server (proxying to Google) and Resource Server. +Compatible with Claude.AI remote MCP server integration. +""" + +import asyncio +import logging +from contextlib import asynccontextmanager +from typing import Any, Optional +from urllib.parse import urlencode + +import httpx +from fastapi import FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse +from pydantic import BaseModel + +from knowledge_base.mcp.config import ( + MCPSettings, + OAUTH_SCOPES, + TOOL_SCOPE_REQUIREMENTS, + check_scope_access, +) +from knowledge_base.mcp.oauth.metadata import ProtectedResourceMetadata +from knowledge_base.mcp.oauth.resource_server import ( + OAuthResourceServer, + extract_user_context, +) +from knowledge_base.mcp.tools import TOOLS, execute_tool, get_tools_for_scopes + +logger = logging.getLogger(__name__) + +# Initialize settings +mcp_settings = MCPSettings() + + +def _get_oauth_audience() -> str: + """Get OAuth audience. For Google OAuth, audience is the client_id.""" + return mcp_settings.MCP_OAUTH_CLIENT_ID + + +def _get_advertised_scopes() -> list[str]: + """Get scopes to advertise. 
Google only understands standard OpenID scopes.""" + return ["openid", "email", "profile"] + + +# Initialize resource server +resource_server = OAuthResourceServer( + resource=mcp_settings.MCP_OAUTH_RESOURCE_IDENTIFIER, + authorization_servers=[mcp_settings.MCP_OAUTH_AUTHORIZATION_SERVER], + audience=_get_oauth_audience(), + scopes_supported=_get_advertised_scopes(), +) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan manager.""" + global resource_server + + resource_server = OAuthResourceServer( + resource=mcp_settings.MCP_OAUTH_RESOURCE_IDENTIFIER, + authorization_servers=[mcp_settings.MCP_OAUTH_AUTHORIZATION_SERVER], + audience=_get_oauth_audience(), + scopes_supported=_get_advertised_scopes(), + ) + + logger.info(f"MCP Server started on {mcp_settings.MCP_HOST}:{mcp_settings.MCP_PORT}") + + yield + + +# Create FastAPI app +app = FastAPI( + title="Knowledge Base MCP Server", + description="MCP server for Keboola AI Knowledge Base with OAuth 2.1 authentication", + version="0.1.0", + lifespan=lifespan, +) + +# CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["https://claude.ai", "https://www.claude.ai"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# ============================================================================= +# OAuth Middleware +# ============================================================================= + + +@app.middleware("http") +async def oauth_middleware(request: Request, call_next): + """OAuth authentication middleware.""" + skip_paths = [ + "/health", + "/.well-known/oauth-protected-resource", + "/.well-known/oauth-authorization-server", + "/authorize", + "/token", + "/register", + "/callback", + "/", + ] + if request.url.path in skip_paths: + return await call_next(request) + + # Extract Bearer token + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + return JSONResponse( + status_code=401, 
+ content={"error": "unauthorized", "error_description": "Missing Bearer token"}, + headers={"WWW-Authenticate": 'Bearer realm="knowledge-base-mcp"'}, + ) + + token = auth_header[7:] + + # Dev mode: skip validation + if mcp_settings.MCP_DEV_MODE: + import os + + dev_email = os.getenv("TEST_USER_EMAIL", "dev@keboola.com") + request.state.user = { + "sub": "dev-user", + "email": dev_email, + "scopes": list(OAUTH_SCOPES.keys()), + "claims": {}, + } + return await call_next(request) + + # Validate token + if resource_server: + try: + claims = await resource_server.validate_token_async(token) + request.state.user = extract_user_context(claims) + return await call_next(request) + except Exception: + return JSONResponse( + status_code=401, + content={"error": "invalid_token", "error_description": "Token validation failed"}, + headers={ + "WWW-Authenticate": 'Bearer realm="knowledge-base-mcp", error="invalid_token"' + }, + ) + + return await call_next(request) + + +# ============================================================================= +# Health & Metadata Endpoints +# ============================================================================= + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return {"status": "healthy", "service": "knowledge-base-mcp-server"} + + +@app.get("/.well-known/oauth-protected-resource") +async def oauth_protected_resource_metadata(): + """RFC 9728 Protected Resource Metadata endpoint.""" + if not resource_server: + raise HTTPException(status_code=503, detail="OAuth not configured") + return resource_server.metadata.to_dict() + + +@app.get("/callback") +async def oauth_callback(code: str = None, state: str = None, error: str = None): + """OAuth authorization code callback (for browser popup flows).""" + if error: + return HTMLResponse( + content=f"

OAuth Error

{error}

", + status_code=400, + ) + + if code: + return HTMLResponse( + content=""" + + +

Authorization Successful

+

You can close this window and return to Claude.

+ + + + """ + % (code, state or ""), + status_code=200, + ) + + return HTMLResponse( + content="

OAuth Callback

", + status_code=200, + ) + + +# ============================================================================= +# OAuth Authorization Server Endpoints (proxy to Google) +# ============================================================================= +# The MCP spec requires the MCP server to act as an OAuth Authorization Server. +# Claude.AI discovers these endpoints via /.well-known/oauth-authorization-server +# or falls back to default paths (/authorize, /token, /register). +# We proxy the OAuth flow to Google as the upstream identity provider. + + +@app.get("/.well-known/oauth-authorization-server") +async def oauth_authorization_server_metadata(request: Request): + """RFC 8414 OAuth Authorization Server Metadata. + + Tells MCP clients (Claude.AI) where our authorization endpoints are. + """ + base_url = _get_base_url(request) + + return { + "issuer": base_url, + "authorization_endpoint": f"{base_url}/authorize", + "token_endpoint": f"{base_url}/token", + "registration_endpoint": f"{base_url}/register", + "response_types_supported": ["code"], + "grant_types_supported": ["authorization_code", "refresh_token"], + "code_challenge_methods_supported": ["S256"], + "token_endpoint_auth_methods_supported": ["client_secret_post", "none"], + "scopes_supported": list(OAUTH_SCOPES.keys()), + } + + +@app.get("/authorize") +async def oauth_authorize(request: Request): + """OAuth authorization endpoint - redirects to Google OAuth. + + Claude.AI sends the user here. We redirect to Google's authorize endpoint, + mapping scopes and preserving PKCE parameters. Google redirects back to + Claude.AI's callback directly. 
+ """ + params = dict(request.query_params) + + # Map any non-Google scopes to Google-compatible scopes + requested_scope = params.get("scope", "") + google_scopes = _map_to_google_scopes(requested_scope) + params["scope"] = google_scopes + + # Ensure client_id is set (use ours if not provided) + if "client_id" not in params or not params["client_id"]: + params["client_id"] = mcp_settings.MCP_OAUTH_CLIENT_ID + + google_authorize_url = ( + f"{mcp_settings.MCP_OAUTH_AUTHORIZATION_ENDPOINT}?{urlencode(params)}" + ) + return RedirectResponse(url=google_authorize_url, status_code=302) + + +@app.post("/token") +async def oauth_token(request: Request): + """OAuth token endpoint - proxies token exchange to Google. + + Claude.AI sends the authorization code here. We forward it to Google's + token endpoint, adding our client_secret for the exchange. + """ + # Parse form data (OAuth token requests use application/x-www-form-urlencoded) + form_data = await request.form() + token_params = dict(form_data) + + # Add our client credentials for the token exchange + token_params["client_id"] = mcp_settings.MCP_OAUTH_CLIENT_ID + token_params["client_secret"] = mcp_settings.MCP_OAUTH_CLIENT_SECRET + + async with httpx.AsyncClient() as client: + response = await client.post( + mcp_settings.MCP_OAUTH_TOKEN_ENDPOINT, + data=token_params, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + # Return Google's response directly to Claude.AI + return JSONResponse( + status_code=response.status_code, + content=response.json(), + ) + + +@app.post("/register") +async def oauth_register(request: Request): + """OAuth Dynamic Client Registration (RFC 7591). + + Returns our Google OAuth client_id so Claude.AI can use it for the + authorization flow. This is a simplified registration that always + returns the same client credentials. 
+ """ + body = await request.json() + redirect_uris = body.get("redirect_uris", []) + client_name = body.get("client_name", "MCP Client") + + return JSONResponse( + status_code=201, + content={ + "client_id": mcp_settings.MCP_OAUTH_CLIENT_ID, + "client_name": client_name, + "redirect_uris": redirect_uris, + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "none", + }, + ) + + +def _get_base_url(request: Request) -> str: + """Get the external base URL, respecting X-Forwarded-Proto from reverse proxies.""" + scheme = request.headers.get("x-forwarded-proto", request.url.scheme) + host = request.headers.get("host", request.url.netloc) + return f"{scheme}://{host}" + + +def _map_to_google_scopes(requested_scope: str) -> str: + """Map MCP/custom scopes to Google-compatible OAuth scopes. + + Claude.AI may send scopes like 'claudeai' or our custom 'kb.read kb.write'. + Google only understands standard OpenID scopes. + """ + google_scopes = {"openid", "email", "profile"} + + if requested_scope: + for scope in requested_scope.split(): + # Keep standard scopes, discard custom ones + if scope in ("openid", "email", "profile"): + google_scopes.add(scope) + + return " ".join(sorted(google_scopes)) + + +# ============================================================================= +# MCP Protocol Endpoint +# ============================================================================= + + +class MCPRequest(BaseModel): + """MCP JSON-RPC request.""" + + jsonrpc: str = "2.0" + method: str + params: Optional[dict[str, Any]] = None + id: Optional[int | str] = None + + +class MCPResponse(BaseModel): + """MCP JSON-RPC response.""" + + jsonrpc: str = "2.0" + result: Optional[Any] = None + error: Optional[dict[str, Any]] = None + id: Optional[int | str] = None + + +@app.get("/mcp") +async def mcp_sse_endpoint(request: Request): + """MCP SSE endpoint for server-initiated messages.""" + from starlette.responses import 
StreamingResponse + + user = getattr(request.state, "user", None) + if not user: + raise HTTPException(status_code=401, detail="Not authenticated") + + async def event_generator(): + yield "data: {}\n\n" + while True: + await asyncio.sleep(30) + yield ": heartbeat\n\n" + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + +@app.post("/mcp") +async def mcp_endpoint(request: Request, mcp_request: MCPRequest): + """MCP JSON-RPC endpoint.""" + user = getattr(request.state, "user", None) + if not user: + raise HTTPException(status_code=401, detail="Not authenticated") + + method = mcp_request.method + params = mcp_request.params or {} + + try: + if method == "initialize": + result = { + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {"listChanged": False}, + "resources": {"subscribe": False, "listChanged": False}, + }, + "serverInfo": { + "name": "keboola-knowledge-base", + "version": "0.1.0", + }, + } + elif method == "notifications/initialized": + return MCPResponse(id=mcp_request.id, result={}) + elif method == "tools/list": + result = await handle_tools_list(user) + elif method == "tools/call": + result = await handle_tools_call(params, user) + elif method == "resources/list": + result = {"resources": []} + elif method == "resources/read": + return MCPResponse( + id=mcp_request.id, + error={"code": -32601, "message": "No resources available"}, + ) + elif method == "ping": + result = {} + else: + return MCPResponse( + id=mcp_request.id, + error={"code": -32601, "message": f"Method not found: {method}"}, + ) + + return MCPResponse(id=mcp_request.id, result=result) + + except HTTPException as e: + return MCPResponse( + id=mcp_request.id, + error={"code": -32000, "message": e.detail}, + ) + except Exception as e: + logger.exception(f"Error handling MCP request: {e}") + return MCPResponse( + id=mcp_request.id, 
+ error={"code": -32603, "message": str(e)}, + ) + + +async def handle_tools_list(user: dict) -> dict: + """Handle tools/list MCP method.""" + user_scopes = user.get("scopes", []) + accessible_tools = get_tools_for_scopes(user_scopes) + + return { + "tools": [ + { + "name": tool.name, + "description": tool.description, + "inputSchema": tool.inputSchema, + } + for tool in accessible_tools + ] + } + + +async def handle_tools_call(params: dict, user: dict) -> dict: + """Handle tools/call MCP method.""" + tool_name = params.get("name") + arguments = params.get("arguments", {}) + + if not tool_name: + raise HTTPException(status_code=400, detail="Missing tool name") + + # Check scope access + user_scopes = user.get("scopes", []) + required_scopes = TOOL_SCOPE_REQUIREMENTS.get(tool_name, ["kb.read"]) + + if not check_scope_access(required_scopes, user_scopes): + raise HTTPException( + status_code=403, + detail=f"Insufficient scope for tool: {tool_name}", + ) + + result = await execute_tool(tool_name, arguments, user) + + return { + "content": [{"type": r.type, "text": r.text} for r in result], + } + + +# ============================================================================= +# Main Entry Point +# ============================================================================= + + +def main(): + """Run MCP HTTP server.""" + import uvicorn + + # Configure logging + logging.basicConfig( + level=logging.DEBUG if mcp_settings.MCP_DEBUG else logging.INFO, + format='{"time": "%(asctime)s", "level": "%(levelname)s", "logger": "%(name)s", "message": "%(message)s"}', + force=True, + ) + + uvicorn.run( + "knowledge_base.mcp.server:app", + host=mcp_settings.MCP_HOST, + port=mcp_settings.MCP_PORT, + reload=mcp_settings.MCP_DEBUG, + ) + + +if __name__ == "__main__": + main() diff --git a/src/knowledge_base/mcp/tools.py b/src/knowledge_base/mcp/tools.py new file mode 100644 index 0000000..f6da0f2 --- /dev/null +++ b/src/knowledge_base/mcp/tools.py @@ -0,0 +1,474 @@ +"""MCP tool 
definitions and execution dispatcher for Knowledge Base.""" + +import json +import logging +import uuid +from datetime import datetime +from typing import Any +from urllib.parse import quote + +from mcp.types import TextContent, Tool + +from knowledge_base.mcp.config import TOOL_SCOPE_REQUIREMENTS, check_scope_access + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Tool Definitions +# ============================================================================= + +TOOLS = [ + Tool( + name="ask_question", + description=( + "Ask a question and get an answer with sources from the Keboola knowledge base. " + "The answer is generated using RAG (retrieval-augmented generation) from indexed " + "Confluence pages, quick facts, and ingested documents." + ), + inputSchema={ + "type": "object", + "properties": { + "question": { + "type": "string", + "description": "The question to ask the knowledge base", + }, + "conversation_history": { + "type": "array", + "description": "Optional previous messages for context continuity", + "items": { + "type": "object", + "properties": { + "role": {"type": "string", "enum": ["user", "assistant"]}, + "content": {"type": "string"}, + }, + "required": ["role", "content"], + }, + }, + }, + "required": ["question"], + }, + ), + Tool( + name="search_knowledge", + description=( + "Search the Keboola knowledge base for documents matching a query. " + "Returns ranked results with titles, content snippets, scores, and Confluence URLs. " + "Uses hybrid search combining semantic similarity, keyword matching, and graph relationships." 
+ ), + inputSchema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query", + }, + "top_k": { + "type": "integer", + "description": "Number of results to return (1-20, default 5)", + "minimum": 1, + "maximum": 20, + "default": 5, + }, + "filters": { + "type": "object", + "description": "Optional filters to narrow results", + "properties": { + "space_key": { + "type": "string", + "description": "Filter by Confluence space key", + }, + "doc_type": { + "type": "string", + "description": "Filter by document type (e.g., webpage, pdf, quick_fact)", + }, + "topics": { + "type": "array", + "items": {"type": "string"}, + "description": "Filter by topics (any match)", + }, + "updated_after": { + "type": "string", + "description": "Filter by update date (ISO format)", + }, + }, + }, + }, + "required": ["query"], + }, + ), + Tool( + name="create_knowledge", + description=( + "Create a quick knowledge fact in the Keboola knowledge base. " + "The fact is indexed directly into Graphiti (Neo4j knowledge graph) " + "and becomes immediately searchable." + ), + inputSchema={ + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "The knowledge content to save", + }, + "topics": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional topic tags for the knowledge", + }, + }, + "required": ["content"], + }, + ), + Tool( + name="ingest_document", + description=( + "Ingest an external document into the Keboola knowledge base. " + "Supports web pages (HTML), PDFs, Google Docs (public/link-shared), " + "and Notion pages (public). The document is chunked and indexed into Graphiti." 
+ ), + inputSchema={ + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "URL of the document to ingest", + }, + "title": { + "type": "string", + "description": "Optional title override for the document", + }, + }, + "required": ["url"], + }, + ), + Tool( + name="submit_feedback", + description=( + "Submit feedback on a knowledge base chunk. This affects the chunk's " + "quality score: 'helpful' increases it, while 'outdated', 'incorrect', " + "and 'confusing' decrease it. Low-scoring chunks are eventually archived." + ), + inputSchema={ + "type": "object", + "properties": { + "chunk_id": { + "type": "string", + "description": "The ID of the chunk to give feedback on", + }, + "feedback_type": { + "type": "string", + "enum": ["helpful", "outdated", "incorrect", "confusing"], + "description": "Type of feedback", + }, + "details": { + "type": "string", + "description": "Optional details or correction suggestions", + }, + }, + "required": ["chunk_id", "feedback_type"], + }, + ), + Tool( + name="check_health", + description=( + "Check the health of the Keboola knowledge base system. " + "Returns status of Neo4j graph database, LLM provider, and search subsystem." 
+ ), + inputSchema={ + "type": "object", + "properties": {}, + }, + ), +] + + +def get_tools_for_scopes(user_scopes: list[str]) -> list[Tool]: + """Return tools accessible to the user based on their scopes.""" + accessible = [] + for tool in TOOLS: + required = TOOL_SCOPE_REQUIREMENTS.get(tool.name, ["kb.read"]) + if check_scope_access(required, user_scopes): + accessible.append(tool) + return accessible + + +# ============================================================================= +# Tool Execution +# ============================================================================= + + +async def execute_tool( + tool_name: str, + arguments: dict[str, Any], + user: dict[str, Any], +) -> list[TextContent]: + """Execute a tool and return results as TextContent list.""" + logger.info(f"Executing tool: {tool_name}, user: {user.get('sub', 'unknown')}") + + try: + if tool_name == "ask_question": + return await _execute_ask_question(arguments, user) + elif tool_name == "search_knowledge": + return await _execute_search_knowledge(arguments, user) + elif tool_name == "create_knowledge": + return await _execute_create_knowledge(arguments, user) + elif tool_name == "ingest_document": + return await _execute_ingest_document(arguments, user) + elif tool_name == "submit_feedback": + return await _execute_submit_feedback(arguments, user) + elif tool_name == "check_health": + return await _execute_check_health(arguments, user) + else: + return [TextContent(type="text", text=f"Unknown tool: {tool_name}")] + except Exception as e: + logger.error(f"Tool execution failed: {tool_name}: {e}", exc_info=True) + return [TextContent(type="text", text=f"Error executing {tool_name}: {str(e)}")] + + +async def _execute_ask_question( + arguments: dict[str, Any], + user: dict[str, Any], +) -> list[TextContent]: + """Execute ask_question tool.""" + from knowledge_base.core.qa import generate_answer, search_knowledge + + question = arguments["question"] + conversation_history = 
arguments.get("conversation_history") + + # Search for relevant chunks + chunks = await search_knowledge(question, limit=5) + + # Generate answer + answer = await generate_answer(question, chunks, conversation_history) + + # Build sources section + sources = [] + for chunk in chunks[:3]: + metadata = chunk.metadata if hasattr(chunk, "metadata") else {} + url = metadata.get("url", "") + title = chunk.page_title if hasattr(chunk, "page_title") else "Unknown" + if url: + sources.append(f"- [{title}]({url})") + else: + sources.append(f"- {title}") + + result = answer + if sources: + result += "\n\nSources:\n" + "\n".join(sources) + + return [TextContent(type="text", text=result)] + + +async def _execute_search_knowledge( + arguments: dict[str, Any], + user: dict[str, Any], +) -> list[TextContent]: + """Execute search_knowledge tool.""" + from knowledge_base.core.qa import search_knowledge + + query = arguments["query"] + top_k = arguments.get("top_k", 5) + filters = arguments.get("filters") + + # Search + results = await search_knowledge(query, limit=top_k * 2 if filters else top_k) + + # Apply filters if present + if filters: + results = _apply_filters(results, filters) + + # Limit to requested count + results = results[:top_k] + + if not results: + return [TextContent(type="text", text=f"No results found for: {query}")] + + # Format results + lines = [f"Found {len(results)} results for: {query}\n"] + for i, r in enumerate(results, 1): + metadata = r.metadata if hasattr(r, "metadata") else {} + url = metadata.get("url", "") + title = r.page_title if hasattr(r, "page_title") else "Unknown" + content_preview = r.content[:200] + "..." if len(r.content) > 200 else r.content + score = f"{r.score:.3f}" if hasattr(r, "score") else "N/A" + + lines.append(f"### {i}. 
{title}") + lines.append(f"**Score:** {score} | **Chunk ID:** {r.chunk_id}") + if url: + lines.append(f"**URL:** {url}") + lines.append(f"\n{content_preview}\n") + + return [TextContent(type="text", text="\n".join(lines))] + + +def _apply_filters(results: list, filters: dict) -> list: + """Apply metadata filters to search results.""" + filtered = [] + for r in results: + metadata = r.metadata if hasattr(r, "metadata") else {} + + if "space_key" in filters and metadata.get("space_key") != filters["space_key"]: + continue + if "doc_type" in filters and metadata.get("doc_type") != filters["doc_type"]: + continue + if "topics" in filters: + result_topics = metadata.get("topics", []) + if isinstance(result_topics, str): + try: + result_topics = json.loads(result_topics) + except (json.JSONDecodeError, TypeError): + result_topics = [result_topics] + if not any(t in result_topics for t in filters["topics"]): + continue + if "updated_after" in filters: + updated_at = metadata.get("updated_at", "") + if updated_at and updated_at < filters["updated_after"]: + continue + + filtered.append(r) + return filtered + + +async def _execute_create_knowledge( + arguments: dict[str, Any], + user: dict[str, Any], +) -> list[TextContent]: + """Execute create_knowledge tool.""" + from knowledge_base.graph.graphiti_indexer import GraphitiIndexer + from knowledge_base.vectorstore.indexer import ChunkData + + content = arguments["content"] + topics = arguments.get("topics", []) + user_email = user.get("email", "mcp-user") + + # Create unique IDs + page_id = f"mcp_{uuid.uuid4().hex[:16]}" + chunk_id = f"{page_id}_0" + now = datetime.utcnow() + + chunk_data = ChunkData( + chunk_id=chunk_id, + content=content, + page_id=page_id, + page_title=f"Quick Fact by {user_email}", + chunk_index=0, + space_key="MCP", + url=f"mcp://user/{quote(user_email, safe='')}", + author=user_email, + created_at=now.isoformat(), + updated_at=now.isoformat(), + chunk_type="text", + parent_headers="[]", + 
quality_score=100.0, + access_count=0, + feedback_count=0, + owner=user_email, + reviewed_by="", + reviewed_at="", + classification="internal", + doc_type="quick_fact", + topics=json.dumps(topics) if topics else "[]", + audience="[]", + complexity="", + summary=content[:200] if len(content) > 200 else content, + ) + + indexer = GraphitiIndexer() + await indexer.index_single_chunk(chunk_data) + + logger.info(f"Created knowledge via MCP: {chunk_id} by {user_email}") + + return [TextContent( + type="text", + text=f"Knowledge saved successfully.\n\n**Chunk ID:** {chunk_id}\n**Content:** {content[:200]}", + )] + + +async def _execute_ingest_document( + arguments: dict[str, Any], + user: dict[str, Any], +) -> list[TextContent]: + """Execute ingest_document tool.""" + from knowledge_base.slack.ingest_doc import get_ingester + + url = arguments["url"] + user_email = user.get("email", "mcp-user") + + ingester = get_ingester() + result = await ingester.ingest_url( + url=url, + created_by=user_email, + channel_id="mcp", + ) + + if result["status"] == "success": + return [TextContent( + type="text", + text=( + f"Document ingested successfully.\n\n" + f"**Title:** {result['title']}\n" + f"**Source type:** {result['source_type']}\n" + f"**Chunks created:** {result['chunks_created']}\n" + f"**Page ID:** {result['page_id']}" + ), + )] + else: + return [TextContent( + type="text", + text=f"Failed to ingest document: {result.get('error', 'Unknown error')}", + )] + + +async def _execute_submit_feedback( + arguments: dict[str, Any], + user: dict[str, Any], +) -> list[TextContent]: + """Execute submit_feedback tool.""" + from knowledge_base.lifecycle.feedback import submit_feedback + + chunk_id = arguments["chunk_id"] + feedback_type = arguments["feedback_type"] + details = arguments.get("details") + user_email = user.get("email", "mcp-user") + + feedback = await submit_feedback( + chunk_id=chunk_id, + slack_user_id=f"mcp:{user_email}", + slack_username=user_email, + 
feedback_type=feedback_type, + comment=details, + ) + + return [TextContent( + type="text", + text=( + f"Feedback submitted.\n\n" + f"**Chunk ID:** {chunk_id}\n" + f"**Feedback type:** {feedback_type}\n" + f"**Feedback ID:** {feedback.id}" + ), + )] + + +async def _execute_check_health( + arguments: dict[str, Any], + user: dict[str, Any], +) -> list[TextContent]: + """Execute check_health tool.""" + from knowledge_base.search import HybridRetriever + + retriever = HybridRetriever() + health = await retriever.check_health() + + status = "healthy" if health.get("graphiti_healthy") else "degraded" + + return [TextContent( + type="text", + text=( + f"Knowledge Base Health: **{status}**\n\n" + f"- Graphiti enabled: {health.get('graphiti_enabled', False)}\n" + f"- Graphiti healthy: {health.get('graphiti_healthy', False)}\n" + f"- Backend: {health.get('backend', 'unknown')}" + ), + )] diff --git a/src/knowledge_base/slack/bot.py b/src/knowledge_base/slack/bot.py index 5f674ce..7e81a72 100644 --- a/src/knowledge_base/slack/bot.py +++ b/src/knowledge_base/slack/bot.py @@ -23,8 +23,7 @@ process_thread_message, record_bot_response, ) -from knowledge_base.rag.factory import get_llm -from knowledge_base.rag.exceptions import LLMError +from knowledge_base.core.qa import search_knowledge, generate_answer from knowledge_base.search.models import SearchResult from knowledge_base.slack.modals import ( build_incorrect_feedback_modal, @@ -316,40 +315,9 @@ def handle_thread_message(event: dict, say: Any, client: WebClient) -> None: async def _search_chunks(query: str, limit: int = 5) -> list[SearchResult]: """Search for relevant chunks using Graphiti hybrid search. - Uses HybridRetriever which delegates to Graphiti's unified search: - - Semantic similarity (embeddings) - - BM25 keyword matching - - Graph relationships - - Returns SearchResult objects with content and metadata. + Delegates to core.qa.search_knowledge for reusability across interfaces. 
""" - logger.info(f"Searching for: '{query[:100]}...'") - - try: - from knowledge_base.search import HybridRetriever - - retriever = HybridRetriever() - health = await retriever.check_health() - logger.info(f"Hybrid search health: {health}") - - # Use Graphiti hybrid search - results = await retriever.search(query, k=limit) - logger.info(f"Hybrid search returned {len(results)} results") - - # Log first result for debugging - if results: - first = results[0] - logger.info( - f"First result: chunk_id={first.chunk_id}, " - f"title={first.page_title}, content_len={len(first.content)}" - ) - - return results - - except Exception as e: - logger.error(f"Hybrid search FAILED (returning 0 results): {e}", exc_info=True) - - return [] + return await search_knowledge(query, limit) async def _generate_answer( @@ -359,79 +327,9 @@ async def _generate_answer( ) -> str: """Generate an answer using LLM with retrieved chunks. - Args: - question: The user's question - chunks: SearchResult objects from Graphiti containing content and metadata - conversation_history: Previous messages in the conversation thread + Delegates to core.qa.generate_answer for reusability across interfaces. """ - if not chunks: - return "I couldn't find relevant information in the knowledge base to answer your question." - - # Build context from chunks (SearchResult has page_title property and content attribute) - context_parts = [] - for i, chunk in enumerate(chunks, 1): - context_parts.append( - f"[Source {i}: {chunk.page_title}]\n{chunk.content[:1000]}" - ) - context = "\n\n---\n\n".join(context_parts) - - # Build conversation history section - conversation_section = "" - if conversation_history: - history_parts = [] - for msg in conversation_history[-6:]: # Last 6 messages for context - role = "User" if msg["role"] == "user" else "Assistant" - # Truncate long messages in history - content = msg["content"][:500] + "..." 
if len(msg["content"]) > 500 else msg["content"] - history_parts.append(f"{role}: {content}") - if history_parts: - conversation_section = f""" -PREVIOUS CONVERSATION: -{chr(10).join(history_parts)} - -(Use this context to understand what the user is asking about and provide continuity) -""" - - prompt = f"""You are Keboola's internal knowledge base assistant. Answer questions ONLY based on the provided context documents. - -CRITICAL RULES: -- ONLY use information explicitly stated in the context documents below. -- Do NOT make up, assume, or hallucinate any information not in the documents. -- If the context doesn't contain enough information to answer, say so clearly. -- When referencing information, mention which source it came from. -{conversation_section} -CONTEXT DOCUMENTS: -{context} - -CURRENT QUESTION: {question} - -INSTRUCTIONS: -- Answer based strictly on the context documents above. -- Be concise and helpful. Use bullet points for multiple items. -- If the documents only partially answer the question, share what IS available and note what's missing. -- Do NOT invent tool names, process steps, or policies not mentioned in the documents. - -Provide your answer:""" - - try: - llm = await get_llm() - logger.info(f"Using LLM provider: {llm.provider_name}") - - # Skip health check - generate() has proper retry logic and error handling - answer = await llm.generate(prompt) - return answer.strip() - except LLMError as e: - logger.error(f"LLM provider error: {e}") - return ( - f"I found {len(chunks)} relevant documents but couldn't generate " - f"an answer at this time. Please try again later." - ) - except Exception as e: - logger.error(f"LLM generation failed: {e}") - return ( - f"I found {len(chunks)} relevant documents but couldn't generate " - f"an answer at this time. Please try again later." 
- ) + return await generate_answer(question, chunks, conversation_history) def _split_text_into_blocks(text: str, max_chars: int = 3000) -> list[dict]: diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..37e861a --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1 @@ +"""Unit tests for MCP server modules.""" diff --git a/tests/unit/test_mcp_oauth.py b/tests/unit/test_mcp_oauth.py new file mode 100644 index 0000000..78a9b28 --- /dev/null +++ b/tests/unit/test_mcp_oauth.py @@ -0,0 +1,234 @@ +"""Tests for MCP OAuth resource server and scope handling.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from knowledge_base.mcp.config import check_scope_access +from knowledge_base.mcp.oauth.resource_server import ( + OAuthResourceServer, + extract_user_context, +) + + +# =========================================================================== +# extract_user_context +# =========================================================================== + + +class TestExtractUserContext: + """Test user context extraction from JWT claims.""" + + def test_google_verified_keboola_email_gets_read_and_write(self): + """Verified @keboola.com Google user should get kb.read + kb.write.""" + claims = { + "iss": "https://accounts.google.com", + "sub": "google-uid-123", + "email": "alice@keboola.com", + "email_verified": True, + } + ctx = extract_user_context(claims) + assert ctx["email"] == "alice@keboola.com" + assert ctx["sub"] == "google-uid-123" + assert "kb.read" in ctx["scopes"] + assert "kb.write" in ctx["scopes"] + # Standard OpenID scopes also present + assert "openid" in ctx["scopes"] + assert "email" in ctx["scopes"] + assert "profile" in ctx["scopes"] + + def test_google_verified_external_email_gets_read_only(self): + """Verified external Google user should get kb.read but NOT kb.write.""" + claims = { + "iss": "https://accounts.google.com", + "sub": "google-uid-456", + "email": "bob@external.com", + 
"email_verified": True, + } + ctx = extract_user_context(claims) + assert ctx["email"] == "bob@external.com" + assert "kb.read" in ctx["scopes"] + assert "kb.write" not in ctx["scopes"] + + def test_google_unverified_email_gets_no_scopes(self): + """Unverified Google user should get no application scopes.""" + claims = { + "iss": "https://accounts.google.com", + "sub": "google-uid-789", + "email": "unverified@keboola.com", + "email_verified": False, + } + ctx = extract_user_context(claims) + assert ctx["scopes"] == [] + + def test_google_missing_email_gets_no_scopes(self): + """Google user without email claim should get no scopes.""" + claims = { + "iss": "https://accounts.google.com", + "sub": "google-uid-noemail", + "email_verified": True, + } + ctx = extract_user_context(claims) + # email is "" which is falsy, so no scopes granted + assert ctx["scopes"] == [] + + def test_non_google_claims_with_scope_string(self): + """Non-Google token with 'scope' claim should parse scopes from the string.""" + claims = { + "iss": "https://some-other-idp.example.com", + "sub": "user-abc", + "email": "charlie@other.com", + "scope": "kb.read kb.write custom_scope", + } + ctx = extract_user_context(claims) + assert ctx["scopes"] == ["kb.read", "kb.write", "custom_scope"] + assert ctx["email"] == "charlie@other.com" + + def test_non_google_claims_with_empty_scope(self): + """Non-Google token with empty scope should produce empty scopes list.""" + claims = { + "iss": "https://another-idp.example.com", + "sub": "user-xyz", + "email": "dave@corp.com", + "scope": "", + } + ctx = extract_user_context(claims) + assert ctx["scopes"] == [] + + def test_claims_without_scope_or_google_issuer(self): + """Token without scope and without Google issuer should produce empty scopes.""" + claims = { + "iss": "https://custom-auth.example.com", + "sub": "user-custom", + "email": "eve@custom.com", + } + ctx = extract_user_context(claims) + assert ctx["scopes"] == [] + + def 
test_sub_fallback_for_email(self): + """When email claim is missing, email should fall back to sub.""" + claims = { + "iss": "https://custom-auth.example.com", + "sub": "user-no-email", + } + ctx = extract_user_context(claims) + assert ctx["email"] == "user-no-email" + + def test_claims_are_preserved(self): + """Original claims dict should be available in the context.""" + claims = { + "iss": "https://accounts.google.com", + "sub": "uid-1", + "email": "test@keboola.com", + "email_verified": True, + "aud": "test-client-id", + "exp": 9999999999, + } + ctx = extract_user_context(claims) + assert ctx["claims"] is claims + + +# =========================================================================== +# OAuthResourceServer +# =========================================================================== + + +class TestOAuthResourceServer: + """Test OAuthResourceServer initialization and properties.""" + + def test_initialization_creates_validator_and_metadata(self): + """OAuthResourceServer should initialize a validator and metadata.""" + server = OAuthResourceServer( + resource="https://kb-mcp.example.com", + authorization_servers=["https://accounts.google.com"], + audience="test-client-id", + scopes_supported=["openid", "email", "profile"], + ) + assert server.resource == "https://kb-mcp.example.com" + assert server.audience == "test-client-id" + assert server.metadata is not None + assert server.validator is not None + + def test_metadata_resource_matches(self): + """Metadata resource field should match the server resource.""" + server = OAuthResourceServer( + resource="https://kb-mcp.example.com", + authorization_servers=["https://accounts.google.com"], + audience="test-client-id", + ) + meta_dict = server.metadata.to_dict() + assert meta_dict["resource"] == "https://kb-mcp.example.com" + assert "https://accounts.google.com" in meta_dict["authorization_servers"] + + def test_metadata_includes_scopes(self): + """Metadata should include advertised scopes.""" + server = 
OAuthResourceServer( + resource="https://kb-mcp.example.com", + authorization_servers=["https://accounts.google.com"], + audience="test-client-id", + scopes_supported=["openid", "email"], + ) + meta_dict = server.metadata.to_dict() + assert "scopes_supported" in meta_dict + assert meta_dict["scopes_supported"] == ["openid", "email"] + + def test_no_authorization_servers_no_validator(self): + """Without authorization servers, validator should not be created.""" + server = OAuthResourceServer( + resource="https://kb-mcp.example.com", + authorization_servers=[], + audience="test-client-id", + ) + with pytest.raises(RuntimeError, match="No authorization server configured"): + _ = server.validator + + def test_google_issuer_sets_google_jwks_uri(self): + """Google issuer should configure Google's JWKS endpoint on the validator.""" + server = OAuthResourceServer( + resource="https://kb-mcp.example.com", + authorization_servers=["https://accounts.google.com"], + audience="test-client-id", + ) + assert server.validator.jwks_uri == "https://www.googleapis.com/oauth2/v3/certs" + assert server.validator.is_google is True + + +# =========================================================================== +# check_scope_access +# =========================================================================== + + +class TestCheckScopeAccess: + """Test scope access checking logic.""" + + def test_matching_scope_grants_access(self): + """If at least one required scope is in granted scopes, access is granted.""" + assert check_scope_access(["kb.read"], ["kb.read", "kb.write"]) is True + + def test_no_matching_scope_denies_access(self): + """If no required scope is in granted scopes, access is denied.""" + assert check_scope_access(["kb.write"], ["kb.read"]) is False + + def test_multiple_required_any_match(self): + """check_scope_access uses ANY (OR) logic for required scopes.""" + assert check_scope_access(["kb.read", "kb.write"], ["kb.read"]) is True + + def 
test_empty_required_denies(self): + """Empty required list means no scope matches -> deny.""" + assert check_scope_access([], ["kb.read"]) is False + + def test_empty_granted_denies(self): + """Empty granted list always denies.""" + assert check_scope_access(["kb.read"], []) is False + + def test_both_empty_denies(self): + """Both empty -> deny.""" + assert check_scope_access([], []) is False + + def test_exact_single_match(self): + """Single required scope matching single granted scope.""" + assert check_scope_access(["kb.write"], ["kb.write"]) is True + + def test_openid_scope_does_not_match_kb_read(self): + """Standard OpenID scopes should not satisfy kb.* requirements.""" + assert check_scope_access(["kb.read"], ["openid", "email", "profile"]) is False diff --git a/tests/unit/test_mcp_server.py b/tests/unit/test_mcp_server.py new file mode 100644 index 0000000..8aafcaa --- /dev/null +++ b/tests/unit/test_mcp_server.py @@ -0,0 +1,485 @@ +"""Tests for MCP HTTP server endpoints and JSON-RPC protocol.""" + +import os +import sys +from dataclasses import dataclass +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Set required env vars BEFORE importing server module, because server.py +# instantiates MCPSettings() at module level which requires these. 
+os.environ.setdefault("MCP_OAUTH_CLIENT_ID", "test-client-id") +os.environ.setdefault("MCP_OAUTH_CLIENT_SECRET", "test-client-secret") +os.environ.setdefault("MCP_OAUTH_RESOURCE_IDENTIFIER", "https://test-kb-mcp.example.com") +os.environ.setdefault("MCP_DEV_MODE", "true") + +from httpx import ASGITransport, AsyncClient + +from knowledge_base.mcp.server import app, mcp_settings # noqa: E402 + + +@dataclass +class _FakeSearchResult: + """Minimal stand-in for SearchResult.""" + + chunk_id: str + content: str + score: float + metadata: dict[str, Any] + + @property + def page_title(self) -> str: + return self.metadata.get("page_title", "") + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def dev_auth_header() -> dict[str, str]: + """Bearer token header -- value is irrelevant in dev mode.""" + return {"Authorization": "Bearer dev-token-placeholder"} + + +@pytest.fixture +async def client(): + """httpx AsyncClient wired to the FastAPI app with dev mode enabled.""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac + + +# =========================================================================== +# Health & Metadata Endpoints +# =========================================================================== + + +class TestHealthEndpoint: + """Test GET /health.""" + + async def test_health_returns_200(self, client: AsyncClient): + resp = await client.get("/health") + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "healthy" + assert body["service"] == "knowledge-base-mcp-server" + + +class TestOAuthMetadataEndpoint: + """Test GET /.well-known/oauth-protected-resource.""" + + async def test_returns_metadata(self, client: AsyncClient): + resp = await client.get("/.well-known/oauth-protected-resource") + assert resp.status_code == 
200 + body = resp.json() + assert "resource" in body + assert "authorization_servers" in body + + +# =========================================================================== +# OAuth Authorization Server Metadata +# =========================================================================== + + +class TestOAuthAuthorizationServerMetadata: + """Test GET /.well-known/oauth-authorization-server.""" + + async def test_returns_metadata(self, client: AsyncClient): + resp = await client.get("/.well-known/oauth-authorization-server") + assert resp.status_code == 200 + body = resp.json() + assert "issuer" in body + assert "authorization_endpoint" in body + assert "token_endpoint" in body + assert "registration_endpoint" in body + assert body["response_types_supported"] == ["code"] + assert "S256" in body["code_challenge_methods_supported"] + + async def test_endpoints_use_base_url(self, client: AsyncClient): + resp = await client.get("/.well-known/oauth-authorization-server") + body = resp.json() + assert body["authorization_endpoint"].endswith("/authorize") + assert body["token_endpoint"].endswith("/token") + assert body["registration_endpoint"].endswith("/register") + + +# =========================================================================== +# OAuth Authorize Endpoint +# =========================================================================== + + +class TestOAuthAuthorize: + """Test GET /authorize - redirects to Google.""" + + async def test_redirects_to_google(self, client: AsyncClient): + resp = await client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": "test-client-id", + "redirect_uri": "https://claude.ai/api/mcp/auth_callback", + "scope": "claudeai", + "state": "test-state", + "code_challenge": "abc123", + "code_challenge_method": "S256", + }, + follow_redirects=False, + ) + assert resp.status_code == 302 + location = resp.headers["location"] + assert "accounts.google.com" in location + # Should map claudeai scope to Google 
scopes + assert "openid" in location + assert "email" in location + assert "profile" in location + # Should NOT contain the custom scope + assert "claudeai" not in location + # Should preserve PKCE params + assert "code_challenge=abc123" in location + assert "state=test-state" in location + + async def test_no_auth_required(self, client: AsyncClient): + """The /authorize endpoint must be accessible without Bearer token.""" + resp = await client.get( + "/authorize", + params={"response_type": "code", "scope": "openid"}, + follow_redirects=False, + ) + assert resp.status_code == 302 + + +# =========================================================================== +# OAuth Token Endpoint +# =========================================================================== + + +class TestOAuthToken: + """Test POST /token - proxies to Google.""" + + async def test_proxies_to_google(self, client: AsyncClient): + """Token endpoint should proxy to Google and return the response.""" + mock_google_response = MagicMock() + mock_google_response.status_code = 200 + mock_google_response.json.return_value = { + "access_token": "ya29.mock-token", + "token_type": "Bearer", + "expires_in": 3600, + "id_token": "eyJ...", + } + + with patch("knowledge_base.mcp.server.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_google_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_cls.return_value = mock_client + + resp = await client.post( + "/token", + data={ + "grant_type": "authorization_code", + "code": "test-auth-code", + "redirect_uri": "https://claude.ai/api/mcp/auth_callback", + "code_verifier": "test-verifier", + }, + ) + + assert resp.status_code == 200 + body = resp.json() + assert body["access_token"] == "ya29.mock-token" + assert body["token_type"] == "Bearer" + + async def test_no_auth_required(self, client: AsyncClient): + """The /token 
endpoint must be accessible without Bearer token.""" + mock_google_response = MagicMock() + mock_google_response.status_code = 400 + mock_google_response.json.return_value = {"error": "invalid_grant"} + + with patch("knowledge_base.mcp.server.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_google_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_cls.return_value = mock_client + + resp = await client.post( + "/token", + data={"grant_type": "authorization_code", "code": "bad-code"}, + ) + + # Should NOT get 401 (auth not required), should get Google's error + assert resp.status_code == 400 + + +# =========================================================================== +# OAuth Dynamic Client Registration +# =========================================================================== + + +class TestOAuthRegister: + """Test POST /register - returns our client_id.""" + + async def test_returns_client_id(self, client: AsyncClient): + resp = await client.post( + "/register", + json={ + "redirect_uris": ["https://claude.ai/api/mcp/auth_callback"], + "client_name": "Claude.AI", + "grant_types": ["authorization_code"], + "response_types": ["code"], + "token_endpoint_auth_method": "none", + }, + ) + assert resp.status_code == 201 + body = resp.json() + assert body["client_id"] == "test-client-id" + assert body["client_name"] == "Claude.AI" + assert "authorization_code" in body["grant_types"] + + async def test_no_auth_required(self, client: AsyncClient): + """The /register endpoint must be accessible without Bearer token.""" + resp = await client.post( + "/register", + json={"redirect_uris": ["https://example.com/callback"]}, + ) + assert resp.status_code == 201 + + +# =========================================================================== +# Authentication +# 
=========================================================================== + + +class TestAuthentication: + """Test authentication middleware.""" + + async def test_post_mcp_without_auth_returns_401(self, client: AsyncClient): + """POST /mcp without Authorization header should return 401.""" + resp = await client.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "ping", "id": 1}, + ) + assert resp.status_code == 401 + + async def test_post_mcp_with_invalid_auth_scheme_returns_401(self, client: AsyncClient): + """POST /mcp with non-Bearer auth should return 401.""" + resp = await client.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "ping", "id": 1}, + headers={"Authorization": "Basic dXNlcjpwYXNz"}, + ) + assert resp.status_code == 401 + + +# =========================================================================== +# MCP Protocol: initialize +# =========================================================================== + + +class TestMCPInitialize: + """Test MCP initialize method.""" + + async def test_initialize_returns_protocol_version( + self, client: AsyncClient, dev_auth_header: dict + ): + resp = await client.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + headers=dev_auth_header, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["jsonrpc"] == "2.0" + assert body["id"] == 1 + result = body["result"] + assert result["protocolVersion"] == "2024-11-05" + assert "tools" in result["capabilities"] + assert result["serverInfo"]["name"] == "keboola-knowledge-base" + + +# =========================================================================== +# MCP Protocol: tools/list +# =========================================================================== + + +class TestMCPToolsList: + """Test MCP tools/list method.""" + + async def test_tools_list_returns_all_tools_in_dev_mode( + self, client: AsyncClient, dev_auth_header: dict + ): + """Dev mode grants all scopes, so all 6 tools should be listed.""" + resp = 
await client.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "tools/list", "id": 2}, + headers=dev_auth_header, + ) + assert resp.status_code == 200 + body = resp.json() + tools = body["result"]["tools"] + assert len(tools) == 6 + names = {t["name"] for t in tools} + assert "ask_question" in names + assert "search_knowledge" in names + assert "create_knowledge" in names + + async def test_tools_list_each_tool_has_schema( + self, client: AsyncClient, dev_auth_header: dict + ): + """Each tool in the list should have name, description, and inputSchema.""" + resp = await client.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "tools/list", "id": 3}, + headers=dev_auth_header, + ) + tools = resp.json()["result"]["tools"] + for tool in tools: + assert "name" in tool + assert "description" in tool + assert "inputSchema" in tool + + +class TestMCPToolsListScopeFiltering: + """Test that tools/list respects user scopes.""" + + async def test_read_only_user_sees_fewer_tools(self): + """A user with only kb.read should see only read tools.""" + from knowledge_base.mcp import server as srv_module + + original_fn = srv_module.handle_tools_list + + async def _custom_handle(user): + # Force read-only scopes + user = {**user, "scopes": ["kb.read"]} + return await original_fn(user) + + with patch.object(srv_module, "handle_tools_list", side_effect=_custom_handle): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + resp = await ac.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "tools/list", "id": 4}, + headers={"Authorization": "Bearer dev-token"}, + ) + + assert resp.status_code == 200 + tools = resp.json()["result"]["tools"] + names = {t["name"] for t in tools} + assert names == {"ask_question", "search_knowledge", "check_health"} + + +# =========================================================================== +# MCP Protocol: tools/call +# 
=========================================================================== + + +class TestMCPToolsCall: + """Test MCP tools/call method.""" + + async def test_tools_call_search_knowledge( + self, client: AsyncClient, dev_auth_header: dict + ): + """tools/call search_knowledge should execute and return content.""" + fake_results = [ + _FakeSearchResult( + chunk_id="c1", + content="Keboola overview content", + score=0.9, + metadata={"page_title": "Overview", "url": "https://wiki.keboola.com"}, + ), + ] + + with patch( + "knowledge_base.core.qa.search_knowledge", + new_callable=AsyncMock, + return_value=fake_results, + ): + resp = await client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": "search_knowledge", + "arguments": {"query": "Keboola overview"}, + }, + "id": 5, + }, + headers=dev_auth_header, + ) + + assert resp.status_code == 200 + body = resp.json() + assert body["id"] == 5 + content = body["result"]["content"] + assert len(content) >= 1 + assert content[0]["type"] == "text" + assert "Found 1 results" in content[0]["text"] + + async def test_tools_call_missing_name_returns_error( + self, client: AsyncClient, dev_auth_header: dict + ): + """tools/call without a tool name should return an error.""" + resp = await client.post( + "/mcp", + json={ + "jsonrpc": "2.0", + "method": "tools/call", + "params": {"arguments": {}}, + "id": 6, + }, + headers=dev_auth_header, + ) + assert resp.status_code == 200 + body = resp.json() + # Should be an error response (code -32000 from HTTPException) + assert body["error"] is not None + assert body["error"]["code"] == -32000 + + +# =========================================================================== +# MCP Protocol: ping +# =========================================================================== + + +class TestMCPPing: + """Test MCP ping method.""" + + async def test_ping_returns_empty_result( + self, client: AsyncClient, dev_auth_header: dict + ): + resp = await 
client.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "ping", "id": 7}, + headers=dev_auth_header, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["result"] == {} + assert body["id"] == 7 + + +# =========================================================================== +# MCP Protocol: unknown method +# =========================================================================== + + +class TestMCPUnknownMethod: + """Test unknown MCP methods.""" + + async def test_unknown_method_returns_error_code( + self, client: AsyncClient, dev_auth_header: dict + ): + resp = await client.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "totally/unknown", "id": 8}, + headers=dev_auth_header, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["error"]["code"] == -32601 + assert "Method not found" in body["error"]["message"] diff --git a/tests/unit/test_mcp_tools.py b/tests/unit/test_mcp_tools.py new file mode 100644 index 0000000..49fdc9d --- /dev/null +++ b/tests/unit/test_mcp_tools.py @@ -0,0 +1,552 @@ +"""Tests for MCP tool definitions and execution dispatcher.""" + +import json +from dataclasses import dataclass +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from mcp.types import TextContent + +from knowledge_base.mcp.tools import ( + TOOLS, + _apply_filters, + execute_tool, + get_tools_for_scopes, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +@dataclass +class _FakeSearchResult: + """Minimal stand-in for SearchResult used in tool execution tests.""" + + chunk_id: str + content: str + score: float + metadata: dict[str, Any] + + @property + def page_title(self) -> str: + return self.metadata.get("page_title", "") + + +def _make_result( + chunk_id: str = "c1", + content: str = "some content", + score: float = 0.9, + page_title: str = "My Page", + 
url: str = "https://example.com", + space_key: str = "ENG", + doc_type: str = "webpage", + topics: list[str] | None = None, + updated_at: str = "2026-01-15T00:00:00Z", +) -> _FakeSearchResult: + return _FakeSearchResult( + chunk_id=chunk_id, + content=content, + score=score, + metadata={ + "page_title": page_title, + "url": url, + "space_key": space_key, + "doc_type": doc_type, + "topics": json.dumps(topics or []), + "updated_at": updated_at, + }, + ) + + +_USER_INTERNAL = {"email": "alice@keboola.com", "scopes": ["kb.read", "kb.write"]} +_USER_EXTERNAL = {"email": "bob@external.com", "scopes": ["kb.read"]} + + +# =========================================================================== +# get_tools_for_scopes +# =========================================================================== + + +class TestGetToolsForScopes: + """Test scope-based tool filtering.""" + + def test_read_scopes_return_read_tools(self): + """kb.read should return ask_question, search_knowledge, check_health.""" + tools = get_tools_for_scopes(["kb.read"]) + names = {t.name for t in tools} + assert names == {"ask_question", "search_knowledge", "check_health"} + + def test_write_scopes_return_write_tools(self): + """kb.write should return create_knowledge, ingest_document, submit_feedback.""" + tools = get_tools_for_scopes(["kb.write"]) + names = {t.name for t in tools} + assert names == {"create_knowledge", "ingest_document", "submit_feedback"} + + def test_both_scopes_return_all_tools(self): + """kb.read + kb.write should return all 6 tools.""" + tools = get_tools_for_scopes(["kb.read", "kb.write"]) + names = {t.name for t in tools} + assert len(names) == 6 + assert names == { + "ask_question", + "search_knowledge", + "check_health", + "create_knowledge", + "ingest_document", + "submit_feedback", + } + + def test_empty_scopes_return_no_tools(self): + """Empty scope list should return no tools.""" + tools = get_tools_for_scopes([]) + assert tools == [] + + def 
test_irrelevant_scopes_return_no_tools(self): + """Scopes like openid/email should not grant access to any tools.""" + tools = get_tools_for_scopes(["openid", "email", "profile"]) + assert tools == [] + + def test_returned_tools_are_tool_instances(self): + """Each returned item should be a Tool with name and inputSchema.""" + tools = get_tools_for_scopes(["kb.read"]) + for tool in tools: + assert hasattr(tool, "name") + assert hasattr(tool, "inputSchema") + + def test_tool_definitions_count(self): + """TOOLS list should contain exactly 6 tool definitions.""" + assert len(TOOLS) == 6 + + +# =========================================================================== +# _apply_filters +# =========================================================================== + + +class TestApplyFilters: + """Test metadata-based filtering of search results.""" + + def test_no_filters_returns_all(self): + """Empty filter dict returns all results unchanged.""" + results = [_make_result(chunk_id="c1"), _make_result(chunk_id="c2")] + filtered = _apply_filters(results, {}) + assert len(filtered) == 2 + + def test_space_key_filter(self): + """Filter by space_key keeps only matching results.""" + results = [ + _make_result(chunk_id="c1", space_key="ENG"), + _make_result(chunk_id="c2", space_key="SALES"), + ] + filtered = _apply_filters(results, {"space_key": "ENG"}) + assert len(filtered) == 1 + assert filtered[0].chunk_id == "c1" + + def test_doc_type_filter(self): + """Filter by doc_type keeps only matching results.""" + results = [ + _make_result(chunk_id="c1", doc_type="webpage"), + _make_result(chunk_id="c2", doc_type="pdf"), + _make_result(chunk_id="c3", doc_type="quick_fact"), + ] + filtered = _apply_filters(results, {"doc_type": "pdf"}) + assert len(filtered) == 1 + assert filtered[0].chunk_id == "c2" + + def test_topics_filter_json_array(self): + """Filter by topics works with JSON-encoded topic arrays.""" + results = [ + _make_result(chunk_id="c1", topics=["deployment", "ci"]), 
+ _make_result(chunk_id="c2", topics=["billing"]), + _make_result(chunk_id="c3", topics=["deployment", "security"]), + ] + filtered = _apply_filters(results, {"topics": ["deployment"]}) + assert len(filtered) == 2 + ids = {r.chunk_id for r in filtered} + assert ids == {"c1", "c3"} + + def test_topics_filter_string_fallback(self): + """When topics is a plain string (not JSON), treat it as a single-element list.""" + result = _FakeSearchResult( + chunk_id="c1", + content="text", + score=0.5, + metadata={"topics": "security"}, + ) + filtered = _apply_filters([result], {"topics": ["security"]}) + assert len(filtered) == 1 + + def test_topics_filter_no_match(self): + """If none of the requested topics match, the result is excluded.""" + results = [_make_result(chunk_id="c1", topics=["billing"])] + filtered = _apply_filters(results, {"topics": ["deployment"]}) + assert len(filtered) == 0 + + def test_updated_after_filter(self): + """Filter by updated_after keeps only results updated after the date.""" + results = [ + _make_result(chunk_id="c1", updated_at="2026-01-01T00:00:00Z"), + _make_result(chunk_id="c2", updated_at="2026-02-01T00:00:00Z"), + ] + filtered = _apply_filters(results, {"updated_after": "2026-01-15T00:00:00Z"}) + assert len(filtered) == 1 + assert filtered[0].chunk_id == "c2" + + def test_combined_filters(self): + """Multiple filters are applied together (AND logic).""" + results = [ + _make_result(chunk_id="c1", space_key="ENG", doc_type="webpage"), + _make_result(chunk_id="c2", space_key="ENG", doc_type="pdf"), + _make_result(chunk_id="c3", space_key="SALES", doc_type="webpage"), + ] + filtered = _apply_filters(results, {"space_key": "ENG", "doc_type": "webpage"}) + assert len(filtered) == 1 + assert filtered[0].chunk_id == "c1" + + def test_filter_result_with_no_metadata(self): + """Results without metadata attribute should not crash.""" + result = MagicMock(spec=[]) # no attributes at all + # _apply_filters checks hasattr(r, "metadata") + filtered = 
_apply_filters([result], {"space_key": "ENG"}) + # Without metadata, space_key check against {} will fail -> excluded + assert len(filtered) == 0 + + +# =========================================================================== +# execute_tool +# =========================================================================== + + +class TestExecuteToolAskQuestion: + """Test ask_question tool execution.""" + + async def test_ask_question_returns_text_content(self): + """ask_question should return a list of TextContent with the answer and sources.""" + fake_chunks = [ + _make_result( + chunk_id="c1", + content="Keboola is a data platform.", + page_title="About Keboola", + url="https://wiki.keboola.com/about", + ), + ] + with ( + patch( + "knowledge_base.core.qa.search_knowledge", + new_callable=AsyncMock, + return_value=fake_chunks, + ), + patch( + "knowledge_base.core.qa.generate_answer", + new_callable=AsyncMock, + return_value="Keboola is a data platform.", + ), + ): + result = await execute_tool( + "ask_question", + {"question": "What is Keboola?"}, + _USER_INTERNAL, + ) + + assert len(result) == 1 + assert isinstance(result[0], TextContent) + assert "Keboola is a data platform." 
in result[0].text + assert "Sources:" in result[0].text + assert "About Keboola" in result[0].text + + async def test_ask_question_no_sources(self): + """When search returns no chunks, sources section is absent.""" + with ( + patch( + "knowledge_base.core.qa.search_knowledge", + new_callable=AsyncMock, + return_value=[], + ), + patch( + "knowledge_base.core.qa.generate_answer", + new_callable=AsyncMock, + return_value="No information found.", + ), + ): + result = await execute_tool( + "ask_question", + {"question": "Something obscure?"}, + _USER_INTERNAL, + ) + + assert len(result) == 1 + assert "Sources:" not in result[0].text + + async def test_ask_question_with_conversation_history(self): + """Conversation history should be passed through to generate_answer.""" + history = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi, how can I help?"}, + ] + mock_generate = AsyncMock(return_value="Follow-up answer.") + with ( + patch( + "knowledge_base.core.qa.search_knowledge", + new_callable=AsyncMock, + return_value=[], + ), + patch( + "knowledge_base.core.qa.generate_answer", + mock_generate, + ), + ): + await execute_tool( + "ask_question", + {"question": "Follow-up?", "conversation_history": history}, + _USER_INTERNAL, + ) + + # generate_answer should have been called with conversation_history + mock_generate.assert_called_once() + call_kwargs = mock_generate.call_args + assert call_kwargs[0][2] == history # third positional arg + + +class TestExecuteToolSearchKnowledge: + """Test search_knowledge tool execution.""" + + async def test_search_returns_formatted_results(self): + """search_knowledge should return formatted markdown-like results.""" + fake_results = [ + _make_result( + chunk_id="c1", + content="First result content", + score=0.95, + page_title="Page One", + url="https://wiki.keboola.com/page1", + ), + _make_result( + chunk_id="c2", + content="Second result content", + score=0.85, + page_title="Page Two", + 
url="https://wiki.keboola.com/page2", + ), + ] + with patch( + "knowledge_base.core.qa.search_knowledge", + new_callable=AsyncMock, + return_value=fake_results, + ): + result = await execute_tool( + "search_knowledge", + {"query": "test query"}, + _USER_INTERNAL, + ) + + assert len(result) == 1 + text = result[0].text + assert "Found 2 results" in text + assert "Page One" in text + assert "Page Two" in text + assert "0.950" in text + assert "c1" in text + + async def test_search_no_results(self): + """When no results found, return appropriate message.""" + with patch( + "knowledge_base.core.qa.search_knowledge", + new_callable=AsyncMock, + return_value=[], + ): + result = await execute_tool( + "search_knowledge", + {"query": "nonexistent stuff"}, + _USER_INTERNAL, + ) + + assert len(result) == 1 + assert "No results found" in result[0].text + + async def test_search_respects_top_k(self): + """top_k limits the number of returned results.""" + many_results = [_make_result(chunk_id=f"c{i}") for i in range(10)] + with patch( + "knowledge_base.core.qa.search_knowledge", + new_callable=AsyncMock, + return_value=many_results, + ): + result = await execute_tool( + "search_knowledge", + {"query": "test", "top_k": 3}, + _USER_INTERNAL, + ) + + text = result[0].text + assert "Found 3 results" in text + + +class TestExecuteToolCreateKnowledge: + """Test create_knowledge tool execution.""" + + async def test_create_knowledge_success(self): + """create_knowledge should index a chunk and return confirmation.""" + mock_indexer = AsyncMock() + mock_indexer.index_single_chunk = AsyncMock() + + with patch( + "knowledge_base.graph.graphiti_indexer.GraphitiIndexer", + return_value=mock_indexer, + ): + result = await execute_tool( + "create_knowledge", + {"content": "Neo4j is a graph database.", "topics": ["databases", "graphs"]}, + _USER_INTERNAL, + ) + + assert len(result) == 1 + text = result[0].text + assert "Knowledge saved successfully" in text + assert "Neo4j is a graph 
database." in text + mock_indexer.index_single_chunk.assert_called_once() + + # Verify the ChunkData passed to the indexer + chunk_data = mock_indexer.index_single_chunk.call_args[0][0] + assert chunk_data.content == "Neo4j is a graph database." + assert chunk_data.space_key == "MCP" + assert chunk_data.doc_type == "quick_fact" + assert "alice@keboola.com" in chunk_data.page_title + + async def test_create_knowledge_default_topics(self): + """create_knowledge without topics should use empty list.""" + mock_indexer = AsyncMock() + mock_indexer.index_single_chunk = AsyncMock() + + with patch( + "knowledge_base.graph.graphiti_indexer.GraphitiIndexer", + return_value=mock_indexer, + ): + await execute_tool( + "create_knowledge", + {"content": "A fact."}, + _USER_INTERNAL, + ) + + chunk_data = mock_indexer.index_single_chunk.call_args[0][0] + assert chunk_data.topics == "[]" + + +class TestExecuteToolSubmitFeedback: + """Test submit_feedback tool execution.""" + + async def test_submit_feedback_success(self): + """submit_feedback should call lifecycle.feedback.submit_feedback.""" + mock_feedback = MagicMock() + mock_feedback.id = 42 + + with patch( + "knowledge_base.lifecycle.feedback.submit_feedback", + new_callable=AsyncMock, + return_value=mock_feedback, + ): + result = await execute_tool( + "submit_feedback", + {"chunk_id": "c123", "feedback_type": "helpful", "details": "Great content!"}, + _USER_INTERNAL, + ) + + assert len(result) == 1 + text = result[0].text + assert "Feedback submitted" in text + assert "c123" in text + assert "helpful" in text + assert "42" in text + + async def test_submit_feedback_without_details(self): + """submit_feedback should work without optional details.""" + mock_feedback = MagicMock() + mock_feedback.id = 7 + + with patch( + "knowledge_base.lifecycle.feedback.submit_feedback", + new_callable=AsyncMock, + return_value=mock_feedback, + ): + result = await execute_tool( + "submit_feedback", + {"chunk_id": "c999", "feedback_type": 
"outdated"}, + _USER_INTERNAL, + ) + + assert "Feedback submitted" in result[0].text + + +class TestExecuteToolCheckHealth: + """Test check_health tool execution.""" + + async def test_check_health_healthy(self): + """check_health should return healthy status when graphiti is up.""" + mock_retriever = AsyncMock() + mock_retriever.check_health = AsyncMock( + return_value={ + "graphiti_enabled": True, + "graphiti_healthy": True, + "backend": "graphiti", + } + ) + + with patch( + "knowledge_base.search.HybridRetriever", + return_value=mock_retriever, + ): + result = await execute_tool("check_health", {}, _USER_INTERNAL) + + assert len(result) == 1 + text = result[0].text + assert "healthy" in text + assert "Graphiti enabled: True" in text + assert "Graphiti healthy: True" in text + + async def test_check_health_degraded(self): + """check_health should return degraded when graphiti is not healthy.""" + mock_retriever = AsyncMock() + mock_retriever.check_health = AsyncMock( + return_value={ + "graphiti_enabled": True, + "graphiti_healthy": False, + "backend": "graphiti", + } + ) + + with patch( + "knowledge_base.search.HybridRetriever", + return_value=mock_retriever, + ): + result = await execute_tool("check_health", {}, _USER_INTERNAL) + + text = result[0].text + assert "degraded" in text + + +class TestExecuteToolUnknown: + """Test unknown tool execution and error handling.""" + + async def test_unknown_tool_returns_error(self): + """Unknown tool name should return an error TextContent.""" + result = await execute_tool("nonexistent_tool", {}, _USER_INTERNAL) + assert len(result) == 1 + assert "Unknown tool: nonexistent_tool" in result[0].text + + async def test_tool_execution_exception_returns_error(self): + """If a tool raises an exception, it should be caught and returned as text.""" + with patch( + "knowledge_base.core.qa.search_knowledge", + new_callable=AsyncMock, + side_effect=RuntimeError("connection lost"), + ): + result = await execute_tool( + "ask_question", 
+ {"question": "test"}, + _USER_INTERNAL, + ) + + assert len(result) == 1 + assert "Error executing ask_question" in result[0].text + assert "connection lost" in result[0].text