diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index e41e8e1..e8efc6b 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -136,7 +136,30 @@ jobs:
docker push ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/slack-bot:$TAG
# ============================================
- # Stage 2b: Build jobs Docker Image
+ # Stage 2b: Build MCP Server Docker Image
+ # ============================================
+ build-mcp:
+ needs: unit-tests
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - uses: google-github-actions/auth@v2
+ with:
+ credentials_json: ${{ secrets.GCP_SA_KEY }}
+ - uses: google-github-actions/setup-gcloud@v2
+ - name: Configure Docker
+ run: gcloud auth configure-docker ${{ env.REGION }}-docker.pkg.dev
+ - name: Build and push mcp-server image
+ env:
+ TAG: ${{ github.sha }}
+ run: |
+ docker build -f deploy/docker/Dockerfile.mcp -t ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/mcp-server:$TAG .
+ docker tag ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/mcp-server:$TAG ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/mcp-server:latest
+ docker push ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/mcp-server:$TAG
+ docker push ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/mcp-server:latest
+
+ # ============================================
+ # Stage 2c: Build jobs Docker Image
# ============================================
build-jobs:
needs: unit-tests
@@ -166,11 +189,11 @@ jobs:
docker push ${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/jobs:$TAG
# ============================================
- # Stage 2c: Build Gate (for branch protection)
+ # Stage 2d: Build Gate (for branch protection)
# ============================================
build:
# Gate job: requires builds + reviewer approvals (on PRs) before staging deploy
- needs: [build-slack-bot, build-jobs, ai-review, security-review]
+ needs: [build-slack-bot, build-mcp, build-jobs, ai-review, security-review]
if: always()
runs-on: ubuntu-latest
steps:
@@ -178,17 +201,20 @@ jobs:
env:
EVENT_NAME: ${{ github.event_name }}
BUILD_BOT: ${{ needs.build-slack-bot.result }}
+ BUILD_MCP: ${{ needs.build-mcp.result }}
BUILD_JOBS: ${{ needs.build-jobs.result }}
AI_REVIEW: ${{ needs.ai-review.result }}
SEC_REVIEW: ${{ needs.security-review.result }}
run: |
echo "Build Slack Bot: $BUILD_BOT"
+ echo "Build MCP Server: $BUILD_MCP"
echo "Build Jobs: $BUILD_JOBS"
echo "AI Review: $AI_REVIEW"
echo "Security Review: $SEC_REVIEW"
# Build jobs must succeed
[[ "$BUILD_BOT" == "success" ]] || { echo "::error::build-slack-bot failed"; exit 1; }
+ [[ "$BUILD_MCP" == "success" ]] || { echo "::error::build-mcp failed"; exit 1; }
[[ "$BUILD_JOBS" == "success" ]] || { echo "::error::build-jobs failed"; exit 1; }
# For PRs: security review is required, AI review is advisory
@@ -231,12 +257,34 @@ jobs:
${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/jobs:$TAG \
${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/jobs:staging
+ # ============================================
+ # Stage 3b: Deploy MCP Server to Staging
+ # ============================================
+ deploy-mcp-staging:
+ needs: build
+ if: always() && needs.build.result == 'success'
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - uses: google-github-actions/auth@v2
+ with:
+ credentials_json: ${{ secrets.GCP_SA_KEY }}
+ - uses: google-github-actions/setup-gcloud@v2
+ - name: Deploy mcp-server to staging
+ env:
+ TAG: ${{ github.sha }}
+ run: |
+ gcloud run services update kb-mcp-staging \
+ --image=${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/mcp-server:$TAG \
+ --region=${{ env.REGION }} \
+ --project=${{ env.PROJECT_ID }}
+
# ============================================
# Stage 4: E2E Tests against Staging (ALL tests)
# ============================================
e2e-tests-staging:
- needs: deploy-staging
- if: always() && needs.deploy-staging.result == 'success'
+ needs: [deploy-staging, deploy-mcp-staging]
+ if: always() && needs.deploy-staging.result == 'success' && needs.deploy-mcp-staging.result == 'success'
runs-on: ubuntu-latest
env:
# Slack configuration (Pattern: {SERVICE}_{ENV}_*)
@@ -344,3 +392,25 @@ jobs:
gcloud artifacts docker tags add \
${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/jobs:$TAG \
${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/jobs:latest
+
+ # ============================================
+ # Stage 5b: Deploy MCP Server to Production (main only)
+ # ============================================
+ deploy-mcp-production:
+ needs: e2e-tests-staging
+ if: always() && needs.e2e-tests-staging.result == 'success' && github.ref == 'refs/heads/main'
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - uses: google-github-actions/auth@v2
+ with:
+ credentials_json: ${{ secrets.GCP_SA_KEY }}
+ - uses: google-github-actions/setup-gcloud@v2
+ - name: Deploy mcp-server to production
+ env:
+ TAG: ${{ github.sha }}
+ run: |
+ gcloud run services update kb-mcp \
+ --image=${{ env.REGION }}-docker.pkg.dev/${{ env.PROJECT_ID }}/knowledge-base/mcp-server:$TAG \
+ --region=${{ env.REGION }} \
+ --project=${{ env.PROJECT_ID }}
diff --git a/deploy/docker/Dockerfile.mcp b/deploy/docker/Dockerfile.mcp
new file mode 100644
index 0000000..a0e1bee
--- /dev/null
+++ b/deploy/docker/Dockerfile.mcp
@@ -0,0 +1,33 @@
+# =============================================================================
+# MCP Server Docker Image
+# =============================================================================
+FROM python:3.11-slim
+
+ENV PYTHONDONTWRITEBYTECODE=1 \
+ PYTHONUNBUFFERED=1 \
+ PIP_NO_CACHE_DIR=1 \
+ PIP_DISABLE_PIP_VERSION_CHECK=1
+
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ gcc \
+ && rm -rf /var/lib/apt/lists/*
+
+# Copy all source files
+COPY pyproject.toml .
+COPY src/ ./src/
+
+# Install Python dependencies
+RUN pip install --no-cache-dir . starlette
+
+# Create non-root user
+RUN useradd --create-home --shell /bin/bash appuser && \
+ chown -R appuser:appuser /app
+
+USER appuser
+
+ENV PORT=8080
+
+CMD ["python", "-m", "knowledge_base.mcp.server"]
diff --git a/deploy/terraform/cloudrun-mcp.tf b/deploy/terraform/cloudrun-mcp.tf
new file mode 100644
index 0000000..004ce70
--- /dev/null
+++ b/deploy/terraform/cloudrun-mcp.tf
@@ -0,0 +1,398 @@
+# =============================================================================
+# MCP Server Cloud Run Services (Production + Staging)
+# =============================================================================
+
+# -----------------------------------------------------------------------------
+# Service Account
+# -----------------------------------------------------------------------------
+resource "google_service_account" "mcp_server" {
+ account_id = "mcp-server"
+ display_name = "MCP Server Service Account"
+}
+
+# Vertex AI access for embeddings and LLM
+resource "google_project_iam_member" "mcp_server_vertex_ai" {
+ project = var.project_id
+ role = "roles/aiplatform.user"
+ member = "serviceAccount:${google_service_account.mcp_server.email}"
+}
+
+# -----------------------------------------------------------------------------
+# Secret Manager - MCP OAuth Secrets
+# -----------------------------------------------------------------------------
+resource "google_secret_manager_secret" "mcp_oauth_client_id" {
+ secret_id = "mcp-oauth-client-id"
+
+ replication {
+ auto {}
+ }
+
+ labels = {
+ environment = var.environment
+ purpose = "mcp"
+ }
+}
+
+resource "google_secret_manager_secret_version" "mcp_oauth_client_id" {
+ secret = google_secret_manager_secret.mcp_oauth_client_id.id
+ secret_data = "REPLACE_ME"
+
+ lifecycle {
+ ignore_changes = [secret_data]
+ }
+}
+
+resource "google_secret_manager_secret" "mcp_oauth_client_secret" {
+ secret_id = "mcp-oauth-client-secret"
+
+ replication {
+ auto {}
+ }
+
+ labels = {
+ environment = var.environment
+ purpose = "mcp"
+ }
+}
+
+resource "google_secret_manager_secret_version" "mcp_oauth_client_secret" {
+ secret = google_secret_manager_secret.mcp_oauth_client_secret.id
+ secret_data = "REPLACE_ME"
+
+ lifecycle {
+ ignore_changes = [secret_data]
+ }
+}
+
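+# Real secret values are added out of band after `terraform apply` (illustrative commands):
+#   gcloud secrets versions add mcp-oauth-client-id --data-file=client_id.txt
+#   gcloud secrets versions add mcp-oauth-client-secret --data-file=client_secret.txt
+# Subsequent value changes are ignored by Terraform via the lifecycle blocks above.
+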
+# IAM: MCP server SA can access OAuth secrets
+resource "google_secret_manager_secret_iam_member" "mcp_oauth_client_id_access" {
+ secret_id = google_secret_manager_secret.mcp_oauth_client_id.secret_id
+ role = "roles/secretmanager.secretAccessor"
+ member = "serviceAccount:${google_service_account.mcp_server.email}"
+}
+
+resource "google_secret_manager_secret_iam_member" "mcp_oauth_client_secret_access" {
+ secret_id = google_secret_manager_secret.mcp_oauth_client_secret.secret_id
+ role = "roles/secretmanager.secretAccessor"
+ member = "serviceAccount:${google_service_account.mcp_server.email}"
+}
+
+# -----------------------------------------------------------------------------
+# Production MCP Server
+# -----------------------------------------------------------------------------
+resource "google_cloud_run_v2_service" "mcp_server" {
+ name = "kb-mcp"
+ location = var.region
+
+ template {
+ scaling {
+ min_instance_count = 1
+ max_instance_count = 5
+ }
+
+ containers {
+ image = "${var.region}-docker.pkg.dev/${var.project_id}/knowledge-base/mcp-server:latest"
+
+ resources {
+ limits = {
+ cpu = "1"
+ memory = "512Mi"
+ }
+ }
+
+ # Graph Database Configuration (Graphiti + Neo4j)
+ env {
+ name = "GRAPH_BACKEND"
+ value = "neo4j"
+ }
+
+ env {
+ name = "GRAPH_ENABLE_GRAPHITI"
+ value = "true"
+ }
+
+ # Neo4j connection - using internal GCE VM IP via VPC connector
+ env {
+ name = "NEO4J_URI"
+ value = "bolt://${google_compute_instance.neo4j_prod.network_interface[0].network_ip}:7687"
+ }
+
+ env {
+ name = "NEO4J_USER"
+ value = "neo4j"
+ }
+
+ env {
+ name = "NEO4J_PASSWORD"
+ value = random_password.neo4j_prod_password.result
+ }
+
+ # LLM Configuration
+ env {
+ name = "LLM_PROVIDER"
+ value = "gemini"
+ }
+
+ env {
+ name = "GEMINI_CONVERSATION_MODEL"
+ value = "gemini-2.5-flash"
+ }
+
+ env {
+ name = "EMBEDDING_PROVIDER"
+ value = "vertex-ai"
+ }
+
+ env {
+ name = "GCP_PROJECT_ID"
+ value = var.project_id
+ }
+
+ env {
+ name = "VERTEX_AI_PROJECT"
+ value = var.project_id
+ }
+
+ env {
+ name = "VERTEX_AI_LOCATION"
+ value = var.region
+ }
+
+ env {
+ name = "GOOGLE_GENAI_USE_VERTEXAI"
+ value = "true"
+ }
+
+ # MCP OAuth Configuration
+ env {
+ name = "MCP_OAUTH_CLIENT_ID"
+ value_source {
+ secret_key_ref {
+ secret = google_secret_manager_secret.mcp_oauth_client_id.secret_id
+ version = "latest"
+ }
+ }
+ }
+
+ env {
+ name = "MCP_OAUTH_CLIENT_SECRET"
+ value_source {
+ secret_key_ref {
+ secret = google_secret_manager_secret.mcp_oauth_client_secret.secret_id
+ version = "latest"
+ }
+ }
+ }
+
+ env {
+ name = "MCP_OAUTH_RESOURCE_IDENTIFIER"
+ value = "https://kb-mcp.${var.base_domain}"
+ }
+
+ env {
+ name = "MCP_DEV_MODE"
+ value = "false"
+ }
+
+ # Health check
+ startup_probe {
+ http_get {
+ path = "/health"
+ port = 8080
+ }
+ initial_delay_seconds = 5
+ period_seconds = 10
+ failure_threshold = 3
+ }
+ }
+
+ vpc_access {
+ connector = google_vpc_access_connector.connector.id
+ egress = "PRIVATE_RANGES_ONLY"
+ }
+
+ service_account = google_service_account.mcp_server.email
+ }
+
+ traffic {
+ type = "TRAFFIC_TARGET_ALLOCATION_TYPE_LATEST"
+ percent = 100
+ }
+
+ depends_on = [
+ google_secret_manager_secret_version.mcp_oauth_client_id,
+ google_secret_manager_secret_version.mcp_oauth_client_secret,
+ ]
+}
+
+# Allow unauthenticated access (OAuth is handled at application level)
+resource "google_cloud_run_v2_service_iam_member" "mcp_server_invoker" {
+ project = var.project_id
+ location = var.region
+ name = google_cloud_run_v2_service.mcp_server.name
+ role = "roles/run.invoker"
+ member = "allUsers"
+}
+
+# -----------------------------------------------------------------------------
+# Staging MCP Server
+# -----------------------------------------------------------------------------
+resource "google_cloud_run_v2_service" "mcp_server_staging" {
+ name = "kb-mcp-staging"
+ location = var.region
+
+ template {
+ scaling {
+ min_instance_count = 0
+ max_instance_count = 3
+ }
+
+ containers {
+ image = "${var.region}-docker.pkg.dev/${var.project_id}/knowledge-base/mcp-server:staging"
+
+ resources {
+ limits = {
+ cpu = "1"
+ memory = "512Mi"
+ }
+ }
+
+ # Graph Database Configuration (Graphiti + Neo4j)
+ env {
+ name = "GRAPH_BACKEND"
+ value = "neo4j"
+ }
+
+ env {
+ name = "GRAPH_ENABLE_GRAPHITI"
+ value = "true"
+ }
+
+ # Neo4j staging connection - using internal GCE VM IP via VPC connector
+ env {
+ name = "NEO4J_URI"
+ value = "bolt://${google_compute_instance.neo4j_staging.network_interface[0].network_ip}:7687"
+ }
+
+ env {
+ name = "NEO4J_USER"
+ value = "neo4j"
+ }
+
+ env {
+ name = "NEO4J_PASSWORD"
+ value = random_password.neo4j_staging_password.result
+ }
+
+ # LLM Configuration
+ env {
+ name = "LLM_PROVIDER"
+ value = "gemini"
+ }
+
+ env {
+ name = "GEMINI_CONVERSATION_MODEL"
+ value = "gemini-2.5-flash"
+ }
+
+ env {
+ name = "EMBEDDING_PROVIDER"
+ value = "vertex-ai"
+ }
+
+ env {
+ name = "GCP_PROJECT_ID"
+ value = var.project_id
+ }
+
+ env {
+ name = "VERTEX_AI_PROJECT"
+ value = var.project_id
+ }
+
+ env {
+ name = "VERTEX_AI_LOCATION"
+ value = var.region
+ }
+
+ env {
+ name = "GOOGLE_GENAI_USE_VERTEXAI"
+ value = "true"
+ }
+
+ # MCP OAuth Configuration (reuse same secrets for staging)
+ env {
+ name = "MCP_OAUTH_CLIENT_ID"
+ value_source {
+ secret_key_ref {
+ secret = google_secret_manager_secret.mcp_oauth_client_id.secret_id
+ version = "latest"
+ }
+ }
+ }
+
+ env {
+ name = "MCP_OAUTH_CLIENT_SECRET"
+ value_source {
+ secret_key_ref {
+ secret = google_secret_manager_secret.mcp_oauth_client_secret.secret_id
+ version = "latest"
+ }
+ }
+ }
+
+ env {
+ name = "MCP_OAUTH_RESOURCE_IDENTIFIER"
+ value = "https://kb-mcp-staging.${var.staging_domain}"
+ }
+
+ env {
+ name = "MCP_DEV_MODE"
+ value = "false"
+ }
+
+ env {
+ name = "ENVIRONMENT"
+ value = "staging"
+ }
+
+ # Health check
+ startup_probe {
+ http_get {
+ path = "/health"
+ port = 8080
+ }
+ initial_delay_seconds = 5
+ period_seconds = 10
+ failure_threshold = 3
+ }
+ }
+
+ vpc_access {
+ connector = google_vpc_access_connector.connector.id
+ egress = "PRIVATE_RANGES_ONLY"
+ }
+
+ service_account = google_service_account.mcp_server.email
+ }
+
+ traffic {
+ type = "TRAFFIC_TARGET_ALLOCATION_TYPE_LATEST"
+ percent = 100
+ }
+
+ depends_on = [
+ google_secret_manager_secret_version.mcp_oauth_client_id,
+ google_secret_manager_secret_version.mcp_oauth_client_secret,
+ google_compute_instance.neo4j_staging,
+ ]
+}
+
+# Allow unauthenticated access (OAuth is handled at application level)
+resource "google_cloud_run_v2_service_iam_member" "mcp_server_staging_invoker" {
+ project = var.project_id
+ location = var.region
+ name = google_cloud_run_v2_service.mcp_server_staging.name
+ role = "roles/run.invoker"
+ member = "allUsers"
+}
diff --git a/pyproject.toml b/pyproject.toml
index 072cc70..ae2c512 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,10 @@ dependencies = [
"google-genai>=1.0.0", # For Gemini LLM client in Graphiti
"kuzu>=0.4.0",
"neo4j>=5.26.0",
+ # MCP server dependencies
+ "mcp>=1.25.0",
+ "authlib>=1.6.0",
+ "PyJWT>=2.10.0",
]
[project.optional-dependencies]
diff --git a/src/knowledge_base/core/__init__.py b/src/knowledge_base/core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/knowledge_base/core/qa.py b/src/knowledge_base/core/qa.py
new file mode 100644
index 0000000..5705624
--- /dev/null
+++ b/src/knowledge_base/core/qa.py
@@ -0,0 +1,131 @@
+"""Core Q&A logic for knowledge base search and answer generation.
+
+Extracted from the Slack bot to be reusable across interfaces (Slack, MCP, API).
+"""
+
+import logging
+
+from knowledge_base.rag.factory import get_llm
+from knowledge_base.rag.exceptions import LLMError
+from knowledge_base.search import HybridRetriever, SearchResult
+
+logger = logging.getLogger(__name__)
+
+
+async def search_knowledge(query: str, limit: int = 5) -> list[SearchResult]:
+ """Search for relevant chunks using Graphiti hybrid search.
+
+ Uses HybridRetriever which delegates to Graphiti's unified search:
+ - Semantic similarity (embeddings)
+ - BM25 keyword matching
+ - Graph relationships
+
+ Returns SearchResult objects with content and metadata.
+ """
+ logger.info(f"Searching for: '{query[:100]}...'")
+
+ try:
+ retriever = HybridRetriever()
+ health = await retriever.check_health()
+ logger.info(f"Hybrid search health: {health}")
+
+ # Use Graphiti hybrid search
+ results = await retriever.search(query, k=limit)
+ logger.info(f"Hybrid search returned {len(results)} results")
+
+ # Log first result for debugging
+ if results:
+ first = results[0]
+ logger.info(
+ f"First result: chunk_id={first.chunk_id}, "
+ f"title={first.page_title}, content_len={len(first.content)}"
+ )
+
+ return results
+
+ except Exception as e:
+ logger.error(f"Hybrid search FAILED (returning 0 results): {e}", exc_info=True)
+
+ return []
+
+
+async def generate_answer(
+ question: str,
+ chunks: list[SearchResult],
+ conversation_history: list[dict[str, str]] | None = None,
+) -> str:
+ """Generate an answer using LLM with retrieved chunks.
+
+ Args:
+ question: The user's question
+ chunks: SearchResult objects from Graphiti containing content and metadata
+ conversation_history: Previous messages in the conversation thread
+ """
+ if not chunks:
+ return "I couldn't find relevant information in the knowledge base to answer your question."
+
+ # Build context from chunks (SearchResult has page_title property and content attribute)
+ context_parts = []
+ for i, chunk in enumerate(chunks, 1):
+ context_parts.append(
+ f"[Source {i}: {chunk.page_title}]\n{chunk.content[:1000]}"
+ )
+ context = "\n\n---\n\n".join(context_parts)
+
+ # Build conversation history section
+ conversation_section = ""
+ if conversation_history:
+ history_parts = []
+ for msg in conversation_history[-6:]: # Last 6 messages for context
+ role = "User" if msg["role"] == "user" else "Assistant"
+ # Truncate long messages in history
+            content = (msg["content"][:500] + "...") if len(msg["content"]) > 500 else msg["content"]
+ history_parts.append(f"{role}: {content}")
+ if history_parts:
+ conversation_section = f"""
+PREVIOUS CONVERSATION:
+{chr(10).join(history_parts)}
+
+(Use this context to understand what the user is asking about and provide continuity)
+"""
+
+ prompt = f"""You are Keboola's internal knowledge base assistant. Answer questions ONLY based on the provided context documents.
+
+CRITICAL RULES:
+- ONLY use information explicitly stated in the context documents below.
+- Do NOT make up, assume, or hallucinate any information not in the documents.
+- If the context doesn't contain enough information to answer, say so clearly.
+- When referencing information, mention which source it came from.
+{conversation_section}
+CONTEXT DOCUMENTS:
+{context}
+
+CURRENT QUESTION: {question}
+
+INSTRUCTIONS:
+- Answer based strictly on the context documents above.
+- Be concise and helpful. Use bullet points for multiple items.
+- If the documents only partially answer the question, share what IS available and note what's missing.
+- Do NOT invent tool names, process steps, or policies not mentioned in the documents.
+
+Provide your answer:"""
+
+ try:
+ llm = await get_llm()
+ logger.info(f"Using LLM provider: {llm.provider_name}")
+
+ # Skip health check - generate() has proper retry logic and error handling
+ answer = await llm.generate(prompt)
+ return answer.strip()
+ except LLMError as e:
+ logger.error(f"LLM provider error: {e}")
+ return (
+ f"I found {len(chunks)} relevant documents but couldn't generate "
+ f"an answer at this time. Please try again later."
+ )
+ except Exception as e:
+ logger.error(f"LLM generation failed: {e}")
+ return (
+ f"I found {len(chunks)} relevant documents but couldn't generate "
+ f"an answer at this time. Please try again later."
+ )
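+
+
+# Typical two-step flow (illustrative query):
+#   results = await search_knowledge("How do I rotate a Keboola Storage API token?", limit=5)
+#   answer = await generate_answer("How do I rotate a Keboola Storage API token?", results)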
diff --git a/src/knowledge_base/mcp/__init__.py b/src/knowledge_base/mcp/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/knowledge_base/mcp/config.py b/src/knowledge_base/mcp/config.py
new file mode 100644
index 0000000..32f9ff8
--- /dev/null
+++ b/src/knowledge_base/mcp/config.py
@@ -0,0 +1,105 @@
+"""MCP server configuration using pydantic-settings."""
+
+from pydantic import field_validator
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+
+class MCPSettings(BaseSettings):
+ """MCP server settings loaded from environment variables.
+
+ Required variables (no defaults - fail fast if missing):
+        MCP_OAUTH_CLIENT_ID: Google OAuth client ID
+        MCP_OAUTH_CLIENT_SECRET: Google OAuth client secret (used for the token exchange)
+        MCP_OAUTH_RESOURCE_IDENTIFIER: Resource server identifier (e.g. "https://kb-mcp.keboola.com")
+ """
+
+ model_config = SettingsConfigDict(
+ env_file=".env",
+ env_file_encoding="utf-8",
+ extra="ignore",
+ )
+
+ # OAuth 2.1 Configuration (Google OAuth)
+ MCP_OAUTH_CLIENT_ID: str # Required - fail fast if missing
+ MCP_OAUTH_CLIENT_SECRET: str # Required - needed for token exchange with Google
+
+ @field_validator("MCP_OAUTH_CLIENT_ID", "MCP_OAUTH_CLIENT_SECRET")
+ @classmethod
+ def must_be_non_empty(cls, v: str, info) -> str:
+ """Ensure OAuth credentials are non-empty strings."""
+ if not v or not v.strip():
+ raise ValueError(
+ f"{info.field_name} must be a non-empty string"
+ )
+ return v
+
+ MCP_OAUTH_AUTHORIZATION_SERVER: str = "https://accounts.google.com"
+ MCP_OAUTH_AUTHORIZATION_ENDPOINT: str = (
+ "https://accounts.google.com/o/oauth2/v2/auth"
+ )
+ MCP_OAUTH_TOKEN_ENDPOINT: str = "https://oauth2.googleapis.com/token"
+ MCP_OAUTH_JWKS_URI: str = "https://www.googleapis.com/oauth2/v3/certs"
+ MCP_OAUTH_ISSUER: str = "https://accounts.google.com"
+ MCP_OAUTH_RESOURCE_IDENTIFIER: str # Required - e.g. "https://kb-mcp.keboola.com"
+ MCP_OAUTH_SCOPES: str = "openid email profile"
+
+ # Rate Limiting
+ MCP_RATE_LIMIT_READ_PER_MINUTE: int = 30
+ MCP_RATE_LIMIT_WRITE_PER_HOUR: int = 20
+
+ # Server
+ MCP_HOST: str = "0.0.0.0"
+ MCP_PORT: int = 8080
+ MCP_DEV_MODE: bool = False
+ MCP_DEBUG: bool = False
+
+
+# OAuth scope definitions
+OAUTH_SCOPES: dict[str, str] = {
+ "openid": "OpenID Connect authentication",
+ "email": "User email address",
+ "profile": "User profile information",
+ "kb.read": "Search and query the knowledge base",
+ "kb.write": "Create knowledge and submit feedback",
+}
+
+# Required scopes per MCP tool
+TOOL_SCOPE_REQUIREMENTS: dict[str, list[str]] = {
+ "ask_question": ["kb.read"],
+ "search_knowledge": ["kb.read"],
+ "create_knowledge": ["kb.write"],
+ "ingest_document": ["kb.write"],
+ "submit_feedback": ["kb.write"],
+ "check_health": ["kb.read"],
+}
+
+# Tools that perform write operations (subject to stricter rate limits)
+WRITE_TOOLS: list[str] = [
+ tool
+ for tool, scopes in TOOL_SCOPE_REQUIREMENTS.items()
+ if "kb.write" in scopes
+]
+
+# Rate limit configuration keyed by operation type
+RATE_LIMITS: dict[str, dict[str, int]] = {
+ "read": {
+ "requests": 30, # Default, overridden by MCPSettings.MCP_RATE_LIMIT_READ_PER_MINUTE
+ "window_seconds": 60,
+ },
+ "write": {
+ "requests": 20, # Default, overridden by MCPSettings.MCP_RATE_LIMIT_WRITE_PER_HOUR
+ "window_seconds": 3600,
+ },
+}
+
+
+def check_scope_access(required_scopes: list[str], granted_scopes: list[str]) -> bool:
+ """Check if any of the required scopes are in the granted scopes.
+
+ Args:
+ required_scopes: Scopes required by the tool (at least one must match).
+ granted_scopes: Scopes granted to the authenticated user/token.
+
+ Returns:
+ True if at least one required scope is present in granted scopes.
+ """
+ return any(scope in granted_scopes for scope in required_scopes)
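+
+
+# Illustrative examples (tool names taken from TOOL_SCOPE_REQUIREMENTS above):
+#   check_scope_access(TOOL_SCOPE_REQUIREMENTS["search_knowledge"], ["kb.read"])   # True
+#   check_scope_access(TOOL_SCOPE_REQUIREMENTS["create_knowledge"], ["kb.read"])   # False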
diff --git a/src/knowledge_base/mcp/oauth/__init__.py b/src/knowledge_base/mcp/oauth/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/knowledge_base/mcp/oauth/metadata.py b/src/knowledge_base/mcp/oauth/metadata.py
new file mode 100644
index 0000000..2298d5a
--- /dev/null
+++ b/src/knowledge_base/mcp/oauth/metadata.py
@@ -0,0 +1,68 @@
+"""
+RFC 9728 OAuth Protected Resource Metadata
+
+Implements the Protected Resource Metadata endpoint for OAuth 2.1 resource servers.
+"""
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+
+@dataclass
+class ProtectedResourceMetadata:
+ """
+ RFC 9728 Protected Resource Metadata.
+
+ Describes the OAuth-protected resource server and its requirements.
+ This metadata is served at /.well-known/oauth-protected-resource
+ """
+
+ # Required: The resource identifier (typically the server URL)
+ resource: str
+
+ # Required: List of authorization servers that can issue tokens
+ authorization_servers: list[str]
+
+ # Optional: Supported OAuth scopes
+ scopes_supported: list[str] = field(default_factory=list)
+
+ # Optional: Bearer token methods supported
+ bearer_methods_supported: list[str] = field(
+ default_factory=lambda: ["header"]
+ )
+
+ # Optional: Resource documentation URL
+ resource_documentation: Optional[str] = None
+
+ # Optional: Additional resource metadata
+ resource_signing_alg_values_supported: list[str] = field(
+ default_factory=lambda: ["RS256", "ES256"]
+ )
+
+ def to_dict(self) -> dict:
+ """Serialize metadata to dictionary for JSON response."""
+ result = {
+ "resource": self.resource,
+ "authorization_servers": self.authorization_servers,
+ }
+
+ if self.scopes_supported:
+ result["scopes_supported"] = self.scopes_supported
+
+ if self.bearer_methods_supported:
+ result["bearer_methods_supported"] = self.bearer_methods_supported
+
+ if self.resource_documentation:
+ result["resource_documentation"] = self.resource_documentation
+
+ if self.resource_signing_alg_values_supported:
+ result["resource_signing_alg_values_supported"] = (
+ self.resource_signing_alg_values_supported
+ )
+
+ return result
+
+ def to_json(self) -> str:
+ """Serialize metadata to JSON string."""
+ import json
+ return json.dumps(self.to_dict(), indent=2)
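+
+
+# Illustrative RFC 9728 document served at /.well-known/oauth-protected-resource
+# (placeholder values, not the deployed configuration):
+#   {
+#     "resource": "https://kb-mcp.example.com",
+#     "authorization_servers": ["https://accounts.google.com"],
+#     "scopes_supported": ["openid", "email", "profile"],
+#     "bearer_methods_supported": ["header"],
+#     "resource_signing_alg_values_supported": ["RS256", "ES256"]
+#   }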
diff --git a/src/knowledge_base/mcp/oauth/resource_server.py b/src/knowledge_base/mcp/oauth/resource_server.py
new file mode 100644
index 0000000..0c2b664
--- /dev/null
+++ b/src/knowledge_base/mcp/oauth/resource_server.py
@@ -0,0 +1,259 @@
+"""
+OAuth 2.1 Resource Server Implementation
+
+Provides FastAPI middleware for OAuth token validation and user context extraction.
+"""
+
+import logging
+from dataclasses import dataclass, field
+from typing import Any, Callable, Optional
+
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.requests import Request
+from starlette.responses import JSONResponse, Response
+
+from .token_validator import TokenValidator, TokenValidationError
+from .metadata import ProtectedResourceMetadata
+
+logger = logging.getLogger(__name__)
+
+
+def extract_user_context(claims: dict[str, Any]) -> dict[str, Any]:
+ """
+ Extract user context from validated token claims.
+
+ Args:
+ claims: Decoded JWT claims
+
+ Returns:
+ User context dictionary with:
+ - email: User's email address
+ - scopes: List of granted scopes
+ - sub: Subject identifier
+
+ Scope logic:
+ - All verified Google users get: kb.read
+ - Verified @keboola.com users additionally get: kb.write
+ - Non-Google tokens: scopes extracted from the 'scope' claim as-is
+
+ Note:
+ Google ID tokens don't include a 'scope' claim. For Google OAuth,
+ we grant default scopes based on email verification and domain.
+ """
+ # Extract scopes (space-separated string to list)
+ scope_string = claims.get("scope", "")
+ scopes = scope_string.split() if scope_string else []
+
+ # For Google OAuth tokens without scope claim, grant default scopes
+ # Google tokens have 'iss' = 'https://accounts.google.com' and 'email_verified'
+ if not scopes and claims.get("iss") == "https://accounts.google.com":
+ email = claims.get("email", "")
+ email_verified = claims.get("email_verified", False)
+
+ if email_verified and email:
+ # Grant read access to all verified Google users
+ scopes = [
+ "openid",
+ "email",
+ "profile",
+ "kb.read",
+ ]
+ logger.info("Google OAuth: granted default scopes for verified user")
+
+ # Grant write access for @keboola.com domain (internal users)
+ # email_verified is already checked above, but enforce explicitly for write scope
+ if email.endswith("@keboola.com") and email_verified:
+ scopes.append("kb.write")
+ logger.info("Google OAuth: granted write scope for internal user")
+
+ return {
+ "sub": claims.get("sub"),
+ "email": claims.get("email", claims.get("sub")),
+ "scopes": scopes,
+ "claims": claims,
+ }
+
+
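+# Illustrative outcomes of the scope logic above (assumed example claims):
+#   verified user with email "jane@keboola.com" -> openid, email, profile, kb.read, kb.write
+#   verified user with email "jane@gmail.com"   -> openid, email, profile, kb.read
+
+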
+@dataclass
+class OAuthResourceServer:
+ """
+ OAuth 2.1 Resource Server configuration.
+
+ Manages token validation and protected resource metadata.
+ """
+
+ resource: str
+ authorization_servers: list[str]
+ audience: str
+ scopes_supported: list[str] = field(default_factory=list)
+
+ # Internal components
+ _validator: Optional[TokenValidator] = field(default=None, repr=False)
+ _metadata: Optional[ProtectedResourceMetadata] = field(default=None, repr=False)
+
+ def __post_init__(self):
+ """Initialize internal components."""
+ if self.authorization_servers:
+ self._validator = TokenValidator(
+ issuer=self.authorization_servers[0],
+ audience=self.audience,
+ )
+
+ self._metadata = ProtectedResourceMetadata(
+ resource=self.resource,
+ authorization_servers=self.authorization_servers,
+ scopes_supported=self.scopes_supported,
+ )
+
+ @property
+ def metadata(self) -> ProtectedResourceMetadata:
+ """Get protected resource metadata."""
+ return self._metadata
+
+ @property
+ def validator(self) -> TokenValidator:
+ """Get token validator."""
+ if not self._validator:
+ raise RuntimeError("No authorization server configured")
+ return self._validator
+
+ def validate_token(self, token: str) -> dict[str, Any]:
+ """Validate access token and return claims."""
+ return self.validator.validate(token)
+
+ async def validate_token_async(self, token: str) -> dict[str, Any]:
+ """Validate access token asynchronously."""
+ return await self.validator.validate_async(token)
+
+
+class OAuthMiddleware(BaseHTTPMiddleware):
+ """
+ FastAPI/Starlette middleware for OAuth token validation.
+
+ Validates Bearer tokens in Authorization header and adds
+ user context to request.state.
+ """
+
+ def __init__(
+ self,
+ app,
+ resource_server: Optional[OAuthResourceServer] = None,
+ exclude_paths: Optional[list[str]] = None,
+ dev_mode: bool = False,
+ ):
+ """
+ Initialize OAuth middleware.
+
+ Args:
+ app: FastAPI/Starlette application
+ resource_server: OAuth resource server configuration
+ exclude_paths: Paths to exclude from auth (e.g., /health)
+ dev_mode: If True, skip validation (for development only)
+ """
+ super().__init__(app)
+ self.resource_server = resource_server
+ self.exclude_paths = exclude_paths or [
+ "/health",
+ "/.well-known/oauth-protected-resource",
+ "/callback",
+ ]
+ self.dev_mode = dev_mode
+
+ def _extract_token(self, request: Request) -> Optional[str]:
+ """Extract Bearer token from Authorization header."""
+ auth_header = request.headers.get("Authorization", "")
+ if auth_header.startswith("Bearer "):
+ return auth_header[7:]
+ return None
+
+ def _should_skip_auth(self, request: Request) -> bool:
+ """Check if path should skip authentication."""
+ path = request.url.path
+ return any(
+ path == excluded or path.startswith(f"{excluded}/")
+ for excluded in self.exclude_paths
+ )
+
+ def _unauthorized_response(self, error: str = "Unauthorized") -> Response:
+ """Return 401 Unauthorized response with WWW-Authenticate header."""
+ return JSONResponse(
+ status_code=401,
+ content={"error": "unauthorized", "error_description": error},
+ headers={
+ "WWW-Authenticate": 'Bearer realm="knowledge-base-mcp", error="invalid_token"'
+ },
+ )
+
+ def _forbidden_response(self, error: str = "Forbidden") -> Response:
+ """Return 403 Forbidden response."""
+ return JSONResponse(
+ status_code=403,
+ content={"error": "insufficient_scope", "error_description": error},
+ )
+
+ async def dispatch(
+ self, request: Request, call_next: Callable
+ ) -> Response:
+ """Process request through OAuth validation."""
+
+ # Skip auth for excluded paths
+ if self._should_skip_auth(request):
+ return await call_next(request)
+
+ # Extract token
+ token = self._extract_token(request)
+ if not token:
+ return self._unauthorized_response("Missing Bearer token")
+
+ # Dev mode: skip validation
+ if self.dev_mode:
+ request.state.user = {
+ "sub": "dev-user",
+ "email": "dev@example.com",
+ "scopes": ["openid", "kb.read", "kb.write"],
+ "claims": {},
+ }
+ return await call_next(request)
+
+ # Validate token
+ if not self.resource_server:
+ return self._unauthorized_response("OAuth not configured")
+
+ try:
+ claims = await self.resource_server.validate_token_async(token)
+ request.state.user = extract_user_context(claims)
+ return await call_next(request)
+ except TokenValidationError:
+ return self._unauthorized_response("Token validation failed")
+ except Exception:
+ return self._unauthorized_response("Token validation failed")
+
+
+def require_scopes(*required_scopes: str):
+ """
+ Dependency for requiring specific OAuth scopes.
+
+ Usage:
+ @app.get("/api/search")
+ async def search(user: dict = Depends(require_scopes("kb.read"))):
+ ...
+ """
+ from fastapi import HTTPException, Request
+
+ async def dependency(request: Request) -> dict:
+ user = getattr(request.state, "user", None)
+ if not user:
+ raise HTTPException(status_code=401, detail="Not authenticated")
+
+ user_scopes = user.get("scopes", [])
+ has_scope = any(scope in user_scopes for scope in required_scopes)
+
+ if not has_scope:
+ raise HTTPException(
+ status_code=403,
+ detail=f"Required scope: {' or '.join(required_scopes)}",
+ )
+
+ return user
+
+ return dependency
diff --git a/src/knowledge_base/mcp/oauth/token_validator.py b/src/knowledge_base/mcp/oauth/token_validator.py
new file mode 100644
index 0000000..9aa647b
--- /dev/null
+++ b/src/knowledge_base/mcp/oauth/token_validator.py
@@ -0,0 +1,271 @@
+"""
+JWT Token Validator for OAuth 2.1 Resource Server
+
+Validates access tokens issued by the authorization server.
+"""
+
+import time
+from dataclasses import dataclass, field
+from typing import Any, Optional
+import logging
+
+import httpx
+
+logger = logging.getLogger(__name__)
+
+
+class TokenValidationError(Exception):
+ """Base exception for token validation errors."""
+ pass
+
+
+class TokenExpiredError(TokenValidationError):
+ """Token has expired."""
+ pass
+
+
+class InvalidTokenError(TokenValidationError):
+ """Token is invalid (bad signature, format, etc.)."""
+ pass
+
+
+class InvalidIssuerError(TokenValidationError):
+ """Token issuer doesn't match expected issuer."""
+ pass
+
+
+class InvalidAudienceError(TokenValidationError):
+ """Token audience doesn't match expected audience."""
+ pass
+
+
+@dataclass
+class TokenValidator:
+ """
+ Validates JWT access tokens using JWKS from authorization server.
+
+ Supports:
+ - RS256, RS384, RS512 (RSA)
+ - ES256, ES384, ES512 (ECDSA)
+ - Google OAuth (ID tokens with client_id as audience)
+ """
+
+ issuer: str
+ audience: str
+ jwks_uri: Optional[str] = None
+ jwks: dict = field(default_factory=dict)
+ _jwks_cache_time: float = field(default=0.0, repr=False)
+ _jwks_cache_ttl: int = 3600 # 1 hour
+ # Google OAuth specific
+ is_google: bool = False
+ authorized_party: Optional[str] = None # For Google: must match client_id
+
+ def __post_init__(self):
+ """Initialize JWKS URI from issuer if not provided."""
+ if not self.jwks_uri:
+ # Google uses a different JWKS endpoint
+ if self.issuer == "https://accounts.google.com":
+ self.jwks_uri = "https://www.googleapis.com/oauth2/v3/certs"
+ self.is_google = True
+ else:
+ self.jwks_uri = f"{self.issuer.rstrip('/')}/.well-known/jwks.json"
+
+ async def fetch_jwks(self) -> dict:
+ """
+ Fetch JWKS from authorization server.
+
+ Caches the JWKS for _jwks_cache_ttl seconds.
+ """
+ now = time.time()
+ if self.jwks and (now - self._jwks_cache_time) < self._jwks_cache_ttl:
+ return self.jwks
+
+ async with httpx.AsyncClient() as client:
+ response = await client.get(self.jwks_uri)
+ response.raise_for_status()
+ self.jwks = response.json()
+ self._jwks_cache_time = now
+
+ return self.jwks
+
+ def validate(self, token: str) -> dict[str, Any]:
+ """
+ Validate a token synchronously.
+
+ For Google OAuth, handles both:
+ - JWT id_tokens (3 parts separated by dots)
+ - Opaque access_tokens (validated via Google's tokeninfo endpoint)
+
+ Args:
+ token: Token string (JWT or opaque)
+
+ Returns:
+ Token claims if valid
+
+ Raises:
+ TokenValidationError: If token is invalid
+ """
+ try:
+ import jwt
+ from jwt import PyJWKClient
+ except ImportError:
+ raise TokenValidationError(
+ "PyJWT library required for token validation"
+ )
+
+ # Check if this is a JWT (3 parts) or opaque token
+ parts = token.split(".")
+ if len(parts) != 3:
+ # Not a JWT - for Google, validate via tokeninfo endpoint
+ if self.is_google:
+ return self._validate_google_access_token(token)
+ raise InvalidTokenError("Invalid JWT format")
+
+ try:
+ # Get signing key from JWKS
+ jwks_client = PyJWKClient(self.jwks_uri)
+ signing_key = jwks_client.get_signing_key_from_jwt(token)
+
+ # Google ID tokens use client_id as audience
+ if self.is_google:
+ # For Google: audience is the client_id
+ claims = jwt.decode(
+ token,
+ signing_key.key,
+ algorithms=["RS256"],
+ issuer=self.issuer,
+ audience=self.audience, # This is the Google client_id
+ options={
+ "require": ["exp", "iss", "aud", "email"],
+ "verify_exp": True,
+ "verify_iss": True,
+ "verify_aud": True,
+ }
+ )
+
+ # Additional Google-specific checks
+ # Verify azp (authorized party) if provided
+ if self.authorized_party and claims.get("azp"):
+ if claims["azp"] != self.authorized_party:
+ raise InvalidTokenError(
+ f"Invalid authorized party: {claims['azp']}"
+ )
+
+ # Google tokens should have email_verified
+ return claims
+ else:
+ # Standard OAuth 2.0 token validation
+ claims = jwt.decode(
+ token,
+ signing_key.key,
+ algorithms=["RS256", "RS384", "RS512", "ES256", "ES384", "ES512"],
+ issuer=self.issuer,
+ audience=self.audience,
+ options={
+ "require": ["exp", "iss", "aud"],
+ "verify_exp": True,
+ "verify_iss": True,
+ "verify_aud": True,
+ }
+ )
+
+ return claims
+
+ except jwt.ExpiredSignatureError:
+ raise TokenExpiredError("Token has expired")
+ except jwt.InvalidIssuerError:
+ raise InvalidIssuerError(f"Invalid issuer, expected {self.issuer}")
+ except jwt.InvalidAudienceError:
+ raise InvalidAudienceError(f"Invalid audience, expected {self.audience}")
+ except jwt.InvalidTokenError as e:
+ raise InvalidTokenError(f"Invalid token: {type(e).__name__}")
+ except Exception as e:
+ raise TokenValidationError(f"Token validation failed: {type(e).__name__}")
+
+ def _validate_google_access_token(self, token: str) -> dict[str, Any]:
+ """
+ Validate a Google opaque access token using tokeninfo endpoint.
+
+ Google access tokens are not JWTs, so we validate them by calling
+ Google's tokeninfo endpoint which returns user info if valid.
+
+ Args:
+ token: Google access token (opaque string)
+
+ Returns:
+ Token claims including email, sub, etc.
+
+ Raises:
+ TokenValidationError: If token is invalid
+ """
+ import httpx
+
+ try:
+ # Call Google's tokeninfo endpoint
+ response = httpx.get(
+ "https://www.googleapis.com/oauth2/v3/tokeninfo",
+ params={"access_token": token},
+ timeout=10.0
+ )
+
+ if response.status_code != 200:
+ error_data = response.json() if response.text else {}
+ raise InvalidTokenError(
+ f"Google token validation failed: {error_data.get('error_description', 'Unknown error')}"
+ )
+
+ claims = response.json()
+
+ # Verify audience matches our client_id
+ token_aud = claims.get("aud") or claims.get("azp")
+ if token_aud != self.audience:
+ raise InvalidAudienceError(
+ f"Invalid audience: expected {self.audience}, got {token_aud}"
+ )
+
+ # Check expiration
+ exp = claims.get("expires_in")
+ if exp is not None and int(exp) <= 0:
+ raise TokenExpiredError("Token has expired")
+
+ # Normalize claims to match JWT format
+ normalized = {
+ "sub": claims.get("sub"),
+ "email": claims.get("email"),
+ "email_verified": claims.get("email_verified") == "true",
+ "aud": token_aud,
+ "iss": "https://accounts.google.com",
+ "scope": claims.get("scope", ""),
+ }
+
+ return normalized
+
+ except httpx.RequestError as e:
+ raise TokenValidationError(f"Failed to validate token with Google: {type(e).__name__}")
+
+ async def validate_async(self, token: str) -> dict[str, Any]:
+ """
+ Validate a token asynchronously.
+
+ Fetches JWKS if not cached before validation.
+ """
+ await self.fetch_jwks()
+ return self.validate(token)
+
+ def get_claims(self, token: str) -> dict[str, Any]:
+ """
+ Extract claims from token without full validation.
+
+ WARNING: Only use this for debugging/logging.
+ Always use validate() for actual token verification.
+ """
+ try:
+ import jwt
+ # Decode without verification
+ claims = jwt.decode(
+ token,
+ options={"verify_signature": False}
+ )
+ return claims
+ except Exception as e:
+ raise InvalidTokenError(f"Cannot decode token: {e}")
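+
+
+# Minimal usage sketch (illustrative; the audience value is a placeholder client ID):
+#   validator = TokenValidator(
+#       issuer="https://accounts.google.com",
+#       audience="1234567890-example.apps.googleusercontent.com",
+#   )
+#   claims = await validator.validate_async(bearer_token)  # raises TokenValidationError if invalid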
diff --git a/src/knowledge_base/mcp/server.py b/src/knowledge_base/mcp/server.py
new file mode 100644
index 0000000..7f5acac
--- /dev/null
+++ b/src/knowledge_base/mcp/server.py
@@ -0,0 +1,517 @@
+"""MCP HTTP Server with OAuth 2.1 for Knowledge Base.
+
+Provides Streamable HTTP transport for MCP protocol with Google OAuth authentication.
+Acts as both OAuth Authorization Server (proxying to Google) and Resource Server.
+Compatible with Claude.AI remote MCP server integration.
+"""
+
+import asyncio
+import logging
+from contextlib import asynccontextmanager
+from typing import Any, Optional
+from urllib.parse import urlencode
+
+import httpx
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
+from pydantic import BaseModel
+
+from knowledge_base.mcp.config import (
+ MCPSettings,
+ OAUTH_SCOPES,
+ TOOL_SCOPE_REQUIREMENTS,
+ check_scope_access,
+)
+from knowledge_base.mcp.oauth.metadata import ProtectedResourceMetadata
+from knowledge_base.mcp.oauth.resource_server import (
+ OAuthResourceServer,
+ extract_user_context,
+)
+from knowledge_base.mcp.tools import TOOLS, execute_tool, get_tools_for_scopes
+
+logger = logging.getLogger(__name__)
+
+# Initialize settings
+mcp_settings = MCPSettings()
+
+
+def _get_oauth_audience() -> str:
+ """Get OAuth audience. For Google OAuth, audience is the client_id."""
+ return mcp_settings.MCP_OAUTH_CLIENT_ID
+
+
+def _get_advertised_scopes() -> list[str]:
+ """Get scopes to advertise. Google only understands standard OpenID scopes."""
+ return ["openid", "email", "profile"]
+
+
+# Initialize resource server
+resource_server = OAuthResourceServer(
+ resource=mcp_settings.MCP_OAUTH_RESOURCE_IDENTIFIER,
+ authorization_servers=[mcp_settings.MCP_OAUTH_AUTHORIZATION_SERVER],
+ audience=_get_oauth_audience(),
+ scopes_supported=_get_advertised_scopes(),
+)
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ """Application lifespan manager."""
+ global resource_server
+
+ resource_server = OAuthResourceServer(
+ resource=mcp_settings.MCP_OAUTH_RESOURCE_IDENTIFIER,
+ authorization_servers=[mcp_settings.MCP_OAUTH_AUTHORIZATION_SERVER],
+ audience=_get_oauth_audience(),
+ scopes_supported=_get_advertised_scopes(),
+ )
+
+ logger.info(f"MCP Server started on {mcp_settings.MCP_HOST}:{mcp_settings.MCP_PORT}")
+
+ yield
+
+
+# Create FastAPI app
+app = FastAPI(
+ title="Knowledge Base MCP Server",
+ description="MCP server for Keboola AI Knowledge Base with OAuth 2.1 authentication",
+ version="0.1.0",
+ lifespan=lifespan,
+)
+
+# CORS middleware
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["https://claude.ai", "https://www.claude.ai"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+
+# =============================================================================
+# OAuth Middleware
+# =============================================================================
+
+
+@app.middleware("http")
+async def oauth_middleware(request: Request, call_next):
+ """OAuth authentication middleware."""
+ skip_paths = [
+ "/health",
+ "/.well-known/oauth-protected-resource",
+ "/.well-known/oauth-authorization-server",
+ "/authorize",
+ "/token",
+ "/register",
+ "/callback",
+ "/",
+ ]
+ if request.url.path in skip_paths:
+ return await call_next(request)
+
+ # Extract Bearer token
+ auth_header = request.headers.get("Authorization", "")
+ if not auth_header.startswith("Bearer "):
+ return JSONResponse(
+ status_code=401,
+ content={"error": "unauthorized", "error_description": "Missing Bearer token"},
+ headers={"WWW-Authenticate": 'Bearer realm="knowledge-base-mcp"'},
+ )
+
+ token = auth_header[7:]
+
+ # Dev mode: skip validation
+ if mcp_settings.MCP_DEV_MODE:
+ import os
+
+ dev_email = os.getenv("TEST_USER_EMAIL", "dev@keboola.com")
+ request.state.user = {
+ "sub": "dev-user",
+ "email": dev_email,
+ "scopes": list(OAUTH_SCOPES.keys()),
+ "claims": {},
+ }
+ return await call_next(request)
+
+ # Validate token
+ if resource_server:
+ try:
+ claims = await resource_server.validate_token_async(token)
+ request.state.user = extract_user_context(claims)
+ return await call_next(request)
+ except Exception:
+ return JSONResponse(
+ status_code=401,
+ content={"error": "invalid_token", "error_description": "Token validation failed"},
+ headers={
+ "WWW-Authenticate": 'Bearer realm="knowledge-base-mcp", error="invalid_token"'
+ },
+ )
+
+ return await call_next(request)
+
+
+# =============================================================================
+# Health & Metadata Endpoints
+# =============================================================================
+
+
+@app.get("/health")
+async def health_check():
+ """Health check endpoint."""
+ return {"status": "healthy", "service": "knowledge-base-mcp-server"}
+
+
+@app.get("/.well-known/oauth-protected-resource")
+async def oauth_protected_resource_metadata():
+ """RFC 9728 Protected Resource Metadata endpoint."""
+ if not resource_server:
+ raise HTTPException(status_code=503, detail="OAuth not configured")
+ return resource_server.metadata.to_dict()
+
+
+@app.get("/callback")
+async def oauth_callback(
+    code: Optional[str] = None, state: Optional[str] = None, error: Optional[str] = None
+):
+ """OAuth authorization code callback (for browser popup flows)."""
+ if error:
+ return HTMLResponse(
+ content=f"
OAuth Error
{error}
",
+ status_code=400,
+ )
+
+ if code:
+ return HTMLResponse(
+ content="""
+
+
+ Authorization Successful
+ You can close this window and return to Claude.
+
+
+
+ """
+ % (code, state or ""),
+ status_code=200,
+ )
+
+ return HTMLResponse(
+ content="OAuth Callback
",
+ status_code=200,
+ )
+
+
+# =============================================================================
+# OAuth Authorization Server Endpoints (proxy to Google)
+# =============================================================================
+# The MCP spec requires the MCP server to act as an OAuth Authorization Server.
+# Claude.AI discovers these endpoints via /.well-known/oauth-authorization-server
+# or falls back to default paths (/authorize, /token, /register).
+# We proxy the OAuth flow to Google as the upstream identity provider.
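+#
+# Rough sequence (as implemented below):
+#   1. The client GETs /.well-known/oauth-authorization-server to discover our endpoints.
+#   2. /authorize 302-redirects the user to Google's consent screen (scopes mapped, PKCE preserved).
+#   3. Google redirects back to the client's callback with an authorization code.
+#   4. The client POSTs the code to /token; we attach our client_secret and forward it to Google.
+#   5. Subsequent /mcp requests carry the resulting Bearer token, validated by the middleware above.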
+
+
+@app.get("/.well-known/oauth-authorization-server")
+async def oauth_authorization_server_metadata(request: Request):
+ """RFC 8414 OAuth Authorization Server Metadata.
+
+ Tells MCP clients (Claude.AI) where our authorization endpoints are.
+ """
+ base_url = _get_base_url(request)
+
+ return {
+ "issuer": base_url,
+ "authorization_endpoint": f"{base_url}/authorize",
+ "token_endpoint": f"{base_url}/token",
+ "registration_endpoint": f"{base_url}/register",
+ "response_types_supported": ["code"],
+ "grant_types_supported": ["authorization_code", "refresh_token"],
+ "code_challenge_methods_supported": ["S256"],
+ "token_endpoint_auth_methods_supported": ["client_secret_post", "none"],
+ "scopes_supported": list(OAUTH_SCOPES.keys()),
+ }
+
+
+@app.get("/authorize")
+async def oauth_authorize(request: Request):
+ """OAuth authorization endpoint - redirects to Google OAuth.
+
+ Claude.AI sends the user here. We redirect to Google's authorize endpoint,
+ mapping scopes and preserving PKCE parameters. Google redirects back to
+ Claude.AI's callback directly.
+ """
+ params = dict(request.query_params)
+
+ # Map any non-Google scopes to Google-compatible scopes
+ requested_scope = params.get("scope", "")
+ google_scopes = _map_to_google_scopes(requested_scope)
+ params["scope"] = google_scopes
+
+ # Ensure client_id is set (use ours if not provided)
+ if "client_id" not in params or not params["client_id"]:
+ params["client_id"] = mcp_settings.MCP_OAUTH_CLIENT_ID
+
+ google_authorize_url = (
+ f"{mcp_settings.MCP_OAUTH_AUTHORIZATION_ENDPOINT}?{urlencode(params)}"
+ )
+ return RedirectResponse(url=google_authorize_url, status_code=302)
+
+
+@app.post("/token")
+async def oauth_token(request: Request):
+ """OAuth token endpoint - proxies token exchange to Google.
+
+ Claude.AI sends the authorization code here. We forward it to Google's
+ token endpoint, adding our client_secret for the exchange.
+ """
+ # Parse form data (OAuth token requests use application/x-www-form-urlencoded)
+ form_data = await request.form()
+ token_params = dict(form_data)
+
+ # Add our client credentials for the token exchange
+ token_params["client_id"] = mcp_settings.MCP_OAUTH_CLIENT_ID
+ token_params["client_secret"] = mcp_settings.MCP_OAUTH_CLIENT_SECRET
+
+ async with httpx.AsyncClient() as client:
+ response = await client.post(
+ mcp_settings.MCP_OAUTH_TOKEN_ENDPOINT,
+ data=token_params,
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
+ )
+
+ # Return Google's response directly to Claude.AI
+ return JSONResponse(
+ status_code=response.status_code,
+ content=response.json(),
+ )
+
+
+@app.post("/register")
+async def oauth_register(request: Request):
+ """OAuth Dynamic Client Registration (RFC 7591).
+
+ Returns our Google OAuth client_id so Claude.AI can use it for the
+ authorization flow. This is a simplified registration that always
+ returns the same client credentials.
+ """
+ body = await request.json()
+ redirect_uris = body.get("redirect_uris", [])
+ client_name = body.get("client_name", "MCP Client")
+
+ return JSONResponse(
+ status_code=201,
+ content={
+ "client_id": mcp_settings.MCP_OAUTH_CLIENT_ID,
+ "client_name": client_name,
+ "redirect_uris": redirect_uris,
+ "grant_types": ["authorization_code", "refresh_token"],
+ "response_types": ["code"],
+ "token_endpoint_auth_method": "none",
+ },
+ )
+
+
+def _get_base_url(request: Request) -> str:
+ """Get the external base URL, respecting X-Forwarded-Proto from reverse proxies."""
+ scheme = request.headers.get("x-forwarded-proto", request.url.scheme)
+ host = request.headers.get("host", request.url.netloc)
+ return f"{scheme}://{host}"
+
+
+def _map_to_google_scopes(requested_scope: str) -> str:
+ """Map MCP/custom scopes to Google-compatible OAuth scopes.
+
+ Claude.AI may send scopes like 'claudeai' or our custom 'kb.read kb.write'.
+ Google only understands standard OpenID scopes.
+ """
+ google_scopes = {"openid", "email", "profile"}
+
+ if requested_scope:
+ for scope in requested_scope.split():
+ # Keep standard scopes, discard custom ones
+ if scope in ("openid", "email", "profile"):
+ google_scopes.add(scope)
+
+ return " ".join(sorted(google_scopes))
+
+
+# =============================================================================
+# MCP Protocol Endpoint
+# =============================================================================
+
+
+class MCPRequest(BaseModel):
+ """MCP JSON-RPC request."""
+
+ jsonrpc: str = "2.0"
+ method: str
+ params: Optional[dict[str, Any]] = None
+ id: Optional[int | str] = None
+
+
+class MCPResponse(BaseModel):
+ """MCP JSON-RPC response."""
+
+ jsonrpc: str = "2.0"
+ result: Optional[Any] = None
+ error: Optional[dict[str, Any]] = None
+ id: Optional[int | str] = None
+
+
+@app.get("/mcp")
+async def mcp_sse_endpoint(request: Request):
+ """MCP SSE endpoint for server-initiated messages."""
+ from starlette.responses import StreamingResponse
+
+ user = getattr(request.state, "user", None)
+ if not user:
+ raise HTTPException(status_code=401, detail="Not authenticated")
+
+ async def event_generator():
+ yield "data: {}\n\n"
+ while True:
+ await asyncio.sleep(30)
+ yield ": heartbeat\n\n"
+
+ return StreamingResponse(
+ event_generator(),
+ media_type="text/event-stream",
+ headers={
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ "X-Accel-Buffering": "no",
+ },
+ )
+
+
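+# Example JSON-RPC exchange over POST /mcp (illustrative payloads):
+#   -> {"jsonrpc": "2.0", "id": 1, "method": "tools/call",
+#       "params": {"name": "search_knowledge", "arguments": {"query": "billing", "top_k": 3}}}
+#   <- {"jsonrpc": "2.0", "id": 1, "result": {"content": [{"type": "text", "text": "..."}]}}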
+@app.post("/mcp")
+async def mcp_endpoint(request: Request, mcp_request: MCPRequest):
+ """MCP JSON-RPC endpoint."""
+ user = getattr(request.state, "user", None)
+ if not user:
+ raise HTTPException(status_code=401, detail="Not authenticated")
+
+ method = mcp_request.method
+ params = mcp_request.params or {}
+
+ try:
+ if method == "initialize":
+ result = {
+ "protocolVersion": "2024-11-05",
+ "capabilities": {
+ "tools": {"listChanged": False},
+ "resources": {"subscribe": False, "listChanged": False},
+ },
+ "serverInfo": {
+ "name": "keboola-knowledge-base",
+ "version": "0.1.0",
+ },
+ }
+ elif method == "notifications/initialized":
+ return MCPResponse(id=mcp_request.id, result={})
+ elif method == "tools/list":
+ result = await handle_tools_list(user)
+ elif method == "tools/call":
+ result = await handle_tools_call(params, user)
+ elif method == "resources/list":
+ result = {"resources": []}
+ elif method == "resources/read":
+ return MCPResponse(
+ id=mcp_request.id,
+ error={"code": -32601, "message": "No resources available"},
+ )
+ elif method == "ping":
+ result = {}
+ else:
+ return MCPResponse(
+ id=mcp_request.id,
+ error={"code": -32601, "message": f"Method not found: {method}"},
+ )
+
+ return MCPResponse(id=mcp_request.id, result=result)
+
+ except HTTPException as e:
+ return MCPResponse(
+ id=mcp_request.id,
+ error={"code": -32000, "message": e.detail},
+ )
+ except Exception as e:
+ logger.exception(f"Error handling MCP request: {e}")
+ return MCPResponse(
+ id=mcp_request.id,
+ error={"code": -32603, "message": str(e)},
+ )
+
+
+async def handle_tools_list(user: dict) -> dict:
+ """Handle tools/list MCP method."""
+ user_scopes = user.get("scopes", [])
+ accessible_tools = get_tools_for_scopes(user_scopes)
+
+ return {
+ "tools": [
+ {
+ "name": tool.name,
+ "description": tool.description,
+ "inputSchema": tool.inputSchema,
+ }
+ for tool in accessible_tools
+ ]
+ }
+
+
+async def handle_tools_call(params: dict, user: dict) -> dict:
+ """Handle tools/call MCP method."""
+ tool_name = params.get("name")
+ arguments = params.get("arguments", {})
+
+ if not tool_name:
+ raise HTTPException(status_code=400, detail="Missing tool name")
+
+ # Check scope access
+ user_scopes = user.get("scopes", [])
+ required_scopes = TOOL_SCOPE_REQUIREMENTS.get(tool_name, ["kb.read"])
+
+ if not check_scope_access(required_scopes, user_scopes):
+ raise HTTPException(
+ status_code=403,
+ detail=f"Insufficient scope for tool: {tool_name}",
+ )
+
+ result = await execute_tool(tool_name, arguments, user)
+
+ return {
+ "content": [{"type": r.type, "text": r.text} for r in result],
+ }
+
+
+# =============================================================================
+# Main Entry Point
+# =============================================================================
+
+
+def main():
+ """Run MCP HTTP server."""
+ import uvicorn
+
+ # Configure logging
+ logging.basicConfig(
+ level=logging.DEBUG if mcp_settings.MCP_DEBUG else logging.INFO,
+ format='{"time": "%(asctime)s", "level": "%(levelname)s", "logger": "%(name)s", "message": "%(message)s"}',
+ force=True,
+ )
+
+ uvicorn.run(
+ "knowledge_base.mcp.server:app",
+ host=mcp_settings.MCP_HOST,
+ port=mcp_settings.MCP_PORT,
+ reload=mcp_settings.MCP_DEBUG,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/knowledge_base/mcp/tools.py b/src/knowledge_base/mcp/tools.py
new file mode 100644
index 0000000..f6da0f2
--- /dev/null
+++ b/src/knowledge_base/mcp/tools.py
@@ -0,0 +1,474 @@
+"""MCP tool definitions and execution dispatcher for Knowledge Base."""
+
+import json
+import logging
+import uuid
+from datetime import datetime
+from typing import Any
+from urllib.parse import quote
+
+from mcp.types import TextContent, Tool
+
+from knowledge_base.mcp.config import TOOL_SCOPE_REQUIREMENTS, check_scope_access
+
+logger = logging.getLogger(__name__)
+
+
+# =============================================================================
+# Tool Definitions
+# =============================================================================
+
+TOOLS = [
+ Tool(
+ name="ask_question",
+ description=(
+ "Ask a question and get an answer with sources from the Keboola knowledge base. "
+ "The answer is generated using RAG (retrieval-augmented generation) from indexed "
+ "Confluence pages, quick facts, and ingested documents."
+ ),
+ inputSchema={
+ "type": "object",
+ "properties": {
+ "question": {
+ "type": "string",
+ "description": "The question to ask the knowledge base",
+ },
+ "conversation_history": {
+ "type": "array",
+ "description": "Optional previous messages for context continuity",
+ "items": {
+ "type": "object",
+ "properties": {
+ "role": {"type": "string", "enum": ["user", "assistant"]},
+ "content": {"type": "string"},
+ },
+ "required": ["role", "content"],
+ },
+ },
+ },
+ "required": ["question"],
+ },
+ ),
+ Tool(
+ name="search_knowledge",
+ description=(
+ "Search the Keboola knowledge base for documents matching a query. "
+ "Returns ranked results with titles, content snippets, scores, and Confluence URLs. "
+ "Uses hybrid search combining semantic similarity, keyword matching, and graph relationships."
+ ),
+ inputSchema={
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "The search query",
+ },
+ "top_k": {
+ "type": "integer",
+ "description": "Number of results to return (1-20, default 5)",
+ "minimum": 1,
+ "maximum": 20,
+ "default": 5,
+ },
+ "filters": {
+ "type": "object",
+ "description": "Optional filters to narrow results",
+ "properties": {
+ "space_key": {
+ "type": "string",
+ "description": "Filter by Confluence space key",
+ },
+ "doc_type": {
+ "type": "string",
+ "description": "Filter by document type (e.g., webpage, pdf, quick_fact)",
+ },
+ "topics": {
+ "type": "array",
+ "items": {"type": "string"},
+ "description": "Filter by topics (any match)",
+ },
+ "updated_after": {
+ "type": "string",
+ "description": "Filter by update date (ISO format)",
+ },
+ },
+ },
+ },
+ "required": ["query"],
+ },
+ ),
+ Tool(
+ name="create_knowledge",
+ description=(
+ "Create a quick knowledge fact in the Keboola knowledge base. "
+ "The fact is indexed directly into Graphiti (Neo4j knowledge graph) "
+ "and becomes immediately searchable."
+ ),
+ inputSchema={
+ "type": "object",
+ "properties": {
+ "content": {
+ "type": "string",
+ "description": "The knowledge content to save",
+ },
+ "topics": {
+ "type": "array",
+ "items": {"type": "string"},
+ "description": "Optional topic tags for the knowledge",
+ },
+ },
+ "required": ["content"],
+ },
+ ),
+ Tool(
+ name="ingest_document",
+ description=(
+ "Ingest an external document into the Keboola knowledge base. "
+ "Supports web pages (HTML), PDFs, Google Docs (public/link-shared), "
+ "and Notion pages (public). The document is chunked and indexed into Graphiti."
+ ),
+ inputSchema={
+ "type": "object",
+ "properties": {
+ "url": {
+ "type": "string",
+ "description": "URL of the document to ingest",
+ },
+ "title": {
+ "type": "string",
+ "description": "Optional title override for the document",
+ },
+ },
+ "required": ["url"],
+ },
+ ),
+ Tool(
+ name="submit_feedback",
+ description=(
+ "Submit feedback on a knowledge base chunk. This affects the chunk's "
+ "quality score: 'helpful' increases it, while 'outdated', 'incorrect', "
+ "and 'confusing' decrease it. Low-scoring chunks are eventually archived."
+ ),
+ inputSchema={
+ "type": "object",
+ "properties": {
+ "chunk_id": {
+ "type": "string",
+ "description": "The ID of the chunk to give feedback on",
+ },
+ "feedback_type": {
+ "type": "string",
+ "enum": ["helpful", "outdated", "incorrect", "confusing"],
+ "description": "Type of feedback",
+ },
+ "details": {
+ "type": "string",
+ "description": "Optional details or correction suggestions",
+ },
+ },
+ "required": ["chunk_id", "feedback_type"],
+ },
+ ),
+ Tool(
+ name="check_health",
+ description=(
+ "Check the health of the Keboola knowledge base system. "
+ "Returns status of Neo4j graph database, LLM provider, and search subsystem."
+ ),
+ inputSchema={
+ "type": "object",
+ "properties": {},
+ },
+ ),
+]
+
+
+def get_tools_for_scopes(user_scopes: list[str]) -> list[Tool]:
+ """Return tools accessible to the user based on their scopes."""
+ accessible = []
+ for tool in TOOLS:
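+        # Tools without an explicit entry in TOOL_SCOPE_REQUIREMENTS default to requiring kb.read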
+ required = TOOL_SCOPE_REQUIREMENTS.get(tool.name, ["kb.read"])
+ if check_scope_access(required, user_scopes):
+ accessible.append(tool)
+ return accessible
+
+
+# =============================================================================
+# Tool Execution
+# =============================================================================
+
+
+async def execute_tool(
+ tool_name: str,
+ arguments: dict[str, Any],
+ user: dict[str, Any],
+) -> list[TextContent]:
+ """Execute a tool and return results as TextContent list."""
+ logger.info(f"Executing tool: {tool_name}, user: {user.get('sub', 'unknown')}")
+
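+    # Dispatch to the per-tool handler; any failure is returned as error text so the MCP response stays well-formed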
+ try:
+ if tool_name == "ask_question":
+ return await _execute_ask_question(arguments, user)
+ elif tool_name == "search_knowledge":
+ return await _execute_search_knowledge(arguments, user)
+ elif tool_name == "create_knowledge":
+ return await _execute_create_knowledge(arguments, user)
+ elif tool_name == "ingest_document":
+ return await _execute_ingest_document(arguments, user)
+ elif tool_name == "submit_feedback":
+ return await _execute_submit_feedback(arguments, user)
+ elif tool_name == "check_health":
+ return await _execute_check_health(arguments, user)
+ else:
+ return [TextContent(type="text", text=f"Unknown tool: {tool_name}")]
+ except Exception as e:
+ logger.error(f"Tool execution failed: {tool_name}: {e}", exc_info=True)
+ return [TextContent(type="text", text=f"Error executing {tool_name}: {str(e)}")]
+
+
+async def _execute_ask_question(
+ arguments: dict[str, Any],
+ user: dict[str, Any],
+) -> list[TextContent]:
+ """Execute ask_question tool."""
+ from knowledge_base.core.qa import generate_answer, search_knowledge
+
+ question = arguments["question"]
+ conversation_history = arguments.get("conversation_history")
+
+ # Search for relevant chunks
+ chunks = await search_knowledge(question, limit=5)
+
+ # Generate answer
+ answer = await generate_answer(question, chunks, conversation_history)
+
+ # Build sources section
+ sources = []
+ for chunk in chunks[:3]:
+ metadata = chunk.metadata if hasattr(chunk, "metadata") else {}
+ url = metadata.get("url", "")
+ title = chunk.page_title if hasattr(chunk, "page_title") else "Unknown"
+ if url:
+ sources.append(f"- [{title}]({url})")
+ else:
+ sources.append(f"- {title}")
+
+ result = answer
+ if sources:
+ result += "\n\nSources:\n" + "\n".join(sources)
+
+ return [TextContent(type="text", text=result)]
+
+
+async def _execute_search_knowledge(
+ arguments: dict[str, Any],
+ user: dict[str, Any],
+) -> list[TextContent]:
+ """Execute search_knowledge tool."""
+ from knowledge_base.core.qa import search_knowledge
+
+ query = arguments["query"]
+ top_k = arguments.get("top_k", 5)
+ filters = arguments.get("filters")
+
+    # Search; over-fetch when filters are set so post-filtering can still return up to top_k results
+    results = await search_knowledge(query, limit=top_k * 2 if filters else top_k)
+
+ # Apply filters if present
+ if filters:
+ results = _apply_filters(results, filters)
+
+ # Limit to requested count
+ results = results[:top_k]
+
+ if not results:
+ return [TextContent(type="text", text=f"No results found for: {query}")]
+
+ # Format results
+ lines = [f"Found {len(results)} results for: {query}\n"]
+ for i, r in enumerate(results, 1):
+ metadata = r.metadata if hasattr(r, "metadata") else {}
+ url = metadata.get("url", "")
+ title = r.page_title if hasattr(r, "page_title") else "Unknown"
+ content_preview = r.content[:200] + "..." if len(r.content) > 200 else r.content
+ score = f"{r.score:.3f}" if hasattr(r, "score") else "N/A"
+
+ lines.append(f"### {i}. {title}")
+ lines.append(f"**Score:** {score} | **Chunk ID:** {r.chunk_id}")
+ if url:
+ lines.append(f"**URL:** {url}")
+ lines.append(f"\n{content_preview}\n")
+
+ return [TextContent(type="text", text="\n".join(lines))]
+
+
+def _apply_filters(results: list, filters: dict) -> list:
+ """Apply metadata filters to search results."""
+ filtered = []
+ for r in results:
+ metadata = r.metadata if hasattr(r, "metadata") else {}
+
+ if "space_key" in filters and metadata.get("space_key") != filters["space_key"]:
+ continue
+ if "doc_type" in filters and metadata.get("doc_type") != filters["doc_type"]:
+ continue
+ if "topics" in filters:
+ result_topics = metadata.get("topics", [])
+ if isinstance(result_topics, str):
+ try:
+ result_topics = json.loads(result_topics)
+ except (json.JSONDecodeError, TypeError):
+ result_topics = [result_topics]
+ if not any(t in result_topics for t in filters["topics"]):
+ continue
+ if "updated_after" in filters:
+ updated_at = metadata.get("updated_at", "")
+ if updated_at and updated_at < filters["updated_after"]:
+ continue
+
+ filtered.append(r)
+ return filtered
+
+
+async def _execute_create_knowledge(
+ arguments: dict[str, Any],
+ user: dict[str, Any],
+) -> list[TextContent]:
+ """Execute create_knowledge tool."""
+ from knowledge_base.graph.graphiti_indexer import GraphitiIndexer
+ from knowledge_base.vectorstore.indexer import ChunkData
+
+ content = arguments["content"]
+ topics = arguments.get("topics", [])
+ user_email = user.get("email", "mcp-user")
+
+ # Create unique IDs
+ page_id = f"mcp_{uuid.uuid4().hex[:16]}"
+ chunk_id = f"{page_id}_0"
+ now = datetime.utcnow()
+
+ chunk_data = ChunkData(
+ chunk_id=chunk_id,
+ content=content,
+ page_id=page_id,
+ page_title=f"Quick Fact by {user_email}",
+ chunk_index=0,
+ space_key="MCP",
+ url=f"mcp://user/{quote(user_email, safe='')}",
+ author=user_email,
+ created_at=now.isoformat(),
+ updated_at=now.isoformat(),
+ chunk_type="text",
+ parent_headers="[]",
+ quality_score=100.0,
+ access_count=0,
+ feedback_count=0,
+ owner=user_email,
+ reviewed_by="",
+ reviewed_at="",
+ classification="internal",
+ doc_type="quick_fact",
+ topics=json.dumps(topics) if topics else "[]",
+ audience="[]",
+ complexity="",
+ summary=content[:200] if len(content) > 200 else content,
+ )
+
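+    # Index the quick fact directly into Graphiti so it becomes searchable immediately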
+ indexer = GraphitiIndexer()
+ await indexer.index_single_chunk(chunk_data)
+
+ logger.info(f"Created knowledge via MCP: {chunk_id} by {user_email}")
+
+ return [TextContent(
+ type="text",
+ text=f"Knowledge saved successfully.\n\n**Chunk ID:** {chunk_id}\n**Content:** {content[:200]}",
+ )]
+
+
+async def _execute_ingest_document(
+ arguments: dict[str, Any],
+ user: dict[str, Any],
+) -> list[TextContent]:
+ """Execute ingest_document tool."""
+ from knowledge_base.slack.ingest_doc import get_ingester
+
+ url = arguments["url"]
+ user_email = user.get("email", "mcp-user")
+
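+    # Reuse the Slack ingestion pipeline; "mcp" is passed as a placeholder channel id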
+ ingester = get_ingester()
+ result = await ingester.ingest_url(
+ url=url,
+ created_by=user_email,
+ channel_id="mcp",
+ )
+
+ if result["status"] == "success":
+ return [TextContent(
+ type="text",
+ text=(
+ f"Document ingested successfully.\n\n"
+ f"**Title:** {result['title']}\n"
+ f"**Source type:** {result['source_type']}\n"
+ f"**Chunks created:** {result['chunks_created']}\n"
+ f"**Page ID:** {result['page_id']}"
+ ),
+ )]
+ else:
+ return [TextContent(
+ type="text",
+ text=f"Failed to ingest document: {result.get('error', 'Unknown error')}",
+ )]
+
+
+async def _execute_submit_feedback(
+ arguments: dict[str, Any],
+ user: dict[str, Any],
+) -> list[TextContent]:
+ """Execute submit_feedback tool."""
+ from knowledge_base.lifecycle.feedback import submit_feedback
+
+ chunk_id = arguments["chunk_id"]
+ feedback_type = arguments["feedback_type"]
+ details = arguments.get("details")
+ user_email = user.get("email", "mcp-user")
+
+ feedback = await submit_feedback(
+ chunk_id=chunk_id,
+ slack_user_id=f"mcp:{user_email}",
+ slack_username=user_email,
+ feedback_type=feedback_type,
+ comment=details,
+ )
+
+ return [TextContent(
+ type="text",
+ text=(
+ f"Feedback submitted.\n\n"
+ f"**Chunk ID:** {chunk_id}\n"
+ f"**Feedback type:** {feedback_type}\n"
+ f"**Feedback ID:** {feedback.id}"
+ ),
+ )]
+
+
+async def _execute_check_health(
+ arguments: dict[str, Any],
+ user: dict[str, Any],
+) -> list[TextContent]:
+ """Execute check_health tool."""
+ from knowledge_base.search import HybridRetriever
+
+ retriever = HybridRetriever()
+ health = await retriever.check_health()
+
+ status = "healthy" if health.get("graphiti_healthy") else "degraded"
+
+ return [TextContent(
+ type="text",
+ text=(
+ f"Knowledge Base Health: **{status}**\n\n"
+ f"- Graphiti enabled: {health.get('graphiti_enabled', False)}\n"
+ f"- Graphiti healthy: {health.get('graphiti_healthy', False)}\n"
+ f"- Backend: {health.get('backend', 'unknown')}"
+ ),
+ )]
diff --git a/src/knowledge_base/slack/bot.py b/src/knowledge_base/slack/bot.py
index 5f674ce..7e81a72 100644
--- a/src/knowledge_base/slack/bot.py
+++ b/src/knowledge_base/slack/bot.py
@@ -23,8 +23,7 @@
process_thread_message,
record_bot_response,
)
-from knowledge_base.rag.factory import get_llm
-from knowledge_base.rag.exceptions import LLMError
+from knowledge_base.core.qa import search_knowledge, generate_answer
from knowledge_base.search.models import SearchResult
from knowledge_base.slack.modals import (
build_incorrect_feedback_modal,
@@ -316,40 +315,9 @@ def handle_thread_message(event: dict, say: Any, client: WebClient) -> None:
async def _search_chunks(query: str, limit: int = 5) -> list[SearchResult]:
"""Search for relevant chunks using Graphiti hybrid search.
- Uses HybridRetriever which delegates to Graphiti's unified search:
- - Semantic similarity (embeddings)
- - BM25 keyword matching
- - Graph relationships
-
- Returns SearchResult objects with content and metadata.
+ Delegates to core.qa.search_knowledge for reusability across interfaces.
"""
- logger.info(f"Searching for: '{query[:100]}...'")
-
- try:
- from knowledge_base.search import HybridRetriever
-
- retriever = HybridRetriever()
- health = await retriever.check_health()
- logger.info(f"Hybrid search health: {health}")
-
- # Use Graphiti hybrid search
- results = await retriever.search(query, k=limit)
- logger.info(f"Hybrid search returned {len(results)} results")
-
- # Log first result for debugging
- if results:
- first = results[0]
- logger.info(
- f"First result: chunk_id={first.chunk_id}, "
- f"title={first.page_title}, content_len={len(first.content)}"
- )
-
- return results
-
- except Exception as e:
- logger.error(f"Hybrid search FAILED (returning 0 results): {e}", exc_info=True)
-
- return []
+ return await search_knowledge(query, limit)
async def _generate_answer(
@@ -359,79 +327,9 @@ async def _generate_answer(
) -> str:
"""Generate an answer using LLM with retrieved chunks.
- Args:
- question: The user's question
- chunks: SearchResult objects from Graphiti containing content and metadata
- conversation_history: Previous messages in the conversation thread
+ Delegates to core.qa.generate_answer for reusability across interfaces.
"""
- if not chunks:
- return "I couldn't find relevant information in the knowledge base to answer your question."
-
- # Build context from chunks (SearchResult has page_title property and content attribute)
- context_parts = []
- for i, chunk in enumerate(chunks, 1):
- context_parts.append(
- f"[Source {i}: {chunk.page_title}]\n{chunk.content[:1000]}"
- )
- context = "\n\n---\n\n".join(context_parts)
-
- # Build conversation history section
- conversation_section = ""
- if conversation_history:
- history_parts = []
- for msg in conversation_history[-6:]: # Last 6 messages for context
- role = "User" if msg["role"] == "user" else "Assistant"
- # Truncate long messages in history
- content = msg["content"][:500] + "..." if len(msg["content"]) > 500 else msg["content"]
- history_parts.append(f"{role}: {content}")
- if history_parts:
- conversation_section = f"""
-PREVIOUS CONVERSATION:
-{chr(10).join(history_parts)}
-
-(Use this context to understand what the user is asking about and provide continuity)
-"""
-
- prompt = f"""You are Keboola's internal knowledge base assistant. Answer questions ONLY based on the provided context documents.
-
-CRITICAL RULES:
-- ONLY use information explicitly stated in the context documents below.
-- Do NOT make up, assume, or hallucinate any information not in the documents.
-- If the context doesn't contain enough information to answer, say so clearly.
-- When referencing information, mention which source it came from.
-{conversation_section}
-CONTEXT DOCUMENTS:
-{context}
-
-CURRENT QUESTION: {question}
-
-INSTRUCTIONS:
-- Answer based strictly on the context documents above.
-- Be concise and helpful. Use bullet points for multiple items.
-- If the documents only partially answer the question, share what IS available and note what's missing.
-- Do NOT invent tool names, process steps, or policies not mentioned in the documents.
-
-Provide your answer:"""
-
- try:
- llm = await get_llm()
- logger.info(f"Using LLM provider: {llm.provider_name}")
-
- # Skip health check - generate() has proper retry logic and error handling
- answer = await llm.generate(prompt)
- return answer.strip()
- except LLMError as e:
- logger.error(f"LLM provider error: {e}")
- return (
- f"I found {len(chunks)} relevant documents but couldn't generate "
- f"an answer at this time. Please try again later."
- )
- except Exception as e:
- logger.error(f"LLM generation failed: {e}")
- return (
- f"I found {len(chunks)} relevant documents but couldn't generate "
- f"an answer at this time. Please try again later."
- )
+ return await generate_answer(question, chunks, conversation_history)
def _split_text_into_blocks(text: str, max_chars: int = 3000) -> list[dict]:
diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py
new file mode 100644
index 0000000..37e861a
--- /dev/null
+++ b/tests/unit/__init__.py
@@ -0,0 +1 @@
+"""Unit tests for MCP server modules."""
diff --git a/tests/unit/test_mcp_oauth.py b/tests/unit/test_mcp_oauth.py
new file mode 100644
index 0000000..78a9b28
--- /dev/null
+++ b/tests/unit/test_mcp_oauth.py
@@ -0,0 +1,234 @@
+"""Tests for MCP OAuth resource server and scope handling."""
+
+import pytest
+
+from knowledge_base.mcp.config import check_scope_access
+from knowledge_base.mcp.oauth.resource_server import (
+ OAuthResourceServer,
+ extract_user_context,
+)
+
+
+# ===========================================================================
+# extract_user_context
+# ===========================================================================
+
+
+class TestExtractUserContext:
+ """Test user context extraction from JWT claims."""
+
+ def test_google_verified_keboola_email_gets_read_and_write(self):
+ """Verified @keboola.com Google user should get kb.read + kb.write."""
+ claims = {
+ "iss": "https://accounts.google.com",
+ "sub": "google-uid-123",
+ "email": "alice@keboola.com",
+ "email_verified": True,
+ }
+ ctx = extract_user_context(claims)
+ assert ctx["email"] == "alice@keboola.com"
+ assert ctx["sub"] == "google-uid-123"
+ assert "kb.read" in ctx["scopes"]
+ assert "kb.write" in ctx["scopes"]
+ # Standard OpenID scopes also present
+ assert "openid" in ctx["scopes"]
+ assert "email" in ctx["scopes"]
+ assert "profile" in ctx["scopes"]
+
+ def test_google_verified_external_email_gets_read_only(self):
+ """Verified external Google user should get kb.read but NOT kb.write."""
+ claims = {
+ "iss": "https://accounts.google.com",
+ "sub": "google-uid-456",
+ "email": "bob@external.com",
+ "email_verified": True,
+ }
+ ctx = extract_user_context(claims)
+ assert ctx["email"] == "bob@external.com"
+ assert "kb.read" in ctx["scopes"]
+ assert "kb.write" not in ctx["scopes"]
+
+ def test_google_unverified_email_gets_no_scopes(self):
+ """Unverified Google user should get no application scopes."""
+ claims = {
+ "iss": "https://accounts.google.com",
+ "sub": "google-uid-789",
+ "email": "unverified@keboola.com",
+ "email_verified": False,
+ }
+ ctx = extract_user_context(claims)
+ assert ctx["scopes"] == []
+
+ def test_google_missing_email_gets_no_scopes(self):
+ """Google user without email claim should get no scopes."""
+ claims = {
+ "iss": "https://accounts.google.com",
+ "sub": "google-uid-noemail",
+ "email_verified": True,
+ }
+ ctx = extract_user_context(claims)
+ # email is "" which is falsy, so no scopes granted
+ assert ctx["scopes"] == []
+
+ def test_non_google_claims_with_scope_string(self):
+ """Non-Google token with 'scope' claim should parse scopes from the string."""
+ claims = {
+ "iss": "https://some-other-idp.example.com",
+ "sub": "user-abc",
+ "email": "charlie@other.com",
+ "scope": "kb.read kb.write custom_scope",
+ }
+ ctx = extract_user_context(claims)
+ assert ctx["scopes"] == ["kb.read", "kb.write", "custom_scope"]
+ assert ctx["email"] == "charlie@other.com"
+
+ def test_non_google_claims_with_empty_scope(self):
+ """Non-Google token with empty scope should produce empty scopes list."""
+ claims = {
+ "iss": "https://another-idp.example.com",
+ "sub": "user-xyz",
+ "email": "dave@corp.com",
+ "scope": "",
+ }
+ ctx = extract_user_context(claims)
+ assert ctx["scopes"] == []
+
+ def test_claims_without_scope_or_google_issuer(self):
+ """Token without scope and without Google issuer should produce empty scopes."""
+ claims = {
+ "iss": "https://custom-auth.example.com",
+ "sub": "user-custom",
+ "email": "eve@custom.com",
+ }
+ ctx = extract_user_context(claims)
+ assert ctx["scopes"] == []
+
+ def test_sub_fallback_for_email(self):
+ """When email claim is missing, email should fall back to sub."""
+ claims = {
+ "iss": "https://custom-auth.example.com",
+ "sub": "user-no-email",
+ }
+ ctx = extract_user_context(claims)
+ assert ctx["email"] == "user-no-email"
+
+ def test_claims_are_preserved(self):
+ """Original claims dict should be available in the context."""
+ claims = {
+ "iss": "https://accounts.google.com",
+ "sub": "uid-1",
+ "email": "test@keboola.com",
+ "email_verified": True,
+ "aud": "test-client-id",
+ "exp": 9999999999,
+ }
+ ctx = extract_user_context(claims)
+ assert ctx["claims"] is claims
+
+
+# ===========================================================================
+# OAuthResourceServer
+# ===========================================================================
+
+
+class TestOAuthResourceServer:
+ """Test OAuthResourceServer initialization and properties."""
+
+ def test_initialization_creates_validator_and_metadata(self):
+ """OAuthResourceServer should initialize a validator and metadata."""
+ server = OAuthResourceServer(
+ resource="https://kb-mcp.example.com",
+ authorization_servers=["https://accounts.google.com"],
+ audience="test-client-id",
+ scopes_supported=["openid", "email", "profile"],
+ )
+ assert server.resource == "https://kb-mcp.example.com"
+ assert server.audience == "test-client-id"
+ assert server.metadata is not None
+ assert server.validator is not None
+
+ def test_metadata_resource_matches(self):
+ """Metadata resource field should match the server resource."""
+ server = OAuthResourceServer(
+ resource="https://kb-mcp.example.com",
+ authorization_servers=["https://accounts.google.com"],
+ audience="test-client-id",
+ )
+ meta_dict = server.metadata.to_dict()
+ assert meta_dict["resource"] == "https://kb-mcp.example.com"
+ assert "https://accounts.google.com" in meta_dict["authorization_servers"]
+
+ def test_metadata_includes_scopes(self):
+ """Metadata should include advertised scopes."""
+ server = OAuthResourceServer(
+ resource="https://kb-mcp.example.com",
+ authorization_servers=["https://accounts.google.com"],
+ audience="test-client-id",
+ scopes_supported=["openid", "email"],
+ )
+ meta_dict = server.metadata.to_dict()
+ assert "scopes_supported" in meta_dict
+ assert meta_dict["scopes_supported"] == ["openid", "email"]
+
+ def test_no_authorization_servers_no_validator(self):
+ """Without authorization servers, validator should not be created."""
+ server = OAuthResourceServer(
+ resource="https://kb-mcp.example.com",
+ authorization_servers=[],
+ audience="test-client-id",
+ )
+ with pytest.raises(RuntimeError, match="No authorization server configured"):
+ _ = server.validator
+
+ def test_google_issuer_sets_google_jwks_uri(self):
+ """Google issuer should configure Google's JWKS endpoint on the validator."""
+ server = OAuthResourceServer(
+ resource="https://kb-mcp.example.com",
+ authorization_servers=["https://accounts.google.com"],
+ audience="test-client-id",
+ )
+ assert server.validator.jwks_uri == "https://www.googleapis.com/oauth2/v3/certs"
+ assert server.validator.is_google is True
+
+
+# ===========================================================================
+# check_scope_access
+# ===========================================================================
+
+
+class TestCheckScopeAccess:
+ """Test scope access checking logic."""
+
+ def test_matching_scope_grants_access(self):
+ """If at least one required scope is in granted scopes, access is granted."""
+ assert check_scope_access(["kb.read"], ["kb.read", "kb.write"]) is True
+
+ def test_no_matching_scope_denies_access(self):
+ """If no required scope is in granted scopes, access is denied."""
+ assert check_scope_access(["kb.write"], ["kb.read"]) is False
+
+ def test_multiple_required_any_match(self):
+ """check_scope_access uses ANY (OR) logic for required scopes."""
+ assert check_scope_access(["kb.read", "kb.write"], ["kb.read"]) is True
+
+ def test_empty_required_denies(self):
+ """Empty required list means no scope matches -> deny."""
+ assert check_scope_access([], ["kb.read"]) is False
+
+ def test_empty_granted_denies(self):
+ """Empty granted list always denies."""
+ assert check_scope_access(["kb.read"], []) is False
+
+ def test_both_empty_denies(self):
+ """Both empty -> deny."""
+ assert check_scope_access([], []) is False
+
+ def test_exact_single_match(self):
+ """Single required scope matching single granted scope."""
+ assert check_scope_access(["kb.write"], ["kb.write"]) is True
+
+ def test_openid_scope_does_not_match_kb_read(self):
+ """Standard OpenID scopes should not satisfy kb.* requirements."""
+ assert check_scope_access(["kb.read"], ["openid", "email", "profile"]) is False
diff --git a/tests/unit/test_mcp_server.py b/tests/unit/test_mcp_server.py
new file mode 100644
index 0000000..8aafcaa
--- /dev/null
+++ b/tests/unit/test_mcp_server.py
@@ -0,0 +1,485 @@
+"""Tests for MCP HTTP server endpoints and JSON-RPC protocol."""
+
+import os
+from dataclasses import dataclass
+from typing import Any
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+# Set required env vars BEFORE importing server module, because server.py
+# instantiates MCPSettings() at module level which requires these.
+os.environ.setdefault("MCP_OAUTH_CLIENT_ID", "test-client-id")
+os.environ.setdefault("MCP_OAUTH_CLIENT_SECRET", "test-client-secret")
+os.environ.setdefault("MCP_OAUTH_RESOURCE_IDENTIFIER", "https://test-kb-mcp.example.com")
+os.environ.setdefault("MCP_DEV_MODE", "true")
+
+from httpx import ASGITransport, AsyncClient  # noqa: E402
+
+from knowledge_base.mcp.server import app, mcp_settings # noqa: E402
+
+
+@dataclass
+class _FakeSearchResult:
+ """Minimal stand-in for SearchResult."""
+
+ chunk_id: str
+ content: str
+ score: float
+ metadata: dict[str, Any]
+
+ @property
+ def page_title(self) -> str:
+ return self.metadata.get("page_title", "")
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture
+def dev_auth_header() -> dict[str, str]:
+ """Bearer token header -- value is irrelevant in dev mode."""
+ return {"Authorization": "Bearer dev-token-placeholder"}
+
+
+@pytest.fixture
+async def client():
+ """httpx AsyncClient wired to the FastAPI app with dev mode enabled."""
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as ac:
+ yield ac
+
+
+# ===========================================================================
+# Health & Metadata Endpoints
+# ===========================================================================
+
+
+class TestHealthEndpoint:
+ """Test GET /health."""
+
+ async def test_health_returns_200(self, client: AsyncClient):
+ resp = await client.get("/health")
+ assert resp.status_code == 200
+ body = resp.json()
+ assert body["status"] == "healthy"
+ assert body["service"] == "knowledge-base-mcp-server"
+
+
+class TestOAuthMetadataEndpoint:
+ """Test GET /.well-known/oauth-protected-resource."""
+
+ async def test_returns_metadata(self, client: AsyncClient):
+ resp = await client.get("/.well-known/oauth-protected-resource")
+ assert resp.status_code == 200
+ body = resp.json()
+ assert "resource" in body
+ assert "authorization_servers" in body
+
+
+# ===========================================================================
+# OAuth Authorization Server Metadata
+# ===========================================================================
+
+
+class TestOAuthAuthorizationServerMetadata:
+ """Test GET /.well-known/oauth-authorization-server."""
+
+ async def test_returns_metadata(self, client: AsyncClient):
+ resp = await client.get("/.well-known/oauth-authorization-server")
+ assert resp.status_code == 200
+ body = resp.json()
+ assert "issuer" in body
+ assert "authorization_endpoint" in body
+ assert "token_endpoint" in body
+ assert "registration_endpoint" in body
+ assert body["response_types_supported"] == ["code"]
+ assert "S256" in body["code_challenge_methods_supported"]
+
+ async def test_endpoints_use_base_url(self, client: AsyncClient):
+ resp = await client.get("/.well-known/oauth-authorization-server")
+ body = resp.json()
+ assert body["authorization_endpoint"].endswith("/authorize")
+ assert body["token_endpoint"].endswith("/token")
+ assert body["registration_endpoint"].endswith("/register")
+
+
+# ===========================================================================
+# OAuth Authorize Endpoint
+# ===========================================================================
+
+
+class TestOAuthAuthorize:
+ """Test GET /authorize - redirects to Google."""
+
+ async def test_redirects_to_google(self, client: AsyncClient):
+ resp = await client.get(
+ "/authorize",
+ params={
+ "response_type": "code",
+ "client_id": "test-client-id",
+ "redirect_uri": "https://claude.ai/api/mcp/auth_callback",
+ "scope": "claudeai",
+ "state": "test-state",
+ "code_challenge": "abc123",
+ "code_challenge_method": "S256",
+ },
+ follow_redirects=False,
+ )
+ assert resp.status_code == 302
+ location = resp.headers["location"]
+ assert "accounts.google.com" in location
+ # Should map claudeai scope to Google scopes
+ assert "openid" in location
+ assert "email" in location
+ assert "profile" in location
+ # Should NOT contain the custom scope
+ assert "claudeai" not in location
+ # Should preserve PKCE params
+ assert "code_challenge=abc123" in location
+ assert "state=test-state" in location
+
+ async def test_no_auth_required(self, client: AsyncClient):
+ """The /authorize endpoint must be accessible without Bearer token."""
+ resp = await client.get(
+ "/authorize",
+ params={"response_type": "code", "scope": "openid"},
+ follow_redirects=False,
+ )
+ assert resp.status_code == 302
+
+
+# ===========================================================================
+# OAuth Token Endpoint
+# ===========================================================================
+
+
+class TestOAuthToken:
+ """Test POST /token - proxies to Google."""
+
+ async def test_proxies_to_google(self, client: AsyncClient):
+ """Token endpoint should proxy to Google and return the response."""
+ mock_google_response = MagicMock()
+ mock_google_response.status_code = 200
+ mock_google_response.json.return_value = {
+ "access_token": "ya29.mock-token",
+ "token_type": "Bearer",
+ "expires_in": 3600,
+ "id_token": "eyJ...",
+ }
+
+ with patch("knowledge_base.mcp.server.httpx.AsyncClient") as mock_client_cls:
+ mock_client = AsyncMock()
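+            # The mock must behave as an async context manager, matching the server's "async with httpx.AsyncClient()" usage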
+ mock_client.post.return_value = mock_google_response
+ mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+ mock_client.__aexit__ = AsyncMock(return_value=None)
+ mock_client_cls.return_value = mock_client
+
+ resp = await client.post(
+ "/token",
+ data={
+ "grant_type": "authorization_code",
+ "code": "test-auth-code",
+ "redirect_uri": "https://claude.ai/api/mcp/auth_callback",
+ "code_verifier": "test-verifier",
+ },
+ )
+
+ assert resp.status_code == 200
+ body = resp.json()
+ assert body["access_token"] == "ya29.mock-token"
+ assert body["token_type"] == "Bearer"
+
+ async def test_no_auth_required(self, client: AsyncClient):
+ """The /token endpoint must be accessible without Bearer token."""
+ mock_google_response = MagicMock()
+ mock_google_response.status_code = 400
+ mock_google_response.json.return_value = {"error": "invalid_grant"}
+
+ with patch("knowledge_base.mcp.server.httpx.AsyncClient") as mock_client_cls:
+ mock_client = AsyncMock()
+ mock_client.post.return_value = mock_google_response
+ mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+ mock_client.__aexit__ = AsyncMock(return_value=None)
+ mock_client_cls.return_value = mock_client
+
+ resp = await client.post(
+ "/token",
+ data={"grant_type": "authorization_code", "code": "bad-code"},
+ )
+
+ # Should NOT get 401 (auth not required), should get Google's error
+ assert resp.status_code == 400
+
+
+# ===========================================================================
+# OAuth Dynamic Client Registration
+# ===========================================================================
+
+
+class TestOAuthRegister:
+ """Test POST /register - returns our client_id."""
+
+ async def test_returns_client_id(self, client: AsyncClient):
+ resp = await client.post(
+ "/register",
+ json={
+ "redirect_uris": ["https://claude.ai/api/mcp/auth_callback"],
+ "client_name": "Claude.AI",
+ "grant_types": ["authorization_code"],
+ "response_types": ["code"],
+ "token_endpoint_auth_method": "none",
+ },
+ )
+ assert resp.status_code == 201
+ body = resp.json()
+ assert body["client_id"] == "test-client-id"
+ assert body["client_name"] == "Claude.AI"
+ assert "authorization_code" in body["grant_types"]
+
+ async def test_no_auth_required(self, client: AsyncClient):
+ """The /register endpoint must be accessible without Bearer token."""
+ resp = await client.post(
+ "/register",
+ json={"redirect_uris": ["https://example.com/callback"]},
+ )
+ assert resp.status_code == 201
+
+
+# ===========================================================================
+# Authentication
+# ===========================================================================
+
+
+class TestAuthentication:
+ """Test authentication middleware."""
+
+ async def test_post_mcp_without_auth_returns_401(self, client: AsyncClient):
+ """POST /mcp without Authorization header should return 401."""
+ resp = await client.post(
+ "/mcp",
+ json={"jsonrpc": "2.0", "method": "ping", "id": 1},
+ )
+ assert resp.status_code == 401
+
+ async def test_post_mcp_with_invalid_auth_scheme_returns_401(self, client: AsyncClient):
+ """POST /mcp with non-Bearer auth should return 401."""
+ resp = await client.post(
+ "/mcp",
+ json={"jsonrpc": "2.0", "method": "ping", "id": 1},
+ headers={"Authorization": "Basic dXNlcjpwYXNz"},
+ )
+ assert resp.status_code == 401
+
+
+# ===========================================================================
+# MCP Protocol: initialize
+# ===========================================================================
+
+
+class TestMCPInitialize:
+ """Test MCP initialize method."""
+
+ async def test_initialize_returns_protocol_version(
+ self, client: AsyncClient, dev_auth_header: dict
+ ):
+ resp = await client.post(
+ "/mcp",
+ json={"jsonrpc": "2.0", "method": "initialize", "id": 1},
+ headers=dev_auth_header,
+ )
+ assert resp.status_code == 200
+ body = resp.json()
+ assert body["jsonrpc"] == "2.0"
+ assert body["id"] == 1
+ result = body["result"]
+ assert result["protocolVersion"] == "2024-11-05"
+ assert "tools" in result["capabilities"]
+ assert result["serverInfo"]["name"] == "keboola-knowledge-base"
+
+
+# ===========================================================================
+# MCP Protocol: tools/list
+# ===========================================================================
+
+
+class TestMCPToolsList:
+ """Test MCP tools/list method."""
+
+ async def test_tools_list_returns_all_tools_in_dev_mode(
+ self, client: AsyncClient, dev_auth_header: dict
+ ):
+ """Dev mode grants all scopes, so all 6 tools should be listed."""
+ resp = await client.post(
+ "/mcp",
+ json={"jsonrpc": "2.0", "method": "tools/list", "id": 2},
+ headers=dev_auth_header,
+ )
+ assert resp.status_code == 200
+ body = resp.json()
+ tools = body["result"]["tools"]
+ assert len(tools) == 6
+ names = {t["name"] for t in tools}
+ assert "ask_question" in names
+ assert "search_knowledge" in names
+ assert "create_knowledge" in names
+
+ async def test_tools_list_each_tool_has_schema(
+ self, client: AsyncClient, dev_auth_header: dict
+ ):
+ """Each tool in the list should have name, description, and inputSchema."""
+ resp = await client.post(
+ "/mcp",
+ json={"jsonrpc": "2.0", "method": "tools/list", "id": 3},
+ headers=dev_auth_header,
+ )
+ tools = resp.json()["result"]["tools"]
+ for tool in tools:
+ assert "name" in tool
+ assert "description" in tool
+ assert "inputSchema" in tool
+
+
+class TestMCPToolsListScopeFiltering:
+ """Test that tools/list respects user scopes."""
+
+ async def test_read_only_user_sees_fewer_tools(self):
+ """A user with only kb.read should see only read tools."""
+ from knowledge_base.mcp import server as srv_module
+
+ original_fn = srv_module.handle_tools_list
+
+ async def _custom_handle(user):
+ # Force read-only scopes
+ user = {**user, "scopes": ["kb.read"]}
+ return await original_fn(user)
+
+ with patch.object(srv_module, "handle_tools_list", side_effect=_custom_handle):
+ transport = ASGITransport(app=app)
+ async with AsyncClient(transport=transport, base_url="http://test") as ac:
+ resp = await ac.post(
+ "/mcp",
+ json={"jsonrpc": "2.0", "method": "tools/list", "id": 4},
+ headers={"Authorization": "Bearer dev-token"},
+ )
+
+ assert resp.status_code == 200
+ tools = resp.json()["result"]["tools"]
+ names = {t["name"] for t in tools}
+ assert names == {"ask_question", "search_knowledge", "check_health"}
+
+
+# ===========================================================================
+# MCP Protocol: tools/call
+# ===========================================================================
+
+
+class TestMCPToolsCall:
+ """Test MCP tools/call method."""
+
+ async def test_tools_call_search_knowledge(
+ self, client: AsyncClient, dev_auth_header: dict
+ ):
+ """tools/call search_knowledge should execute and return content."""
+ fake_results = [
+ _FakeSearchResult(
+ chunk_id="c1",
+ content="Keboola overview content",
+ score=0.9,
+ metadata={"page_title": "Overview", "url": "https://wiki.keboola.com"},
+ ),
+ ]
+
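+        # Patch search_knowledge at its source module; the tool handler imports it lazily at call time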
+ with patch(
+ "knowledge_base.core.qa.search_knowledge",
+ new_callable=AsyncMock,
+ return_value=fake_results,
+ ):
+ resp = await client.post(
+ "/mcp",
+ json={
+ "jsonrpc": "2.0",
+ "method": "tools/call",
+ "params": {
+ "name": "search_knowledge",
+ "arguments": {"query": "Keboola overview"},
+ },
+ "id": 5,
+ },
+ headers=dev_auth_header,
+ )
+
+ assert resp.status_code == 200
+ body = resp.json()
+ assert body["id"] == 5
+ content = body["result"]["content"]
+ assert len(content) >= 1
+ assert content[0]["type"] == "text"
+ assert "Found 1 results" in content[0]["text"]
+
+ async def test_tools_call_missing_name_returns_error(
+ self, client: AsyncClient, dev_auth_header: dict
+ ):
+ """tools/call without a tool name should return an error."""
+ resp = await client.post(
+ "/mcp",
+ json={
+ "jsonrpc": "2.0",
+ "method": "tools/call",
+ "params": {"arguments": {}},
+ "id": 6,
+ },
+ headers=dev_auth_header,
+ )
+ assert resp.status_code == 200
+ body = resp.json()
+ # Should be an error response (code -32000 from HTTPException)
+ assert body["error"] is not None
+ assert body["error"]["code"] == -32000
+
+
+# ===========================================================================
+# MCP Protocol: ping
+# ===========================================================================
+
+
+class TestMCPPing:
+ """Test MCP ping method."""
+
+ async def test_ping_returns_empty_result(
+ self, client: AsyncClient, dev_auth_header: dict
+ ):
+ resp = await client.post(
+ "/mcp",
+ json={"jsonrpc": "2.0", "method": "ping", "id": 7},
+ headers=dev_auth_header,
+ )
+ assert resp.status_code == 200
+ body = resp.json()
+ assert body["result"] == {}
+ assert body["id"] == 7
+
+
+# ===========================================================================
+# MCP Protocol: unknown method
+# ===========================================================================
+
+
+class TestMCPUnknownMethod:
+ """Test unknown MCP methods."""
+
+ async def test_unknown_method_returns_error_code(
+ self, client: AsyncClient, dev_auth_header: dict
+ ):
+ resp = await client.post(
+ "/mcp",
+ json={"jsonrpc": "2.0", "method": "totally/unknown", "id": 8},
+ headers=dev_auth_header,
+ )
+ assert resp.status_code == 200
+ body = resp.json()
+ assert body["error"]["code"] == -32601
+ assert "Method not found" in body["error"]["message"]
diff --git a/tests/unit/test_mcp_tools.py b/tests/unit/test_mcp_tools.py
new file mode 100644
index 0000000..49fdc9d
--- /dev/null
+++ b/tests/unit/test_mcp_tools.py
@@ -0,0 +1,552 @@
+"""Tests for MCP tool definitions and execution dispatcher."""
+
+import json
+from dataclasses import dataclass
+from typing import Any
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from mcp.types import TextContent
+
+from knowledge_base.mcp.tools import (
+ TOOLS,
+ _apply_filters,
+ execute_tool,
+ get_tools_for_scopes,
+)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+@dataclass
+class _FakeSearchResult:
+ """Minimal stand-in for SearchResult used in tool execution tests."""
+
+ chunk_id: str
+ content: str
+ score: float
+ metadata: dict[str, Any]
+
+ @property
+ def page_title(self) -> str:
+ return self.metadata.get("page_title", "")
+
+
+def _make_result(
+ chunk_id: str = "c1",
+ content: str = "some content",
+ score: float = 0.9,
+ page_title: str = "My Page",
+ url: str = "https://example.com",
+ space_key: str = "ENG",
+ doc_type: str = "webpage",
+ topics: list[str] | None = None,
+ updated_at: str = "2026-01-15T00:00:00Z",
+) -> _FakeSearchResult:
+ return _FakeSearchResult(
+ chunk_id=chunk_id,
+ content=content,
+ score=score,
+ metadata={
+ "page_title": page_title,
+ "url": url,
+ "space_key": space_key,
+ "doc_type": doc_type,
+ "topics": json.dumps(topics or []),
+ "updated_at": updated_at,
+ },
+ )
+
+
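+# Minimal user contexts for tool execution tests; the handlers only read email and scopes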
+_USER_INTERNAL = {"email": "alice@keboola.com", "scopes": ["kb.read", "kb.write"]}
+_USER_EXTERNAL = {"email": "bob@external.com", "scopes": ["kb.read"]}
+
+
+# ===========================================================================
+# get_tools_for_scopes
+# ===========================================================================
+
+
+class TestGetToolsForScopes:
+ """Test scope-based tool filtering."""
+
+ def test_read_scopes_return_read_tools(self):
+ """kb.read should return ask_question, search_knowledge, check_health."""
+ tools = get_tools_for_scopes(["kb.read"])
+ names = {t.name for t in tools}
+ assert names == {"ask_question", "search_knowledge", "check_health"}
+
+ def test_write_scopes_return_write_tools(self):
+ """kb.write should return create_knowledge, ingest_document, submit_feedback."""
+ tools = get_tools_for_scopes(["kb.write"])
+ names = {t.name for t in tools}
+ assert names == {"create_knowledge", "ingest_document", "submit_feedback"}
+
+ def test_both_scopes_return_all_tools(self):
+ """kb.read + kb.write should return all 6 tools."""
+ tools = get_tools_for_scopes(["kb.read", "kb.write"])
+ names = {t.name for t in tools}
+ assert len(names) == 6
+ assert names == {
+ "ask_question",
+ "search_knowledge",
+ "check_health",
+ "create_knowledge",
+ "ingest_document",
+ "submit_feedback",
+ }
+
+ def test_empty_scopes_return_no_tools(self):
+ """Empty scope list should return no tools."""
+ tools = get_tools_for_scopes([])
+ assert tools == []
+
+ def test_irrelevant_scopes_return_no_tools(self):
+ """Scopes like openid/email should not grant access to any tools."""
+ tools = get_tools_for_scopes(["openid", "email", "profile"])
+ assert tools == []
+
+ def test_returned_tools_are_tool_instances(self):
+ """Each returned item should be a Tool with name and inputSchema."""
+ tools = get_tools_for_scopes(["kb.read"])
+ for tool in tools:
+ assert hasattr(tool, "name")
+ assert hasattr(tool, "inputSchema")
+
+ def test_tool_definitions_count(self):
+ """TOOLS list should contain exactly 6 tool definitions."""
+ assert len(TOOLS) == 6
+
+
+# ===========================================================================
+# _apply_filters
+# ===========================================================================
+
+
+class TestApplyFilters:
+ """Test metadata-based filtering of search results."""
+
+ def test_no_filters_returns_all(self):
+ """Empty filter dict returns all results unchanged."""
+ results = [_make_result(chunk_id="c1"), _make_result(chunk_id="c2")]
+ filtered = _apply_filters(results, {})
+ assert len(filtered) == 2
+
+ def test_space_key_filter(self):
+ """Filter by space_key keeps only matching results."""
+ results = [
+ _make_result(chunk_id="c1", space_key="ENG"),
+ _make_result(chunk_id="c2", space_key="SALES"),
+ ]
+ filtered = _apply_filters(results, {"space_key": "ENG"})
+ assert len(filtered) == 1
+ assert filtered[0].chunk_id == "c1"
+
+ def test_doc_type_filter(self):
+ """Filter by doc_type keeps only matching results."""
+ results = [
+ _make_result(chunk_id="c1", doc_type="webpage"),
+ _make_result(chunk_id="c2", doc_type="pdf"),
+ _make_result(chunk_id="c3", doc_type="quick_fact"),
+ ]
+ filtered = _apply_filters(results, {"doc_type": "pdf"})
+ assert len(filtered) == 1
+ assert filtered[0].chunk_id == "c2"
+
+ def test_topics_filter_json_array(self):
+ """Filter by topics works with JSON-encoded topic arrays."""
+ results = [
+ _make_result(chunk_id="c1", topics=["deployment", "ci"]),
+ _make_result(chunk_id="c2", topics=["billing"]),
+ _make_result(chunk_id="c3", topics=["deployment", "security"]),
+ ]
+ filtered = _apply_filters(results, {"topics": ["deployment"]})
+ assert len(filtered) == 2
+ ids = {r.chunk_id for r in filtered}
+ assert ids == {"c1", "c3"}
+
+ def test_topics_filter_string_fallback(self):
+ """When topics is a plain string (not JSON), treat it as a single-element list."""
+ result = _FakeSearchResult(
+ chunk_id="c1",
+ content="text",
+ score=0.5,
+ metadata={"topics": "security"},
+ )
+ filtered = _apply_filters([result], {"topics": ["security"]})
+ assert len(filtered) == 1
+
+ def test_topics_filter_no_match(self):
+ """If none of the requested topics match, the result is excluded."""
+ results = [_make_result(chunk_id="c1", topics=["billing"])]
+ filtered = _apply_filters(results, {"topics": ["deployment"]})
+ assert len(filtered) == 0
+
+ def test_updated_after_filter(self):
+ """Filter by updated_after keeps only results updated after the date."""
+ results = [
+ _make_result(chunk_id="c1", updated_at="2026-01-01T00:00:00Z"),
+ _make_result(chunk_id="c2", updated_at="2026-02-01T00:00:00Z"),
+ ]
+ filtered = _apply_filters(results, {"updated_after": "2026-01-15T00:00:00Z"})
+ assert len(filtered) == 1
+ assert filtered[0].chunk_id == "c2"
+
+ def test_combined_filters(self):
+ """Multiple filters are applied together (AND logic)."""
+ results = [
+ _make_result(chunk_id="c1", space_key="ENG", doc_type="webpage"),
+ _make_result(chunk_id="c2", space_key="ENG", doc_type="pdf"),
+ _make_result(chunk_id="c3", space_key="SALES", doc_type="webpage"),
+ ]
+ filtered = _apply_filters(results, {"space_key": "ENG", "doc_type": "webpage"})
+ assert len(filtered) == 1
+ assert filtered[0].chunk_id == "c1"
+
+ def test_filter_result_with_no_metadata(self):
+ """Results without metadata attribute should not crash."""
+ result = MagicMock(spec=[]) # no attributes at all
+        # _apply_filters falls back to an empty metadata dict when the attribute is missing
+        filtered = _apply_filters([result], {"space_key": "ENG"})
+        # The space_key lookup then returns None, which never matches the filter, so the result is excluded
+ assert len(filtered) == 0
+
+
+# ===========================================================================
+# execute_tool
+# ===========================================================================
+
+
+class TestExecuteToolAskQuestion:
+ """Test ask_question tool execution."""
+
+ async def test_ask_question_returns_text_content(self):
+ """ask_question should return a list of TextContent with the answer and sources."""
+ fake_chunks = [
+ _make_result(
+ chunk_id="c1",
+ content="Keboola is a data platform.",
+ page_title="About Keboola",
+ url="https://wiki.keboola.com/about",
+ ),
+ ]
+ with (
+ patch(
+ "knowledge_base.core.qa.search_knowledge",
+ new_callable=AsyncMock,
+ return_value=fake_chunks,
+ ),
+ patch(
+ "knowledge_base.core.qa.generate_answer",
+ new_callable=AsyncMock,
+ return_value="Keboola is a data platform.",
+ ),
+ ):
+ result = await execute_tool(
+ "ask_question",
+ {"question": "What is Keboola?"},
+ _USER_INTERNAL,
+ )
+
+ assert len(result) == 1
+ assert isinstance(result[0], TextContent)
+ assert "Keboola is a data platform." in result[0].text
+ assert "Sources:" in result[0].text
+ assert "About Keboola" in result[0].text
+
+ async def test_ask_question_no_sources(self):
+ """When search returns no chunks, sources section is absent."""
+ with (
+ patch(
+ "knowledge_base.core.qa.search_knowledge",
+ new_callable=AsyncMock,
+ return_value=[],
+ ),
+ patch(
+ "knowledge_base.core.qa.generate_answer",
+ new_callable=AsyncMock,
+ return_value="No information found.",
+ ),
+ ):
+ result = await execute_tool(
+ "ask_question",
+ {"question": "Something obscure?"},
+ _USER_INTERNAL,
+ )
+
+ assert len(result) == 1
+ assert "Sources:" not in result[0].text
+
+ async def test_ask_question_with_conversation_history(self):
+ """Conversation history should be passed through to generate_answer."""
+ history = [
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hi, how can I help?"},
+ ]
+ mock_generate = AsyncMock(return_value="Follow-up answer.")
+ with (
+ patch(
+ "knowledge_base.core.qa.search_knowledge",
+ new_callable=AsyncMock,
+ return_value=[],
+ ),
+ patch(
+ "knowledge_base.core.qa.generate_answer",
+ mock_generate,
+ ),
+ ):
+ await execute_tool(
+ "ask_question",
+ {"question": "Follow-up?", "conversation_history": history},
+ _USER_INTERNAL,
+ )
+
+ # generate_answer should have been called with conversation_history
+ mock_generate.assert_called_once()
+        call_args = mock_generate.call_args
+        assert call_args[0][2] == history  # conversation_history is the third positional argument
+
+
+class TestExecuteToolSearchKnowledge:
+ """Test search_knowledge tool execution."""
+
+ async def test_search_returns_formatted_results(self):
+ """search_knowledge should return formatted markdown-like results."""
+ fake_results = [
+ _make_result(
+ chunk_id="c1",
+ content="First result content",
+ score=0.95,
+ page_title="Page One",
+ url="https://wiki.keboola.com/page1",
+ ),
+ _make_result(
+ chunk_id="c2",
+ content="Second result content",
+ score=0.85,
+ page_title="Page Two",
+ url="https://wiki.keboola.com/page2",
+ ),
+ ]
+ with patch(
+ "knowledge_base.core.qa.search_knowledge",
+ new_callable=AsyncMock,
+ return_value=fake_results,
+ ):
+ result = await execute_tool(
+ "search_knowledge",
+ {"query": "test query"},
+ _USER_INTERNAL,
+ )
+
+ assert len(result) == 1
+ text = result[0].text
+ assert "Found 2 results" in text
+ assert "Page One" in text
+ assert "Page Two" in text
+ assert "0.950" in text
+ assert "c1" in text
+
+ async def test_search_no_results(self):
+ """When no results found, return appropriate message."""
+ with patch(
+ "knowledge_base.core.qa.search_knowledge",
+ new_callable=AsyncMock,
+ return_value=[],
+ ):
+ result = await execute_tool(
+ "search_knowledge",
+ {"query": "nonexistent stuff"},
+ _USER_INTERNAL,
+ )
+
+ assert len(result) == 1
+ assert "No results found" in result[0].text
+
+ async def test_search_respects_top_k(self):
+ """top_k limits the number of returned results."""
+ many_results = [_make_result(chunk_id=f"c{i}") for i in range(10)]
+ with patch(
+ "knowledge_base.core.qa.search_knowledge",
+ new_callable=AsyncMock,
+ return_value=many_results,
+ ):
+ result = await execute_tool(
+ "search_knowledge",
+ {"query": "test", "top_k": 3},
+ _USER_INTERNAL,
+ )
+
+ text = result[0].text
+ assert "Found 3 results" in text
+
+
+class TestExecuteToolCreateKnowledge:
+ """Test create_knowledge tool execution."""
+
+ async def test_create_knowledge_success(self):
+ """create_knowledge should index a chunk and return confirmation."""
+ mock_indexer = AsyncMock()
+ mock_indexer.index_single_chunk = AsyncMock()
+
+ with patch(
+ "knowledge_base.graph.graphiti_indexer.GraphitiIndexer",
+ return_value=mock_indexer,
+ ):
+ result = await execute_tool(
+ "create_knowledge",
+ {"content": "Neo4j is a graph database.", "topics": ["databases", "graphs"]},
+ _USER_INTERNAL,
+ )
+
+ assert len(result) == 1
+ text = result[0].text
+ assert "Knowledge saved successfully" in text
+ assert "Neo4j is a graph database." in text
+ mock_indexer.index_single_chunk.assert_called_once()
+
+ # Verify the ChunkData passed to the indexer
+ chunk_data = mock_indexer.index_single_chunk.call_args[0][0]
+ assert chunk_data.content == "Neo4j is a graph database."
+ assert chunk_data.space_key == "MCP"
+ assert chunk_data.doc_type == "quick_fact"
+ assert "alice@keboola.com" in chunk_data.page_title
+
+ async def test_create_knowledge_default_topics(self):
+ """create_knowledge without topics should use empty list."""
+ mock_indexer = AsyncMock()
+ mock_indexer.index_single_chunk = AsyncMock()
+
+ with patch(
+ "knowledge_base.graph.graphiti_indexer.GraphitiIndexer",
+ return_value=mock_indexer,
+ ):
+ await execute_tool(
+ "create_knowledge",
+ {"content": "A fact."},
+ _USER_INTERNAL,
+ )
+
+ chunk_data = mock_indexer.index_single_chunk.call_args[0][0]
+ assert chunk_data.topics == "[]"
+
+
+class TestExecuteToolSubmitFeedback:
+ """Test submit_feedback tool execution."""
+
+ async def test_submit_feedback_success(self):
+ """submit_feedback should call lifecycle.feedback.submit_feedback."""
+ mock_feedback = MagicMock()
+ mock_feedback.id = 42
+
+ with patch(
+ "knowledge_base.lifecycle.feedback.submit_feedback",
+ new_callable=AsyncMock,
+ return_value=mock_feedback,
+ ):
+ result = await execute_tool(
+ "submit_feedback",
+ {"chunk_id": "c123", "feedback_type": "helpful", "details": "Great content!"},
+ _USER_INTERNAL,
+ )
+
+ assert len(result) == 1
+ text = result[0].text
+ assert "Feedback submitted" in text
+ assert "c123" in text
+ assert "helpful" in text
+ assert "42" in text
+
+ async def test_submit_feedback_without_details(self):
+ """submit_feedback should work without optional details."""
+ mock_feedback = MagicMock()
+ mock_feedback.id = 7
+
+ with patch(
+ "knowledge_base.lifecycle.feedback.submit_feedback",
+ new_callable=AsyncMock,
+ return_value=mock_feedback,
+ ):
+ result = await execute_tool(
+ "submit_feedback",
+ {"chunk_id": "c999", "feedback_type": "outdated"},
+ _USER_INTERNAL,
+ )
+
+ assert "Feedback submitted" in result[0].text
+
+
+class TestExecuteToolCheckHealth:
+ """Test check_health tool execution."""
+
+ async def test_check_health_healthy(self):
+ """check_health should return healthy status when graphiti is up."""
+ mock_retriever = AsyncMock()
+ mock_retriever.check_health = AsyncMock(
+ return_value={
+ "graphiti_enabled": True,
+ "graphiti_healthy": True,
+ "backend": "graphiti",
+ }
+ )
+
+ with patch(
+ "knowledge_base.search.HybridRetriever",
+ return_value=mock_retriever,
+ ):
+ result = await execute_tool("check_health", {}, _USER_INTERNAL)
+
+ assert len(result) == 1
+ text = result[0].text
+ assert "healthy" in text
+ assert "Graphiti enabled: True" in text
+ assert "Graphiti healthy: True" in text
+
+ async def test_check_health_degraded(self):
+ """check_health should return degraded when graphiti is not healthy."""
+ mock_retriever = AsyncMock()
+ mock_retriever.check_health = AsyncMock(
+ return_value={
+ "graphiti_enabled": True,
+ "graphiti_healthy": False,
+ "backend": "graphiti",
+ }
+ )
+
+ with patch(
+ "knowledge_base.search.HybridRetriever",
+ return_value=mock_retriever,
+ ):
+ result = await execute_tool("check_health", {}, _USER_INTERNAL)
+
+ text = result[0].text
+ assert "degraded" in text
+
+
+class TestExecuteToolUnknown:
+ """Test unknown tool execution and error handling."""
+
+ async def test_unknown_tool_returns_error(self):
+ """Unknown tool name should return an error TextContent."""
+ result = await execute_tool("nonexistent_tool", {}, _USER_INTERNAL)
+ assert len(result) == 1
+ assert "Unknown tool: nonexistent_tool" in result[0].text
+
+ async def test_tool_execution_exception_returns_error(self):
+ """If a tool raises an exception, it should be caught and returned as text."""
+ with patch(
+ "knowledge_base.core.qa.search_knowledge",
+ new_callable=AsyncMock,
+ side_effect=RuntimeError("connection lost"),
+ ):
+ result = await execute_tool(
+ "ask_question",
+ {"question": "test"},
+ _USER_INTERNAL,
+ )
+
+ assert len(result) == 1
+ assert "Error executing ask_question" in result[0].text
+ assert "connection lost" in result[0].text