diff --git a/src/llm.ts b/src/llm.ts index 46c6295..4cb9d3e 100644 --- a/src/llm.ts +++ b/src/llm.ts @@ -502,8 +502,20 @@ export class LlamaCpp implements LLM { // (likely a binary/build config issue in node-llama-cpp). // @ts-expect-error node-llama-cpp API compat const gpuTypes = await getLlamaGpuTypes(); - // Prefer CUDA > Metal > Vulkan > CPU - const preferred = (["cuda", "metal", "vulkan"] as const).find(g => gpuTypes.includes(g)); + + // QMD_GPU: override GPU backend. Values: cuda, metal, vulkan, false/off/0 + const gpuOverride = process.env.QMD_GPU?.toLowerCase(); + const forceCpu = gpuOverride === "false" || gpuOverride === "0" || gpuOverride === "off"; + + let preferred: "cuda" | "metal" | "vulkan" | undefined; + if (forceCpu) { + preferred = undefined; + } else if (gpuOverride && (["cuda", "metal", "vulkan"] as const).some(g => g === gpuOverride)) { + preferred = gpuOverride as "cuda" | "metal" | "vulkan"; + } else { + // Prefer CUDA > Metal > Vulkan > CPU + preferred = (["cuda", "metal", "vulkan"] as const).find(g => gpuTypes.includes(g)); + } let llama: Llama; if (preferred) { diff --git a/src/mcp.ts b/src/mcp.ts index 323f469..5cd1608 100644 --- a/src/mcp.ts +++ b/src/mcp.ts @@ -88,18 +88,27 @@ function formatSearchSummary(results: SearchResultItem[], query: string): string * Injected into the LLM's system prompt via MCP initialize response — * gives the LLM immediate context about what's searchable without a tool call. */ -function buildInstructions(store: Store): string { +function buildInstructions(store: Store, defaultCollection?: string): string { const status = store.getStatus(); const lines: string[] = []; // --- What is this? --- const globalCtx = getGlobalContext(); - lines.push(`QMD is your local search engine over ${status.totalDocuments} markdown documents.`); + if (defaultCollection) { + const col = status.collections.find(c => c.name === defaultCollection); + const docCount = col?.documents ?? status.totalDocuments; + lines.push(`QMD is your local search engine over ${docCount} markdown documents (collection: "${defaultCollection}").`); + } else { + lines.push(`QMD is your local search engine over ${status.totalDocuments} markdown documents.`); + } if (globalCtx) lines.push(`Context: ${globalCtx}`); // --- What's searchable? --- if (status.collections.length > 0) { lines.push(""); + if (defaultCollection) { + lines.push(`Default collection: "${defaultCollection}".`); + } lines.push("Collections (scope with `collection` parameter):"); for (const col of status.collections) { const collConfig = getCollection(col.name); @@ -150,10 +159,13 @@ function buildInstructions(store: Store): string { * Create an MCP server with all QMD tools, resources, and prompts registered. * Shared by both stdio and HTTP transports. */ -function createMcpServer(store: Store): McpServer { +function createMcpServer(store: Store, collectionOverride?: string): McpServer { + // Collection scoping: URL path (/mcp/RAMP) > QMD_COLLECTION env var > none + const defaultCollection = collectionOverride || process.env.QMD_COLLECTION || undefined; + const server = new McpServer( { name: "qmd", version: "0.9.9" }, - { instructions: buildInstructions(store) }, + { instructions: buildInstructions(store, defaultCollection) }, ); // --------------------------------------------------------------------------- @@ -320,8 +332,9 @@ Intent-aware lex (C++ performance, not sports): query: s.query, })); - // Use default collections if none specified - const effectiveCollections = collections ?? getDefaultCollectionNames(); + // Use default collections if none specified; URL-scoped collection takes priority + const effectiveCollections = collections + ?? (defaultCollection ? [defaultCollection] : getDefaultCollectionNames()); const results = await structuredSearch(store, subSearches, { collections: effectiveCollections.length > 0 ? effectiveCollections : undefined, @@ -552,15 +565,32 @@ export type HttpServerHandle = { /** * Start MCP server over Streamable HTTP (JSON responses, no SSE). * Binds to localhost only. Returns a handle for shutdown and port discovery. + * + * Creates a new transport + McpServer per client session so multiple + * clients (or the same client reconnecting) can initialize independently. + * All sessions share the same SQLite store. */ export async function startMcpHttpServer(port: number, options?: { quiet?: boolean }): Promise { const store = createStore(); - const mcpServer = createMcpServer(store); - const transport = new WebStandardStreamableHTTPServerTransport({ - sessionIdGenerator: () => randomUUID(), - enableJsonResponse: true, - }); - await mcpServer.connect(transport); + + // Per-session state: each client connection gets its own transport + McpServer + const sessions = new Map(); + + /** Spin up a new transport + McpServer for a fresh client session. */ + async function createSession(collection?: string): Promise { + const transport = new WebStandardStreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + enableJsonResponse: true, + onsessioninitialized: (sessionId: string) => { + sessions.set(sessionId, transport); + const scope = collection ? ` [${collection}]` : ""; + log(`${ts()} session ${sessionId.slice(0, 8)}${scope} created (${sessions.size} active)`); + }, + }); + const server = createMcpServer(store, collection); + await server.connect(transport); + return transport; + } const startTime = Date.now(); const quiet = options?.quiet ?? false; @@ -599,25 +629,62 @@ export async function startMcpHttpServer(port: number, options?: { quiet?: boole return Buffer.concat(chunks).toString(); } + /** Build a Record from node IncomingMessage headers. */ + function extractHeaders(nodeReq: IncomingMessage): Record { + const h: Record = {}; + for (const [k, v] of Object.entries(nodeReq.headers)) { + if (typeof v === "string") h[k] = v; + } + return h; + } + + /** + * Extract collection name from URL path: /mcp/RAMP → "RAMP", /mcp → undefined. + */ + function collectionFromPath(pathname: string): string | undefined { + const match = pathname.match(/^\/mcp\/([^/?]+)/); + return match?.[1] ? decodeURIComponent(match[1]) : undefined; + } + + /** Look up the transport for a request, creating a new session when needed. */ + async function resolveTransport( + nodeReq: IncomingMessage, + body?: any, + ): Promise { + const sessionId = nodeReq.headers["mcp-session-id"]; + if (typeof sessionId === "string" && sessions.has(sessionId)) { + return sessions.get(sessionId)!; + } + // No (valid) session ID — only POST is allowed (new initialize request). + if (nodeReq.method === "POST" && body?.method === "initialize") { + const collection = collectionFromPath(nodeReq.url || ""); + return createSession(collection); + } + return null; + } + const httpServer = createServer(async (nodeReq: IncomingMessage, nodeRes: ServerResponse) => { const reqStart = Date.now(); const pathname = nodeReq.url || "/"; try { if (pathname === "/health" && nodeReq.method === "GET") { - const body = JSON.stringify({ status: "ok", uptime: Math.floor((Date.now() - startTime) / 1000) }); + const body = JSON.stringify({ + status: "ok", + uptime: Math.floor((Date.now() - startTime) / 1000), + sessions: sessions.size, + }); nodeRes.writeHead(200, { "Content-Type": "application/json" }); nodeRes.end(body); log(`${ts()} GET /health (${Date.now() - reqStart}ms)`); return; } - // REST endpoint: POST /search — structured search without MCP protocol // REST endpoint: POST /query (alias: /search) — structured search without MCP protocol if ((pathname === "/query" || pathname === "/search") && nodeReq.method === "POST") { const rawBody = await collectBody(nodeReq); const params = JSON.parse(rawBody); - + // Validate required fields if (!params.searches || !Array.isArray(params.searches)) { nodeRes.writeHead(400, { "Content-Type": "application/json" }); @@ -663,16 +730,23 @@ export async function startMcpHttpServer(port: number, options?: { quiet?: boole return; } - if (pathname === "/mcp" && nodeReq.method === "POST") { + if (pathname.startsWith("/mcp") && nodeReq.method === "POST") { const rawBody = await collectBody(nodeReq); const body = JSON.parse(rawBody); const label = describeRequest(body); - const url = `http://localhost:${port}${pathname}`; - const headers: Record = {}; - for (const [k, v] of Object.entries(nodeReq.headers)) { - if (typeof v === "string") headers[k] = v; + const transport = await resolveTransport(nodeReq, body); + if (!transport) { + nodeRes.writeHead(400, { "Content-Type": "application/json" }); + nodeRes.end(JSON.stringify({ + jsonrpc: "2.0", + error: { code: -32000, message: "Bad Request: missing or invalid session" }, + id: body?.id ?? null, + })); + log(`${ts()} POST /mcp ${label} → 400 bad session (${Date.now() - reqStart}ms)`); + return; } - const request = new Request(url, { method: "POST", headers, body: rawBody }); + const url = `http://localhost:${port}${pathname}`; + const request = new Request(url, { method: "POST", headers: extractHeaders(nodeReq), body: rawBody }); const response = await transport.handleRequest(request, { parsedBody: body }); nodeRes.writeHead(response.status, Object.fromEntries(response.headers)); nodeRes.end(Buffer.from(await response.arrayBuffer())); @@ -680,14 +754,18 @@ export async function startMcpHttpServer(port: number, options?: { quiet?: boole return; } - if (pathname === "/mcp") { - const url = `http://localhost:${port}${pathname}`; - const headers: Record = {}; - for (const [k, v] of Object.entries(nodeReq.headers)) { - if (typeof v === "string") headers[k] = v; + if (pathname.startsWith("/mcp")) { + // GET (SSE) or DELETE — require existing session + const sessionId = nodeReq.headers["mcp-session-id"]; + const transport = typeof sessionId === "string" ? sessions.get(sessionId) : undefined; + if (!transport) { + nodeRes.writeHead(400, { "Content-Type": "application/json" }); + nodeRes.end(JSON.stringify({ error: "Missing or invalid mcp-session-id header" })); + return; } + const url = `http://localhost:${port}${pathname}`; const rawBody = nodeReq.method !== "GET" && nodeReq.method !== "HEAD" ? await collectBody(nodeReq) : undefined; - const request = new Request(url, { method: nodeReq.method || "GET", headers, ...(rawBody ? { body: rawBody } : {}) }); + const request = new Request(url, { method: nodeReq.method || "GET", headers: extractHeaders(nodeReq), ...(rawBody ? { body: rawBody } : {}) }); const response = await transport.handleRequest(request); nodeRes.writeHead(response.status, Object.fromEntries(response.headers)); nodeRes.end(Buffer.from(await response.arrayBuffer())); @@ -714,7 +792,8 @@ export async function startMcpHttpServer(port: number, options?: { quiet?: boole const stop = async () => { if (stopping) return; stopping = true; - await transport.close(); + for (const t of sessions.values()) await t.close().catch(() => {}); + sessions.clear(); httpServer.close(); store.close(); await disposeDefaultLlamaCpp(); diff --git a/src/store.ts b/src/store.ts index ff08c2a..b5c037e 100644 --- a/src/store.ts +++ b/src/store.ts @@ -16,6 +16,7 @@ import type { Database } from "./db.js"; import picomatch from "picomatch"; import { createHash } from "crypto"; import { realpathSync, statSync, mkdirSync } from "node:fs"; +import { homedir as osHomedir } from "node:os"; import { LlamaCpp, getDefaultLlamaCpp, @@ -43,7 +44,7 @@ import { // Configuration // ============================================================================= -const HOME = process.env.HOME || "/tmp"; +const HOME = osHomedir(); export const DEFAULT_EMBED_MODEL = "embeddinggemma"; export const DEFAULT_RERANK_MODEL = "ExpedientFalcon/qwen3-reranker:0.6b-q8_0"; export const DEFAULT_QUERY_MODEL = "Qwen/Qwen3-1.7B";