diff --git a/.github/workflows/hugo.yaml b/.github/workflows/hugo.yaml deleted file mode 100644 index 8175ffaa3..000000000 --- a/.github/workflows/hugo.yaml +++ /dev/null @@ -1,179 +0,0 @@ -# GitHub Actions workflow for building and deploying Hugo site to GitHub Pages -name: Deploy Hugo site to Pages - -on: - # Runs on pushes targeting the default branch - push: - branches: - - master - - # Runs on pull requests - pull_request: - branches: - - master - - # Allows you to run this workflow manually from the Actions tab - workflow_dispatch: - -# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages -permissions: - contents: read - pages: write - id-token: write - -# Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. -# However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. -concurrency: - group: "pages" - cancel-in-progress: false - -# Default to bash -defaults: - run: - shell: bash - -jobs: - # Build job - build: - runs-on: ubuntu-latest - outputs: - pages_configured: ${{ steps.pages.outcome == 'success' }} - env: - HUGO_VERSION: 0.148.0 - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - submodules: recursive - fetch-depth: 0 # Fetch all history for .GitInfo and .Lastmod - - - name: Setup Go - uses: actions/setup-go@v5 - with: - go-version: '1.23' - - - name: Generate documentation - run: | - go run ./cmd/starmap/main.go generate docs --output ./docs --verbose - - - name: Setup Hugo - uses: peaceiris/actions-hugo@v3 - with: - hugo-version: ${{ env.HUGO_VERSION }} - extended: true - - - name: Setup Node - uses: actions/setup-node@v4 - with: - node-version: '20' - cache: 'npm' - cache-dependency-path: site/package-lock.json - if: hashFiles('site/package-lock.json') != '' - - - name: Install Node dependencies - run: | - if [ -f site/package.json ]; then - cd site && npm ci - fi - - - name: Setup Pages - id: pages - uses: 
actions/configure-pages@v5 - if: github.event_name != 'pull_request' - continue-on-error: true - - - name: Check Pages Configuration - if: github.event_name != 'pull_request' && steps.pages.outcome != 'success' - run: | - echo "::warning::GitHub Pages is not enabled for this repository." - echo "::warning::To enable deployment, go to Settings > Pages and configure GitHub Pages." - echo "::warning::The documentation has been built but will not be deployed." - - - name: Install Dart Sass - run: sudo snap install dart-sass - - - name: Ensure content symlink exists - run: | - cd site - # Remove existing symlink or directory if it exists - rm -rf content - # Create fresh symlink - ln -s ../docs content - - - name: Initialize Hugo theme submodule - run: | - git submodule update --init --recursive - # Fix theme structure - create symlink for partials - cd site/themes/hugo-book/layouts - ln -s _partials partials || true - - - name: Build with Hugo - env: - HUGO_CACHEDIR: ${{ runner.temp }}/hugo_cache - HUGO_ENVIRONMENT: production - TZ: UTC - run: | - cd site - # Use Pages base URL if available, otherwise use default - BASE_URL="${{ steps.pages.outputs.base_url }}/" - if [ -z "$BASE_URL" ] || [ "$BASE_URL" = "/" ]; then - BASE_URL="https://agentstation.github.io/starmap/" - fi - hugo \ - --gc \ - --minify \ - --baseURL "$BASE_URL" \ - --buildDrafts=${{ github.event_name == 'pull_request' }} - - - name: Upload artifact - uses: actions/upload-pages-artifact@v3 - with: - path: ./site/public - if: github.event_name != 'pull_request' && steps.pages.outcome == 'success' - - - name: Upload preview artifact - uses: actions/upload-artifact@v4 - with: - name: hugo-preview-${{ github.event.pull_request.number || github.sha }} - path: ./site/public - if: github.event_name == 'pull_request' - - # Deployment job - deploy: - environment: - name: github-pages - url: ${{ steps.deployment.outputs.page_url }} - runs-on: ubuntu-latest - needs: build - if: github.event_name != 'pull_request' 
&& needs.build.outputs.pages_configured == 'true' - steps: - - name: Deploy to GitHub Pages - id: deployment - uses: actions/deploy-pages@v4 - - # Comment on PR with preview link - preview-comment: - runs-on: ubuntu-latest - needs: build - if: github.event_name == 'pull_request' - permissions: - pull-requests: write - steps: - - name: Comment preview link - uses: actions/github-script@v7 - with: - script: | - const artifactUrl = `https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}`; - const comment = `šŸ“š **Documentation Preview Ready!** - - The documentation has been built successfully. You can download the preview artifact from: - ${artifactUrl} - - Once this PR is merged, the documentation will be automatically deployed to GitHub Pages.`; - - github.rest.issues.createComment({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - body: comment - }); \ No newline at end of file diff --git a/.gitignore b/.gitignore index ac9634b4a..6f4e414b3 100644 --- a/.gitignore +++ b/.gitignore @@ -128,9 +128,14 @@ config.local.* # models.dev sources (created during sync) models.dev-git/ -# Hugo generated directories +# Hugo-generated files (Hugo is no longer used but files may exist locally) /public/ -/resources/ -/site/public/ -/site/resources/ *.lock + +# OpenAPI intermediate files (cleaned up during generation) +internal/embedded/openapi/docs.go # Generated Go code (not needed, we use //go:embed) +internal/embedded/openapi/swagger.json # Temporary file (renamed to openapi.json) +internal/embedded/openapi/swagger.yaml # Temporary file (renamed to openapi.yaml) + +# Nix build artifacts +/devbox/swag/result # Nix build output symlink diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 296bf734a..000000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "site/themes/hugo-book"] - path = site/themes/hugo-book - url = https://github.com/alex-shpak/hugo-book diff --git a/API.md 
b/API.md index c8993c78b..23c1473ed 100644 --- a/API.md +++ b/API.md @@ -90,14 +90,15 @@ Package starmap provides a unified AI model catalog system with automatic update - [func WithAutoUpdatesDisabled\(\) Option](<#WithAutoUpdatesDisabled>) - [func WithEmbeddedCatalog\(\) Option](<#WithEmbeddedCatalog>) - [func WithLocalPath\(path string\) Option](<#WithLocalPath>) - - [func WithRemoteServer\(url string, apiKey \*string\) Option](<#WithRemoteServer>) - - [func WithRemoteServerOnly\(\) Option](<#WithRemoteServerOnly>) + - [func WithRemoteServerAPIKey\(apiKey string\) Option](<#WithRemoteServerAPIKey>) + - [func WithRemoteServerOnly\(url string\) Option](<#WithRemoteServerOnly>) + - [func WithRemoteServerURL\(url string\) Option](<#WithRemoteServerURL>) - [type Persistence](<#Persistence>) - [type Updater](<#Updater>) -## type [AutoUpdateFunc]() +## type [AutoUpdateFunc]() AutoUpdateFunc is a function that updates the catalog. @@ -211,7 +212,7 @@ type ModelUpdatedHook func(old, updated catalogs.Model) ``` -## type [Option]() +## type [Option]() Option is a function that configures a Starmap instance. @@ -220,7 +221,7 @@ type Option func(*options) error ``` -### func [WithAutoUpdateFunc]() +### func [WithAutoUpdateFunc]() ```go func WithAutoUpdateFunc(fn AutoUpdateFunc) Option @@ -229,7 +230,7 @@ func WithAutoUpdateFunc(fn AutoUpdateFunc) Option WithAutoUpdateFunc configures a custom function for updating the catalog. -### func [WithAutoUpdateInterval]() +### func [WithAutoUpdateInterval]() ```go func WithAutoUpdateInterval(interval time.Duration) Option @@ -238,7 +239,7 @@ func WithAutoUpdateInterval(interval time.Duration) Option WithAutoUpdateInterval configures how often to automatically update the catalog. -### func [WithAutoUpdatesDisabled]() +### func [WithAutoUpdatesDisabled]() ```go func WithAutoUpdatesDisabled() Option @@ -247,7 +248,7 @@ func WithAutoUpdatesDisabled() Option WithAutoUpdatesDisabled configures whether automatic updates are disabled. 
-### func [WithEmbeddedCatalog]() +### func [WithEmbeddedCatalog]() ```go func WithEmbeddedCatalog() Option @@ -256,7 +257,7 @@ func WithEmbeddedCatalog() Option WithEmbeddedCatalog configures whether to use an embedded catalog. It defaults to false, but takes precedence over WithLocalPath if set. -### func [WithLocalPath]() +### func [WithLocalPath]() ```go func WithLocalPath(path string) Option @@ -264,24 +265,33 @@ func WithLocalPath(path string) Option WithLocalPath configures the local source to use a specific catalog path. - -### func [WithRemoteServer]() + +### func [WithRemoteServerAPIKey]() ```go -func WithRemoteServer(url string, apiKey *string) Option +func WithRemoteServerAPIKey(apiKey string) Option ``` -WithRemoteServer configures the remote server for catalog updates. A url is required, an api key can be provided for authentication, otherwise use nil to skip Bearer token authentication. +WithRemoteServerAPIKey configures the remote server API key. -### func [WithRemoteServerOnly]() +### func [WithRemoteServerOnly]() ```go -func WithRemoteServerOnly() Option +func WithRemoteServerOnly(url string) Option ``` WithRemoteServerOnly configures whether to only use the remote server and not hit provider APIs. + +### func [WithRemoteServerURL]() + +```go +func WithRemoteServerURL(url string) Option +``` + +WithRemoteServerURL configures the remote server URL. 
+ ## type [Persistence]() @@ -312,4 +322,4 @@ type Updater interface { Generated by [gomarkdoc]() - + \ No newline at end of file diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 39a1d97b7..3b8128b6a 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -1076,7 +1076,19 @@ starmap/ │ ā”œā”€ā”€ internal/ # Internal packages │ ā”œā”€ā”€ embedded/ # Embedded catalog data -│ │ └── catalog/ # Embedded YAML files +│ │ ā”œā”€ā”€ catalog/ # Embedded YAML files +│ │ └── openapi/ # OpenAPI 3.1 specs (JSON/YAML) +│ ā”œā”€ā”€ server/ # HTTP server implementation +│ │ ā”œā”€ā”€ server.go # Server struct & lifecycle +│ │ ā”œā”€ā”€ config.go # Configuration management +│ │ ā”œā”€ā”€ router.go # Route registration & middleware +│ │ └── handlers/ # HTTP request handlers +│ │ ā”œā”€ā”€ models.go # Model endpoints +│ │ ā”œā”€ā”€ providers.go # Provider endpoints +│ │ ā”œā”€ā”€ admin.go # Admin operations +│ │ ā”œā”€ā”€ health.go # Health checks +│ │ ā”œā”€ā”€ realtime.go # WebSocket/SSE +│ │ └── openapi.go # OpenAPI spec endpoints │ ā”œā”€ā”€ sources/ # Source implementations │ │ ā”œā”€ā”€ providers/ # Provider API clients │ │ │ ā”œā”€ā”€ openai/ # OpenAI client diff --git a/CLAUDE.md b/CLAUDE.md index 5c654fbb7..0d2e0be67 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -256,7 +256,7 @@ for _, provider := range providers { **Core packages**: catalogs, reconciler, authority, sources, errors, logging, constants, convert -**Internal**: embedded, sources/{providers,modelsdev,local,clients}, transport +**Internal**: embedded, server, server/handlers, sources/{providers,modelsdev,local,clients}, transport **Application**: cmd/application (interface), cmd/starmap/app (implementation) @@ -289,6 +289,10 @@ GOOGLE_APPLICATION_CREDENTIALS=/path/to/service-account.json | `sync.go` | 12-step sync pipeline | | `cmd/application/application.go` | Application interface (idiomatic location) | | `cmd/starmap/app/app.go` | App implementation | +| `cmd/starmap/cmd/serve/command.go` | HTTP server CLI 
command | +| `internal/server/server.go` | Server lifecycle & dependencies | +| `internal/server/router.go` | Route registration & middleware | +| `internal/server/handlers/handlers.go` | Handler base structure | | `pkg/reconciler/reconciler.go` | Multi-source reconciliation | | `pkg/authority/authority.go` | Field-level authorities | | `internal/sources/providers/providers.go` | Concurrent provider fetching | diff --git a/Makefile b/Makefile index 50606b54a..1c936a711 100644 --- a/Makefile +++ b/Makefile @@ -501,11 +501,28 @@ testdata: ## Update testdata for all providers (use PROVIDER=name for specific p fi # Documentation -generate: ## Generate all documentation (Go docs only) +openapi: ## Generate OpenAPI 3.1 documentation (embedded in binary) + @echo "$(BLUE)Generating OpenAPI 3.1 documentation...$(NC)" + @$(RUN_PREFIX) which swag > /dev/null || (echo "$(RED)swag not found. Run 'devbox shell' to enter the development environment$(NC)" && exit 1) + @echo "$(YELLOW)Step 1/3: Generating OpenAPI 3.1 with swag v2...$(NC)" + @$(RUN_PREFIX) swag init -g internal/server/docs.go -o internal/embedded/openapi --parseDependency --parseInternal --v3.1 + @echo "$(YELLOW)Step 2/3: Renaming generated files...$(NC)" + @mv internal/embedded/openapi/swagger.json internal/embedded/openapi/openapi.json + @mv internal/embedded/openapi/swagger.yaml internal/embedded/openapi/openapi.yaml + @rm -f internal/embedded/openapi/docs.go + @echo "$(YELLOW)Step 3/3: Verifying embedded specs...$(NC)" + @test -f internal/embedded/openapi/openapi.json || (echo "$(RED)Error: openapi.json not found$(NC)" && exit 1) + @test -f internal/embedded/openapi/openapi.yaml || (echo "$(RED)Error: openapi.yaml not found$(NC)" && exit 1) + @echo "$(GREEN)OpenAPI 3.1 specs generated and ready for embedding$(NC)" + @echo "$(GREEN) - internal/embedded/openapi/openapi.json$(NC)" + @echo "$(GREEN) - internal/embedded/openapi/openapi.yaml$(NC)" + @echo "$(BLUE)Specs will be embedded in binary via //go:embed$(NC)" + 
+generate: openapi ## Generate all documentation (Go docs and OpenAPI) @echo "$(BLUE)Generating Go documentation...$(NC)" @$(RUN_PREFIX) which gomarkdoc > /dev/null || (echo "$(RED)gomarkdoc not found. Install with: go install github.com/princjef/gomarkdoc/cmd/gomarkdoc@latest$(NC)" && exit 1) $(GOCMD) generate ./... - @echo "$(GREEN)Go documentation generation complete$(NC)" + @echo "$(GREEN)All documentation generation complete$(NC)" godoc: ## Generate only Go documentation using go generate @echo "$(BLUE)Generating Go documentation...$(NC)" diff --git a/README.md b/README.md index 4af154ab9..1f960e369 100644 --- a/README.md +++ b/README.md @@ -400,9 +400,67 @@ For detailed source hierarchy, authority rules, and how sources work together, s Starmap includes 500+ models from 10+ providers (OpenAI, Anthropic, Google, Groq, DeepSeek, Cerebras, and more). Each package includes comprehensive documentation in its README. -## HTTP Server (Coming Soon) +## HTTP Server -Future HTTP server with REST API, GraphQL, WebSocket, and webhooks for centralized catalog service with multi-tenant support. 
+Start a production-ready REST API server for programmatic catalog access: + +```bash +# Start on default port 8080 +starmap serve + +# Custom configuration +starmap serve --port 3000 --cors --auth --rate-limit 100 + +# With specific CORS origins +starmap serve --cors-origins "https://example.com,https://app.example.com" +``` + +**Features:** +- **RESTful API**: Models, providers, search endpoints with filtering +- **Real-time Updates**: WebSocket (`/api/v1/updates/ws`) and SSE (`/api/v1/updates/stream`) +- **Performance**: In-memory caching, rate limiting (per-IP) +- **Security**: Optional API key authentication, CORS support +- **Monitoring**: Health checks (`/health`, `/api/v1/ready`), metrics endpoint +- **Documentation**: OpenAPI 3.1 specs at `/api/v1/openapi.json` + +**API Endpoints:** +```bash +# Models +GET /api/v1/models # List with filtering +GET /api/v1/models/{id} # Get specific model +POST /api/v1/models/search # Advanced search + +# Providers +GET /api/v1/providers # List providers +GET /api/v1/providers/{id} # Get specific provider +GET /api/v1/providers/{id}/models # Get provider's models + +# Admin +POST /api/v1/update # Trigger catalog sync +GET /api/v1/stats # Catalog statistics + +# Health +GET /health # Liveness probe +GET /api/v1/ready # Readiness check +``` + +**Configuration Flags:** +- `--port, -p`: Server port (default: 8080) +- `--host`: Bind address (default: localhost) +- `--cors`: Enable CORS for all origins +- `--cors-origins`: Specific CORS origins (comma-separated) +- `--auth`: Enable API key authentication +- `--rate-limit`: Requests per minute per IP (default: 100) +- `--cache-ttl`: Cache TTL in seconds (default: 300) + +**Environment Variables:** +```bash +HTTP_PORT=8080 +HTTP_HOST=0.0.0.0 +STARMAP_API_KEY=your-api-key # If --auth enabled +``` + +For full server documentation, see [internal/server/README.md](internal/server/README.md). 
## Configuration diff --git a/REST_API.md b/REST_API.md new file mode 100644 index 000000000..66328ab75 --- /dev/null +++ b/REST_API.md @@ -0,0 +1,809 @@ +# Starmap API Documentation + +> REST API documentation for the Starmap HTTP server + +**Version:** v1 +**Base URL:** `http://localhost:8080/api/v1` +**Last Updated:** 2025-10-15 + +## Table of Contents + +- [Overview](#overview) +- [Getting Started](#getting-started) +- [Authentication](#authentication) +- [Response Format](#response-format) +- [Error Handling](#error-handling) +- [Endpoints](#endpoints) + - [Models](#models) + - [Providers](#providers) + - [Administration](#administration) + - [Health & Metrics](#health--metrics) + - [Real-time Updates](#real-time-updates) +- [Filtering & Search](#filtering--search) +- [Rate Limiting](#rate-limiting) +- [CORS](#cors) +- [Examples](#examples) + +## Overview + +The Starmap HTTP API provides programmatic access to the unified AI model catalog. It offers: + +- **RESTful endpoints** for querying models and providers +- **Advanced filtering** with multiple criteria +- **Real-time updates** via WebSocket and Server-Sent Events +- **In-memory caching** for performance +- **Rate limiting** to prevent abuse +- **Optional authentication** with API keys + +## Getting Started + +### Starting the Server + +```bash +# Start with default settings (port 8080, no auth) +starmap serve + +# Start with custom port +starmap serve --port 3000 + +# Enable authentication +export API_KEY="your-secret-key" +starmap serve --auth + +# Enable CORS for specific origins +starmap serve --cors-origins "https://example.com,https://app.example.com" + +# Full configuration +starmap serve \ + --port 8080 \ + --host localhost \ + --cors \ + --auth \ + --rate-limit 100 \ + --cache-ttl 300 +``` + +### Configuration Options + +| Flag | Environment Variable | Default | Description | +|------|---------------------|---------|-------------| +| `--port` | `HTTP_PORT` | `8080` | Server port | +| `--host` | 
`HTTP_HOST` | `localhost` | Bind address | +| `--cors` | - | `false` | Enable CORS for all origins | +| `--cors-origins` | `CORS_ORIGINS` | - | Allowed CORS origins (comma-separated) | +| `--auth` | `ENABLE_AUTH` | `false` | Enable API key authentication | +| `--auth-header` | - | `X-API-Key` | Authentication header name | +| `--rate-limit` | `RATE_LIMIT_RPM` | `100` | Requests per minute per IP | +| `--cache-ttl` | `CACHE_TTL` | `300` | Cache TTL in seconds | +| `--read-timeout` | `READ_TIMEOUT` | `10s` | HTTP read timeout | +| `--write-timeout` | `WRITE_TIMEOUT` | `10s` | HTTP write timeout | +| `--idle-timeout` | `IDLE_TIMEOUT` | `120s` | HTTP idle timeout | + +## Authentication + +When authentication is enabled, all requests (except health endpoints) require an API key. + +### API Key Header + +```http +X-API-Key: your-secret-key +``` + +Or using the Authorization header: + +```http +Authorization: Bearer your-secret-key +``` + +### Public Endpoints + +The following endpoints are always publicly accessible: + +- `GET /health` +- `GET /api/v1/health` +- `GET /api/v1/ready` + +### Example + +```bash +# With X-API-Key header +curl -H "X-API-Key: your-secret-key" \ + http://localhost:8080/api/v1/models + +# With Authorization header +curl -H "Authorization: Bearer your-secret-key" \ + http://localhost:8080/api/v1/models +``` + +## Response Format + +All API responses follow a consistent format: + +### Success Response + +```json +{ + "data": { + // Response data here + }, + "error": null +} +``` + +### Error Response + +```json +{ + "data": null, + "error": { + "code": "ERROR_CODE", + "message": "Human-readable error message", + "details": "Additional error details" + } +} +``` + +## Error Handling + +### Error Codes + +| Code | HTTP Status | Description | +|------|-------------|-------------| +| `BAD_REQUEST` | 400 | Invalid request format or parameters | +| `UNAUTHORIZED` | 401 | Invalid or missing API key | +| `NOT_FOUND` | 404 | Resource not found | +| 
`METHOD_NOT_ALLOWED` | 405 | HTTP method not supported | +| `RATE_LIMITED` | 429 | Rate limit exceeded | +| `INTERNAL_ERROR` | 500 | Internal server error | +| `SERVICE_UNAVAILABLE` | 503 | Service temporarily unavailable | + +### Example Error Response + +```json +{ + "data": null, + "error": { + "code": "NOT_FOUND", + "message": "Model not found", + "details": "No model with ID 'gpt-5' exists" + } +} +``` + +## Endpoints + +### Models + +#### List Models + +```http +GET /api/v1/models +``` + +List all models with optional filtering. + +**Query Parameters:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `id` | string | Filter by exact model ID | +| `name` | string | Filter by exact model name (case-insensitive) | +| `name_contains` | string | Filter by partial model name match | +| `provider` | string | Filter by provider ID | +| `modality_input` | string | Filter by input modality (comma-separated) | +| `modality_output` | string | Filter by output modality (comma-separated) | +| `feature` | string | Filter by feature (streaming, tool_calls, etc.) 
| +| `tag` | string | Filter by tag (comma-separated) | +| `open_weights` | boolean | Filter by open weights status | +| `min_context` | integer | Minimum context window size | +| `max_context` | integer | Maximum context window size | +| `sort` | string | Sort field (id, name, release_date, context_window) | +| `order` | string | Sort order (asc, desc) | +| `limit` | integer | Maximum results (default: 100, max: 1000) | +| `offset` | integer | Result offset for pagination | + +**Example Request:** + +```bash +curl "http://localhost:8080/api/v1/models?provider=openai&feature=tool_calls&limit=10" +``` + +**Example Response:** + +```json +{ + "data": { + "models": [ + { + "id": "gpt-4", + "name": "GPT-4", + "description": "Large multimodal model", + "features": { + "modalities": { + "input": ["text", "image"], + "output": ["text"] + }, + "tool_calls": true, + "streaming": true + }, + "limits": { + "context_window": 128000, + "output_tokens": 16384 + } + } + ], + "pagination": { + "total": 1, + "limit": 10, + "offset": 0, + "count": 1 + } + }, + "error": null +} +``` + +#### Get Model by ID + +```http +GET /api/v1/models/{id} +``` + +Retrieve detailed information about a specific model. 
+ +**Path Parameters:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `id` | string | Model ID | + +**Example Request:** + +```bash +curl http://localhost:8080/api/v1/models/gpt-4 +``` + +**Example Response:** + +```json +{ + "data": { + "id": "gpt-4", + "name": "GPT-4", + "authors": [ + { + "name": "OpenAI", + "url": "https://openai.com" + } + ], + "description": "Large multimodal model with advanced reasoning", + "metadata": { + "release_date": "2023-03-14T00:00:00Z", + "open_weights": false, + "tags": ["chat", "vision"] + }, + "features": { + "modalities": { + "input": ["text", "image"], + "output": ["text"] + }, + "tool_calls": true, + "tools": true, + "tool_choice": true, + "streaming": true + }, + "limits": { + "context_window": 128000, + "output_tokens": 16384 + }, + "pricing": { + "tokens": { + "input": { + "per_1m": 30.0 + }, + "output": { + "per_1m": 60.0 + } + } + } + }, + "error": null +} +``` + +#### Advanced Model Search + +```http +POST /api/v1/models/search +``` + +Perform advanced search with multiple criteria. 
+ +**Request Body:** + +```json +{ + "ids": ["gpt-4", "claude-3-opus"], + "name_contains": "gpt", + "provider": "openai", + "modalities": { + "input": ["text", "image"], + "output": ["text"] + }, + "features": { + "streaming": true, + "tool_calls": true + }, + "tags": ["chat", "vision"], + "open_weights": false, + "context_window": { + "min": 32000, + "max": 200000 + }, + "output_tokens": { + "min": 4000, + "max": 16000 + }, + "release_date": { + "after": "2024-01-01", + "before": "2025-01-01" + }, + "sort": "release_date", + "order": "desc", + "max_results": 100 +} +``` + +**Example Request:** + +```bash +curl -X POST http://localhost:8080/api/v1/models/search \ + -H "Content-Type: application/json" \ + -d '{ + "provider": "openai", + "features": {"tool_calls": true}, + "context_window": {"min": 32000} + }' +``` + +**Example Response:** + +```json +{ + "data": { + "models": [...], + "count": 5 + }, + "error": null +} +``` + +### Providers + +#### List Providers + +```http +GET /api/v1/providers +``` + +List all providers. + +**Example Request:** + +```bash +curl http://localhost:8080/api/v1/providers +``` + +**Example Response:** + +```json +{ + "data": { + "providers": [ + { + "id": "openai", + "name": "OpenAI", + "model_count": 42, + "headquarters": "San Francisco, CA", + "docs_url": "https://platform.openai.com/docs" + } + ], + "count": 1 + }, + "error": null +} +``` + +#### Get Provider by ID + +```http +GET /api/v1/providers/{id} +``` + +Retrieve detailed information about a specific provider. + +**Example Request:** + +```bash +curl http://localhost:8080/api/v1/providers/openai +``` + +#### Get Provider Models + +```http +GET /api/v1/providers/{id}/models +``` + +List all models for a specific provider. 
+ +**Example Request:** + +```bash +curl http://localhost:8080/api/v1/providers/openai/models +``` + +**Example Response:** + +```json +{ + "data": { + "provider": { + "id": "openai", + "name": "OpenAI" + }, + "models": [...], + "count": 42 + }, + "error": null +} +``` + +### Administration + +#### Trigger Catalog Update + +```http +POST /api/v1/update +``` + +Manually trigger catalog synchronization. + +**Query Parameters:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `provider` | string | Update specific provider only | + +**Example Request:** + +```bash +# Update all providers +curl -X POST http://localhost:8080/api/v1/update + +# Update specific provider +curl -X POST "http://localhost:8080/api/v1/update?provider=openai" +``` + +**Example Response:** + +```json +{ + "data": { + "status": "completed", + "total_changes": 5, + "providers_changed": 1, + "dry_run": false + }, + "error": null +} +``` + +#### Get Catalog Statistics + +```http +GET /api/v1/stats +``` + +Get catalog statistics. + +**Example Response:** + +```json +{ + "data": { + "models": { + "total": 250 + }, + "providers": { + "total": 8 + }, + "cache": { + "item_count": 42 + }, + "realtime": { + "websocket_clients": 3, + "sse_clients": 1 + } + }, + "error": null +} +``` + +### Health & Metrics + +#### Health Check + +```http +GET /api/v1/health +GET /health +``` + +Health check endpoint (liveness probe). + +**Example Response:** + +```json +{ + "data": { + "status": "healthy", + "service": "starmap-api", + "version": "v1" + }, + "error": null +} +``` + +#### Readiness Check + +```http +GET /api/v1/ready +``` + +Readiness check including cache and data source status. + +**Example Response:** + +```json +{ + "data": { + "status": "ready", + "cache": { + "items": 42 + }, + "websocket_clients": 3, + "sse_clients": 1 + }, + "error": null +} +``` + +#### Metrics + +```http +GET /metrics +``` + +Prometheus-compatible metrics endpoint. 
+ +### Real-time Updates + +#### WebSocket + +```http +WS /api/v1/updates/ws +``` + +WebSocket connection for real-time catalog updates. + +**Message Format:** + +```json +{ + "type": "sync.completed", + "timestamp": "2025-10-14T12:00:00Z", + "data": { + "total_changes": 5, + "providers_changed": 1 + } +} +``` + +**Event Types:** + +- `client.connected` - Client connected to stream +- `sync.started` - Catalog sync initiated +- `sync.completed` - Catalog sync finished +- `model.created` - New model added +- `model.updated` - Model modified +- `model.deleted` - Model removed + +**Example (JavaScript):** + +```javascript +const ws = new WebSocket('ws://localhost:8080/api/v1/updates/ws'); + +ws.onmessage = (event) => { + const message = JSON.parse(event.data); + console.log('Event:', message.type, message.data); +}; +``` + +#### Server-Sent Events (SSE) + +```http +GET /api/v1/updates/stream +``` + +Server-Sent Events stream for catalog change notifications. + +**Example (JavaScript):** + +```javascript +const eventSource = new EventSource('http://localhost:8080/api/v1/updates/stream'); + +eventSource.addEventListener('sync.completed', (event) => { + const data = JSON.parse(event.data); + console.log('Sync completed:', data); +}); + +eventSource.addEventListener('connected', (event) => { + console.log('Connected to updates stream'); +}); +``` + +## Filtering & Search + +### Simple Filtering (GET) + +Use query parameters for simple filtering: + +```bash +# Filter by provider +curl "http://localhost:8080/api/v1/models?provider=openai" + +# Multiple filters +curl "http://localhost:8080/api/v1/models?provider=openai&feature=tool_calls&min_context=32000" + +# Modality filtering +curl "http://localhost:8080/api/v1/models?modality_input=text,image&modality_output=text" + +# Tag filtering +curl "http://localhost:8080/api/v1/models?tag=chat,vision" +``` + +### Advanced Search (POST) + +Use the search endpoint for complex queries: + +```bash +curl -X POST 
http://localhost:8080/api/v1/models/search \ + -H "Content-Type: application/json" \ + -d '{ + "provider": "openai", + "features": { + "tool_calls": true, + "streaming": true + }, + "context_window": { + "min": 32000 + }, + "tags": ["chat"], + "sort": "release_date", + "order": "desc" + }' +``` + +## Rate Limiting + +The API enforces rate limiting per IP address. + +**Default:** 100 requests per minute +**Header:** Rate limit info in response headers (future) + +When rate limited, you'll receive a `429` response: + +```json +{ + "data": null, + "error": { + "code": "RATE_LIMITED", + "message": "Rate limit exceeded", + "details": "Too many requests. Please try again later." + } +} +``` + +## CORS + +CORS can be configured via command-line flags: + +```bash +# Enable CORS for all origins +starmap serve --cors + +# Enable CORS for specific origins +starmap serve --cors-origins "https://example.com,https://app.example.com" +``` + +## Examples + +### Complete Workflow + +```bash +# 1. Start server +starmap serve --port 8080 + +# 2. Check health +curl http://localhost:8080/health + +# 3. List all models +curl http://localhost:8080/api/v1/models + +# 4. Search for specific models +curl -X POST http://localhost:8080/api/v1/models/search \ + -H "Content-Type: application/json" \ + -d '{"provider": "openai", "features": {"tool_calls": true}}' + +# 5. Get specific model +curl http://localhost:8080/api/v1/models/gpt-4 + +# 6. Get provider info +curl http://localhost:8080/api/v1/providers/openai + +# 7. Trigger catalog update +curl -X POST http://localhost:8080/api/v1/update + +# 8. 
Check statistics +curl http://localhost:8080/api/v1/stats +``` + +### With Authentication + +```bash +export API_KEY="your-secret-key" + +# Start server with auth +starmap serve --auth + +# Make authenticated request +curl -H "X-API-Key: $API_KEY" \ + http://localhost:8080/api/v1/models +``` + +### Real-time Updates + +```javascript +// WebSocket example +const ws = new WebSocket('ws://localhost:8080/api/v1/updates/ws'); + +ws.onopen = () => console.log('Connected'); +ws.onmessage = (event) => { + const msg = JSON.parse(event.data); + if (msg.type === 'sync.completed') { + console.log('Catalog updated:', msg.data.total_changes, 'changes'); + } +}; + +// SSE example +const eventSource = new EventSource('http://localhost:8080/api/v1/updates/stream'); +eventSource.onmessage = (event) => { + const data = JSON.parse(event.data); + console.log('Update:', data); +}; +``` + +## Best Practices + +1. **Use Caching**: Results are cached by default (5 min TTL) +2. **Filter Early**: Use query parameters to reduce response size +3. **Paginate**: Use `limit` and `offset` for large result sets +4. **Handle Errors**: Always check the `error` field in responses +5. **Rate Limits**: Implement client-side rate limiting +6. **Real-time**: Use WebSocket/SSE for live updates instead of polling +7. **Authentication**: Keep API keys secure, never commit to version control + +## Support + +For issues, questions, or feature requests, please visit: +- GitHub: https://github.com/agentstation/starmap +- Documentation: https://docs.starmap.dev (future) diff --git a/cmd/starmap/cmd/serve/api.go b/cmd/starmap/cmd/serve/api.go deleted file mode 100644 index 6ab3d6d97..000000000 --- a/cmd/starmap/cmd/serve/api.go +++ /dev/null @@ -1,212 +0,0 @@ -package serve - -import ( - "fmt" - "net/http" - "os" - "time" - - "github.com/spf13/cobra" - - "github.com/agentstation/starmap/cmd/application" -) - -// NewAPICommand creates the serve api command using app context. 
-func NewAPICommand(app application.Application) *cobra.Command { - cmd := &cobra.Command{ - Use: "api", - Short: "Serve REST API server", - Long: `Start a REST API server for the starmap catalog. - -Features: - - RESTful endpoints for models, providers, and authors - - CORS support for web applications - - Rate limiting and authentication - - Health checks and metrics - - Graceful shutdown - -The API provides programmatic access to the starmap catalog with -endpoints for listing, searching, and retrieving model information.`, - Example: ` starmap serve api # Start on default port 8080 - starmap serve api --port 3000 # Start on custom port - starmap serve api --cors # Enable CORS for all origins - starmap serve api --auth # Enable API key authentication`, - RunE: func(cmd *cobra.Command, args []string) error { - return runAPI(cmd, args, app) - }, - } - - // Add common server flags - AddCommonFlags(cmd, getDefaultAPIPort()) - - // Add API-specific flags - cmd.Flags().Bool("cors", false, "Enable CORS for all origins") - cmd.Flags().StringSlice("cors-origins", []string{}, "Allowed CORS origins") - cmd.Flags().Bool("auth", false, "Enable API key authentication") - cmd.Flags().String("auth-header", "X-API-Key", "Authentication header name") - cmd.Flags().Int("rate-limit", 100, "Requests per minute per IP") - cmd.Flags().Bool("metrics", true, "Enable metrics endpoint") - cmd.Flags().String("prefix", "/api/v1", "API path prefix") - - return cmd -} - -// runAPI starts the API server using app context. 
-func runAPI(cmd *cobra.Command, _ []string, app application.Application) error { - config, err := GetServerConfig(cmd, getDefaultAPIPort()) - if err != nil { - return fmt.Errorf("getting server config: %w", err) - } - - // Get API-specific flags - corsEnabled, _ := cmd.Flags().GetBool("cors") - corsOrigins, _ := cmd.Flags().GetStringSlice("cors-origins") - authEnabled, _ := cmd.Flags().GetBool("auth") - authHeader, _ := cmd.Flags().GetString("auth-header") - rateLimit, _ := cmd.Flags().GetInt("rate-limit") - metricsEnabled, _ := cmd.Flags().GetBool("metrics") - pathPrefix, _ := cmd.Flags().GetString("prefix") - - // Override with environment-specific port - if envPort := os.Getenv("STARMAP_API_PORT"); envPort != "" { - if port, err := parsePort(envPort); err == nil { - config.Port = port - } - } - - logger := app.Logger() - logger.Info(). - Int("port", config.Port). - Str("host", config.Host). - Str("prefix", pathPrefix). - Bool("cors", corsEnabled). - Bool("auth", authEnabled). - Int("rate_limit", rateLimit). - Msg("Starting API server") - - // Create HTTP server - server := &http.Server{ - Addr: config.Address(), - Handler: createAPIHandler(app, corsEnabled, corsOrigins, authEnabled, authHeader, rateLimit, metricsEnabled, pathPrefix), - ReadTimeout: 15 * time.Second, - WriteTimeout: 15 * time.Second, - IdleTimeout: 60 * time.Second, - } - - // Start server with graceful shutdown - return StartServerWithGracefulShutdown(server, "API") -} - -// createAPIHandler creates the HTTP handler using app context. 
-func createAPIHandler(app application.Application, corsEnabled bool, corsOrigins []string, authEnabled bool, authHeader string, _ int, metricsEnabled bool, pathPrefix string) http.Handler { - // Initialize API handlers with app context - apiHandlers, err := NewAPIHandlers(app) - if err != nil { - logger := app.Logger() - logger.Error().Err(err).Msg("Failed to initialize API handlers") - // Return a handler that returns 503 for all requests - return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - http.Error(w, `{"error":"service unavailable","message":"failed to load catalog"}`, http.StatusServiceUnavailable) - }) - } - - mux := http.NewServeMux() - - // Health check endpoint - mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if _, err := fmt.Fprint(w, `{"status":"healthy","service":"starmap-api","version":"v1"}`); err != nil { - app.Logger().Error().Err(err).Msg("Failed to write health check response") - } - }) - - // Middleware wrapper to apply CORS and auth - wrap := func(handler http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // Apply CORS if enabled - if corsEnabled { - applyCORS(w, corsOrigins) - // Handle preflight requests - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusOK) - return - } - } - - // Apply auth if enabled - if authEnabled && !isAuthenticated(r, authHeader) { - w.Header().Set("Content-Type", "application/json") - http.Error(w, `{"error":"unauthorized","message":"valid API key required"}`, http.StatusUnauthorized) - return - } - - // Call the actual handler - handler(w, r) - } - } - - // REST API endpoints following documented spec - - // Models endpoints - mux.HandleFunc(pathPrefix+"/models", wrap(apiHandlers.ModelsHandler)) - mux.HandleFunc(pathPrefix+"/models/", wrap(apiHandlers.ModelByIDHandler)) - - // Providers endpoints - 
mux.HandleFunc(pathPrefix+"/providers", wrap(apiHandlers.ProvidersHandler)) - mux.HandleFunc(pathPrefix+"/providers/", wrap(apiHandlers.ProviderByIDHandler)) - - // Future endpoints (placeholder responses) - mux.HandleFunc(pathPrefix+"/webhooks", wrap(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/json") - http.Error(w, `{"error":"not implemented","message":"webhooks endpoint coming soon"}`, http.StatusNotImplemented) - })) - - mux.HandleFunc(pathPrefix+"/updates/stream", wrap(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/json") - http.Error(w, `{"error":"not implemented","message":"SSE updates endpoint coming soon"}`, http.StatusNotImplemented) - })) - - mux.HandleFunc(pathPrefix+"/sync", wrap(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/json") - http.Error(w, `{"error":"not implemented","message":"sync endpoint coming soon"}`, http.StatusNotImplemented) - })) - - // Metrics endpoint (optional) - if metricsEnabled { - mux.HandleFunc("/metrics", func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/plain") - modelsCount := len(apiHandlers.catalog.Models().List()) - if _, err := fmt.Fprintf(w, "# Starmap API Metrics\n# starmap_api_requests_total 0\n# starmap_catalog_models_total %d\n", modelsCount); err != nil { - app.Logger().Error().Err(err).Msg("Failed to write metrics response") - } - }) - } - - return mux -} - -// applyCORS applies CORS headers to the response. 
-func applyCORS(w http.ResponseWriter, allowedOrigins []string) { - if len(allowedOrigins) == 0 { - w.Header().Set("Access-Control-Allow-Origin", "*") - } else { - // In a real implementation, you'd check the request origin against allowed origins - w.Header().Set("Access-Control-Allow-Origin", allowedOrigins[0]) - } - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key") -} - -// isAuthenticated checks if the request is authenticated. -func isAuthenticated(r *http.Request, authHeader string) bool { - apiKey := r.Header.Get(authHeader) - // Placeholder implementation - in real use, validate against configured API keys - return apiKey != "" -} - -// getDefaultAPIPort returns the default port for API server. -func getDefaultAPIPort() int { - // Common HTTP API server port - return 8080 -} diff --git a/cmd/starmap/cmd/serve/command.go b/cmd/starmap/cmd/serve/command.go index a594efcf7..6675a87e5 100644 --- a/cmd/starmap/cmd/serve/command.go +++ b/cmd/starmap/cmd/serve/command.go @@ -1,34 +1,236 @@ +// Package serve provides HTTP server commands for the Starmap CLI. package serve import ( + "context" + "fmt" + "net/http" + "os" + "os/signal" + "strconv" + "syscall" + "time" + + "github.com/rs/zerolog" "github.com/spf13/cobra" "github.com/agentstation/starmap/cmd/application" + "github.com/agentstation/starmap/internal/server" ) // NewCommand creates the serve command using app context. func NewCommand(app application.Application) *cobra.Command { cmd := &cobra.Command{ - Use: "serve", - Short: "Start HTTP servers for various resources", - Long: `Serve starts HTTP servers for different starmap resources. + Use: "serve", + Aliases: []string{"server"}, + Short: "Start the REST API server with WebSocket and SSE support", + Long: `Start a production-ready REST API server for the starmap catalog. 
+ +Features: + - RESTful endpoints for models, providers, and catalog management + - WebSocket support for real-time updates (/api/v1/updates/ws) + - Server-Sent Events (SSE) for streaming updates (/api/v1/updates/stream) + - In-memory caching with configurable TTL + - Rate limiting (requests per minute per IP) + - API key authentication (optional) + - CORS support for web applications + - Request logging and panic recovery + - Graceful shutdown with connection draining + - Health checks and metrics endpoints + - OpenAPI 3.1 documentation (/api/v1/openapi.json) + +The API provides programmatic access to the starmap catalog with +comprehensive filtering, search, and real-time notification capabilities.`, + Example: ` # Start on default port 8080 + starmap serve -Available services: - api - REST API server [default: :8080] + # Start on custom port with authentication + starmap serve --port 3000 --auth -Examples: - starmap serve api --port 3000 # Start API server on :3000 + # Enable CORS for specific origins + starmap serve --cors-origins "https://example.com,https://app.example.com" -Environment Variables: - PORT - Default port for single services - STARMAP_API_PORT - API server port - HOST - Bind address (default: localhost)`, - Example: ` starmap serve api --cors - starmap serve api --port 8080`, + # Enable rate limiting + starmap serve --rate-limit 60 + + # Full configuration + starmap serve --port 8080 --cors --auth --rate-limit 100`, + RunE: func(cmd *cobra.Command, args []string) error { + return runServer(cmd, args, app) + }, } - // Add subcommands with app context - cmd.AddCommand(NewAPICommand(app)) + // Server configuration flags + cmd.Flags().IntP("port", "p", 8080, "Server port") + cmd.Flags().String("host", "localhost", "Bind address") + + // CORS flags + cmd.Flags().Bool("cors", false, "Enable CORS for all origins") + cmd.Flags().StringSlice("cors-origins", []string{}, "Allowed CORS origins (comma-separated)") + + // Authentication flags + 
cmd.Flags().Bool("auth", false, "Enable API key authentication") + cmd.Flags().String("auth-header", "X-API-Key", "Authentication header name") + + // Performance flags + cmd.Flags().Int("rate-limit", 100, "Requests per minute per IP (0 to disable)") + cmd.Flags().Int("cache-ttl", 300, "Cache TTL in seconds") + + // Timeout flags + cmd.Flags().Duration("read-timeout", 10*time.Second, "HTTP read timeout") + cmd.Flags().Duration("write-timeout", 10*time.Second, "HTTP write timeout") + cmd.Flags().Duration("idle-timeout", 120*time.Second, "HTTP idle timeout") + + // Features flags + cmd.Flags().Bool("metrics", true, "Enable metrics endpoint") + cmd.Flags().String("prefix", "/api/v1", "API path prefix") return cmd } + +// runServer starts the API server. +func runServer(cmd *cobra.Command, _ []string, app application.Application) error { + // Parse flags into configuration + cfg := parseConfig(cmd) + logger := app.Logger() + + logger.Info(). + Int("port", cfg.Port). + Str("host", cfg.Host). + Str("prefix", cfg.PathPrefix). + Bool("cors", cfg.CORSEnabled). + Bool("auth", cfg.AuthEnabled). + Int("rate_limit", cfg.RateLimit). + Dur("cache_ttl", cfg.CacheTTL). + Msg("Starting API server") + + // Create server + srv, err := server.New(app, cfg) + if err != nil { + return fmt.Errorf("creating server: %w", err) + } + + // Start background services (WebSocket hub, SSE broadcaster) + srv.Start() + + // Create HTTP server + httpServer := &http.Server{ + Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), + Handler: srv.Handler(), + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + IdleTimeout: cfg.IdleTimeout, + } + + // Start server with graceful shutdown + return startWithGracefulShutdown(httpServer, srv, logger) +} + +// parseConfig parses command flags into server configuration. 
+func parseConfig(cmd *cobra.Command) server.Config { + port, _ := cmd.Flags().GetInt("port") + host, _ := cmd.Flags().GetString("host") + corsEnabled, _ := cmd.Flags().GetBool("cors") + corsOrigins, _ := cmd.Flags().GetStringSlice("cors-origins") + authEnabled, _ := cmd.Flags().GetBool("auth") + authHeader, _ := cmd.Flags().GetString("auth-header") + rateLimit, _ := cmd.Flags().GetInt("rate-limit") + cacheTTL, _ := cmd.Flags().GetInt("cache-ttl") + readTimeout, _ := cmd.Flags().GetDuration("read-timeout") + writeTimeout, _ := cmd.Flags().GetDuration("write-timeout") + idleTimeout, _ := cmd.Flags().GetDuration("idle-timeout") + metricsEnabled, _ := cmd.Flags().GetBool("metrics") + pathPrefix, _ := cmd.Flags().GetString("prefix") + + // Override with environment variables + if envPort := os.Getenv("HTTP_PORT"); envPort != "" { + if p, err := parsePort(envPort); err == nil { + port = p + } + } + if envHost := os.Getenv("HTTP_HOST"); envHost != "" { + host = envHost + } + + return server.Config{ + Host: host, + Port: port, + PathPrefix: pathPrefix, + CORSEnabled: corsEnabled, + CORSOrigins: corsOrigins, + AuthEnabled: authEnabled, + AuthHeader: authHeader, + RateLimit: rateLimit, + CacheTTL: time.Duration(cacheTTL) * time.Second, + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, + IdleTimeout: idleTimeout, + MetricsEnabled: metricsEnabled, + } +} + +// parsePort safely parses a port string to integer. +func parsePort(portStr string) (int, error) { + port, err := strconv.Atoi(portStr) + if err != nil { + return 0, fmt.Errorf("invalid port number: %s", portStr) + } + if port < 1 || port > 65535 { + return 0, fmt.Errorf("port out of range: %d", port) + } + return port, nil +} + +// startWithGracefulShutdown starts the HTTP server with graceful shutdown. 
+func startWithGracefulShutdown(httpServer *http.Server, srv *server.Server, logger *zerolog.Logger) error { + // Server errors channel + serverErr := make(chan error, 1) + + // Start server in goroutine + go func() { + logger.Info(). + Str("addr", httpServer.Addr). + Str("service", "API"). + Msg("Server starting") + + fmt.Printf("šŸš€ Starting API server on %s\n", httpServer.Addr) + fmt.Println(" Press Ctrl+C to stop") + + if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + serverErr <- fmt.Errorf("server failed: %w", err) + } + }() + + // Wait for interrupt signal + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + + select { + case err := <-serverErr: + return err + case sig := <-quit: + logger.Info(). + Str("signal", sig.String()). + Msg("Shutdown signal received") + + fmt.Printf("\nšŸ›‘ Shutting down API server...\n") + + // Create shutdown context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Shutdown HTTP server + if err := httpServer.Shutdown(ctx); err != nil { + return fmt.Errorf("server shutdown failed: %w", err) + } + + // Shutdown background services + if err := srv.Shutdown(ctx); err != nil { + logger.Warn().Err(err).Msg("Background services shutdown had issues") + } + + logger.Info().Msg("Server stopped gracefully") + fmt.Printf("āœ… API server stopped gracefully\n") + return nil + } +} diff --git a/cmd/starmap/cmd/serve/handlers.go b/cmd/starmap/cmd/serve/handlers.go deleted file mode 100644 index 8c6682ffb..000000000 --- a/cmd/starmap/cmd/serve/handlers.go +++ /dev/null @@ -1,425 +0,0 @@ -package serve - -import ( - "encoding/json" - "fmt" - "net/http" - "strconv" - "strings" - - "github.com/agentstation/starmap/cmd/application" - "github.com/agentstation/starmap/internal/cmd/provider" - "github.com/agentstation/starmap/pkg/catalogs" - "github.com/agentstation/starmap/pkg/errors" - 
"github.com/agentstation/starmap/pkg/logging" -) - -// APIHandlers holds the catalog and provides HTTP handlers for REST endpoints. -type APIHandlers struct { - catalog catalogs.Catalog -} - -// NewAPIHandlers creates a new API handlers instance using app context. -func NewAPIHandlers(app application.Application) (*APIHandlers, error) { - cat, err := app.Catalog() - if err != nil { - return nil, fmt.Errorf("loading catalog: %w", err) - } - - return &APIHandlers{ - catalog: cat, - }, nil -} - -// ModelsHandler handles /api/v1/models requests. -func (h *APIHandlers) ModelsHandler(w http.ResponseWriter, r *http.Request) { - logging.Debug(). - Str("method", r.Method). - Str("path", r.URL.Path). - Msg("Handling models request") - - switch r.Method { - case http.MethodGet: - h.handleGetModels(w, r) - case http.MethodPost: - h.handleSearchModels(w, r) - default: - h.methodNotAllowed(w, r) - } -} - -// ProvidersHandler handles /api/v1/providers requests. -func (h *APIHandlers) ProvidersHandler(w http.ResponseWriter, r *http.Request) { - logging.Debug(). - Str("method", r.Method). - Str("path", r.URL.Path). - Msg("Handling providers request") - - switch r.Method { - case http.MethodGet: - h.handleGetProviders(w, r) - default: - h.methodNotAllowed(w, r) - } -} - -// ModelByIDHandler handles /api/v1/models/{id} requests. -func (h *APIHandlers) ModelByIDHandler(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - h.methodNotAllowed(w, r) - return - } - - // Extract model ID from path - path := strings.TrimPrefix(r.URL.Path, "/api/v1/models/") - modelID := strings.Split(path, "/")[0] - - if modelID == "" { - h.badRequest(w, "Model ID is required") - return - } - - logging.Debug(). - Str("model_id", modelID). - Msg("Handling model by ID request") - - h.handleGetModelByID(w, r, modelID) -} - -// ProviderByIDHandler handles /api/v1/providers/{id} requests. 
-func (h *APIHandlers) ProviderByIDHandler(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - h.methodNotAllowed(w, r) - return - } - - // Extract provider ID from path - path := strings.TrimPrefix(r.URL.Path, "/api/v1/providers/") - parts := strings.Split(path, "/") - providerID := parts[0] - - if providerID == "" { - h.badRequest(w, "Provider ID is required") - return - } - - logging.Debug(). - Str("provider_id", providerID). - Msg("Handling provider by ID request") - - // Check if this is a sub-resource request (e.g., /providers/{id}/models) - if len(parts) > 1 && parts[1] == "models" { - h.handleGetProviderModels(w, r, providerID) - return - } - - h.handleGetProviderByID(w, r, providerID) -} - -// handleGetModels returns a list of all models. -func (h *APIHandlers) handleGetModels(w http.ResponseWriter, r *http.Request) { - // Parse query parameters - query := r.URL.Query() - providerFilter := query.Get("provider") - limitStr := query.Get("limit") - offsetStr := query.Get("offset") - - // Parse limit and offset - limit := 100 // Default limit - if limitStr != "" { - if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 1000 { - limit = l - } - } - - offset := 0 - if offsetStr != "" { - if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 { - offset = o - } - } - - // Get all models - allModels := h.catalog.Models().List() - - // Apply provider filter if specified - var filteredModels []*catalogs.Model - if providerFilter != "" { - // Get models from specific provider - providers := h.catalog.Providers().List() - for _, prov := range providers { - if string(prov.ID) == providerFilter { - for _, model := range prov.Models { - filteredModels = append(filteredModels, model) - } - break - } - } - } else { - // Convert to pointer slice for compatibility - filteredModels = make([]*catalogs.Model, len(allModels)) - for i := range allModels { - filteredModels[i] = &allModels[i] - } - } - - // Apply pagination - total := 
len(filteredModels) - start := offset - end := offset + limit - - if start >= total { - filteredModels = []*catalogs.Model{} - } else { - if end > total { - end = total - } - filteredModels = filteredModels[start:end] - } - - // Create response - response := map[string]any{ - "models": filteredModels, - "pagination": map[string]any{ - "total": total, - "limit": limit, - "offset": offset, - "count": len(filteredModels), - }, - } - - h.jsonResponse(w, http.StatusOK, response) -} - -// handleGetModelByID returns a specific model by ID. -func (h *APIHandlers) handleGetModelByID(w http.ResponseWriter, _ *http.Request, modelID string) { - // Use the catalog's FindModel method - model, err := h.catalog.FindModel(modelID) - if err != nil { - if _, ok := err.(*errors.NotFoundError); ok { - h.notFound(w, fmt.Sprintf("Model '%s' not found", modelID)) - return - } - h.internalError(w, err) - return - } - - h.jsonResponse(w, http.StatusOK, model) -} - -// handleSearchModels handles POST /api/v1/models/search. 
-func (h *APIHandlers) handleSearchModels(w http.ResponseWriter, r *http.Request) { - var searchReq struct { - Query string `json:"query"` - Providers []string `json:"providers,omitempty"` - Capability string `json:"capability,omitempty"` - MinContext int64 `json:"min_context,omitempty"` - MaxPrice float64 `json:"max_price,omitempty"` - Limit int `json:"limit,omitempty"` - Offset int `json:"offset,omitempty"` - } - - if err := json.NewDecoder(r.Body).Decode(&searchReq); err != nil { - h.badRequest(w, "Invalid JSON request body") - return - } - - // Set default limit - if searchReq.Limit == 0 { - searchReq.Limit = 100 - } - - // Get all models and filter based on search criteria - allModels := h.catalog.Models().List() - results := make([]catalogs.Model, 0, len(allModels)) - - for _, model := range allModels { - // Apply filters - if searchReq.Query != "" { - queryLower := strings.ToLower(searchReq.Query) - if !strings.Contains(strings.ToLower(model.Name), queryLower) && - !strings.Contains(strings.ToLower(model.ID), queryLower) && - !strings.Contains(strings.ToLower(model.Description), queryLower) { - continue - } - } - - if len(searchReq.Providers) > 0 { - // Check if model belongs to any of the requested providers - found := false - providers := h.catalog.Providers().List() - for _, prov := range providers { - for _, reqProv := range searchReq.Providers { - if string(prov.ID) == reqProv { - if _, exists := prov.Models[model.ID]; exists { - found = true - break - } - } - } - if found { - break - } - } - if !found { - continue - } - } - - if searchReq.MinContext > 0 && model.Limits != nil { - if model.Limits.ContextWindow < searchReq.MinContext { - continue - } - } - - if searchReq.MaxPrice > 0 && model.Pricing != nil && model.Pricing.Tokens != nil && model.Pricing.Tokens.Input != nil { - if model.Pricing.Tokens.Input.Per1M > searchReq.MaxPrice { - continue - } - } - - results = append(results, model) - } - - // Apply pagination - total := len(results) - start := 
searchReq.Offset - end := searchReq.Offset + searchReq.Limit - - if start >= total { - results = []catalogs.Model{} - } else { - if end > total { - end = total - } - results = results[start:end] - } - - // Create response - response := map[string]any{ - "models": results, - "search": map[string]any{ - "query": searchReq.Query, - "providers": searchReq.Providers, - "capability": searchReq.Capability, - }, - "pagination": map[string]any{ - "total": total, - "limit": searchReq.Limit, - "offset": searchReq.Offset, - "count": len(results), - }, - } - - h.jsonResponse(w, http.StatusOK, response) -} - -// handleGetProviders returns a list of all providers. -func (h *APIHandlers) handleGetProviders(w http.ResponseWriter, _ *http.Request) { - providers := h.catalog.Providers().List() - - // Create simplified provider list - providerList := make([]map[string]any, 0, len(providers)) - for _, prov := range providers { - providerInfo := map[string]any{ - "id": prov.ID, - "name": prov.Name, - "model_count": len(prov.Models), - } - - if prov.Headquarters != nil { - providerInfo["headquarters"] = *prov.Headquarters - } - - if prov.Catalog != nil && prov.Catalog.Docs != nil { - providerInfo["docs_url"] = *prov.Catalog.Docs - } - - providerList = append(providerList, providerInfo) - } - - h.jsonResponse(w, http.StatusOK, map[string]any{ - "providers": providerList, - "count": len(providerList), - }) -} - -// handleGetProviderByID returns a specific provider by ID. -func (h *APIHandlers) handleGetProviderByID(w http.ResponseWriter, _ *http.Request, providerID string) { - prov, err := provider.Get(h.catalog, providerID) - if err != nil { - if _, ok := err.(*errors.NotFoundError); ok { - h.notFound(w, fmt.Sprintf("Provider '%s' not found", providerID)) - return - } - h.internalError(w, err) - return - } - - h.jsonResponse(w, http.StatusOK, prov) -} - -// handleGetProviderModels returns models for a specific provider. 
-func (h *APIHandlers) handleGetProviderModels(w http.ResponseWriter, _ *http.Request, providerID string) { - prov, err := provider.Get(h.catalog, providerID) - if err != nil { - if _, ok := err.(*errors.NotFoundError); ok { - h.notFound(w, fmt.Sprintf("Provider '%s' not found", providerID)) - return - } - h.internalError(w, err) - return - } - - // Convert map to slice - models := make([]*catalogs.Model, 0, len(prov.Models)) - for _, model := range prov.Models { - models = append(models, model) - } - - response := map[string]any{ - "provider": map[string]any{ - "id": prov.ID, - "name": prov.Name, - }, - "models": models, - "count": len(models), - } - - h.jsonResponse(w, http.StatusOK, response) -} - -// Helper methods for HTTP responses - -func (h *APIHandlers) jsonResponse(w http.ResponseWriter, status int, data any) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - - if err := json.NewEncoder(w).Encode(data); err != nil { - logging.Error().Err(err).Msg("Failed to encode JSON response") - } -} - -func (h *APIHandlers) errorResponse(w http.ResponseWriter, status int, message string) { - h.jsonResponse(w, status, map[string]string{ - "error": http.StatusText(status), - "message": message, - }) -} - -func (h *APIHandlers) badRequest(w http.ResponseWriter, message string) { - h.errorResponse(w, http.StatusBadRequest, message) -} - -func (h *APIHandlers) notFound(w http.ResponseWriter, message string) { - h.errorResponse(w, http.StatusNotFound, message) -} - -func (h *APIHandlers) methodNotAllowed(w http.ResponseWriter, r *http.Request) { - h.errorResponse(w, http.StatusMethodNotAllowed, fmt.Sprintf("Method %s not allowed", r.Method)) -} - -func (h *APIHandlers) internalError(w http.ResponseWriter, err error) { - logging.Error().Err(err).Msg("Internal server error") - h.errorResponse(w, http.StatusInternalServerError, "Internal server error") -} diff --git a/cmd/starmap/cmd/serve/shared.go b/cmd/starmap/cmd/serve/shared.go deleted file 
mode 100644 index 8f741e674..000000000 --- a/cmd/starmap/cmd/serve/shared.go +++ /dev/null @@ -1,125 +0,0 @@ -package serve - -import ( - "context" - "fmt" - "net/http" - "os" - "os/signal" - "strconv" - "syscall" - "time" - - "github.com/spf13/cobra" -) - -// ServerConfig holds common server configuration. -type ServerConfig struct { - Port int - Host string - Environment string - ConfigFile string -} - -// GetServerConfig extracts common server configuration from command flags and environment. -func GetServerConfig(cmd *cobra.Command, defaultPort int) (*ServerConfig, error) { - port, _ := cmd.Flags().GetInt("port") - host, _ := cmd.Flags().GetString("host") - env, _ := cmd.Flags().GetString("env") - config, _ := cmd.Flags().GetString("config") - - // Use default port if not specified - if port == 0 { - port = defaultPort - } - - // Override with PORT environment variable if set - if envPort := os.Getenv("PORT"); envPort != "" { - if p, err := strconv.Atoi(envPort); err == nil { - port = p - } - } - - // Override with HOST environment variable if set - if envHost := os.Getenv("HOST"); envHost != "" { - host = envHost - } - - return &ServerConfig{ - Port: port, - Host: host, - Environment: env, - ConfigFile: config, - }, nil -} - -// Address returns the full address string for binding. -func (c *ServerConfig) Address() string { - return fmt.Sprintf("%s:%d", c.Host, c.Port) -} - -// URL returns the full URL for the server. -func (c *ServerConfig) URL() string { - hostname := c.Host - if hostname == "" || hostname == "0.0.0.0" { - hostname = "localhost" - } - return fmt.Sprintf("http://%s:%d", hostname, c.Port) -} - -// AddCommonFlags adds common server flags to a command. 
-func AddCommonFlags(cmd *cobra.Command, defaultPort int) { - cmd.Flags().IntP("port", "p", defaultPort, "Port to bind server to") - cmd.Flags().String("host", "localhost", "Host address to bind to") - cmd.Flags().String("env", "development", "Environment mode (development, production)") - cmd.Flags().String("config", "", "Configuration file path") -} - -// StartServerWithGracefulShutdown starts an HTTP server with graceful shutdown. -func StartServerWithGracefulShutdown(server *http.Server, serviceName string) error { - // Start server in a goroutine - serverErr := make(chan error, 1) - go func() { - fmt.Printf("šŸš€ Starting %s server on %s\n", serviceName, server.Addr) - fmt.Println("Press Ctrl+C to stop") - - if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - serverErr <- fmt.Errorf("server failed to start: %w", err) - } - }() - - // Wait for interrupt signal to gracefully shutdown the server - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - - select { - case err := <-serverErr: - return err - case <-quit: - fmt.Printf("\nšŸ›‘ Shutting down %s server...\n", serviceName) - - // Give outstanding requests a deadline to complete - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - // Shutdown server gracefully - if err := server.Shutdown(ctx); err != nil { - return fmt.Errorf("server forced to shutdown: %w", err) - } - - fmt.Printf("āœ… %s server stopped gracefully\n", serviceName) - return nil - } -} - -// parsePort safely parses a port string to integer. 
-func parsePort(portStr string) (int, error) { - port, err := strconv.Atoi(portStr) - if err != nil { - return 0, fmt.Errorf("invalid port number: %s", portStr) - } - if port < 1 || port > 65535 { - return 0, fmt.Errorf("port out of range: %d", port) - } - return port, nil -} diff --git a/devbox.json b/devbox.json index da6ea6817..09ddd5a4b 100644 --- a/devbox.json +++ b/devbox.json @@ -2,19 +2,17 @@ "packages": [ "go@1.24.5", "gomarkdoc@1.1.0", - "hugo@0.148.2", - "nodejs@24.5.0", "golangci-lint@2.3.1", - "goreleaser@2.11.2" + "goreleaser@2.11.2", + "path:./devbox/swag#swag" ], "shell": { "init_hook": [ "PS1=\"$(echo -e \"\\033[1;34m%~\\033[0m \\n\\033[0;32m%n@devbox\\033[0m āžœ \")\"", "echo 'šŸ“¦ Starmap development environment loaded'", - "echo 'Tools available: go, hugo, gomarkdoc, golangci-lint, node'" + "echo 'Tools available: go, gomarkdoc, swag (v2 via flake), golangci-lint'" ], "scripts": { - "site": "cd site && hugo serve", "vet": "go vet ./...", "lint": "golangci-lint run", "lint-fix": "golangci-lint run --fix", diff --git a/devbox.lock b/devbox.lock index 0176719b9..8eceb535e 100644 --- a/devbox.lock +++ b/devbox.lock @@ -196,135 +196,6 @@ "store_path": "/nix/store/43bcvl4gr5q8j34gvy0px31y2p6d9kcq-goreleaser-2.11.2" } } - }, - "hugo@0.148.2": { - "last_modified": "2025-07-28T17:09:23Z", - "resolved": "github:NixOS/nixpkgs/648f70160c03151bc2121d179291337ad6bc564b#hugo", - "source": "devbox-search", - "version": "0.148.2", - "systems": { - "aarch64-darwin": { - "outputs": [ - { - "name": "out", - "path": "/nix/store/9yj3fphpkjkkhhr8pfxk5r6ws41n17qy-hugo-0.148.2", - "default": true - } - ], - "store_path": "/nix/store/9yj3fphpkjkkhhr8pfxk5r6ws41n17qy-hugo-0.148.2" - }, - "aarch64-linux": { - "outputs": [ - { - "name": "out", - "path": "/nix/store/zg314jwfh6q428ymnl57nv7jqyyaq1qz-hugo-0.148.2", - "default": true - } - ], - "store_path": "/nix/store/zg314jwfh6q428ymnl57nv7jqyyaq1qz-hugo-0.148.2" - }, - "x86_64-darwin": { - "outputs": [ - { - "name": "out", 
- "path": "/nix/store/wyhy996fyn61b1cgbl5b7xpha74kxzcq-hugo-0.148.2", - "default": true - } - ], - "store_path": "/nix/store/wyhy996fyn61b1cgbl5b7xpha74kxzcq-hugo-0.148.2" - }, - "x86_64-linux": { - "outputs": [ - { - "name": "out", - "path": "/nix/store/ji15a88bnlzhpkinvbmbb9vnkhmmx3m2-hugo-0.148.2", - "default": true - } - ], - "store_path": "/nix/store/ji15a88bnlzhpkinvbmbb9vnkhmmx3m2-hugo-0.148.2" - } - } - }, - "nodejs@24.5.0": { - "last_modified": "2025-08-11T07:05:29Z", - "plugin_version": "0.0.2", - "resolved": "github:NixOS/nixpkgs/9585e9192aadc13ec3e49f33f8333bd3cda524df#nodejs_24", - "source": "devbox-search", - "version": "24.5.0", - "systems": { - "aarch64-darwin": { - "outputs": [ - { - "name": "out", - "path": "/nix/store/b1j05q96hwagn787p2jlgqcjg2nf5x49-nodejs-24.5.0", - "default": true - }, - { - "name": "dev", - "path": "/nix/store/j6ayg4xpqy9xdxgrhpqylzq8v7v07c6r-nodejs-24.5.0-dev" - }, - { - "name": "libv8", - "path": "/nix/store/3ys6v5s5gvd9snwnl4saynl6av7mz3vy-nodejs-24.5.0-libv8" - } - ], - "store_path": "/nix/store/b1j05q96hwagn787p2jlgqcjg2nf5x49-nodejs-24.5.0" - }, - "aarch64-linux": { - "outputs": [ - { - "name": "out", - "path": "/nix/store/1kn0vh4gf3a22arldrw694apq3fhgp15-nodejs-24.5.0", - "default": true - }, - { - "name": "dev", - "path": "/nix/store/i3lqaj3j6znhnzh8ayka6q85r81ppxnw-nodejs-24.5.0-dev" - }, - { - "name": "libv8", - "path": "/nix/store/jjw6xgmg6qynp336g9igqnzlfbhzxr2i-nodejs-24.5.0-libv8" - } - ], - "store_path": "/nix/store/1kn0vh4gf3a22arldrw694apq3fhgp15-nodejs-24.5.0" - }, - "x86_64-darwin": { - "outputs": [ - { - "name": "out", - "path": "/nix/store/sbcg21wp4bdzyh2542v77sp535kvfbfq-nodejs-24.5.0", - "default": true - }, - { - "name": "libv8", - "path": "/nix/store/75b7iix0pbmxmfnmv90l3q0ll1gc75az-nodejs-24.5.0-libv8" - }, - { - "name": "dev", - "path": "/nix/store/fg7pi9s6m0spci1pfqbny0kxmk832i3r-nodejs-24.5.0-dev" - } - ], - "store_path": "/nix/store/sbcg21wp4bdzyh2542v77sp535kvfbfq-nodejs-24.5.0" - }, - 
"x86_64-linux": { - "outputs": [ - { - "name": "out", - "path": "/nix/store/357id3rjy9417k4dkvxxmpgd9bxrwc7l-nodejs-24.5.0", - "default": true - }, - { - "name": "dev", - "path": "/nix/store/0drh8jjq84sji6889l2k3ysmvy7sc9sg-nodejs-24.5.0-dev" - }, - { - "name": "libv8", - "path": "/nix/store/kdlv4q7sgap0z43cylklhxz1g1q7751b-nodejs-24.5.0-libv8" - } - ], - "store_path": "/nix/store/357id3rjy9417k4dkvxxmpgd9bxrwc7l-nodejs-24.5.0" - } - } } } } diff --git a/devbox/swag/flake.nix b/devbox/swag/flake.nix new file mode 100644 index 000000000..9b381f319 --- /dev/null +++ b/devbox/swag/flake.nix @@ -0,0 +1,73 @@ +{ + description = "Swag v2 - OpenAPI 3.1 documentation generator"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = { self, nixpkgs, flake-utils }: + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = nixpkgs.legacyPackages.${system}; + version = "2.0.0-rc4"; + + platformInfo = { + "x86_64-linux" = { + platform = "Linux_x86_64"; + sha256 = pkgs.lib.fakeHash; + }; + "aarch64-linux" = { + platform = "Linux_arm64"; + sha256 = pkgs.lib.fakeHash; + }; + "i686-linux" = { + platform = "Linux_i386"; + sha256 = pkgs.lib.fakeHash; + }; + "x86_64-darwin" = { + platform = "Darwin_x86_64"; + sha256 = pkgs.lib.fakeHash; + }; + "aarch64-darwin" = { + platform = "Darwin_arm64"; + sha256 = "sha256-eeMsOoXkqQpO9PkE6VGjBPG/slDtVCKfNSBT/NRSyqs="; + }; + }; + + buildSwag = info: + pkgs.stdenv.mkDerivation { + pname = "swag"; + inherit version; + + src = pkgs.fetchurl { + url = "https://github.com/swaggo/swag/releases/download/v${version}/swag_${version}_${info.platform}.tar.gz"; + sha256 = info.sha256; + }; + + sourceRoot = "."; + + installPhase = '' + mkdir -p $out/bin + cp swag $out/bin/ + chmod +x $out/bin/swag + ''; + }; + + swagForSystem = + if builtins.hasAttr system platformInfo + then buildSwag platformInfo.${system} + else null; + + in { + packages = { + swag = swagForSystem; + default = 
swagForSystem; + }; + + apps.default = flake-utils.lib.mkApp { + drv = self.packages.${system}.default; + }; + } + ); +} diff --git a/go.mod b/go.mod index b39c84ca4..fc90fd822 100644 --- a/go.mod +++ b/go.mod @@ -7,9 +7,11 @@ require ( github.com/agentstation/utc v0.0.0-20250811234424-7f4e474c689c github.com/goccy/go-yaml v1.18.0 github.com/google/go-cmp v0.7.0 + github.com/gorilla/websocket v1.5.3 github.com/joho/godotenv v1.5.1 github.com/mattn/go-isatty v0.0.19 github.com/olekukonko/tablewriter v1.0.9 + github.com/patrickmn/go-cache v2.1.0+incompatible github.com/rs/zerolog v1.34.0 github.com/spf13/cobra v1.9.1 github.com/spf13/pflag v1.0.6 @@ -22,7 +24,6 @@ require ( require ( cloud.google.com/go v0.120.0 // indirect cloud.google.com/go/compute/metadata v0.7.0 // indirect - github.com/cpuguy83/go-md2man/v2 v2.0.6 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/fatih/color v1.15.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect @@ -33,7 +34,6 @@ require ( github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect github.com/googleapis/gax-go/v2 v2.15.0 // indirect - github.com/gorilla/websocket v1.5.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect @@ -42,7 +42,6 @@ require ( github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect - github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sagikazarmark/locafero v0.7.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.12.0 // indirect diff --git a/go.sum b/go.sum index c39b3b5b5..c9bd526c4 100644 --- a/go.sum +++ b/go.sum @@ -7,7 +7,6 @@ cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQ github.com/agentstation/utc v0.0.0-20250811234424-7f4e474c689c 
h1:0tLR/VD8pruLqKP7AYgPTEObMoyS7eQzqCJg2/FRdkE= github.com/agentstation/utc v0.0.0-20250811234424-7f4e474c689c/go.mod h1:6/sYtnBRR4MhH8Oj9JTG9buSeBgWA9c1ppctsY1nFa8= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/cpuguy83/go-md2man/v2 v2.0.6 h1:XJtiaUW6dEEqVuZiMTn1ldk455QWwEIsMIJlo5vtkx0= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -65,6 +64,8 @@ github.com/olekukonko/ll v0.0.9 h1:Y+1YqDfVkqMWuEQMclsF9HUR5+a82+dxJuL1HHSRpxI= github.com/olekukonko/ll v0.0.9/go.mod h1:En+sEW0JNETl26+K8eZ6/W4UQ7CYSrrgg/EdIYT2H8g= github.com/olekukonko/tablewriter v1.0.9 h1:XGwRsYLC2bY7bNd93Dk51bcPZksWZmLYuaTHR0FqfL8= github.com/olekukonko/tablewriter v1.0.9/go.mod h1:5c+EBPeSqvXnLLgkm9isDdzR3wjfBkHR9Nhfp3NWrzo= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -77,7 +78,6 @@ github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWN github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= -github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo= github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k= diff --git a/internal/embedded/openapi/embed.go b/internal/embedded/openapi/embed.go new file mode 100644 index 000000000..cf50fedd8 --- /dev/null +++ b/internal/embedded/openapi/embed.go @@ -0,0 +1,17 @@ +// Package openapi embeds the OpenAPI 3.0 specification files for the Starmap HTTP API. +// These files are embedded at build time and served by the API server at runtime. +package openapi + +import _ "embed" + +// SpecJSON contains the OpenAPI 3.0 specification in JSON format. +// Served at: GET /api/v1/openapi.json +// +//go:embed openapi.json +var SpecJSON []byte + +// SpecYAML contains the OpenAPI 3.0 specification in YAML format. +// Served at: GET /api/v1/openapi.yaml +// +//go:embed openapi.yaml +var SpecYAML []byte diff --git a/internal/embedded/openapi/generate.go b/internal/embedded/openapi/generate.go new file mode 100644 index 000000000..115f9fb8d --- /dev/null +++ b/internal/embedded/openapi/generate.go @@ -0,0 +1,10 @@ +package openapi + +// OpenAPI specs are generated by the Makefile's 'openapi' target. +// Run: make openapi +// +// The generation process: +// 1. swag v2 natively generates OpenAPI 3.1 from code annotations +// 2. Generated files are renamed from swagger.* to openapi.* +// 3. Intermediate docs.go is removed (not needed with //go:embed) +// 4. 
Final openapi.json and openapi.yaml are embedded via embed.go diff --git a/internal/embedded/openapi/openapi.json b/internal/embedded/openapi/openapi.json new file mode 100644 index 000000000..82c5d0961 --- /dev/null +++ b/internal/embedded/openapi/openapi.json @@ -0,0 +1,10 @@ +{ + "components": {"schemas":{"catalogs.ArchitectureType":{"description":"Type of architecture","type":"string","x-enum-comments":{"ArchitectureTypeCNN":"Convolutional Neural Networks","ArchitectureTypeDiffusion":"Diffusion models (Stable Diffusion, DALL-E, etc.)","ArchitectureTypeGAN":"Generative Adversarial Networks","ArchitectureTypeGRU":"Gated Recurrent Unit networks","ArchitectureTypeLSTM":"Long Short-Term Memory networks","ArchitectureTypeMoE":"Mixture of Experts (Mixtral, GLaM, Switch Transformer)","ArchitectureTypeRNN":"Recurrent Neural Networks","ArchitectureTypeTransformer":"Transformer-based models (GPT, BERT, LLaMA, etc.)","ArchitectureTypeVAE":"Variational Autoencoders"},"x-enum-varnames":["ArchitectureTypeTransformer","ArchitectureTypeMoE","ArchitectureTypeCNN","ArchitectureTypeRNN","ArchitectureTypeLSTM","ArchitectureTypeGRU","ArchitectureTypeVAE","ArchitectureTypeGAN","ArchitectureTypeDiffusion"]},"catalogs.Author":{"properties":{"aliases":{"description":"Alternative IDs this author is known by (e.g., in provider catalogs)","items":{"description":"Unique identifier for the 
author","type":"string","x-enum-varnames":["AuthorIDOpenAI","AuthorIDAnthropic","AuthorIDGoogle","AuthorIDDeepMind","AuthorIDMeta","AuthorIDMicrosoft","AuthorIDMistralAI","AuthorIDCohere","AuthorIDGroq","AuthorIDAlibabaQwen","AuthorIDQwen","AuthorIDXAI","AuthorIDStanford","AuthorIDMIT","AuthorIDCMU","AuthorIDUCBerkeley","AuthorIDCornell","AuthorIDPrinceton","AuthorIDHarvard","AuthorIDOxford","AuthorIDCambridge","AuthorIDETHZurich","AuthorIDUWashington","AuthorIDUChicago","AuthorIDYale","AuthorIDDuke","AuthorIDCaltech","AuthorIDHuggingFace","AuthorIDEleutherAI","AuthorIDTogether","AuthorIDMosaicML","AuthorIDStabilityAI","AuthorIDRunwayML","AuthorIDMidjourney","AuthorIDLAION","AuthorIDBigScience","AuthorIDAlignmentRC","AuthorIDH2OAI","AuthorIDMoxin","AuthorIDBaidu","AuthorIDTencent","AuthorIDByteDance","AuthorIDDeepSeek","AuthorIDBAAI","AuthorID01AI","AuthorIDBaichuan","AuthorIDMiniMax","AuthorIDMoonshot","AuthorIDShanghaiAI","AuthorIDZhipuAI","AuthorIDSenseTime","AuthorIDHuawei","AuthorIDTsinghua","AuthorIDPeking","AuthorIDNVIDIA","AuthorIDSalesforce","AuthorIDIBM","AuthorIDApple","AuthorIDAmazon","AuthorIDAdept","AuthorIDAI21","AuthorIDInflection","AuthorIDCharacter","AuthorIDPerplexity","AuthorIDAnysphere","AuthorIDCursor","AuthorIDCognitiveComputations","AuthorIDEricHartford","AuthorIDNousResearch","AuthorIDTeknium","AuthorIDJonDurbin","AuthorIDLMSYS","AuthorIDVicuna","AuthorIDAlpacaTeam","AuthorIDWizardLM","AuthorIDOpenOrca","AuthorIDPhind","AuthorIDCodeFuse","AuthorIDTHUDM","AuthorIDGeorgiaTechRI","AuthorIDFastChat","AuthorIDUnknown"]},"type":"array","uniqueItems":false},"catalog":{"$ref":"#/components/schemas/catalogs.AuthorCatalog"},"created_at":{"description":"Timestamps for record keeping and auditing","type":"string"},"description":{"description":"Description of what the author is known for","type":"string"},"github":{"description":"GitHub profile/organization URL","type":"string"},"headquarters":{"description":"Company/organization 
info","type":"string"},"huggingface":{"description":"Hugging Face profile/organization URL","type":"string"},"icon_url":{"description":"Author icon/logo URL","type":"string"},"id":{"$ref":"#/components/schemas/catalogs.AuthorID"},"name":{"description":"Display name of the author","type":"string"},"twitter":{"description":"X (formerly Twitter) profile URL","type":"string"},"updated_at":{"description":"Last updated date (YYYY-MM or YYYY-MM-DD format)","type":"string"},"website":{"description":"Website, social links, and other relevant URLs","type":"string"}},"type":"object"},"catalogs.AuthorAttribution":{"description":"Model attribution configuration for multi-provider inference","properties":{"patterns":{"description":"Glob patterns to match model IDs","items":{"type":"string"},"type":"array","uniqueItems":false},"provider_id":{"$ref":"#/components/schemas/catalogs.ProviderID"}},"type":"object"},"catalogs.AuthorCatalog":{"description":"Catalog and models","properties":{"attribution":{"$ref":"#/components/schemas/catalogs.AuthorAttribution"},"description":{"description":"Optional description of this mapping relationship","type":"string"}},"type":"object"},"catalogs.AuthorID":{"description":"Unique identifier for the 
author","type":"string","x-enum-varnames":["AuthorIDOpenAI","AuthorIDAnthropic","AuthorIDGoogle","AuthorIDDeepMind","AuthorIDMeta","AuthorIDMicrosoft","AuthorIDMistralAI","AuthorIDCohere","AuthorIDGroq","AuthorIDAlibabaQwen","AuthorIDQwen","AuthorIDXAI","AuthorIDStanford","AuthorIDMIT","AuthorIDCMU","AuthorIDUCBerkeley","AuthorIDCornell","AuthorIDPrinceton","AuthorIDHarvard","AuthorIDOxford","AuthorIDCambridge","AuthorIDETHZurich","AuthorIDUWashington","AuthorIDUChicago","AuthorIDYale","AuthorIDDuke","AuthorIDCaltech","AuthorIDHuggingFace","AuthorIDEleutherAI","AuthorIDTogether","AuthorIDMosaicML","AuthorIDStabilityAI","AuthorIDRunwayML","AuthorIDMidjourney","AuthorIDLAION","AuthorIDBigScience","AuthorIDAlignmentRC","AuthorIDH2OAI","AuthorIDMoxin","AuthorIDBaidu","AuthorIDTencent","AuthorIDByteDance","AuthorIDDeepSeek","AuthorIDBAAI","AuthorID01AI","AuthorIDBaichuan","AuthorIDMiniMax","AuthorIDMoonshot","AuthorIDShanghaiAI","AuthorIDZhipuAI","AuthorIDSenseTime","AuthorIDHuawei","AuthorIDTsinghua","AuthorIDPeking","AuthorIDNVIDIA","AuthorIDSalesforce","AuthorIDIBM","AuthorIDApple","AuthorIDAmazon","AuthorIDAdept","AuthorIDAI21","AuthorIDInflection","AuthorIDCharacter","AuthorIDPerplexity","AuthorIDAnysphere","AuthorIDCursor","AuthorIDCognitiveComputations","AuthorIDEricHartford","AuthorIDNousResearch","AuthorIDTeknium","AuthorIDJonDurbin","AuthorIDLMSYS","AuthorIDVicuna","AuthorIDAlpacaTeam","AuthorIDWizardLM","AuthorIDOpenOrca","AuthorIDPhind","AuthorIDCodeFuse","AuthorIDTHUDM","AuthorIDGeorgiaTechRI","AuthorIDFastChat","AuthorIDUnknown"]},"catalogs.AuthorMapping":{"description":"Author extraction","properties":{"field":{"description":"Field to extract from (e.g., \"owned_by\")","type":"string"},"normalized":{"additionalProperties":{"description":"Unique identifier for the 
author","type":"string","x-enum-varnames":["AuthorIDOpenAI","AuthorIDAnthropic","AuthorIDGoogle","AuthorIDDeepMind","AuthorIDMeta","AuthorIDMicrosoft","AuthorIDMistralAI","AuthorIDCohere","AuthorIDGroq","AuthorIDAlibabaQwen","AuthorIDQwen","AuthorIDXAI","AuthorIDStanford","AuthorIDMIT","AuthorIDCMU","AuthorIDUCBerkeley","AuthorIDCornell","AuthorIDPrinceton","AuthorIDHarvard","AuthorIDOxford","AuthorIDCambridge","AuthorIDETHZurich","AuthorIDUWashington","AuthorIDUChicago","AuthorIDYale","AuthorIDDuke","AuthorIDCaltech","AuthorIDHuggingFace","AuthorIDEleutherAI","AuthorIDTogether","AuthorIDMosaicML","AuthorIDStabilityAI","AuthorIDRunwayML","AuthorIDMidjourney","AuthorIDLAION","AuthorIDBigScience","AuthorIDAlignmentRC","AuthorIDH2OAI","AuthorIDMoxin","AuthorIDBaidu","AuthorIDTencent","AuthorIDByteDance","AuthorIDDeepSeek","AuthorIDBAAI","AuthorID01AI","AuthorIDBaichuan","AuthorIDMiniMax","AuthorIDMoonshot","AuthorIDShanghaiAI","AuthorIDZhipuAI","AuthorIDSenseTime","AuthorIDHuawei","AuthorIDTsinghua","AuthorIDPeking","AuthorIDNVIDIA","AuthorIDSalesforce","AuthorIDIBM","AuthorIDApple","AuthorIDAmazon","AuthorIDAdept","AuthorIDAI21","AuthorIDInflection","AuthorIDCharacter","AuthorIDPerplexity","AuthorIDAnysphere","AuthorIDCursor","AuthorIDCognitiveComputations","AuthorIDEricHartford","AuthorIDNousResearch","AuthorIDTeknium","AuthorIDJonDurbin","AuthorIDLMSYS","AuthorIDVicuna","AuthorIDAlpacaTeam","AuthorIDWizardLM","AuthorIDOpenOrca","AuthorIDPhind","AuthorIDCodeFuse","AuthorIDTHUDM","AuthorIDGeorgiaTechRI","AuthorIDFastChat","AuthorIDUnknown"]},"description":"Normalization map (e.g., \"Meta\" -\u003e \"meta\")","type":"object"}},"type":"object"},"catalogs.EndpointType":{"description":"Required: API style","type":"string","x-enum-varnames":["EndpointTypeOpenAI","EndpointTypeAnthropic","EndpointTypeGoogle","EndpointTypeGoogleCloud"]},"catalogs.FeatureRule":{"properties":{"contains":{"description":"If field contains any of these 
strings","items":{"type":"string"},"type":"array","uniqueItems":false},"feature":{"description":"Feature to enable (e.g., \"tools\", \"reasoning\")","type":"string"},"field":{"description":"Field to check (e.g., \"id\", \"owned_by\")","type":"string"},"value":{"description":"Value to set for the feature","type":"boolean"}},"type":"object"},"catalogs.FieldMapping":{"properties":{"from":{"description":"Source field path in API response (e.g., \"max_model_len\")","type":"string"},"to":{"description":"Target field path in Model (e.g., \"limits.context_window\")","type":"string"}},"type":"object"},"catalogs.FloatRange":{"description":"Alternative sampling strategies (niche)","properties":{"default":{"description":"Default value","type":"number"},"max":{"description":"Maximum value","type":"number"},"min":{"description":"Minimum value","type":"number"}},"type":"object"},"catalogs.IntRange":{"description":"ReasoningTokens - specific token allocation for reasoning processes","properties":{"default":{"description":"Default value","type":"integer"},"max":{"description":"Maximum value","type":"integer"},"min":{"description":"Minimum value","type":"integer"}},"type":"object"},"catalogs.Model":{"properties":{"attachments":{"$ref":"#/components/schemas/catalogs.ModelAttachments"},"authors":{"description":"Authors/organizations of the model (if known)","items":{"$ref":"#/components/schemas/catalogs.Author"},"type":"array","uniqueItems":false},"created_at":{"description":"Timestamps for record keeping and auditing","type":"string"},"description":{"description":"Description of the model and its use cases","type":"string"},"features":{"$ref":"#/components/schemas/catalogs.ModelFeatures"},"generation":{"$ref":"#/components/schemas/catalogs.ModelGeneration"},"id":{"description":"Core identity","type":"string"},"limits":{"$ref":"#/components/schemas/catalogs.ModelLimits"},"metadata":{"$ref":"#/components/schemas/catalogs.ModelMetadata"},"name":{"description":"Display name (must not be 
empty)","type":"string"},"pricing":{"$ref":"#/components/schemas/catalogs.ModelPricing"},"reasoning":{"$ref":"#/components/schemas/catalogs.ModelControlLevels"},"reasoning_tokens":{"$ref":"#/components/schemas/catalogs.IntRange"},"response":{"$ref":"#/components/schemas/catalogs.ModelDelivery"},"tools":{"$ref":"#/components/schemas/catalogs.ModelTools"},"updated_at":{"description":"Last updated date (YYYY-MM or YYYY-MM-DD format)","type":"string"},"verbosity":{"$ref":"#/components/schemas/catalogs.ModelControlLevels"}},"type":"object"},"catalogs.ModelArchitecture":{"description":"Technical architecture details","properties":{"base_model":{"description":"Base model ID if fine-tuned","type":"string"},"fine_tuned":{"description":"Whether this is a fine-tuned variant","type":"boolean"},"parameter_count":{"description":"Model size (e.g., \"7B\", \"70B\", \"405B\")","type":"string"},"precision":{"description":"Legacy precision format (use Quantization for filtering)","type":"string"},"quantization":{"$ref":"#/components/schemas/catalogs.Quantization"},"quantized":{"description":"Whether the model has been quantized","type":"boolean"},"tokenizer":{"$ref":"#/components/schemas/catalogs.Tokenizer"},"type":{"$ref":"#/components/schemas/catalogs.ArchitectureType"}},"type":"object"},"catalogs.ModelAttachments":{"description":"Attachments - attachment support details","properties":{"max_file_size":{"description":"Maximum file size in bytes","type":"integer"},"max_files":{"description":"Maximum number of files per request","type":"integer"},"mime_types":{"description":"Supported MIME types","items":{"type":"string"},"type":"array","uniqueItems":false}},"type":"object"},"catalogs.ModelControlLevel":{"type":"string","x-enum-varnames":["ModelControlLevelMinimum","ModelControlLevelLow","ModelControlLevelMedium","ModelControlLevelHigh","ModelControlLevelMaximum"]},"catalogs.ModelControlLevels":{"description":"Verbosity - response verbosity 
levels","properties":{"default":{"description":"Default level","type":"string","x-enum-varnames":["ModelControlLevelMinimum","ModelControlLevelLow","ModelControlLevelMedium","ModelControlLevelHigh","ModelControlLevelMaximum"]},"levels":{"description":"Which levels this model supports","items":{"$ref":"#/components/schemas/catalogs.ModelControlLevel"},"type":"array","uniqueItems":false}},"type":"object"},"catalogs.ModelDelivery":{"description":"Delivery - technical response delivery capabilities (formats, protocols, streaming)","properties":{"formats":{"description":"Available response formats (if format_response feature enabled)","items":{"$ref":"#/components/schemas/catalogs.ModelResponseFormat"},"type":"array","uniqueItems":false},"protocols":{"description":"Response delivery mechanisms","items":{"$ref":"#/components/schemas/catalogs.ModelResponseProtocol"},"type":"array","uniqueItems":false},"streaming":{"description":"Supported streaming modes (sse, websocket, chunked)","items":{"$ref":"#/components/schemas/catalogs.ModelStreaming"},"type":"array","uniqueItems":false}},"type":"object"},"catalogs.ModelFeatures":{"description":"Features - what this model can do","properties":{"allowed_tokens":{"description":"[Niche] Supports token whitelist","type":"boolean"},"attachments":{"description":"Attachment support details","type":"boolean"},"bad_words":{"description":"[Advanced] Supports bad words/disallowed tokens","type":"boolean"},"best_of":{"description":"[Advanced] Supports server-side sampling with best selection","type":"boolean"},"contrastive_search_penalty_alpha":{"description":"[Niche] Supports contrastive decoding","type":"boolean"},"diversity_penalty":{"description":"[Niche] Supports diversity penalty in beam search","type":"boolean"},"early_stopping":{"description":"[Niche] Supports early stopping in beam search","type":"boolean"},"echo":{"description":"[Advanced] Supports echoing prompt with 
completion","type":"boolean"},"format_response":{"description":"Response delivery","type":"boolean"},"frequency_penalty":{"description":"Generation control - Repetition control","type":"boolean"},"include_reasoning":{"description":"Supports including reasoning traces in response","type":"boolean"},"length_penalty":{"description":"[Niche] Supports length penalty (seq2seq style)","type":"boolean"},"logit_bias":{"description":"Generation control - Token biasing","type":"boolean"},"logprobs":{"description":"Generation control - Observability","type":"boolean"},"max_output_tokens":{"description":"[Core] Supports max_output_tokens parameter (some providers distinguish from max_tokens)","type":"boolean"},"max_tokens":{"description":"Generation control - Length and termination","type":"boolean"},"min_p":{"description":"[Advanced] Supports min_p parameter (minimum probability threshold)","type":"boolean"},"mirostat":{"description":"Generation control - Alternative sampling strategies (niche)","type":"boolean"},"mirostat_eta":{"description":"[Niche] Supports Mirostat eta parameter","type":"boolean"},"mirostat_tau":{"description":"[Niche] Supports Mirostat tau parameter","type":"boolean"},"modalities":{"$ref":"#/components/schemas/catalogs.ModelModalities"},"n":{"description":"Generation control - Multiplicity and reranking","type":"boolean"},"no_repeat_ngram_size":{"description":"[Niche] Supports n-gram repetition blocking","type":"boolean"},"num_beams":{"description":"Generation control - Beam search (niche)","type":"boolean"},"presence_penalty":{"description":"[Core] Supports presence penalty","type":"boolean"},"reasoning":{"description":"Reasoning \u0026 Verbosity","type":"boolean"},"reasoning_effort":{"description":"Supports configurable reasoning intensity","type":"boolean"},"reasoning_tokens":{"description":"Supports specific reasoning token allocation","type":"boolean"},"repetition_penalty":{"description":"[Advanced] Supports repetition 
penalty","type":"boolean"},"seed":{"description":"Generation control - Determinism","type":"boolean"},"stop":{"description":"[Core] Supports stop sequences/words","type":"boolean"},"stop_token_ids":{"description":"[Advanced] Supports stop token IDs (numeric)","type":"boolean"},"streaming":{"description":"Supports response streaming","type":"boolean"},"structured_outputs":{"description":"Supports structured outputs (JSON schema validation)","type":"boolean"},"temperature":{"description":"Generation control - Core sampling and decoding","type":"boolean"},"tfs":{"description":"[Advanced] Supports tail free sampling","type":"boolean"},"tool_calls":{"description":"Core capabilities\nTool calling system - three distinct aspects:","type":"boolean"},"tool_choice":{"description":"Supports tool choice strategies (auto/none/required control)","type":"boolean"},"tools":{"description":"Accepts tool definitions in requests (accepts tools parameter)","type":"boolean"},"top_a":{"description":"[Advanced] Supports top_a parameter (top-a sampling)","type":"boolean"},"top_k":{"description":"[Advanced] Supports top_k parameter","type":"boolean"},"top_logprobs":{"description":"[Core] Supports returning top N log probabilities","type":"boolean"},"top_p":{"description":"[Core] Supports top_p parameter (nucleus sampling)","type":"boolean"},"typical_p":{"description":"[Advanced] Supports typical_p parameter (typical sampling)","type":"boolean"},"verbosity":{"description":"Supports verbosity control (GPT-5+)","type":"boolean"},"web_search":{"description":"Supports web search capabilities","type":"boolean"}},"type":"object"},"catalogs.ModelGeneration":{"description":"Generation - core chat completions generation 
controls","properties":{"best_of":{"$ref":"#/components/schemas/catalogs.IntRange"},"contrastive_search_penalty_alpha":{"$ref":"#/components/schemas/catalogs.FloatRange"},"diversity_penalty":{"$ref":"#/components/schemas/catalogs.FloatRange"},"frequency_penalty":{"$ref":"#/components/schemas/catalogs.FloatRange"},"length_penalty":{"$ref":"#/components/schemas/catalogs.FloatRange"},"max_output_tokens":{"type":"integer"},"max_tokens":{"description":"Length and termination","type":"integer"},"min_p":{"$ref":"#/components/schemas/catalogs.FloatRange"},"mirostat_eta":{"$ref":"#/components/schemas/catalogs.FloatRange"},"mirostat_tau":{"$ref":"#/components/schemas/catalogs.FloatRange"},"n":{"$ref":"#/components/schemas/catalogs.IntRange"},"no_repeat_ngram_size":{"$ref":"#/components/schemas/catalogs.IntRange"},"num_beams":{"$ref":"#/components/schemas/catalogs.IntRange"},"presence_penalty":{"$ref":"#/components/schemas/catalogs.FloatRange"},"repetition_penalty":{"$ref":"#/components/schemas/catalogs.FloatRange"},"temperature":{"$ref":"#/components/schemas/catalogs.FloatRange"},"tfs":{"$ref":"#/components/schemas/catalogs.FloatRange"},"top_a":{"$ref":"#/components/schemas/catalogs.FloatRange"},"top_k":{"$ref":"#/components/schemas/catalogs.IntRange"},"top_logprobs":{"description":"Observability","type":"integer"},"top_p":{"$ref":"#/components/schemas/catalogs.FloatRange"},"typical_p":{"$ref":"#/components/schemas/catalogs.FloatRange"}},"type":"object"},"catalogs.ModelLimits":{"description":"Model limits","properties":{"context_window":{"description":"Context window size in tokens","type":"integer"},"output_tokens":{"description":"Maximum output tokens","type":"integer"}},"type":"object"},"catalogs.ModelMetadata":{"description":"Metadata - version and timing information","properties":{"architecture":{"$ref":"#/components/schemas/catalogs.ModelArchitecture"},"knowledge_cutoff":{"description":"Knowledge cutoff date (YYYY-MM or YYYY-MM-DD 
format)","type":"string"},"open_weights":{"description":"Whether model weights are open","type":"boolean"},"release_date":{"description":"Release date (YYYY-MM or YYYY-MM-DD format)","type":"string"},"tags":{"description":"Use case tags for categorizing the model","items":{"$ref":"#/components/schemas/catalogs.ModelTag"},"type":"array","uniqueItems":false}},"type":"object"},"catalogs.ModelModalities":{"description":"Input/Output modalities","properties":{"input":{"description":"Supported input modalities","items":{"$ref":"#/components/schemas/catalogs.ModelModality"},"type":"array","uniqueItems":false},"output":{"description":"Supported output modalities","items":{"type":"string","x-enum-comments":{"ModelModalityEmbedding":"Vector embeddings"},"x-enum-varnames":["ModelModalityText","ModelModalityAudio","ModelModalityImage","ModelModalityVideo","ModelModalityPDF","ModelModalityEmbedding"]},"type":"array","uniqueItems":false}},"type":"object"},"catalogs.ModelModality":{"type":"string","x-enum-comments":{"ModelModalityEmbedding":"Vector embeddings"},"x-enum-varnames":["ModelModalityText","ModelModalityAudio","ModelModalityImage","ModelModalityVideo","ModelModalityPDF","ModelModalityEmbedding"]},"catalogs.ModelOperationPricing":{"description":"Fixed costs per operation","properties":{"audio_gen":{"description":"Cost per audio generated","type":"number"},"audio_input":{"description":"Cost per audio input","type":"number"},"function_call":{"description":"Cost per function call","type":"number"},"image_gen":{"description":"Generation operations","type":"number"},"image_input":{"description":"Media operations","type":"number"},"request":{"description":"Core operations","type":"number"},"tool_use":{"description":"Cost per tool usage","type":"number"},"video_gen":{"description":"Cost per video generated","type":"number"},"video_input":{"description":"Cost per video input","type":"number"},"web_search":{"description":"Service 
operations","type":"number"}},"type":"object"},"catalogs.ModelPricing":{"description":"Operational characteristics","properties":{"currency":{"$ref":"#/components/schemas/catalogs.ModelPricingCurrency"},"operations":{"$ref":"#/components/schemas/catalogs.ModelOperationPricing"},"tokens":{"$ref":"#/components/schemas/catalogs.ModelTokenPricing"}},"type":"object"},"catalogs.ModelPricingCurrency":{"description":"Metadata","type":"string","x-enum-comments":{"ModelPricingCurrencyAUD":"Australian Dollar","ModelPricingCurrencyCAD":"Canadian Dollar","ModelPricingCurrencyCNY":"Chinese Yuan","ModelPricingCurrencyEUR":"Euro","ModelPricingCurrencyGBP":"British Pound Sterling","ModelPricingCurrencyJPY":"Japanese Yen","ModelPricingCurrencyNZD":"New Zealand Dollar","ModelPricingCurrencyUSD":"US Dollar"},"x-enum-varnames":["ModelPricingCurrencyUSD","ModelPricingCurrencyEUR","ModelPricingCurrencyJPY","ModelPricingCurrencyGBP","ModelPricingCurrencyAUD","ModelPricingCurrencyCAD","ModelPricingCurrencyCNY","ModelPricingCurrencyNZD"]},"catalogs.ModelResponseFormat":{"type":"string","x-enum-comments":{"ModelResponseFormatFunctionCall":"Tool/function calling for structured data","ModelResponseFormatJSON":"JSON encouraged via prompting","ModelResponseFormatJSONMode":"Forced valid JSON (OpenAI style)","ModelResponseFormatJSONObject":"Same as json_mode (OpenAI API name)","ModelResponseFormatJSONSchema":"Schema-validated JSON (OpenAI structured output)","ModelResponseFormatStructuredOutput":"General structured output support","ModelResponseFormatText":"Plain text responses (default)"},"x-enum-varnames":["ModelResponseFormatText","ModelResponseFormatJSON","ModelResponseFormatJSONMode","ModelResponseFormatJSONObject","ModelResponseFormatJSONSchema","ModelResponseFormatStructuredOutput","ModelResponseFormatFunctionCall"]},"catalogs.ModelResponseProtocol":{"type":"string","x-enum-comments":{"ModelResponseProtocolGRPC":"gRPC protocol","ModelResponseProtocolHTTP":"HTTP/HTTPS REST 
API","ModelResponseProtocolWebSocket":"WebSocket protocol"},"x-enum-varnames":["ModelResponseProtocolHTTP","ModelResponseProtocolGRPC","ModelResponseProtocolWebSocket"]},"catalogs.ModelStreaming":{"type":"string","x-enum-comments":{"ModelStreamingChunked":"HTTP chunked transfer encoding","ModelStreamingSSE":"Server-Sent Events streaming","ModelStreamingWebSocket":"WebSocket streaming"},"x-enum-varnames":["ModelStreamingSSE","ModelStreamingWebSocket","ModelStreamingChunked"]},"catalogs.ModelTag":{"type":"string","x-enum-comments":{"ModelTagAudio":"Audio processing","ModelTagChat":"Conversational AI","ModelTagCoding":"Programming and code generation","ModelTagCreative":"Creative content generation","ModelTagEducation":"Educational content","ModelTagEmbedding":"Text embeddings","ModelTagFinance":"Financial analysis","ModelTagFunctionCalling":"Tool/function calling","ModelTagImageToText":"Image captioning/OCR","ModelTagInstruct":"Instruction following","ModelTagLegal":"Legal document processing","ModelTagMath":"Mathematical problem solving","ModelTagMedical":"Medical and healthcare","ModelTagMultimodal":"Multiple input modalities","ModelTagQA":"Question answering","ModelTagReasoning":"Logical reasoning and problem solving","ModelTagResearch":"Research and analysis","ModelTagRoleplay":"Character roleplay and simulation","ModelTagScience":"Scientific applications","ModelTagSpeechToText":"Speech recognition","ModelTagSummarization":"Text summarization","ModelTagTextToImage":"Text-to-image generation","ModelTagTextToSpeech":"Text-to-speech synthesis","ModelTagTranslation":"Language translation","ModelTagVision":"Computer vision","ModelTagWriting":"Creative and technical 
writing"},"x-enum-varnames":["ModelTagCoding","ModelTagWriting","ModelTagReasoning","ModelTagMath","ModelTagChat","ModelTagInstruct","ModelTagResearch","ModelTagCreative","ModelTagRoleplay","ModelTagFunctionCalling","ModelTagEmbedding","ModelTagSummarization","ModelTagTranslation","ModelTagQA","ModelTagVision","ModelTagMultimodal","ModelTagAudio","ModelTagTextToImage","ModelTagTextToSpeech","ModelTagSpeechToText","ModelTagImageToText","ModelTagMedical","ModelTagLegal","ModelTagFinance","ModelTagScience","ModelTagEducation"]},"catalogs.ModelTokenCachePricing":{"description":"Cache operations","properties":{"read":{"$ref":"#/components/schemas/catalogs.ModelTokenCost"},"write":{"$ref":"#/components/schemas/catalogs.ModelTokenCost"}},"type":"object"},"catalogs.ModelTokenCost":{"description":"Alternative flat cache structure (for backward compatibility)","properties":{"per_1m_tokens":{"description":"Cost per 1M tokens","type":"number"},"per_token":{"description":"Cost per individual token","type":"number"}},"type":"object"},"catalogs.ModelTokenPricing":{"description":"Token-based costs","properties":{"cache":{"$ref":"#/components/schemas/catalogs.ModelTokenCachePricing"},"cache_read":{"$ref":"#/components/schemas/catalogs.ModelTokenCost"},"cache_write":{"$ref":"#/components/schemas/catalogs.ModelTokenCost"},"input":{"$ref":"#/components/schemas/catalogs.ModelTokenCost"},"output":{"$ref":"#/components/schemas/catalogs.ModelTokenCost"},"reasoning":{"$ref":"#/components/schemas/catalogs.ModelTokenCost"}},"type":"object"},"catalogs.ModelTools":{"description":"Tools - external tool and capability integrations","properties":{"tool_choices":{"description":"Tool calling configuration\nSpecifies which tool choice strategies this model supports.\nRequires both Tools=true and ToolChoice=true in ModelFeatures.\nCommon values: [\"auto\"], [\"auto\", \"none\"], [\"auto\", \"none\", 
\"required\"]","items":{"$ref":"#/components/schemas/catalogs.ToolChoice"},"type":"array","uniqueItems":false},"web_search":{"$ref":"#/components/schemas/catalogs.ModelWebSearch"}},"type":"object"},"catalogs.ModelWebSearch":{"description":"Web search configuration\nOnly applicable if WebSearch=true in ModelFeatures","properties":{"default_context_size":{"description":"Default search context size","type":"string","x-enum-varnames":["ModelControlLevelMinimum","ModelControlLevelLow","ModelControlLevelMedium","ModelControlLevelHigh","ModelControlLevelMaximum"]},"max_results":{"description":"Plugin-based web search options (for models using OpenRouter's web plugin)","type":"integer"},"search_context_sizes":{"description":"Built-in web search options (for models with native web search like GPT-4.1, Perplexity)","items":{"type":"string","x-enum-varnames":["ModelControlLevelMinimum","ModelControlLevelLow","ModelControlLevelMedium","ModelControlLevelHigh","ModelControlLevelMaximum"]},"type":"array","uniqueItems":false},"search_prompt":{"description":"Custom prompt for search results","type":"string"}},"type":"object"},"catalogs.Provider":{"properties":{"aliases":{"description":"Alternative IDs this provider is known by (e.g., in models.dev)","items":{"description":"Optional provider to source models 
from","type":"string","x-enum-varnames":["ProviderIDAlibabaQwen","ProviderIDAnthropic","ProviderIDAnyscale","ProviderIDCerebras","ProviderIDCheckstep","ProviderIDCohere","ProviderIDConectys","ProviderIDCove","ProviderIDDeepMind","ProviderIDDeepSeek","ProviderIDGoogleAIStudio","ProviderIDGoogleVertex","ProviderIDGroq","ProviderIDHuggingFace","ProviderIDMeta","ProviderIDMicrosoft","ProviderIDMistralAI","ProviderIDOpenAI","ProviderIDOpenRouter","ProviderIDPerplexity","ProviderIDReplicate","ProviderIDSafetyKit","ProviderIDTogetherAI","ProviderIDVirtuousAI","ProviderIDWebPurify","ProviderIDXAI"]},"type":"array","uniqueItems":false},"api_key":{"$ref":"#/components/schemas/catalogs.ProviderAPIKey"},"catalog":{"$ref":"#/components/schemas/catalogs.ProviderCatalog"},"chat_completions":{"$ref":"#/components/schemas/catalogs.ProviderChatCompletions"},"env_vars":{"description":"Environment variables configuration","items":{"$ref":"#/components/schemas/catalogs.ProviderEnvVar"},"type":"array","uniqueItems":false},"governance_policy":{"$ref":"#/components/schemas/catalogs.ProviderGovernancePolicy"},"headquarters":{"description":"Company headquarters location","type":"string"},"icon_url":{"description":"Provider icon/logo URL","type":"string"},"id":{"description":"Core identification and integration","type":"string","x-enum-varnames":["ProviderIDAlibabaQwen","ProviderIDAnthropic","ProviderIDAnyscale","ProviderIDCerebras","ProviderIDCheckstep","ProviderIDCohere","ProviderIDConectys","ProviderIDCove","ProviderIDDeepMind","ProviderIDDeepSeek","ProviderIDGoogleAIStudio","ProviderIDGoogleVertex","ProviderIDGroq","ProviderIDHuggingFace","ProviderIDMeta","ProviderIDMicrosoft","ProviderIDMistralAI","ProviderIDOpenAI","ProviderIDOpenRouter","ProviderIDPerplexity","ProviderIDReplicate","ProviderIDSafetyKit","ProviderIDTogetherAI","ProviderIDVirtuousAI","ProviderIDWebPurify","ProviderIDXAI"]},"name":{"description":"Display name (must not be 
empty)","type":"string"},"privacy_policy":{"$ref":"#/components/schemas/catalogs.ProviderPrivacyPolicy"},"retention_policy":{"$ref":"#/components/schemas/catalogs.ProviderRetentionPolicy"},"status_page_url":{"description":"Status \u0026 Health","type":"string"}},"type":"object"},"catalogs.ProviderAPIKey":{"description":"API key configuration","properties":{"header":{"description":"Header name to send the API key in","type":"string"},"name":{"description":"Name of the API key parameter","type":"string"},"pattern":{"description":"Glob pattern to match the API key","type":"string"},"query_param":{"description":"Query parameter name to send the API key in","type":"string"},"scheme":{"$ref":"#/components/schemas/catalogs.ProviderAPIKeyScheme"}},"type":"object"},"catalogs.ProviderAPIKeyScheme":{"description":"Authentication scheme (e.g., \"Bearer\", \"Basic\", or empty for direct value)","type":"string","x-enum-comments":{"ProviderAPIKeySchemeBasic":"Basic authentication","ProviderAPIKeySchemeBearer":"Bearer token authentication (OAuth 2.0 style)","ProviderAPIKeySchemeDirect":"Direct value (no scheme prefix)"},"x-enum-varnames":["ProviderAPIKeySchemeBearer","ProviderAPIKeySchemeBasic","ProviderAPIKeySchemeDirect"]},"catalogs.ProviderCatalog":{"description":"Models","properties":{"authors":{"description":"List of authors to fetch from (for providers like Google Vertex AI)","items":{"description":"Unique identifier for the 
author","type":"string","x-enum-varnames":["AuthorIDOpenAI","AuthorIDAnthropic","AuthorIDGoogle","AuthorIDDeepMind","AuthorIDMeta","AuthorIDMicrosoft","AuthorIDMistralAI","AuthorIDCohere","AuthorIDGroq","AuthorIDAlibabaQwen","AuthorIDQwen","AuthorIDXAI","AuthorIDStanford","AuthorIDMIT","AuthorIDCMU","AuthorIDUCBerkeley","AuthorIDCornell","AuthorIDPrinceton","AuthorIDHarvard","AuthorIDOxford","AuthorIDCambridge","AuthorIDETHZurich","AuthorIDUWashington","AuthorIDUChicago","AuthorIDYale","AuthorIDDuke","AuthorIDCaltech","AuthorIDHuggingFace","AuthorIDEleutherAI","AuthorIDTogether","AuthorIDMosaicML","AuthorIDStabilityAI","AuthorIDRunwayML","AuthorIDMidjourney","AuthorIDLAION","AuthorIDBigScience","AuthorIDAlignmentRC","AuthorIDH2OAI","AuthorIDMoxin","AuthorIDBaidu","AuthorIDTencent","AuthorIDByteDance","AuthorIDDeepSeek","AuthorIDBAAI","AuthorID01AI","AuthorIDBaichuan","AuthorIDMiniMax","AuthorIDMoonshot","AuthorIDShanghaiAI","AuthorIDZhipuAI","AuthorIDSenseTime","AuthorIDHuawei","AuthorIDTsinghua","AuthorIDPeking","AuthorIDNVIDIA","AuthorIDSalesforce","AuthorIDIBM","AuthorIDApple","AuthorIDAmazon","AuthorIDAdept","AuthorIDAI21","AuthorIDInflection","AuthorIDCharacter","AuthorIDPerplexity","AuthorIDAnysphere","AuthorIDCursor","AuthorIDCognitiveComputations","AuthorIDEricHartford","AuthorIDNousResearch","AuthorIDTeknium","AuthorIDJonDurbin","AuthorIDLMSYS","AuthorIDVicuna","AuthorIDAlpacaTeam","AuthorIDWizardLM","AuthorIDOpenOrca","AuthorIDPhind","AuthorIDCodeFuse","AuthorIDTHUDM","AuthorIDGeorgiaTechRI","AuthorIDFastChat","AuthorIDUnknown"]},"type":"array","uniqueItems":false},"docs":{"description":"Documentation URL","type":"string"},"endpoint":{"$ref":"#/components/schemas/catalogs.ProviderEndpoint"}},"type":"object"},"catalogs.ProviderChatCompletions":{"description":"Chat completions API configuration","properties":{"health_api_url":{"description":"URL to health/status API for this service","type":"string"},"health_components":{"description":"Specific components 
to monitor for chat completions","items":{"$ref":"#/components/schemas/catalogs.ProviderHealthComponent"},"type":"array","uniqueItems":false},"url":{"description":"Chat completions API endpoint URL","type":"string"}},"type":"object"},"catalogs.ProviderEndpoint":{"description":"API endpoint configuration","properties":{"auth_required":{"description":"Required: Whether auth needed","type":"boolean"},"author_mapping":{"$ref":"#/components/schemas/catalogs.AuthorMapping"},"feature_rules":{"description":"Feature inference rules","items":{"$ref":"#/components/schemas/catalogs.FeatureRule"},"type":"array","uniqueItems":false},"field_mappings":{"description":"Field mappings","items":{"$ref":"#/components/schemas/catalogs.FieldMapping"},"type":"array","uniqueItems":false},"type":{"$ref":"#/components/schemas/catalogs.EndpointType"},"url":{"description":"Required: API endpoint","type":"string"}},"type":"object"},"catalogs.ProviderEnvVar":{"properties":{"description":{"description":"Human-readable description","type":"string"},"name":{"description":"Environment variable name","type":"string"},"pattern":{"description":"Optional validation pattern","type":"string"},"required":{"description":"Whether this env var is required","type":"boolean"}},"type":"object"},"catalogs.ProviderGovernancePolicy":{"description":"Oversight and moderation practices","properties":{"moderated":{"description":"Whether provider content is moderated","type":"boolean"},"moderation_required":{"description":"Whether the provider requires moderation","type":"boolean"},"moderator":{"description":"Who moderates the provider","type":"string"}},"type":"object"},"catalogs.ProviderHealthComponent":{"properties":{"id":{"description":"Component ID from the health API","type":"string"},"name":{"description":"Human-readable component name","type":"string"}},"type":"object"},"catalogs.ProviderID":{"description":"Optional provider to source models 
from","type":"string","x-enum-varnames":["ProviderIDAlibabaQwen","ProviderIDAnthropic","ProviderIDAnyscale","ProviderIDCerebras","ProviderIDCheckstep","ProviderIDCohere","ProviderIDConectys","ProviderIDCove","ProviderIDDeepMind","ProviderIDDeepSeek","ProviderIDGoogleAIStudio","ProviderIDGoogleVertex","ProviderIDGroq","ProviderIDHuggingFace","ProviderIDMeta","ProviderIDMicrosoft","ProviderIDMistralAI","ProviderIDOpenAI","ProviderIDOpenRouter","ProviderIDPerplexity","ProviderIDReplicate","ProviderIDSafetyKit","ProviderIDTogetherAI","ProviderIDVirtuousAI","ProviderIDWebPurify","ProviderIDXAI"]},"catalogs.ProviderPrivacyPolicy":{"description":"Privacy, Retention, and Governance Policies","properties":{"privacy_policy_url":{"description":"Link to privacy policy","type":"string"},"retains_data":{"description":"Whether provider stores/retains user data","type":"boolean"},"terms_of_service_url":{"description":"Link to terms of service","type":"string"},"trains_on_data":{"description":"Whether provider trains models on user data","type":"boolean"}},"type":"object"},"catalogs.ProviderRetentionPolicy":{"description":"Data retention and deletion practices","properties":{"details":{"description":"Human-readable description","type":"string"},"duration":{"$ref":"#/components/schemas/time.Duration"},"type":{"$ref":"#/components/schemas/catalogs.ProviderRetentionType"}},"type":"object"},"catalogs.ProviderRetentionType":{"description":"Type of retention policy","type":"string","x-enum-comments":{"ProviderRetentionTypeConditional":"Based on conditions (e.g., \"until account deletion\")","ProviderRetentionTypeFixed":"Specific duration (use Duration field)","ProviderRetentionTypeIndefinite":"Forever (duration = nil)","ProviderRetentionTypeNone":"No retention (immediate deletion)"},"x-enum-varnames":["ProviderRetentionTypeFixed","ProviderRetentionTypeNone","ProviderRetentionTypeIndefinite","ProviderRetentionTypeConditional"]},"catalogs.Quantization":{"description":"Quantization level 
used by the model","type":"string","x-enum-comments":{"QuantizationBF16":"Brain floating point (16 bit)","QuantizationFP16":"Floating point (16 bit)","QuantizationFP32":"Floating point (32 bit)","QuantizationFP4":"Floating point (4 bit)","QuantizationFP6":"Floating point (6 bit)","QuantizationFP8":"Floating point (8 bit)","QuantizationINT4":"Integer (4 bit)","QuantizationINT8":"Integer (8 bit)","QuantizationUnknown":"Unknown quantization"},"x-enum-varnames":["QuantizationINT4","QuantizationINT8","QuantizationFP4","QuantizationFP6","QuantizationFP8","QuantizationFP16","QuantizationBF16","QuantizationFP32","QuantizationUnknown"]},"catalogs.Tokenizer":{"description":"Tokenizer type used by the model","type":"string","x-enum-comments":{"TokenizerClaude":"Claude tokenizer","TokenizerCohere":"Cohere tokenizer","TokenizerDeepSeek":"DeepSeek tokenizer","TokenizerGPT":"GPT tokenizer (OpenAI)","TokenizerGemini":"Gemini tokenizer (Google)","TokenizerGrok":"Grok tokenizer (xAI)","TokenizerLlama2":"LLaMA 2 tokenizer","TokenizerLlama3":"LLaMA 3 tokenizer","TokenizerLlama4":"LLaMA 4 tokenizer","TokenizerMistral":"Mistral tokenizer","TokenizerNova":"Nova tokenizer (Amazon)","TokenizerQwen":"Qwen tokenizer","TokenizerQwen3":"Qwen 3 tokenizer","TokenizerRouter":"Router-based tokenizer","TokenizerUnknown":"Unknown tokenizer type","TokenizerYi":"Yi tokenizer"},"x-enum-varnames":["TokenizerClaude","TokenizerCohere","TokenizerDeepSeek","TokenizerGPT","TokenizerGemini","TokenizerGrok","TokenizerLlama2","TokenizerLlama3","TokenizerLlama4","TokenizerMistral","TokenizerNova","TokenizerQwen","TokenizerQwen3","TokenizerRouter","TokenizerYi","TokenizerUnknown"]},"catalogs.ToolChoice":{"type":"string","x-enum-comments":{"ToolChoiceAuto":"Model autonomously decides whether to call tools based on context","ToolChoiceNone":"Model will never call tools, even if tool definitions are provided","ToolChoiceRequired":"Model must call at least one tool before 
responding"},"x-enum-varnames":["ToolChoiceAuto","ToolChoiceNone","ToolChoiceRequired"]},"data":{"properties":{"data":{"type":"object"}},"type":"object"},"error":{"properties":{"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"},"handlers.DateRange":{"properties":{"after":{"type":"string"},"before":{"type":"string"}},"type":"object"},"handlers.IntRange":{"properties":{"max":{"type":"integer"},"min":{"type":"integer"}},"type":"object"},"handlers.SearchModalities":{"properties":{"input":{"items":{"type":"string"},"type":"array","uniqueItems":false},"output":{"items":{"type":"string"},"type":"array","uniqueItems":false}},"type":"object"},"handlers.SearchRequest":{"properties":{"context_window":{"$ref":"#/components/schemas/handlers.IntRange"},"features":{"additionalProperties":{"type":"boolean"},"type":"object"},"ids":{"items":{"type":"string"},"type":"array","uniqueItems":false},"max_results":{"type":"integer"},"modalities":{"$ref":"#/components/schemas/handlers.SearchModalities"},"name_contains":{"type":"string"},"open_weights":{"type":"boolean"},"order":{"type":"string"},"output_tokens":{"$ref":"#/components/schemas/handlers.IntRange"},"provider":{"type":"string"},"release_date":{"$ref":"#/components/schemas/handlers.DateRange"},"sort":{"type":"string"},"tags":{"items":{"type":"string"},"type":"array","uniqueItems":false}},"type":"object"},"response.Error":{"properties":{"code":{"type":"string"},"details":{"type":"string"},"message":{"type":"string"}},"type":"object"},"response.Response":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"},"time.Duration":{"description":"nil = forever, 0 = immediate deletion","type":"integer","x-enum-varnames":["minDuration","maxDuration","Nanosecond","Microsecond","Millisecond","Second","Minute","Hour"]}},"securitySchemes":{"ApiKeyAuth":{"description":"API key for authentication (optional, 
configurable)","in":"header","name":"X-API-Key","type":"apiKey"}}}, + "info": {"contact":{"name":"Starmap Project","url":"https://github.com/agentstation/starmap"},"description":"REST API for the Starmap AI model catalog with real-time updates via WebSocket and SSE.\n\nFeatures:\n- Comprehensive model and provider queries\n- Advanced filtering and search\n- Real-time updates via WebSocket and Server-Sent Events\n- In-memory caching for performance\n- Rate limiting and authentication support","license":{"name":"MIT","url":"https://github.com/agentstation/starmap/blob/master/LICENSE"},"title":"Starmap API","version":"1.0"}, + "externalDocs": {"description":"","url":""}, + "paths": {"/api/v1/health":{"get":{"description":"Health check endpoint (liveness probe)","requestBody":{"content":{"application/json":{"schema":{"type":"object"}}}},"responses":{"200":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"OK"}},"summary":"Health check","tags":["health"]}},"/api/v1/models":{"get":{"description":"List all models with optional filtering","parameters":[{"description":"Filter by exact model ID","in":"query","name":"id","schema":{"type":"string"}},{"description":"Filter by exact model name (case-insensitive)","in":"query","name":"name","schema":{"type":"string"}},{"description":"Filter by partial model name match","in":"query","name":"name_contains","schema":{"type":"string"}},{"description":"Filter by provider ID","in":"query","name":"provider","schema":{"type":"string"}},{"description":"Filter by input modality (comma-separated)","in":"query","name":"modality_input","schema":{"type":"string"}},{"description":"Filter by output modality (comma-separated)","in":"query","name":"modality_output","schema":{"type":"string"}},{"description":"Filter by feature (streaming, tool_calls, 
etc.)","in":"query","name":"feature","schema":{"type":"string"}},{"description":"Filter by tag (comma-separated)","in":"query","name":"tag","schema":{"type":"string"}},{"description":"Filter by open weights status","in":"query","name":"open_weights","schema":{"type":"boolean"}},{"description":"Minimum context window size","in":"query","name":"min_context","schema":{"type":"integer"}},{"description":"Maximum context window size","in":"query","name":"max_context","schema":{"type":"integer"}},{"description":"Sort field (id, name, release_date, context_window, created_at, updated_at)","in":"query","name":"sort","schema":{"type":"string"}},{"description":"Sort order (asc, desc)","in":"query","name":"order","schema":{"type":"string"}},{"description":"Maximum number of results (default: 100, max: 1000)","in":"query","name":"limit","schema":{"type":"integer"}},{"description":"Result offset for pagination","in":"query","name":"offset","schema":{"type":"integer"}}],"requestBody":{"content":{"application/json":{"schema":{"type":"object"}}}},"responses":{"200":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"OK"},"400":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"Bad Request"},"500":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"Internal Server Error"}},"security":[{"ApiKeyAuth":[]}],"summary":"List models","tags":["models"]}},"/api/v1/models/search":{"post":{"description":"Advanced search with multiple 
criteria","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/handlers.SearchRequest"}}},"description":"Search criteria","required":true},"responses":{"200":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"OK"},"400":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"Bad Request"},"500":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"Internal Server Error"}},"security":[{"ApiKeyAuth":[]}],"summary":"Search models","tags":["models"]}},"/api/v1/models/{id}":{"get":{"description":"Retrieve detailed information about a specific model","parameters":[{"description":"Model ID","in":"path","name":"id","required":true,"schema":{"type":"string"}}],"requestBody":{"content":{"application/json":{"schema":{"type":"object"}}}},"responses":{"200":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"OK"},"404":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"Not Found"},"500":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"Internal Server Error"}},"security":[{"ApiKeyAuth":[]}],"summary":"Get model by 
ID","tags":["models"]}},"/api/v1/openapi.json":{"get":{"description":"Returns the OpenAPI 3.1 specification for this API in JSON format","responses":{"200":{"content":{"application/json":{"schema":{"type":"object"}}},"description":"OpenAPI 3.1 specification"}},"summary":"Get OpenAPI specification (JSON)","tags":["meta"]}},"/api/v1/openapi.yaml":{"get":{"description":"Returns the OpenAPI 3.1 specification for this API in YAML format","responses":{"200":{"content":{"application/json":{"schema":{"type":"string"}},"application/x-yaml":{"schema":{"type":"string"}}},"description":"OpenAPI 3.1 specification"}},"summary":"Get OpenAPI specification (YAML)","tags":["meta"]}},"/api/v1/providers":{"get":{"description":"List all providers","requestBody":{"content":{"application/json":{"schema":{"type":"object"}}}},"responses":{"200":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"OK"},"500":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"Internal Server Error"}},"security":[{"ApiKeyAuth":[]}],"summary":"List providers","tags":["providers"]}},"/api/v1/providers/{id}":{"get":{"description":"Retrieve detailed information about a specific provider","parameters":[{"description":"Provider 
ID","in":"path","name":"id","required":true,"schema":{"type":"string"}}],"requestBody":{"content":{"application/json":{"schema":{"type":"object"}}}},"responses":{"200":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"OK"},"404":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"Not Found"},"500":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"Internal Server Error"}},"security":[{"ApiKeyAuth":[]}],"summary":"Get provider by ID","tags":["providers"]}},"/api/v1/providers/{id}/models":{"get":{"description":"List all models for a specific provider","parameters":[{"description":"Provider ID","in":"path","name":"id","required":true,"schema":{"type":"string"}}],"requestBody":{"content":{"application/json":{"schema":{"type":"object"}}}},"responses":{"200":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"OK"},"404":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"Not Found"},"500":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"Internal Server Error"}},"security":[{"ApiKeyAuth":[]}],"summary":"Get provider 
models","tags":["providers"]}},"/api/v1/ready":{"get":{"description":"Readiness check including cache and data source status","requestBody":{"content":{"application/json":{"schema":{"type":"object"}}}},"responses":{"200":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"OK"},"503":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"Service Unavailable"}},"summary":"Readiness check","tags":["health"]}},"/api/v1/stats":{"get":{"description":"Get catalog statistics (model count, provider count, last sync)","requestBody":{"content":{"application/json":{"schema":{"type":"object"}}}},"responses":{"200":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"OK"},"500":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"Internal Server Error"}},"security":[{"ApiKeyAuth":[]}],"summary":"Catalog statistics","tags":["admin"]}},"/api/v1/update":{"post":{"description":"Manually trigger catalog synchronization","parameters":[{"description":"Update specific provider 
only","in":"query","name":"provider","schema":{"type":"string"}}],"requestBody":{"content":{"application/json":{"schema":{"type":"object"}}}},"responses":{"200":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"OK"},"500":{"content":{"application/json":{"schema":{"allOf":[{"$ref":"#/components/schemas/error"}],"properties":{"data":{},"error":{"$ref":"#/components/schemas/response.Error"}},"type":"object"}}},"description":"Internal Server Error"}},"security":[{"ApiKeyAuth":[]}],"summary":"Trigger catalog update","tags":["admin"]}},"/api/v1/updates/stream":{"get":{"description":"Server-Sent Events stream for catalog change notifications","responses":{"200":{"content":{"text/event-stream":{"schema":{"type":"string"}}},"description":"Event stream"}},"summary":"SSE updates stream","tags":["updates"]}},"/api/v1/updates/ws":{"get":{"description":"WebSocket connection for real-time catalog updates","responses":{"101":{"description":"Switching Protocols"}},"summary":"WebSocket updates","tags":["updates"]}}}, + "openapi": "3.1.0", + "servers": [ + {"url":"localhost:8080/api/v1"} + ] +} \ No newline at end of file diff --git a/internal/embedded/openapi/openapi.yaml b/internal/embedded/openapi/openapi.yaml new file mode 100644 index 000000000..02b2ae581 --- /dev/null +++ b/internal/embedded/openapi/openapi.yaml @@ -0,0 +1,2223 @@ +components: + schemas: + catalogs.ArchitectureType: + description: Type of architecture + type: string + x-enum-comments: + ArchitectureTypeCNN: Convolutional Neural Networks + ArchitectureTypeDiffusion: Diffusion models (Stable Diffusion, DALL-E, etc.) 
+ ArchitectureTypeGAN: Generative Adversarial Networks + ArchitectureTypeGRU: Gated Recurrent Unit networks + ArchitectureTypeLSTM: Long Short-Term Memory networks + ArchitectureTypeMoE: Mixture of Experts (Mixtral, GLaM, Switch Transformer) + ArchitectureTypeRNN: Recurrent Neural Networks + ArchitectureTypeTransformer: Transformer-based models (GPT, BERT, LLaMA, etc.) + ArchitectureTypeVAE: Variational Autoencoders + x-enum-varnames: + - ArchitectureTypeTransformer + - ArchitectureTypeMoE + - ArchitectureTypeCNN + - ArchitectureTypeRNN + - ArchitectureTypeLSTM + - ArchitectureTypeGRU + - ArchitectureTypeVAE + - ArchitectureTypeGAN + - ArchitectureTypeDiffusion + catalogs.Author: + properties: + aliases: + description: Alternative IDs this author is known by (e.g., in provider + catalogs) + items: + description: Unique identifier for the author + type: string + x-enum-varnames: + - AuthorIDOpenAI + - AuthorIDAnthropic + - AuthorIDGoogle + - AuthorIDDeepMind + - AuthorIDMeta + - AuthorIDMicrosoft + - AuthorIDMistralAI + - AuthorIDCohere + - AuthorIDGroq + - AuthorIDAlibabaQwen + - AuthorIDQwen + - AuthorIDXAI + - AuthorIDStanford + - AuthorIDMIT + - AuthorIDCMU + - AuthorIDUCBerkeley + - AuthorIDCornell + - AuthorIDPrinceton + - AuthorIDHarvard + - AuthorIDOxford + - AuthorIDCambridge + - AuthorIDETHZurich + - AuthorIDUWashington + - AuthorIDUChicago + - AuthorIDYale + - AuthorIDDuke + - AuthorIDCaltech + - AuthorIDHuggingFace + - AuthorIDEleutherAI + - AuthorIDTogether + - AuthorIDMosaicML + - AuthorIDStabilityAI + - AuthorIDRunwayML + - AuthorIDMidjourney + - AuthorIDLAION + - AuthorIDBigScience + - AuthorIDAlignmentRC + - AuthorIDH2OAI + - AuthorIDMoxin + - AuthorIDBaidu + - AuthorIDTencent + - AuthorIDByteDance + - AuthorIDDeepSeek + - AuthorIDBAAI + - AuthorID01AI + - AuthorIDBaichuan + - AuthorIDMiniMax + - AuthorIDMoonshot + - AuthorIDShanghaiAI + - AuthorIDZhipuAI + - AuthorIDSenseTime + - AuthorIDHuawei + - AuthorIDTsinghua + - AuthorIDPeking + - 
AuthorIDNVIDIA + - AuthorIDSalesforce + - AuthorIDIBM + - AuthorIDApple + - AuthorIDAmazon + - AuthorIDAdept + - AuthorIDAI21 + - AuthorIDInflection + - AuthorIDCharacter + - AuthorIDPerplexity + - AuthorIDAnysphere + - AuthorIDCursor + - AuthorIDCognitiveComputations + - AuthorIDEricHartford + - AuthorIDNousResearch + - AuthorIDTeknium + - AuthorIDJonDurbin + - AuthorIDLMSYS + - AuthorIDVicuna + - AuthorIDAlpacaTeam + - AuthorIDWizardLM + - AuthorIDOpenOrca + - AuthorIDPhind + - AuthorIDCodeFuse + - AuthorIDTHUDM + - AuthorIDGeorgiaTechRI + - AuthorIDFastChat + - AuthorIDUnknown + type: array + uniqueItems: false + catalog: + $ref: '#/components/schemas/catalogs.AuthorCatalog' + created_at: + description: Timestamps for record keeping and auditing + type: string + description: + description: Description of what the author is known for + type: string + github: + description: GitHub profile/organization URL + type: string + headquarters: + description: Company/organization info + type: string + huggingface: + description: Hugging Face profile/organization URL + type: string + icon_url: + description: Author icon/logo URL + type: string + id: + $ref: '#/components/schemas/catalogs.AuthorID' + name: + description: Display name of the author + type: string + twitter: + description: X (formerly Twitter) profile URL + type: string + updated_at: + description: Last updated date (YYYY-MM or YYYY-MM-DD format) + type: string + website: + description: Website, social links, and other relevant URLs + type: string + type: object + catalogs.AuthorAttribution: + description: Model attribution configuration for multi-provider inference + properties: + patterns: + description: Glob patterns to match model IDs + items: + type: string + type: array + uniqueItems: false + provider_id: + $ref: '#/components/schemas/catalogs.ProviderID' + type: object + catalogs.AuthorCatalog: + description: Catalog and models + properties: + attribution: + $ref: 
'#/components/schemas/catalogs.AuthorAttribution' + description: + description: Optional description of this mapping relationship + type: string + type: object + catalogs.AuthorID: + description: Unique identifier for the author + type: string + x-enum-varnames: + - AuthorIDOpenAI + - AuthorIDAnthropic + - AuthorIDGoogle + - AuthorIDDeepMind + - AuthorIDMeta + - AuthorIDMicrosoft + - AuthorIDMistralAI + - AuthorIDCohere + - AuthorIDGroq + - AuthorIDAlibabaQwen + - AuthorIDQwen + - AuthorIDXAI + - AuthorIDStanford + - AuthorIDMIT + - AuthorIDCMU + - AuthorIDUCBerkeley + - AuthorIDCornell + - AuthorIDPrinceton + - AuthorIDHarvard + - AuthorIDOxford + - AuthorIDCambridge + - AuthorIDETHZurich + - AuthorIDUWashington + - AuthorIDUChicago + - AuthorIDYale + - AuthorIDDuke + - AuthorIDCaltech + - AuthorIDHuggingFace + - AuthorIDEleutherAI + - AuthorIDTogether + - AuthorIDMosaicML + - AuthorIDStabilityAI + - AuthorIDRunwayML + - AuthorIDMidjourney + - AuthorIDLAION + - AuthorIDBigScience + - AuthorIDAlignmentRC + - AuthorIDH2OAI + - AuthorIDMoxin + - AuthorIDBaidu + - AuthorIDTencent + - AuthorIDByteDance + - AuthorIDDeepSeek + - AuthorIDBAAI + - AuthorID01AI + - AuthorIDBaichuan + - AuthorIDMiniMax + - AuthorIDMoonshot + - AuthorIDShanghaiAI + - AuthorIDZhipuAI + - AuthorIDSenseTime + - AuthorIDHuawei + - AuthorIDTsinghua + - AuthorIDPeking + - AuthorIDNVIDIA + - AuthorIDSalesforce + - AuthorIDIBM + - AuthorIDApple + - AuthorIDAmazon + - AuthorIDAdept + - AuthorIDAI21 + - AuthorIDInflection + - AuthorIDCharacter + - AuthorIDPerplexity + - AuthorIDAnysphere + - AuthorIDCursor + - AuthorIDCognitiveComputations + - AuthorIDEricHartford + - AuthorIDNousResearch + - AuthorIDTeknium + - AuthorIDJonDurbin + - AuthorIDLMSYS + - AuthorIDVicuna + - AuthorIDAlpacaTeam + - AuthorIDWizardLM + - AuthorIDOpenOrca + - AuthorIDPhind + - AuthorIDCodeFuse + - AuthorIDTHUDM + - AuthorIDGeorgiaTechRI + - AuthorIDFastChat + - AuthorIDUnknown + catalogs.AuthorMapping: + description: Author 
extraction + properties: + field: + description: Field to extract from (e.g., "owned_by") + type: string + normalized: + additionalProperties: + description: Unique identifier for the author + type: string + x-enum-varnames: + - AuthorIDOpenAI + - AuthorIDAnthropic + - AuthorIDGoogle + - AuthorIDDeepMind + - AuthorIDMeta + - AuthorIDMicrosoft + - AuthorIDMistralAI + - AuthorIDCohere + - AuthorIDGroq + - AuthorIDAlibabaQwen + - AuthorIDQwen + - AuthorIDXAI + - AuthorIDStanford + - AuthorIDMIT + - AuthorIDCMU + - AuthorIDUCBerkeley + - AuthorIDCornell + - AuthorIDPrinceton + - AuthorIDHarvard + - AuthorIDOxford + - AuthorIDCambridge + - AuthorIDETHZurich + - AuthorIDUWashington + - AuthorIDUChicago + - AuthorIDYale + - AuthorIDDuke + - AuthorIDCaltech + - AuthorIDHuggingFace + - AuthorIDEleutherAI + - AuthorIDTogether + - AuthorIDMosaicML + - AuthorIDStabilityAI + - AuthorIDRunwayML + - AuthorIDMidjourney + - AuthorIDLAION + - AuthorIDBigScience + - AuthorIDAlignmentRC + - AuthorIDH2OAI + - AuthorIDMoxin + - AuthorIDBaidu + - AuthorIDTencent + - AuthorIDByteDance + - AuthorIDDeepSeek + - AuthorIDBAAI + - AuthorID01AI + - AuthorIDBaichuan + - AuthorIDMiniMax + - AuthorIDMoonshot + - AuthorIDShanghaiAI + - AuthorIDZhipuAI + - AuthorIDSenseTime + - AuthorIDHuawei + - AuthorIDTsinghua + - AuthorIDPeking + - AuthorIDNVIDIA + - AuthorIDSalesforce + - AuthorIDIBM + - AuthorIDApple + - AuthorIDAmazon + - AuthorIDAdept + - AuthorIDAI21 + - AuthorIDInflection + - AuthorIDCharacter + - AuthorIDPerplexity + - AuthorIDAnysphere + - AuthorIDCursor + - AuthorIDCognitiveComputations + - AuthorIDEricHartford + - AuthorIDNousResearch + - AuthorIDTeknium + - AuthorIDJonDurbin + - AuthorIDLMSYS + - AuthorIDVicuna + - AuthorIDAlpacaTeam + - AuthorIDWizardLM + - AuthorIDOpenOrca + - AuthorIDPhind + - AuthorIDCodeFuse + - AuthorIDTHUDM + - AuthorIDGeorgiaTechRI + - AuthorIDFastChat + - AuthorIDUnknown + description: Normalization map (e.g., "Meta" -> "meta") + type: object + type: object + 
catalogs.EndpointType: + description: 'Required: API style' + type: string + x-enum-varnames: + - EndpointTypeOpenAI + - EndpointTypeAnthropic + - EndpointTypeGoogle + - EndpointTypeGoogleCloud + catalogs.FeatureRule: + properties: + contains: + description: If field contains any of these strings + items: + type: string + type: array + uniqueItems: false + feature: + description: Feature to enable (e.g., "tools", "reasoning") + type: string + field: + description: Field to check (e.g., "id", "owned_by") + type: string + value: + description: Value to set for the feature + type: boolean + type: object + catalogs.FieldMapping: + properties: + from: + description: Source field path in API response (e.g., "max_model_len") + type: string + to: + description: Target field path in Model (e.g., "limits.context_window") + type: string + type: object + catalogs.FloatRange: + description: Alternative sampling strategies (niche) + properties: + default: + description: Default value + type: number + max: + description: Maximum value + type: number + min: + description: Minimum value + type: number + type: object + catalogs.IntRange: + description: ReasoningTokens - specific token allocation for reasoning processes + properties: + default: + description: Default value + type: integer + max: + description: Maximum value + type: integer + min: + description: Minimum value + type: integer + type: object + catalogs.Model: + properties: + attachments: + $ref: '#/components/schemas/catalogs.ModelAttachments' + authors: + description: Authors/organizations of the model (if known) + items: + $ref: '#/components/schemas/catalogs.Author' + type: array + uniqueItems: false + created_at: + description: Timestamps for record keeping and auditing + type: string + description: + description: Description of the model and its use cases + type: string + features: + $ref: '#/components/schemas/catalogs.ModelFeatures' + generation: + $ref: '#/components/schemas/catalogs.ModelGeneration' + id: + 
description: Core identity + type: string + limits: + $ref: '#/components/schemas/catalogs.ModelLimits' + metadata: + $ref: '#/components/schemas/catalogs.ModelMetadata' + name: + description: Display name (must not be empty) + type: string + pricing: + $ref: '#/components/schemas/catalogs.ModelPricing' + reasoning: + $ref: '#/components/schemas/catalogs.ModelControlLevels' + reasoning_tokens: + $ref: '#/components/schemas/catalogs.IntRange' + response: + $ref: '#/components/schemas/catalogs.ModelDelivery' + tools: + $ref: '#/components/schemas/catalogs.ModelTools' + updated_at: + description: Last updated date (YYYY-MM or YYYY-MM-DD format) + type: string + verbosity: + $ref: '#/components/schemas/catalogs.ModelControlLevels' + type: object + catalogs.ModelArchitecture: + description: Technical architecture details + properties: + base_model: + description: Base model ID if fine-tuned + type: string + fine_tuned: + description: Whether this is a fine-tuned variant + type: boolean + parameter_count: + description: Model size (e.g., "7B", "70B", "405B") + type: string + precision: + description: Legacy precision format (use Quantization for filtering) + type: string + quantization: + $ref: '#/components/schemas/catalogs.Quantization' + quantized: + description: Whether the model has been quantized + type: boolean + tokenizer: + $ref: '#/components/schemas/catalogs.Tokenizer' + type: + $ref: '#/components/schemas/catalogs.ArchitectureType' + type: object + catalogs.ModelAttachments: + description: Attachments - attachment support details + properties: + max_file_size: + description: Maximum file size in bytes + type: integer + max_files: + description: Maximum number of files per request + type: integer + mime_types: + description: Supported MIME types + items: + type: string + type: array + uniqueItems: false + type: object + catalogs.ModelControlLevel: + type: string + x-enum-varnames: + - ModelControlLevelMinimum + - ModelControlLevelLow + - 
ModelControlLevelMedium + - ModelControlLevelHigh + - ModelControlLevelMaximum + catalogs.ModelControlLevels: + description: Verbosity - response verbosity levels + properties: + default: + description: Default level + type: string + x-enum-varnames: + - ModelControlLevelMinimum + - ModelControlLevelLow + - ModelControlLevelMedium + - ModelControlLevelHigh + - ModelControlLevelMaximum + levels: + description: Which levels this model supports + items: + $ref: '#/components/schemas/catalogs.ModelControlLevel' + type: array + uniqueItems: false + type: object + catalogs.ModelDelivery: + description: Delivery - technical response delivery capabilities (formats, protocols, + streaming) + properties: + formats: + description: Available response formats (if format_response feature enabled) + items: + $ref: '#/components/schemas/catalogs.ModelResponseFormat' + type: array + uniqueItems: false + protocols: + description: Response delivery mechanisms + items: + $ref: '#/components/schemas/catalogs.ModelResponseProtocol' + type: array + uniqueItems: false + streaming: + description: Supported streaming modes (sse, websocket, chunked) + items: + $ref: '#/components/schemas/catalogs.ModelStreaming' + type: array + uniqueItems: false + type: object + catalogs.ModelFeatures: + description: Features - what this model can do + properties: + allowed_tokens: + description: '[Niche] Supports token whitelist' + type: boolean + attachments: + description: Attachment support details + type: boolean + bad_words: + description: '[Advanced] Supports bad words/disallowed tokens' + type: boolean + best_of: + description: '[Advanced] Supports server-side sampling with best selection' + type: boolean + contrastive_search_penalty_alpha: + description: '[Niche] Supports contrastive decoding' + type: boolean + diversity_penalty: + description: '[Niche] Supports diversity penalty in beam search' + type: boolean + early_stopping: + description: '[Niche] Supports early stopping in beam search' + 
type: boolean + echo: + description: '[Advanced] Supports echoing prompt with completion' + type: boolean + format_response: + description: Response delivery + type: boolean + frequency_penalty: + description: Generation control - Repetition control + type: boolean + include_reasoning: + description: Supports including reasoning traces in response + type: boolean + length_penalty: + description: '[Niche] Supports length penalty (seq2seq style)' + type: boolean + logit_bias: + description: Generation control - Token biasing + type: boolean + logprobs: + description: Generation control - Observability + type: boolean + max_output_tokens: + description: '[Core] Supports max_output_tokens parameter (some providers + distinguish from max_tokens)' + type: boolean + max_tokens: + description: Generation control - Length and termination + type: boolean + min_p: + description: '[Advanced] Supports min_p parameter (minimum probability threshold)' + type: boolean + mirostat: + description: Generation control - Alternative sampling strategies (niche) + type: boolean + mirostat_eta: + description: '[Niche] Supports Mirostat eta parameter' + type: boolean + mirostat_tau: + description: '[Niche] Supports Mirostat tau parameter' + type: boolean + modalities: + $ref: '#/components/schemas/catalogs.ModelModalities' + "n": + description: Generation control - Multiplicity and reranking + type: boolean + no_repeat_ngram_size: + description: '[Niche] Supports n-gram repetition blocking' + type: boolean + num_beams: + description: Generation control - Beam search (niche) + type: boolean + presence_penalty: + description: '[Core] Supports presence penalty' + type: boolean + reasoning: + description: Reasoning & Verbosity + type: boolean + reasoning_effort: + description: Supports configurable reasoning intensity + type: boolean + reasoning_tokens: + description: Supports specific reasoning token allocation + type: boolean + repetition_penalty: + description: '[Advanced] Supports 
repetition penalty' + type: boolean + seed: + description: Generation control - Determinism + type: boolean + stop: + description: '[Core] Supports stop sequences/words' + type: boolean + stop_token_ids: + description: '[Advanced] Supports stop token IDs (numeric)' + type: boolean + streaming: + description: Supports response streaming + type: boolean + structured_outputs: + description: Supports structured outputs (JSON schema validation) + type: boolean + temperature: + description: Generation control - Core sampling and decoding + type: boolean + tfs: + description: '[Advanced] Supports tail free sampling' + type: boolean + tool_calls: + description: |- + Core capabilities + Tool calling system - three distinct aspects: + type: boolean + tool_choice: + description: Supports tool choice strategies (auto/none/required control) + type: boolean + tools: + description: Accepts tool definitions in requests (accepts tools parameter) + type: boolean + top_a: + description: '[Advanced] Supports top_a parameter (top-a sampling)' + type: boolean + top_k: + description: '[Advanced] Supports top_k parameter' + type: boolean + top_logprobs: + description: '[Core] Supports returning top N log probabilities' + type: boolean + top_p: + description: '[Core] Supports top_p parameter (nucleus sampling)' + type: boolean + typical_p: + description: '[Advanced] Supports typical_p parameter (typical sampling)' + type: boolean + verbosity: + description: Supports verbosity control (GPT-5+) + type: boolean + web_search: + description: Supports web search capabilities + type: boolean + type: object + catalogs.ModelGeneration: + description: Generation - core chat completions generation controls + properties: + best_of: + $ref: '#/components/schemas/catalogs.IntRange' + contrastive_search_penalty_alpha: + $ref: '#/components/schemas/catalogs.FloatRange' + diversity_penalty: + $ref: '#/components/schemas/catalogs.FloatRange' + frequency_penalty: + $ref: 
'#/components/schemas/catalogs.FloatRange' + length_penalty: + $ref: '#/components/schemas/catalogs.FloatRange' + max_output_tokens: + type: integer + max_tokens: + description: Length and termination + type: integer + min_p: + $ref: '#/components/schemas/catalogs.FloatRange' + mirostat_eta: + $ref: '#/components/schemas/catalogs.FloatRange' + mirostat_tau: + $ref: '#/components/schemas/catalogs.FloatRange' + "n": + $ref: '#/components/schemas/catalogs.IntRange' + no_repeat_ngram_size: + $ref: '#/components/schemas/catalogs.IntRange' + num_beams: + $ref: '#/components/schemas/catalogs.IntRange' + presence_penalty: + $ref: '#/components/schemas/catalogs.FloatRange' + repetition_penalty: + $ref: '#/components/schemas/catalogs.FloatRange' + temperature: + $ref: '#/components/schemas/catalogs.FloatRange' + tfs: + $ref: '#/components/schemas/catalogs.FloatRange' + top_a: + $ref: '#/components/schemas/catalogs.FloatRange' + top_k: + $ref: '#/components/schemas/catalogs.IntRange' + top_logprobs: + description: Observability + type: integer + top_p: + $ref: '#/components/schemas/catalogs.FloatRange' + typical_p: + $ref: '#/components/schemas/catalogs.FloatRange' + type: object + catalogs.ModelLimits: + description: Model limits + properties: + context_window: + description: Context window size in tokens + type: integer + output_tokens: + description: Maximum output tokens + type: integer + type: object + catalogs.ModelMetadata: + description: Metadata - version and timing information + properties: + architecture: + $ref: '#/components/schemas/catalogs.ModelArchitecture' + knowledge_cutoff: + description: Knowledge cutoff date (YYYY-MM or YYYY-MM-DD format) + type: string + open_weights: + description: Whether model weights are open + type: boolean + release_date: + description: Release date (YYYY-MM or YYYY-MM-DD format) + type: string + tags: + description: Use case tags for categorizing the model + items: + $ref: '#/components/schemas/catalogs.ModelTag' + type: array + 
uniqueItems: false + type: object + catalogs.ModelModalities: + description: Input/Output modalities + properties: + input: + description: Supported input modalities + items: + $ref: '#/components/schemas/catalogs.ModelModality' + type: array + uniqueItems: false + output: + description: Supported output modalities + items: + type: string + x-enum-comments: + ModelModalityEmbedding: Vector embeddings + x-enum-varnames: + - ModelModalityText + - ModelModalityAudio + - ModelModalityImage + - ModelModalityVideo + - ModelModalityPDF + - ModelModalityEmbedding + type: array + uniqueItems: false + type: object + catalogs.ModelModality: + type: string + x-enum-comments: + ModelModalityEmbedding: Vector embeddings + x-enum-varnames: + - ModelModalityText + - ModelModalityAudio + - ModelModalityImage + - ModelModalityVideo + - ModelModalityPDF + - ModelModalityEmbedding + catalogs.ModelOperationPricing: + description: Fixed costs per operation + properties: + audio_gen: + description: Cost per audio generated + type: number + audio_input: + description: Cost per audio input + type: number + function_call: + description: Cost per function call + type: number + image_gen: + description: Generation operations + type: number + image_input: + description: Media operations + type: number + request: + description: Core operations + type: number + tool_use: + description: Cost per tool usage + type: number + video_gen: + description: Cost per video generated + type: number + video_input: + description: Cost per video input + type: number + web_search: + description: Service operations + type: number + type: object + catalogs.ModelPricing: + description: Operational characteristics + properties: + currency: + $ref: '#/components/schemas/catalogs.ModelPricingCurrency' + operations: + $ref: '#/components/schemas/catalogs.ModelOperationPricing' + tokens: + $ref: '#/components/schemas/catalogs.ModelTokenPricing' + type: object + catalogs.ModelPricingCurrency: + description: Metadata + 
type: string + x-enum-comments: + ModelPricingCurrencyAUD: Australian Dollar + ModelPricingCurrencyCAD: Canadian Dollar + ModelPricingCurrencyCNY: Chinese Yuan + ModelPricingCurrencyEUR: Euro + ModelPricingCurrencyGBP: British Pound Sterling + ModelPricingCurrencyJPY: Japanese Yen + ModelPricingCurrencyNZD: New Zealand Dollar + ModelPricingCurrencyUSD: US Dollar + x-enum-varnames: + - ModelPricingCurrencyUSD + - ModelPricingCurrencyEUR + - ModelPricingCurrencyJPY + - ModelPricingCurrencyGBP + - ModelPricingCurrencyAUD + - ModelPricingCurrencyCAD + - ModelPricingCurrencyCNY + - ModelPricingCurrencyNZD + catalogs.ModelResponseFormat: + type: string + x-enum-comments: + ModelResponseFormatFunctionCall: Tool/function calling for structured data + ModelResponseFormatJSON: JSON encouraged via prompting + ModelResponseFormatJSONMode: Forced valid JSON (OpenAI style) + ModelResponseFormatJSONObject: Same as json_mode (OpenAI API name) + ModelResponseFormatJSONSchema: Schema-validated JSON (OpenAI structured output) + ModelResponseFormatStructuredOutput: General structured output support + ModelResponseFormatText: Plain text responses (default) + x-enum-varnames: + - ModelResponseFormatText + - ModelResponseFormatJSON + - ModelResponseFormatJSONMode + - ModelResponseFormatJSONObject + - ModelResponseFormatJSONSchema + - ModelResponseFormatStructuredOutput + - ModelResponseFormatFunctionCall + catalogs.ModelResponseProtocol: + type: string + x-enum-comments: + ModelResponseProtocolGRPC: gRPC protocol + ModelResponseProtocolHTTP: HTTP/HTTPS REST API + ModelResponseProtocolWebSocket: WebSocket protocol + x-enum-varnames: + - ModelResponseProtocolHTTP + - ModelResponseProtocolGRPC + - ModelResponseProtocolWebSocket + catalogs.ModelStreaming: + type: string + x-enum-comments: + ModelStreamingChunked: HTTP chunked transfer encoding + ModelStreamingSSE: Server-Sent Events streaming + ModelStreamingWebSocket: WebSocket streaming + x-enum-varnames: + - ModelStreamingSSE + - 
ModelStreamingWebSocket + - ModelStreamingChunked + catalogs.ModelTag: + type: string + x-enum-comments: + ModelTagAudio: Audio processing + ModelTagChat: Conversational AI + ModelTagCoding: Programming and code generation + ModelTagCreative: Creative content generation + ModelTagEducation: Educational content + ModelTagEmbedding: Text embeddings + ModelTagFinance: Financial analysis + ModelTagFunctionCalling: Tool/function calling + ModelTagImageToText: Image captioning/OCR + ModelTagInstruct: Instruction following + ModelTagLegal: Legal document processing + ModelTagMath: Mathematical problem solving + ModelTagMedical: Medical and healthcare + ModelTagMultimodal: Multiple input modalities + ModelTagQA: Question answering + ModelTagReasoning: Logical reasoning and problem solving + ModelTagResearch: Research and analysis + ModelTagRoleplay: Character roleplay and simulation + ModelTagScience: Scientific applications + ModelTagSpeechToText: Speech recognition + ModelTagSummarization: Text summarization + ModelTagTextToImage: Text-to-image generation + ModelTagTextToSpeech: Text-to-speech synthesis + ModelTagTranslation: Language translation + ModelTagVision: Computer vision + ModelTagWriting: Creative and technical writing + x-enum-varnames: + - ModelTagCoding + - ModelTagWriting + - ModelTagReasoning + - ModelTagMath + - ModelTagChat + - ModelTagInstruct + - ModelTagResearch + - ModelTagCreative + - ModelTagRoleplay + - ModelTagFunctionCalling + - ModelTagEmbedding + - ModelTagSummarization + - ModelTagTranslation + - ModelTagQA + - ModelTagVision + - ModelTagMultimodal + - ModelTagAudio + - ModelTagTextToImage + - ModelTagTextToSpeech + - ModelTagSpeechToText + - ModelTagImageToText + - ModelTagMedical + - ModelTagLegal + - ModelTagFinance + - ModelTagScience + - ModelTagEducation + catalogs.ModelTokenCachePricing: + description: Cache operations + properties: + read: + $ref: '#/components/schemas/catalogs.ModelTokenCost' + write: + $ref: 
'#/components/schemas/catalogs.ModelTokenCost' + type: object + catalogs.ModelTokenCost: + description: Alternative flat cache structure (for backward compatibility) + properties: + per_1m_tokens: + description: Cost per 1M tokens + type: number + per_token: + description: Cost per individual token + type: number + type: object + catalogs.ModelTokenPricing: + description: Token-based costs + properties: + cache: + $ref: '#/components/schemas/catalogs.ModelTokenCachePricing' + cache_read: + $ref: '#/components/schemas/catalogs.ModelTokenCost' + cache_write: + $ref: '#/components/schemas/catalogs.ModelTokenCost' + input: + $ref: '#/components/schemas/catalogs.ModelTokenCost' + output: + $ref: '#/components/schemas/catalogs.ModelTokenCost' + reasoning: + $ref: '#/components/schemas/catalogs.ModelTokenCost' + type: object + catalogs.ModelTools: + description: Tools - external tool and capability integrations + properties: + tool_choices: + description: |- + Tool calling configuration + Specifies which tool choice strategies this model supports. + Requires both Tools=true and ToolChoice=true in ModelFeatures. 
+ Common values: ["auto"], ["auto", "none"], ["auto", "none", "required"] + items: + $ref: '#/components/schemas/catalogs.ToolChoice' + type: array + uniqueItems: false + web_search: + $ref: '#/components/schemas/catalogs.ModelWebSearch' + type: object + catalogs.ModelWebSearch: + description: |- + Web search configuration + Only applicable if WebSearch=true in ModelFeatures + properties: + default_context_size: + description: Default search context size + type: string + x-enum-varnames: + - ModelControlLevelMinimum + - ModelControlLevelLow + - ModelControlLevelMedium + - ModelControlLevelHigh + - ModelControlLevelMaximum + max_results: + description: Plugin-based web search options (for models using OpenRouter's + web plugin) + type: integer + search_context_sizes: + description: Built-in web search options (for models with native web search + like GPT-4.1, Perplexity) + items: + type: string + x-enum-varnames: + - ModelControlLevelMinimum + - ModelControlLevelLow + - ModelControlLevelMedium + - ModelControlLevelHigh + - ModelControlLevelMaximum + type: array + uniqueItems: false + search_prompt: + description: Custom prompt for search results + type: string + type: object + catalogs.Provider: + properties: + aliases: + description: Alternative IDs this provider is known by (e.g., in models.dev) + items: + description: Optional provider to source models from + type: string + x-enum-varnames: + - ProviderIDAlibabaQwen + - ProviderIDAnthropic + - ProviderIDAnyscale + - ProviderIDCerebras + - ProviderIDCheckstep + - ProviderIDCohere + - ProviderIDConectys + - ProviderIDCove + - ProviderIDDeepMind + - ProviderIDDeepSeek + - ProviderIDGoogleAIStudio + - ProviderIDGoogleVertex + - ProviderIDGroq + - ProviderIDHuggingFace + - ProviderIDMeta + - ProviderIDMicrosoft + - ProviderIDMistralAI + - ProviderIDOpenAI + - ProviderIDOpenRouter + - ProviderIDPerplexity + - ProviderIDReplicate + - ProviderIDSafetyKit + - ProviderIDTogetherAI + - ProviderIDVirtuousAI + - 
ProviderIDWebPurify + - ProviderIDXAI + type: array + uniqueItems: false + api_key: + $ref: '#/components/schemas/catalogs.ProviderAPIKey' + catalog: + $ref: '#/components/schemas/catalogs.ProviderCatalog' + chat_completions: + $ref: '#/components/schemas/catalogs.ProviderChatCompletions' + env_vars: + description: Environment variables configuration + items: + $ref: '#/components/schemas/catalogs.ProviderEnvVar' + type: array + uniqueItems: false + governance_policy: + $ref: '#/components/schemas/catalogs.ProviderGovernancePolicy' + headquarters: + description: Company headquarters location + type: string + icon_url: + description: Provider icon/logo URL + type: string + id: + description: Core identification and integration + type: string + x-enum-varnames: + - ProviderIDAlibabaQwen + - ProviderIDAnthropic + - ProviderIDAnyscale + - ProviderIDCerebras + - ProviderIDCheckstep + - ProviderIDCohere + - ProviderIDConectys + - ProviderIDCove + - ProviderIDDeepMind + - ProviderIDDeepSeek + - ProviderIDGoogleAIStudio + - ProviderIDGoogleVertex + - ProviderIDGroq + - ProviderIDHuggingFace + - ProviderIDMeta + - ProviderIDMicrosoft + - ProviderIDMistralAI + - ProviderIDOpenAI + - ProviderIDOpenRouter + - ProviderIDPerplexity + - ProviderIDReplicate + - ProviderIDSafetyKit + - ProviderIDTogetherAI + - ProviderIDVirtuousAI + - ProviderIDWebPurify + - ProviderIDXAI + name: + description: Display name (must not be empty) + type: string + privacy_policy: + $ref: '#/components/schemas/catalogs.ProviderPrivacyPolicy' + retention_policy: + $ref: '#/components/schemas/catalogs.ProviderRetentionPolicy' + status_page_url: + description: Status & Health + type: string + type: object + catalogs.ProviderAPIKey: + description: API key configuration + properties: + header: + description: Header name to send the API key in + type: string + name: + description: Name of the API key parameter + type: string + pattern: + description: Glob pattern to match the API key + type: string + 
query_param: + description: Query parameter name to send the API key in + type: string + scheme: + $ref: '#/components/schemas/catalogs.ProviderAPIKeyScheme' + type: object + catalogs.ProviderAPIKeyScheme: + description: Authentication scheme (e.g., "Bearer", "Basic", or empty for direct + value) + type: string + x-enum-comments: + ProviderAPIKeySchemeBasic: Basic authentication + ProviderAPIKeySchemeBearer: Bearer token authentication (OAuth 2.0 style) + ProviderAPIKeySchemeDirect: Direct value (no scheme prefix) + x-enum-varnames: + - ProviderAPIKeySchemeBearer + - ProviderAPIKeySchemeBasic + - ProviderAPIKeySchemeDirect + catalogs.ProviderCatalog: + description: Models + properties: + authors: + description: List of authors to fetch from (for providers like Google Vertex + AI) + items: + description: Unique identifier for the author + type: string + x-enum-varnames: + - AuthorIDOpenAI + - AuthorIDAnthropic + - AuthorIDGoogle + - AuthorIDDeepMind + - AuthorIDMeta + - AuthorIDMicrosoft + - AuthorIDMistralAI + - AuthorIDCohere + - AuthorIDGroq + - AuthorIDAlibabaQwen + - AuthorIDQwen + - AuthorIDXAI + - AuthorIDStanford + - AuthorIDMIT + - AuthorIDCMU + - AuthorIDUCBerkeley + - AuthorIDCornell + - AuthorIDPrinceton + - AuthorIDHarvard + - AuthorIDOxford + - AuthorIDCambridge + - AuthorIDETHZurich + - AuthorIDUWashington + - AuthorIDUChicago + - AuthorIDYale + - AuthorIDDuke + - AuthorIDCaltech + - AuthorIDHuggingFace + - AuthorIDEleutherAI + - AuthorIDTogether + - AuthorIDMosaicML + - AuthorIDStabilityAI + - AuthorIDRunwayML + - AuthorIDMidjourney + - AuthorIDLAION + - AuthorIDBigScience + - AuthorIDAlignmentRC + - AuthorIDH2OAI + - AuthorIDMoxin + - AuthorIDBaidu + - AuthorIDTencent + - AuthorIDByteDance + - AuthorIDDeepSeek + - AuthorIDBAAI + - AuthorID01AI + - AuthorIDBaichuan + - AuthorIDMiniMax + - AuthorIDMoonshot + - AuthorIDShanghaiAI + - AuthorIDZhipuAI + - AuthorIDSenseTime + - AuthorIDHuawei + - AuthorIDTsinghua + - AuthorIDPeking + - AuthorIDNVIDIA + - 
AuthorIDSalesforce + - AuthorIDIBM + - AuthorIDApple + - AuthorIDAmazon + - AuthorIDAdept + - AuthorIDAI21 + - AuthorIDInflection + - AuthorIDCharacter + - AuthorIDPerplexity + - AuthorIDAnysphere + - AuthorIDCursor + - AuthorIDCognitiveComputations + - AuthorIDEricHartford + - AuthorIDNousResearch + - AuthorIDTeknium + - AuthorIDJonDurbin + - AuthorIDLMSYS + - AuthorIDVicuna + - AuthorIDAlpacaTeam + - AuthorIDWizardLM + - AuthorIDOpenOrca + - AuthorIDPhind + - AuthorIDCodeFuse + - AuthorIDTHUDM + - AuthorIDGeorgiaTechRI + - AuthorIDFastChat + - AuthorIDUnknown + type: array + uniqueItems: false + docs: + description: Documentation URL + type: string + endpoint: + $ref: '#/components/schemas/catalogs.ProviderEndpoint' + type: object + catalogs.ProviderChatCompletions: + description: Chat completions API configuration + properties: + health_api_url: + description: URL to health/status API for this service + type: string + health_components: + description: Specific components to monitor for chat completions + items: + $ref: '#/components/schemas/catalogs.ProviderHealthComponent' + type: array + uniqueItems: false + url: + description: Chat completions API endpoint URL + type: string + type: object + catalogs.ProviderEndpoint: + description: API endpoint configuration + properties: + auth_required: + description: 'Required: Whether auth needed' + type: boolean + author_mapping: + $ref: '#/components/schemas/catalogs.AuthorMapping' + feature_rules: + description: Feature inference rules + items: + $ref: '#/components/schemas/catalogs.FeatureRule' + type: array + uniqueItems: false + field_mappings: + description: Field mappings + items: + $ref: '#/components/schemas/catalogs.FieldMapping' + type: array + uniqueItems: false + type: + $ref: '#/components/schemas/catalogs.EndpointType' + url: + description: 'Required: API endpoint' + type: string + type: object + catalogs.ProviderEnvVar: + properties: + description: + description: Human-readable description + type: string 
+ name: + description: Environment variable name + type: string + pattern: + description: Optional validation pattern + type: string + required: + description: Whether this env var is required + type: boolean + type: object + catalogs.ProviderGovernancePolicy: + description: Oversight and moderation practices + properties: + moderated: + description: Whether provider content is moderated + type: boolean + moderation_required: + description: Whether the provider requires moderation + type: boolean + moderator: + description: Who moderates the provider + type: string + type: object + catalogs.ProviderHealthComponent: + properties: + id: + description: Component ID from the health API + type: string + name: + description: Human-readable component name + type: string + type: object + catalogs.ProviderID: + description: Optional provider to source models from + type: string + x-enum-varnames: + - ProviderIDAlibabaQwen + - ProviderIDAnthropic + - ProviderIDAnyscale + - ProviderIDCerebras + - ProviderIDCheckstep + - ProviderIDCohere + - ProviderIDConectys + - ProviderIDCove + - ProviderIDDeepMind + - ProviderIDDeepSeek + - ProviderIDGoogleAIStudio + - ProviderIDGoogleVertex + - ProviderIDGroq + - ProviderIDHuggingFace + - ProviderIDMeta + - ProviderIDMicrosoft + - ProviderIDMistralAI + - ProviderIDOpenAI + - ProviderIDOpenRouter + - ProviderIDPerplexity + - ProviderIDReplicate + - ProviderIDSafetyKit + - ProviderIDTogetherAI + - ProviderIDVirtuousAI + - ProviderIDWebPurify + - ProviderIDXAI + catalogs.ProviderPrivacyPolicy: + description: Privacy, Retention, and Governance Policies + properties: + privacy_policy_url: + description: Link to privacy policy + type: string + retains_data: + description: Whether provider stores/retains user data + type: boolean + terms_of_service_url: + description: Link to terms of service + type: string + trains_on_data: + description: Whether provider trains models on user data + type: boolean + type: object + 
catalogs.ProviderRetentionPolicy: + description: Data retention and deletion practices + properties: + details: + description: Human-readable description + type: string + duration: + $ref: '#/components/schemas/time.Duration' + type: + $ref: '#/components/schemas/catalogs.ProviderRetentionType' + type: object + catalogs.ProviderRetentionType: + description: Type of retention policy + type: string + x-enum-comments: + ProviderRetentionTypeConditional: Based on conditions (e.g., "until account + deletion") + ProviderRetentionTypeFixed: Specific duration (use Duration field) + ProviderRetentionTypeIndefinite: Forever (duration = nil) + ProviderRetentionTypeNone: No retention (immediate deletion) + x-enum-varnames: + - ProviderRetentionTypeFixed + - ProviderRetentionTypeNone + - ProviderRetentionTypeIndefinite + - ProviderRetentionTypeConditional + catalogs.Quantization: + description: Quantization level used by the model + type: string + x-enum-comments: + QuantizationBF16: Brain floating point (16 bit) + QuantizationFP4: Floating point (4 bit) + QuantizationFP6: Floating point (6 bit) + QuantizationFP8: Floating point (8 bit) + QuantizationFP16: Floating point (16 bit) + QuantizationFP32: Floating point (32 bit) + QuantizationINT4: Integer (4 bit) + QuantizationINT8: Integer (8 bit) + QuantizationUnknown: Unknown quantization + x-enum-varnames: + - QuantizationINT4 + - QuantizationINT8 + - QuantizationFP4 + - QuantizationFP6 + - QuantizationFP8 + - QuantizationFP16 + - QuantizationBF16 + - QuantizationFP32 + - QuantizationUnknown + catalogs.Tokenizer: + description: Tokenizer type used by the model + type: string + x-enum-comments: + TokenizerClaude: Claude tokenizer + TokenizerCohere: Cohere tokenizer + TokenizerDeepSeek: DeepSeek tokenizer + TokenizerGPT: GPT tokenizer (OpenAI) + TokenizerGemini: Gemini tokenizer (Google) + TokenizerGrok: Grok tokenizer (xAI) + TokenizerLlama2: LLaMA 2 tokenizer + TokenizerLlama3: LLaMA 3 tokenizer + TokenizerLlama4: LLaMA 4 
tokenizer + TokenizerMistral: Mistral tokenizer + TokenizerNova: Nova tokenizer (Amazon) + TokenizerQwen: Qwen tokenizer + TokenizerQwen3: Qwen 3 tokenizer + TokenizerRouter: Router-based tokenizer + TokenizerUnknown: Unknown tokenizer type + TokenizerYi: Yi tokenizer + x-enum-varnames: + - TokenizerClaude + - TokenizerCohere + - TokenizerDeepSeek + - TokenizerGPT + - TokenizerGemini + - TokenizerGrok + - TokenizerLlama2 + - TokenizerLlama3 + - TokenizerLlama4 + - TokenizerMistral + - TokenizerNova + - TokenizerQwen + - TokenizerQwen3 + - TokenizerRouter + - TokenizerYi + - TokenizerUnknown + catalogs.ToolChoice: + type: string + x-enum-comments: + ToolChoiceAuto: Model autonomously decides whether to call tools based on + context + ToolChoiceNone: Model will never call tools, even if tool definitions are + provided + ToolChoiceRequired: Model must call at least one tool before responding + x-enum-varnames: + - ToolChoiceAuto + - ToolChoiceNone + - ToolChoiceRequired + data: + properties: + data: + type: object + type: object + error: + properties: + error: + $ref: '#/components/schemas/response.Error' + type: object + handlers.DateRange: + properties: + after: + type: string + before: + type: string + type: object + handlers.IntRange: + properties: + max: + type: integer + min: + type: integer + type: object + handlers.SearchModalities: + properties: + input: + items: + type: string + type: array + uniqueItems: false + output: + items: + type: string + type: array + uniqueItems: false + type: object + handlers.SearchRequest: + properties: + context_window: + $ref: '#/components/schemas/handlers.IntRange' + features: + additionalProperties: + type: boolean + type: object + ids: + items: + type: string + type: array + uniqueItems: false + max_results: + type: integer + modalities: + $ref: '#/components/schemas/handlers.SearchModalities' + name_contains: + type: string + open_weights: + type: boolean + order: + type: string + output_tokens: + $ref: 
'#/components/schemas/handlers.IntRange' + provider: + type: string + release_date: + $ref: '#/components/schemas/handlers.DateRange' + sort: + type: string + tags: + items: + type: string + type: array + uniqueItems: false + type: object + response.Error: + properties: + code: + type: string + details: + type: string + message: + type: string + type: object + response.Response: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + time.Duration: + description: nil = forever, 0 = immediate deletion + type: integer + x-enum-varnames: + - minDuration + - maxDuration + - Nanosecond + - Microsecond + - Millisecond + - Second + - Minute + - Hour + securitySchemes: + ApiKeyAuth: + description: API key for authentication (optional, configurable) + in: header + name: X-API-Key + type: apiKey +externalDocs: + description: "" + url: "" +info: + contact: + name: Starmap Project + url: https://github.com/agentstation/starmap + description: |- + REST API for the Starmap AI model catalog with real-time updates via WebSocket and SSE. 
+ + Features: + - Comprehensive model and provider queries + - Advanced filtering and search + - Real-time updates via WebSocket and Server-Sent Events + - In-memory caching for performance + - Rate limiting and authentication support + license: + name: MIT + url: https://github.com/agentstation/starmap/blob/master/LICENSE + title: Starmap API + version: "1.0" +openapi: 3.1.0 +paths: + /api/v1/health: + get: + description: Health check endpoint (liveness probe) + requestBody: + content: + application/json: + schema: + type: object + responses: + "200": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: OK + summary: Health check + tags: + - health + /api/v1/models: + get: + description: List all models with optional filtering + parameters: + - description: Filter by exact model ID + in: query + name: id + schema: + type: string + - description: Filter by exact model name (case-insensitive) + in: query + name: name + schema: + type: string + - description: Filter by partial model name match + in: query + name: name_contains + schema: + type: string + - description: Filter by provider ID + in: query + name: provider + schema: + type: string + - description: Filter by input modality (comma-separated) + in: query + name: modality_input + schema: + type: string + - description: Filter by output modality (comma-separated) + in: query + name: modality_output + schema: + type: string + - description: Filter by feature (streaming, tool_calls, etc.) 
+ in: query + name: feature + schema: + type: string + - description: Filter by tag (comma-separated) + in: query + name: tag + schema: + type: string + - description: Filter by open weights status + in: query + name: open_weights + schema: + type: boolean + - description: Minimum context window size + in: query + name: min_context + schema: + type: integer + - description: Maximum context window size + in: query + name: max_context + schema: + type: integer + - description: Sort field (id, name, release_date, context_window, created_at, + updated_at) + in: query + name: sort + schema: + type: string + - description: Sort order (asc, desc) + in: query + name: order + schema: + type: string + - description: 'Maximum number of results (default: 100, max: 1000)' + in: query + name: limit + schema: + type: integer + - description: Result offset for pagination + in: query + name: offset + schema: + type: integer + requestBody: + content: + application/json: + schema: + type: object + responses: + "200": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: OK + "400": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: Bad Request + "500": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: Internal Server Error + security: + - ApiKeyAuth: [] + summary: List models + tags: + - models + /api/v1/models/{id}: + get: + description: Retrieve detailed information about a specific model + parameters: + - description: Model ID + in: path + name: id + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + 
type: object + responses: + "200": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: OK + "404": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: Not Found + "500": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: Internal Server Error + security: + - ApiKeyAuth: [] + summary: Get model by ID + tags: + - models + /api/v1/models/search: + post: + description: Advanced search with multiple criteria + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/handlers.SearchRequest' + description: Search criteria + required: true + responses: + "200": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: OK + "400": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: Bad Request + "500": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: Internal Server Error + security: + - ApiKeyAuth: [] + summary: Search models + tags: + - models + /api/v1/openapi.json: + get: + description: Returns the OpenAPI 3.1 specification for this API in JSON format + responses: + "200": + content: + application/json: + schema: + type: object + description: OpenAPI 3.1 specification + summary: 
Get OpenAPI specification (JSON) + tags: + - meta + /api/v1/openapi.yaml: + get: + description: Returns the OpenAPI 3.1 specification for this API in YAML format + responses: + "200": + content: + application/json: + schema: + type: string + application/x-yaml: + schema: + type: string + description: OpenAPI 3.1 specification + summary: Get OpenAPI specification (YAML) + tags: + - meta + /api/v1/providers: + get: + description: List all providers + requestBody: + content: + application/json: + schema: + type: object + responses: + "200": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: OK + "500": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: Internal Server Error + security: + - ApiKeyAuth: [] + summary: List providers + tags: + - providers + /api/v1/providers/{id}: + get: + description: Retrieve detailed information about a specific provider + parameters: + - description: Provider ID + in: path + name: id + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + type: object + responses: + "200": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: OK + "404": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: Not Found + "500": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: Internal 
Server Error + security: + - ApiKeyAuth: [] + summary: Get provider by ID + tags: + - providers + /api/v1/providers/{id}/models: + get: + description: List all models for a specific provider + parameters: + - description: Provider ID + in: path + name: id + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + type: object + responses: + "200": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: OK + "404": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: Not Found + "500": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: Internal Server Error + security: + - ApiKeyAuth: [] + summary: Get provider models + tags: + - providers + /api/v1/ready: + get: + description: Readiness check including cache and data source status + requestBody: + content: + application/json: + schema: + type: object + responses: + "200": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: OK + "503": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: Service Unavailable + summary: Readiness check + tags: + - health + /api/v1/stats: + get: + description: Get catalog statistics (model count, provider count, last sync) + requestBody: + content: + application/json: + schema: + type: object + responses: + "200": + content: + 
application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: OK + "500": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: Internal Server Error + security: + - ApiKeyAuth: [] + summary: Catalog statistics + tags: + - admin + /api/v1/update: + post: + description: Manually trigger catalog synchronization + parameters: + - description: Update specific provider only + in: query + name: provider + schema: + type: string + requestBody: + content: + application/json: + schema: + type: object + responses: + "200": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: OK + "500": + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/error' + properties: + data: {} + error: + $ref: '#/components/schemas/response.Error' + type: object + description: Internal Server Error + security: + - ApiKeyAuth: [] + summary: Trigger catalog update + tags: + - admin + /api/v1/updates/stream: + get: + description: Server-Sent Events stream for catalog change notifications + responses: + "200": + content: + text/event-stream: + schema: + type: string + description: Event stream + summary: SSE updates stream + tags: + - updates + /api/v1/updates/ws: + get: + description: WebSocket connection for real-time catalog updates + responses: + "101": + description: Switching Protocols + summary: WebSocket updates + tags: + - updates +servers: +- url: localhost:8080/api/v1 diff --git a/internal/server/cache/README.md b/internal/server/cache/README.md new file mode 100644 index 000000000..29ecba234 --- /dev/null +++ b/internal/server/cache/README.md @@ -0,0 
+1,124 @@ + + + + +# cache + +```go +import "github.com/agentstation/starmap/internal/server/cache" +``` + +Package cache provides an in\-memory caching layer for the HTTP server. It uses patrickmn/go\-cache for TTL\-based caching with LRU\-like eviction. + +## Index + +- [type Cache](<#Cache>) + - [func New\(defaultTTL, cleanupInterval time.Duration\) \*Cache](<#New>) + - [func \(c \*Cache\) Clear\(\)](<#Cache.Clear>) + - [func \(c \*Cache\) Delete\(key string\)](<#Cache.Delete>) + - [func \(c \*Cache\) Get\(key string\) \(any, bool\)](<#Cache.Get>) + - [func \(c \*Cache\) GetStats\(\) Stats](<#Cache.GetStats>) + - [func \(c \*Cache\) ItemCount\(\) int](<#Cache.ItemCount>) + - [func \(c \*Cache\) Set\(key string, value any\)](<#Cache.Set>) + - [func \(c \*Cache\) SetWithTTL\(key string, value any, ttl time.Duration\)](<#Cache.SetWithTTL>) +- [type Stats](<#Stats>) + + + +## type [Cache]() + +Cache wraps go\-cache with additional features for HTTP caching. + +```go +type Cache struct { + // contains filtered or unexported fields +} +``` + + +### func [New]() + +```go +func New(defaultTTL, cleanupInterval time.Duration) *Cache +``` + +New creates a new cache with the given TTL and cleanup interval. defaultTTL is the default expiration time for cache entries. cleanupInterval is how often expired items are removed from memory. + + +### func \(\*Cache\) [Clear]() + +```go +func (c *Cache) Clear() +``` + +Clear removes all items from the cache. + + +### func \(\*Cache\) [Delete]() + +```go +func (c *Cache) Delete(key string) +``` + +Delete removes a value from the cache. + + +### func \(\*Cache\) [Get]() + +```go +func (c *Cache) Get(key string) (any, bool) +``` + +Get retrieves a value from the cache. + + +### func \(\*Cache\) [GetStats]() + +```go +func (c *Cache) GetStats() Stats +``` + +GetStats returns current cache statistics. 
+ + +### func \(\*Cache\) [ItemCount]() + +```go +func (c *Cache) ItemCount() int +``` + +ItemCount returns the number of items in the cache. + + +### func \(\*Cache\) [Set]() + +```go +func (c *Cache) Set(key string, value any) +``` + +Set stores a value in the cache with default TTL. + + +### func \(\*Cache\) [SetWithTTL]() + +```go +func (c *Cache) SetWithTTL(key string, value any, ttl time.Duration) +``` + +SetWithTTL stores a value in the cache with custom TTL. + + +## type [Stats]() + +Stats returns cache statistics. + +```go +type Stats struct { + ItemCount int `json:"item_count"` +} +``` + +Generated by [gomarkdoc]() + + + \ No newline at end of file diff --git a/internal/server/cache/cache.go b/internal/server/cache/cache.go new file mode 100644 index 000000000..195201b3d --- /dev/null +++ b/internal/server/cache/cache.go @@ -0,0 +1,65 @@ +// Package cache provides an in-memory caching layer for the HTTP server. +// It uses patrickmn/go-cache for TTL-based caching with LRU-like eviction. +package cache + +import ( + "time" + + gocache "github.com/patrickmn/go-cache" +) + +// Cache wraps go-cache with additional features for HTTP caching. +type Cache struct { + store *gocache.Cache +} + +// New creates a new cache with the given TTL and cleanup interval. +// defaultTTL is the default expiration time for cache entries. +// cleanupInterval is how often expired items are removed from memory. +func New(defaultTTL, cleanupInterval time.Duration) *Cache { + return &Cache{ + store: gocache.New(defaultTTL, cleanupInterval), + } +} + +// Get retrieves a value from the cache. +func (c *Cache) Get(key string) (any, bool) { + return c.store.Get(key) +} + +// Set stores a value in the cache with default TTL. +func (c *Cache) Set(key string, value any) { + c.store.Set(key, value, gocache.DefaultExpiration) +} + +// SetWithTTL stores a value in the cache with custom TTL. 
+func (c *Cache) SetWithTTL(key string, value any, ttl time.Duration) { + c.store.Set(key, value, ttl) +} + +// Delete removes a value from the cache. +func (c *Cache) Delete(key string) { + c.store.Delete(key) +} + +// Clear removes all items from the cache. +func (c *Cache) Clear() { + c.store.Flush() +} + +// ItemCount returns the number of items in the cache. +func (c *Cache) ItemCount() int { + return c.store.ItemCount() +} + +// Stats returns cache statistics. +type Stats struct { + ItemCount int `json:"item_count"` +} + +// GetStats returns current cache statistics. +func (c *Cache) GetStats() Stats { + return Stats{ + ItemCount: c.store.ItemCount(), + } +} diff --git a/internal/server/cache/cache_test.go b/internal/server/cache/cache_test.go new file mode 100644 index 000000000..9beb71e8f --- /dev/null +++ b/internal/server/cache/cache_test.go @@ -0,0 +1,370 @@ +package cache + +import ( + "sync" + "testing" + "time" +) + +// TestCache_New tests cache creation. +func TestCache_New(t *testing.T) { + c := New(5*time.Minute, 10*time.Minute) + if c == nil { + t.Fatal("New() returned nil") + } + if c.store == nil { + t.Error("cache store not initialized") + } +} + +// TestCache_BasicOperations tests Get, Set, and Delete. 
+func TestCache_BasicOperations(t *testing.T) { + c := New(5*time.Minute, 10*time.Minute) + + t.Run("Set and Get", func(t *testing.T) { + c.Set("key1", "value1") + + val, found := c.Get("key1") + if !found { + t.Error("expected key1 to be found") + } + if val != "value1" { + t.Errorf("expected value1, got %v", val) + } + }) + + t.Run("Get non-existent key", func(t *testing.T) { + _, found := c.Get("nonexistent") + if found { + t.Error("expected nonexistent key to not be found") + } + }) + + t.Run("Set and Delete", func(t *testing.T) { + c.Set("key2", "value2") + c.Delete("key2") + + _, found := c.Get("key2") + if found { + t.Error("expected key2 to be deleted") + } + }) + + t.Run("Delete non-existent key", func(t *testing.T) { + // Should not panic + c.Delete("nonexistent") + }) +} + +// TestCache_SetWithTTL tests custom TTL. +func TestCache_SetWithTTL(t *testing.T) { + c := New(5*time.Minute, 10*time.Minute) + + // Set with very short TTL + c.SetWithTTL("expiring", "value", 50*time.Millisecond) + + // Should exist immediately + _, found := c.Get("expiring") + if !found { + t.Error("expected key to exist immediately") + } + + // Wait for expiration + time.Sleep(100 * time.Millisecond) + + // Should be expired + _, found = c.Get("expiring") + if found { + t.Error("expected key to be expired") + } +} + +// TestCache_Clear tests clearing all items. +func TestCache_Clear(t *testing.T) { + c := New(5*time.Minute, 10*time.Minute) + + // Add multiple items + c.Set("key1", "value1") + c.Set("key2", "value2") + c.Set("key3", "value3") + + if count := c.ItemCount(); count != 3 { + t.Errorf("expected 3 items, got %d", count) + } + + // Clear cache + c.Clear() + + if count := c.ItemCount(); count != 0 { + t.Errorf("expected 0 items after clear, got %d", count) + } + + // Verify items are gone + _, found := c.Get("key1") + if found { + t.Error("expected key1 to be cleared") + } +} + +// TestCache_ItemCount tests item counting. 
+func TestCache_ItemCount(t *testing.T) { + c := New(5*time.Minute, 10*time.Minute) + + tests := []struct { + name string + setup func() + expected int + }{ + { + name: "empty cache", + setup: func() {}, + expected: 0, + }, + { + name: "one item", + setup: func() { + c.Set("key1", "value1") + }, + expected: 1, + }, + { + name: "multiple items", + setup: func() { + c.Set("key2", "value2") + c.Set("key3", "value3") + }, + expected: 3, + }, + { + name: "after deletion", + setup: func() { + c.Delete("key1") + }, + expected: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setup() + count := c.ItemCount() + if count != tt.expected { + t.Errorf("expected %d items, got %d", tt.expected, count) + } + }) + } +} + +// TestCache_GetStats tests statistics retrieval. +func TestCache_GetStats(t *testing.T) { + c := New(5*time.Minute, 10*time.Minute) + + // Add some items + c.Set("key1", "value1") + c.Set("key2", "value2") + + stats := c.GetStats() + if stats.ItemCount != 2 { + t.Errorf("expected ItemCount=2, got %d", stats.ItemCount) + } +} + +// TestCache_ConcurrentAccess tests thread-safety with concurrent operations. 
+func TestCache_ConcurrentAccess(t *testing.T) { + c := New(5*time.Minute, 10*time.Minute) + + const numGoroutines = 100 + const numOperations = 100 + + var wg sync.WaitGroup + + // Concurrent writes + t.Run("concurrent writes", func(t *testing.T) { + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + key := "key-" + string(rune(id)) + "-" + string(rune(j)) + c.Set(key, id*numOperations+j) + } + }(i) + } + wg.Wait() + }) + + // Concurrent reads + t.Run("concurrent reads", func(t *testing.T) { + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + key := "key-" + string(rune(id)) + "-" + string(rune(j)) + c.Get(key) + } + }(i) + } + wg.Wait() + }) + + // Mixed operations + t.Run("mixed operations", func(t *testing.T) { + wg.Add(numGoroutines * 3) + + // Writers + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + c.Set("mixed-"+string(rune(id)), j) + } + }(i) + } + + // Readers + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + c.Get("mixed-" + string(rune(id))) + } + }(i) + } + + // Deleters + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + c.Delete("mixed-" + string(rune(id))) + } + }(i) + } + + wg.Wait() + }) + + // Should not panic - test passes if we get here +} + +// TestCache_ComplexTypes tests caching complex data types. 
+func TestCache_ComplexTypes(t *testing.T) { + c := New(5*time.Minute, 10*time.Minute) + + type TestStruct struct { + Name string + Count int + Tags []string + } + + tests := []struct { + name string + key string + value any + }{ + { + name: "string", + key: "str", + value: "hello", + }, + { + name: "int", + key: "int", + value: 42, + }, + { + name: "slice", + key: "slice", + value: []string{"a", "b", "c"}, + }, + { + name: "map", + key: "map", + value: map[string]int{"one": 1, "two": 2}, + }, + { + name: "struct", + key: "struct", + value: TestStruct{ + Name: "test", + Count: 123, + Tags: []string{"tag1", "tag2"}, + }, + }, + { + name: "pointer", + key: "ptr", + value: &TestStruct{ + Name: "pointer-test", + Count: 456, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c.Set(tt.key, tt.value) + + val, found := c.Get(tt.key) + if !found { + t.Errorf("expected key %q to be found", tt.key) + } + + // Type assertion depends on the type + // Just verify we got something back + if val == nil { + t.Errorf("expected non-nil value for key %q", tt.key) + } + }) + } +} + +// TestCache_Overwrite tests overwriting existing keys. +func TestCache_Overwrite(t *testing.T) { + c := New(5*time.Minute, 10*time.Minute) + + // Set initial value + c.Set("key", "value1") + + val, _ := c.Get("key") + if val != "value1" { + t.Errorf("expected value1, got %v", val) + } + + // Overwrite with new value + c.Set("key", "value2") + + val, _ = c.Get("key") + if val != "value2" { + t.Errorf("expected value2, got %v", val) + } + + // Verify only one item in cache + if count := c.ItemCount(); count != 1 { + t.Errorf("expected 1 item, got %d", count) + } +} + +// TestCache_DefaultExpiration tests default TTL behavior. 
+func TestCache_DefaultExpiration(t *testing.T) { + // Create cache with 100ms default TTL + c := New(100*time.Millisecond, 200*time.Millisecond) + + c.Set("key", "value") + + // Should exist immediately + _, found := c.Get("key") + if !found { + t.Error("expected key to exist immediately") + } + + // Wait for default expiration + time.Sleep(150 * time.Millisecond) + + // Should be expired + _, found = c.Get("key") + if found { + t.Error("expected key to be expired after default TTL") + } +} diff --git a/internal/server/cache/generate.go b/internal/server/cache/generate.go new file mode 100644 index 000000000..e4938eaef --- /dev/null +++ b/internal/server/cache/generate.go @@ -0,0 +1,3 @@ +package cache + +//go:generate gomarkdoc -e -o README.md . --repository.path /internal/server/cache diff --git a/internal/server/config.go b/internal/server/config.go new file mode 100644 index 000000000..5459cbfa9 --- /dev/null +++ b/internal/server/config.go @@ -0,0 +1,52 @@ +package server + +import "time" + +// Config holds server configuration. +type Config struct { + // Server settings + Host string + Port int + + // API settings + PathPrefix string + + // CORS settings + CORSEnabled bool + CORSOrigins []string + + // Authentication settings + AuthEnabled bool + AuthHeader string + + // Performance settings + RateLimit int // Requests per minute per IP (0 to disable) + CacheTTL time.Duration + + // HTTP timeouts + ReadTimeout time.Duration + WriteTimeout time.Duration + IdleTimeout time.Duration + + // Features + MetricsEnabled bool +} + +// DefaultConfig returns a Config with sensible defaults. 
func DefaultConfig() Config {
	return Config{
		Host:           "localhost",
		Port:           8080,
		PathPrefix:     "/api/v1",
		CORSEnabled:    false,
		CORSOrigins:    []string{},
		AuthEnabled:    false,
		AuthHeader:     "X-API-Key",
		RateLimit:      100, // requests per minute per IP
		CacheTTL:       5 * time.Minute,
		ReadTimeout:    10 * time.Second,
		WriteTimeout:   10 * time.Second,
		IdleTimeout:    120 * time.Second, // keep-alive idle limit
		MetricsEnabled: true,
	}
}
diff --git a/internal/server/docs.go b/internal/server/docs.go
new file mode 100644
index 000000000..f80e66099
--- /dev/null
+++ b/internal/server/docs.go
@@ -0,0 +1,31 @@
// Package server provides HTTP server implementation for the Starmap API.
//
// This file contains general API documentation annotations for Swag/OpenAPI generation.
// These annotations describe the overall API (title, version, security, etc.)
// while individual endpoint annotations live in the handler files.
package server

// @title Starmap API
// @version 1.0
// @description REST API for the Starmap AI model catalog with real-time updates via WebSocket and SSE.
+// @description +// @description Features: +// @description - Comprehensive model and provider queries +// @description - Advanced filtering and search +// @description - Real-time updates via WebSocket and Server-Sent Events +// @description - In-memory caching for performance +// @description - Rate limiting and authentication support +// +// @contact.name Starmap Project +// @contact.url https://github.com/agentstation/starmap +// +// @license.name MIT +// @license.url https://github.com/agentstation/starmap/blob/master/LICENSE +// +// @host localhost:8080 +// @BasePath /api/v1 +// +// @securityDefinitions.apikey ApiKeyAuth +// @in header +// @name X-API-Key +// @description API key for authentication (optional, configurable) diff --git a/internal/server/events/adapters/adapters_test.go b/internal/server/events/adapters/adapters_test.go new file mode 100644 index 000000000..e51831d82 --- /dev/null +++ b/internal/server/events/adapters/adapters_test.go @@ -0,0 +1,350 @@ +package adapters + +import ( + "testing" + "time" + + "github.com/agentstation/starmap/internal/server/events" + "github.com/agentstation/starmap/internal/server/sse" + ws "github.com/agentstation/starmap/internal/server/websocket" + "github.com/rs/zerolog" +) + +// TestNewSSESubscriber tests SSE subscriber creation. +func TestNewSSESubscriber(t *testing.T) { + logger := zerolog.Nop() + broadcaster := sse.NewBroadcaster(&logger) + + sub := NewSSESubscriber(broadcaster) + + if sub == nil { + t.Fatal("NewSSESubscriber returned nil") + } + + if sub.broadcaster != broadcaster { + t.Error("broadcaster not set correctly") + } +} + +// TestSSESubscriber_Send tests sending events via SSE adapter. 
+func TestSSESubscriber_Send(t *testing.T) { + logger := zerolog.Nop() + broadcaster := sse.NewBroadcaster(&logger) + sub := NewSSESubscriber(broadcaster) + + // Test sending various event types + testEvents := []events.Event{ + {Type: events.ModelAdded, Timestamp: time.Now(), Data: map[string]any{"model": "gpt-4"}}, + {Type: events.ModelUpdated, Timestamp: time.Now(), Data: map[string]any{"model": "claude-3"}}, + {Type: events.ModelDeleted, Timestamp: time.Now(), Data: map[string]any{"id": "gpt-3"}}, + {Type: events.SyncStarted, Timestamp: time.Now(), Data: map[string]any{"provider": "openai"}}, + {Type: events.SyncCompleted, Timestamp: time.Now(), Data: map[string]any{"count": 10}}, + {Type: events.ClientConnected, Timestamp: time.Now(), Data: map[string]any{"id": "client-1"}}, + } + + for i, event := range testEvents { + err := sub.Send(event) + if err != nil { + t.Errorf("event %d: Send() returned error: %v", i, err) + } + } +} + +// TestSSESubscriber_Send_WithNilData tests sending event with nil data. +func TestSSESubscriber_Send_WithNilData(t *testing.T) { + logger := zerolog.Nop() + broadcaster := sse.NewBroadcaster(&logger) + sub := NewSSESubscriber(broadcaster) + + event := events.Event{ + Type: events.ModelAdded, + Timestamp: time.Now(), + Data: nil, + } + + err := sub.Send(event) + if err != nil { + t.Errorf("Send() with nil data returned error: %v", err) + } +} + +// TestSSESubscriber_Send_WithComplexData tests sending event with complex data types. 
+func TestSSESubscriber_Send_WithComplexData(t *testing.T) { + logger := zerolog.Nop() + broadcaster := sse.NewBroadcaster(&logger) + sub := NewSSESubscriber(broadcaster) + + complexData := map[string]any{ + "models": []string{"gpt-4", "claude-3", "gemini-pro"}, + "count": 100, + "metadata": map[string]any{ + "provider": "openai", + "version": "v1", + }, + "tags": []string{"production", "verified"}, + } + + event := events.Event{ + Type: events.SyncCompleted, + Timestamp: time.Now(), + Data: complexData, + } + + err := sub.Send(event) + if err != nil { + t.Errorf("Send() with complex data returned error: %v", err) + } +} + +// TestSSESubscriber_Close tests closing SSE subscriber. +func TestSSESubscriber_Close(t *testing.T) { + logger := zerolog.Nop() + broadcaster := sse.NewBroadcaster(&logger) + sub := NewSSESubscriber(broadcaster) + + // Close should be a no-op and not return error + err := sub.Close() + if err != nil { + t.Errorf("Close() returned error: %v", err) + } + + // Should be able to call Close multiple times + err = sub.Close() + if err != nil { + t.Errorf("second Close() returned error: %v", err) + } + + // Should still be able to send after close (since Close is a no-op) + event := events.Event{ + Type: events.ModelAdded, + Timestamp: time.Now(), + Data: map[string]any{"test": true}, + } + + err = sub.Send(event) + if err != nil { + t.Errorf("Send() after Close() returned error: %v", err) + } +} + +// TestNewWebSocketSubscriber tests WebSocket subscriber creation. +func TestNewWebSocketSubscriber(t *testing.T) { + logger := zerolog.Nop() + hub := ws.NewHub(&logger) + + sub := NewWebSocketSubscriber(hub) + + if sub == nil { + t.Fatal("NewWebSocketSubscriber returned nil") + } + + if sub.hub != hub { + t.Error("hub not set correctly") + } +} + +// TestWebSocketSubscriber_Send tests sending events via WebSocket adapter. 
+func TestWebSocketSubscriber_Send(t *testing.T) { + logger := zerolog.Nop() + hub := ws.NewHub(&logger) + sub := NewWebSocketSubscriber(hub) + + // Test sending various event types + testEvents := []events.Event{ + {Type: events.ModelAdded, Timestamp: time.Now(), Data: map[string]any{"model": "gpt-4"}}, + {Type: events.ModelUpdated, Timestamp: time.Now(), Data: map[string]any{"model": "claude-3"}}, + {Type: events.ModelDeleted, Timestamp: time.Now(), Data: map[string]any{"id": "gpt-3"}}, + {Type: events.SyncStarted, Timestamp: time.Now(), Data: map[string]any{"provider": "openai"}}, + {Type: events.SyncCompleted, Timestamp: time.Now(), Data: map[string]any{"count": 50}}, + {Type: events.ClientConnected, Timestamp: time.Now(), Data: map[string]any{"id": "ws-1"}}, + } + + for i, event := range testEvents { + err := sub.Send(event) + if err != nil { + t.Errorf("event %d: Send() returned error: %v", i, err) + } + } +} + +// TestWebSocketSubscriber_Send_WithNilData tests sending event with nil data. +func TestWebSocketSubscriber_Send_WithNilData(t *testing.T) { + logger := zerolog.Nop() + hub := ws.NewHub(&logger) + sub := NewWebSocketSubscriber(hub) + + event := events.Event{ + Type: events.ModelAdded, + Timestamp: time.Now(), + Data: nil, + } + + err := sub.Send(event) + if err != nil { + t.Errorf("Send() with nil data returned error: %v", err) + } +} + +// TestWebSocketSubscriber_Send_WithComplexData tests sending event with complex data types. 
+func TestWebSocketSubscriber_Send_WithComplexData(t *testing.T) { + logger := zerolog.Nop() + hub := ws.NewHub(&logger) + sub := NewWebSocketSubscriber(hub) + + complexData := map[string]any{ + "models": []string{"gpt-4", "claude-3", "gemini-pro"}, + "count": 100, + "metadata": map[string]any{ + "provider": "openai", + "version": "v1", + }, + "tags": []string{"production", "verified"}, + } + + event := events.Event{ + Type: events.SyncCompleted, + Timestamp: time.Now(), + Data: complexData, + } + + err := sub.Send(event) + if err != nil { + t.Errorf("Send() with complex data returned error: %v", err) + } +} + +// TestWebSocketSubscriber_Close tests closing WebSocket subscriber. +func TestWebSocketSubscriber_Close(t *testing.T) { + logger := zerolog.Nop() + hub := ws.NewHub(&logger) + sub := NewWebSocketSubscriber(hub) + + // Close should be a no-op and not return error + err := sub.Close() + if err != nil { + t.Errorf("Close() returned error: %v", err) + } + + // Should be able to call Close multiple times + err = sub.Close() + if err != nil { + t.Errorf("second Close() returned error: %v", err) + } + + // Should still be able to send after close (since Close is a no-op) + event := events.Event{ + Type: events.ModelAdded, + Timestamp: time.Now(), + Data: map[string]any{"test": true}, + } + + err = sub.Send(event) + if err != nil { + t.Errorf("Send() after Close() returned error: %v", err) + } +} + +// TestAdapters_EventTypeConversion tests that all event types are handled correctly. 
+func TestAdapters_EventTypeConversion(t *testing.T) { + eventTypes := []events.EventType{ + events.ModelAdded, + events.ModelUpdated, + events.ModelDeleted, + events.SyncStarted, + events.SyncCompleted, + events.ClientConnected, + } + + logger := zerolog.Nop() + + for _, eventType := range eventTypes { + t.Run(string(eventType), func(t *testing.T) { + // Test SSE subscriber + sseBroadcaster := sse.NewBroadcaster(&logger) + sseSub := NewSSESubscriber(sseBroadcaster) + + sseEvent := events.Event{ + Type: eventType, + Timestamp: time.Now(), + Data: map[string]any{"test": true}, + } + + if err := sseSub.Send(sseEvent); err != nil { + t.Errorf("SSE Send() failed: %v", err) + } + + // Test WebSocket subscriber + wsHub := ws.NewHub(&logger) + wsSub := NewWebSocketSubscriber(wsHub) + + wsEvent := events.Event{ + Type: eventType, + Timestamp: time.Now(), + Data: map[string]any{"test": true}, + } + + if err := wsSub.Send(wsEvent); err != nil { + t.Errorf("WebSocket Send() failed: %v", err) + } + }) + } +} + +// TestAdapters_ConcurrentSend tests concurrent sending to ensure thread safety. 
+func TestAdapters_ConcurrentSend(t *testing.T) { + logger := zerolog.Nop() + + t.Run("SSE concurrent", func(t *testing.T) { + broadcaster := sse.NewBroadcaster(&logger) + sub := NewSSESubscriber(broadcaster) + + done := make(chan bool) + for i := 0; i < 10; i++ { + go func(id int) { + defer func() { done <- true }() + for j := 0; j < 10; j++ { + event := events.Event{ + Type: events.ModelAdded, + Timestamp: time.Now(), + Data: map[string]any{"id": id, "iteration": j}, + } + if err := sub.Send(event); err != nil { + t.Errorf("goroutine %d: Send() failed: %v", id, err) + } + } + }(i) + } + + // Wait for all goroutines + for i := 0; i < 10; i++ { + <-done + } + }) + + t.Run("WebSocket concurrent", func(t *testing.T) { + hub := ws.NewHub(&logger) + sub := NewWebSocketSubscriber(hub) + + done := make(chan bool) + for i := 0; i < 10; i++ { + go func(id int) { + defer func() { done <- true }() + for j := 0; j < 10; j++ { + event := events.Event{ + Type: events.ModelAdded, + Timestamp: time.Now(), + Data: map[string]any{"id": id, "iteration": j}, + } + if err := sub.Send(event); err != nil { + t.Errorf("goroutine %d: Send() failed: %v", id, err) + } + } + }(i) + } + + // Wait for all goroutines + for i := 0; i < 10; i++ { + <-done + } + }) +} diff --git a/internal/server/events/adapters/sse.go b/internal/server/events/adapters/sse.go new file mode 100644 index 000000000..90ee03638 --- /dev/null +++ b/internal/server/events/adapters/sse.go @@ -0,0 +1,33 @@ +package adapters + +import ( + "fmt" + + "github.com/agentstation/starmap/internal/server/events" + "github.com/agentstation/starmap/internal/server/sse" +) + +// SSESubscriber adapts the SSE broadcaster to the Subscriber interface. +type SSESubscriber struct { + broadcaster *sse.Broadcaster +} + +// NewSSESubscriber creates a new SSE subscriber. +func NewSSESubscriber(broadcaster *sse.Broadcaster) *SSESubscriber { + return &SSESubscriber{broadcaster: broadcaster} +} + +// Send delivers an event to all SSE clients. 
+func (s *SSESubscriber) Send(event events.Event) error { + s.broadcaster.Broadcast(sse.Event{ + Event: string(event.Type), + ID: fmt.Sprintf("%d", event.Timestamp.Unix()), + Data: event.Data, + }) + return nil +} + +// Close is a no-op for SSE (broadcaster manages its own lifecycle). +func (s *SSESubscriber) Close() error { + return nil +} diff --git a/internal/server/events/adapters/websocket.go b/internal/server/events/adapters/websocket.go new file mode 100644 index 000000000..655d4b756 --- /dev/null +++ b/internal/server/events/adapters/websocket.go @@ -0,0 +1,32 @@ +// Package adapters provides transport-specific implementations of the Subscriber interface. +package adapters + +import ( + "github.com/agentstation/starmap/internal/server/events" + ws "github.com/agentstation/starmap/internal/server/websocket" +) + +// WebSocketSubscriber adapts the WebSocket hub to the Subscriber interface. +type WebSocketSubscriber struct { + hub *ws.Hub +} + +// NewWebSocketSubscriber creates a new WebSocket subscriber. +func NewWebSocketSubscriber(hub *ws.Hub) *WebSocketSubscriber { + return &WebSocketSubscriber{hub: hub} +} + +// Send delivers an event to all WebSocket clients. +func (w *WebSocketSubscriber) Send(event events.Event) error { + w.hub.Broadcast(ws.Message{ + Type: string(event.Type), + Timestamp: event.Timestamp, + Data: event.Data, + }) + return nil +} + +// Close is a no-op for WebSocket (hub manages its own lifecycle). +func (w *WebSocketSubscriber) Close() error { + return nil +} diff --git a/internal/server/events/broker.go b/internal/server/events/broker.go new file mode 100644 index 000000000..f79903951 --- /dev/null +++ b/internal/server/events/broker.go @@ -0,0 +1,130 @@ +package events + +import ( + "context" + "sync" + "time" + + "github.com/rs/zerolog" +) + +// Broker manages event distribution to multiple subscribers. +// It provides a central hub for catalog events, fanning them out to +// all registered subscribers (WebSocket, SSE, etc.) 
concurrently. +type Broker struct { + subscribers []Subscriber + events chan Event + register chan Subscriber + unregister chan Subscriber + mu sync.RWMutex + logger *zerolog.Logger +} + +// NewBroker creates a new event broker. +func NewBroker(logger *zerolog.Logger) *Broker { + return &Broker{ + subscribers: make([]Subscriber, 0), + events: make(chan Event, 256), + register: make(chan Subscriber), + unregister: make(chan Subscriber), + logger: logger, + } +} + +// Run starts the broker's event loop. Should be called in a goroutine. +// The broker will run until the context is cancelled. +func (b *Broker) Run(ctx context.Context) { + for { + select { + case <-ctx.Done(): + // Graceful shutdown: close all subscribers + b.mu.Lock() + for _, sub := range b.subscribers { + _ = sub.Close() + } + b.subscribers = nil + b.mu.Unlock() + b.logger.Info().Msg("Event broker shut down") + return + + case sub := <-b.register: + b.mu.Lock() + b.subscribers = append(b.subscribers, sub) + b.mu.Unlock() + b.logger.Info(). + Int("total_subscribers", len(b.subscribers)). + Msg("Subscriber registered") + + case sub := <-b.unregister: + b.mu.Lock() + for i, s := range b.subscribers { + if s == sub { + b.subscribers = append(b.subscribers[:i], b.subscribers[i+1:]...) + _ = s.Close() + break + } + } + b.mu.Unlock() + b.logger.Info(). + Int("total_subscribers", len(b.subscribers)). + Msg("Subscriber unregistered") + + case event := <-b.events: + b.mu.RLock() + subs := make([]Subscriber, len(b.subscribers)) + copy(subs, b.subscribers) + b.mu.RUnlock() + + // Fan-out to all subscribers concurrently + for _, sub := range subs { + go func(s Subscriber, e Event) { + if err := s.Send(e); err != nil { + b.logger.Warn(). + Err(err). + Str("event_type", string(e.Type)). + Msg("Failed to send event to subscriber") + } + }(sub, event) + } + + b.logger.Debug(). + Str("event_type", string(event.Type)). + Int("subscribers", len(subs)). 
+ Msg("Event broadcasted") + } + } +} + +// Publish sends an event to all subscribers. +func (b *Broker) Publish(eventType EventType, data any) { + event := Event{ + Type: eventType, + Timestamp: time.Now(), + Data: data, + } + + select { + case b.events <- event: + default: + b.logger.Warn(). + Str("event_type", string(eventType)). + Msg("Event channel full, event dropped") + } +} + +// Subscribe registers a new subscriber to receive events. +func (b *Broker) Subscribe(sub Subscriber) { + b.register <- sub +} + +// Unsubscribe removes a subscriber from receiving events. +func (b *Broker) Unsubscribe(sub Subscriber) { + b.unregister <- sub +} + +// SubscriberCount returns the current number of subscribers. +func (b *Broker) SubscriberCount() int { + b.mu.RLock() + defer b.mu.RUnlock() + return len(b.subscribers) +} diff --git a/internal/server/events/broker_test.go b/internal/server/events/broker_test.go new file mode 100644 index 000000000..e994df92e --- /dev/null +++ b/internal/server/events/broker_test.go @@ -0,0 +1,130 @@ +package events + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/rs/zerolog" +) + +// mockSubscriber is a mock subscriber for testing. +type mockSubscriber struct { + events []Event + mu sync.Mutex + closed bool +} + +func newMockSubscriber() *mockSubscriber { + return &mockSubscriber{ + events: make([]Event, 0), + } +} + +func (m *mockSubscriber) Send(event Event) error { + m.mu.Lock() + defer m.mu.Unlock() + m.events = append(m.events, event) + return nil +} + +func (m *mockSubscriber) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + m.closed = true + return nil +} + +func (m *mockSubscriber) EventCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return len(m.events) +} + +// TestBroker_NewBroker tests broker creation. 
+func TestBroker_NewBroker(t *testing.T) { + logger := zerolog.Nop() + b := NewBroker(&logger) + + if b == nil { + t.Fatal("NewBroker returned nil") + } + + if b.subscribers == nil { + t.Error("subscribers slice not initialized") + } + + if b.events == nil { + t.Error("events channel not initialized") + } + + if b.register == nil { + t.Error("register channel not initialized") + } + + if b.unregister == nil { + t.Error("unregister channel not initialized") + } +} + +// TestBroker_BasicOperation tests basic broker operations. +func TestBroker_BasicOperation(t *testing.T) { + logger := zerolog.Nop() + b := NewBroker(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go b.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Subscribe + sub := newMockSubscriber() + b.Subscribe(sub) + time.Sleep(10 * time.Millisecond) + + if count := b.SubscriberCount(); count != 1 { + t.Fatalf("expected 1 subscriber, got %d", count) + } + + // Publish event + b.Publish(ModelAdded, map[string]any{"model": "test"}) + time.Sleep(50 * time.Millisecond) + + // Verify event received + if count := sub.EventCount(); count != 1 { + t.Errorf("expected 1 event, got %d", count) + } +} + +// TestBroker_Shutdown tests graceful shutdown. 
+func TestBroker_Shutdown(t *testing.T) { + logger := zerolog.Nop() + b := NewBroker(&logger) + + ctx, cancel := context.WithCancel(context.Background()) + + go b.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Subscribe + sub1 := newMockSubscriber() + sub2 := newMockSubscriber() + b.Subscribe(sub1) + b.Subscribe(sub2) + time.Sleep(10 * time.Millisecond) + + if count := b.SubscriberCount(); count != 2 { + t.Fatalf("expected 2 subscribers, got %d", count) + } + + // Trigger shutdown + cancel() + time.Sleep(50 * time.Millisecond) + + // Verify all subscribers disconnected + if count := b.SubscriberCount(); count != 0 { + t.Errorf("expected 0 subscribers after shutdown, got %d", count) + } +} diff --git a/internal/server/events/subscriber.go b/internal/server/events/subscriber.go new file mode 100644 index 000000000..a39e5f2c0 --- /dev/null +++ b/internal/server/events/subscriber.go @@ -0,0 +1,13 @@ +package events + +// Subscriber is an interface for event consumers. +// Implementations adapt the unified event stream to specific transport +// mechanisms (WebSocket, SSE, MQTT, webhooks, etc.). +type Subscriber interface { + // Send delivers an event to the subscriber. + // Implementations should be non-blocking and handle errors gracefully. + Send(Event) error + + // Close cleanly shuts down the subscriber. + Close() error +} diff --git a/internal/server/events/types.go b/internal/server/events/types.go new file mode 100644 index 000000000..0b8e63107 --- /dev/null +++ b/internal/server/events/types.go @@ -0,0 +1,34 @@ +// Package events provides a unified event system for real-time catalog updates. +// +// This package implements a broker pattern that connects Starmap's hooks system +// to multiple transport mechanisms (WebSocket, SSE, etc.) through a common event +// pipeline. This eliminates code duplication and provides a single point for +// event distribution. +package events + +import "time" + +// EventType represents the type of catalog event. 
+type EventType string + +// Event types for catalog changes. +const ( + // Model events (from Starmap hooks) + ModelAdded EventType = "model.added" + ModelUpdated EventType = "model.updated" + ModelDeleted EventType = "model.deleted" + + // Sync events (from sync operations) + SyncStarted EventType = "sync.started" + SyncCompleted EventType = "sync.completed" + + // Client events (from transport layers) + ClientConnected EventType = "client.connected" +) + +// Event represents a catalog event with type, timestamp, and data. +type Event struct { + Type EventType `json:"type"` + Timestamp time.Time `json:"timestamp"` + Data any `json:"data"` +} diff --git a/internal/server/filter/README.md b/internal/server/filter/README.md new file mode 100644 index 000000000..74743cd7a --- /dev/null +++ b/internal/server/filter/README.md @@ -0,0 +1,84 @@ + + + + +# filter + +```go +import "github.com/agentstation/starmap/internal/server/filter" +``` + +Package filter provides query parameter parsing and filtering for API endpoints. + +## Index + +- [type ModelFilter](<#ModelFilter>) + - [func ParseModelFilter\(r \*http.Request\) ModelFilter](<#ParseModelFilter>) + - [func \(f ModelFilter\) Apply\(models \[\]catalogs.Model\) \[\]catalogs.Model](<#ModelFilter.Apply>) + + + +## type [ModelFilter]() + +ModelFilter contains all possible filter criteria for models. 
+ +```go +type ModelFilter struct { + // Basic filters + ID string + Name string + NameContains string + Provider string + + // Modality filters + ModalityInput []string + ModalityOutput []string + + // Feature filters + Features map[string]bool + + // Metadata filters + Tags []string + OpenWeights *bool + + // Numeric range filters + MinContext int64 + MaxContext int64 + MinOutput int64 + MaxOutput int64 + + // Date filters + ReleasedAfter *time.Time + ReleasedBefore *time.Time + + // Pagination + Sort string + Order string + Limit int + Offset int + MaxResults int +} +``` + + +### func [ParseModelFilter]() + +```go +func ParseModelFilter(r *http.Request) ModelFilter +``` + +ParseModelFilter extracts model filter parameters from HTTP request. + + +### func \(ModelFilter\) [Apply]() + +```go +func (f ModelFilter) Apply(models []catalogs.Model) []catalogs.Model +``` + +Apply applies the filter to a list of models and returns filtered results. + +Generated by [gomarkdoc]() + + + \ No newline at end of file diff --git a/internal/server/filter/filter.go b/internal/server/filter/filter.go new file mode 100644 index 000000000..5b55bbd62 --- /dev/null +++ b/internal/server/filter/filter.go @@ -0,0 +1,324 @@ +// Package filter provides query parameter parsing and filtering for API endpoints. +package filter + +import ( + "net/http" + "strconv" + "strings" + "time" + + "github.com/agentstation/starmap/pkg/catalogs" +) + +// ModelFilter contains all possible filter criteria for models. 
+type ModelFilter struct { + // Basic filters + ID string + Name string + NameContains string + Provider string + + // Modality filters + ModalityInput []string + ModalityOutput []string + + // Feature filters + Features map[string]bool + + // Metadata filters + Tags []string + OpenWeights *bool + + // Numeric range filters + MinContext int64 + MaxContext int64 + MinOutput int64 + MaxOutput int64 + + // Date filters + ReleasedAfter *time.Time + ReleasedBefore *time.Time + + // Pagination + Sort string + Order string + Limit int + Offset int + MaxResults int +} + +// ParseModelFilter extracts model filter parameters from HTTP request. +func ParseModelFilter(r *http.Request) ModelFilter { + q := r.URL.Query() + + filter := ModelFilter{ + ID: q.Get("id"), + Name: q.Get("name"), + NameContains: q.Get("name_contains"), + Provider: q.Get("provider"), + Sort: q.Get("sort"), + Order: q.Get("order"), + Limit: parseIntOrDefault(q.Get("limit"), 100), + Offset: parseIntOrDefault(q.Get("offset"), 0), + MaxResults: parseIntOrDefault(q.Get("max_results"), 1000), + } + + // Parse modalities + if modalInput := q.Get("modality_input"); modalInput != "" { + filter.ModalityInput = strings.Split(modalInput, ",") + } + if modalOutput := q.Get("modality_output"); modalOutput != "" { + filter.ModalityOutput = strings.Split(modalOutput, ",") + } + + // Parse features + filter.Features = make(map[string]bool) + for _, feature := range []string{"streaming", "tool_calls", "tools", "tool_choice", "reasoning", "temperature", "max_tokens"} { + if val := q.Get("feature_" + feature); val != "" { + if b, err := strconv.ParseBool(val); err == nil { + filter.Features[feature] = b + } + } + } + // Also support shorthand "feature=streaming" format + if feature := q.Get("feature"); feature != "" { + filter.Features[feature] = true + } + + // Parse tags + if tags := q.Get("tag"); tags != "" { + filter.Tags = strings.Split(tags, ",") + } + + // Parse open_weights + if ow := q.Get("open_weights"); ow != "" 
{ + if b, err := strconv.ParseBool(ow); err == nil { + filter.OpenWeights = &b + } + } + + // Parse context window ranges + if minCtx := q.Get("min_context"); minCtx != "" { + if i, err := strconv.ParseInt(minCtx, 10, 64); err == nil { + filter.MinContext = i + } + } + if maxCtx := q.Get("max_context"); maxCtx != "" { + if i, err := strconv.ParseInt(maxCtx, 10, 64); err == nil { + filter.MaxContext = i + } + } + + // Parse output token ranges + if minOut := q.Get("min_output"); minOut != "" { + if i, err := strconv.ParseInt(minOut, 10, 64); err == nil { + filter.MinOutput = i + } + } + if maxOut := q.Get("max_output"); maxOut != "" { + if i, err := strconv.ParseInt(maxOut, 10, 64); err == nil { + filter.MaxOutput = i + } + } + + // Parse date ranges + if after := q.Get("released_after"); after != "" { + if t, err := time.Parse(time.RFC3339, after); err == nil { + filter.ReleasedAfter = &t + } + } + if before := q.Get("released_before"); before != "" { + if t, err := time.Parse(time.RFC3339, before); err == nil { + filter.ReleasedBefore = &t + } + } + + return filter +} + +// Apply applies the filter to a list of models and returns filtered results. +func (f ModelFilter) Apply(models []catalogs.Model) []catalogs.Model { + var results []catalogs.Model + + for _, model := range models { + if f.matches(model) { + results = append(results, model) + } + } + + // Apply sorting + if f.Sort != "" { + results = f.sort(results) + } + + return results +} + +// matches checks if a model matches the filter criteria. +func (f ModelFilter) matches(model catalogs.Model) bool { + return f.matchesBasicFilters(model) && + f.matchesModalityFilters(model) && + f.matchesFeaturesFilter(model) && + f.matchesMetadataFilters(model) && + f.matchesLimitFilters(model) && + f.matchesDateFilters(model) +} + +// matchesBasicFilters checks ID, name, and name contains filters. 
+func (f ModelFilter) matchesBasicFilters(model catalogs.Model) bool { + if f.ID != "" && model.ID != f.ID { + return false + } + if f.Name != "" && !strings.EqualFold(model.Name, f.Name) { + return false + } + if f.NameContains != "" && !strings.Contains(strings.ToLower(model.Name), strings.ToLower(f.NameContains)) { + return false + } + return true +} + +// matchesModalityFilters checks input and output modality filters. +func (f ModelFilter) matchesModalityFilters(model catalogs.Model) bool { + if len(f.ModalityInput) > 0 && model.Features != nil { + if !modalityContainsAll(model.Features.Modalities.Input, f.ModalityInput) { + return false + } + } + if len(f.ModalityOutput) > 0 && model.Features != nil { + if !modalityContainsAll(model.Features.Modalities.Output, f.ModalityOutput) { + return false + } + } + return true +} + +// matchesFeaturesFilter checks feature capability filters. +func (f ModelFilter) matchesFeaturesFilter(model catalogs.Model) bool { + if len(f.Features) > 0 && model.Features != nil { + for feature, required := range f.Features { + if !matchFeature(model.Features, feature, required) { + return false + } + } + } + return true +} + +// matchesMetadataFilters checks tags and open weights filters. +func (f ModelFilter) matchesMetadataFilters(model catalogs.Model) bool { + if len(f.Tags) > 0 && model.Metadata != nil { + if !tagContainsAny(model.Metadata.Tags, f.Tags) { + return false + } + } + if f.OpenWeights != nil && model.Metadata != nil { + if model.Metadata.OpenWeights != *f.OpenWeights { + return false + } + } + return true +} + +// matchesLimitFilters checks context window and output token range filters. 
+func (f ModelFilter) matchesLimitFilters(model catalogs.Model) bool { + if model.Limits == nil { + return true + } + if f.MinContext > 0 && model.Limits.ContextWindow < f.MinContext { + return false + } + if f.MaxContext > 0 && model.Limits.ContextWindow > f.MaxContext { + return false + } + if f.MinOutput > 0 && model.Limits.OutputTokens < f.MinOutput { + return false + } + if f.MaxOutput > 0 && model.Limits.OutputTokens > f.MaxOutput { + return false + } + return true +} + +// matchesDateFilters checks release date range filters. +func (f ModelFilter) matchesDateFilters(model catalogs.Model) bool { + if model.Metadata == nil || model.Metadata.ReleaseDate.IsZero() { + return true + } + if f.ReleasedAfter != nil && model.Metadata.ReleaseDate.Time.Before(*f.ReleasedAfter) { + return false + } + if f.ReleasedBefore != nil && model.Metadata.ReleaseDate.Time.After(*f.ReleasedBefore) { + return false + } + return true +} + +// sort sorts models based on the sort field and order. +func (f ModelFilter) sort(models []catalogs.Model) []catalogs.Model { + // Simple implementation - for production, use more sophisticated sorting + // This is a placeholder that maintains current order + return models +} + +// matchFeature checks if a model has a specific feature. +func matchFeature(features *catalogs.ModelFeatures, feature string, required bool) bool { + switch feature { + case "streaming": + return features.Streaming == required + case "tool_calls": + return features.ToolCalls == required + case "tools": + return features.Tools == required + case "tool_choice": + return features.ToolChoice == required + case "reasoning": + return features.Reasoning == required + case "temperature": + return features.Temperature == required + case "max_tokens": + return features.MaxTokens == required + default: + return true + } +} + +// modalityContainsAll checks if modality slice contains all required values. 
+func modalityContainsAll(slice []catalogs.ModelModality, required []string) bool { + for _, req := range required { + found := false + for _, item := range slice { + if strings.EqualFold(string(item), req) { + found = true + break + } + } + if !found { + return false + } + } + return true +} + +// tagContainsAny checks if tag slice contains any of the values. +func tagContainsAny(slice []catalogs.ModelTag, values []string) bool { + for _, val := range values { + for _, item := range slice { + if strings.EqualFold(string(item), val) { + return true + } + } + } + return false +} + +// parseIntOrDefault parses an integer or returns default. +func parseIntOrDefault(s string, def int) int { + if s == "" { + return def + } + if i, err := strconv.Atoi(s); err == nil { + return i + } + return def +} diff --git a/internal/server/filter/filter_test.go b/internal/server/filter/filter_test.go new file mode 100644 index 000000000..84604b823 --- /dev/null +++ b/internal/server/filter/filter_test.go @@ -0,0 +1,676 @@ +package filter + +import ( + "net/http/httptest" + "testing" + "time" + + "github.com/agentstation/starmap/pkg/catalogs" +) + +// TestParseModelFilter tests query parameter parsing into ModelFilter struct. 
+func TestParseModelFilter(t *testing.T) { + tests := []struct { + name string + query string + expected ModelFilter + }{ + { + name: "empty query", + query: "", + expected: ModelFilter{ + Features: map[string]bool{}, + Limit: 100, + Offset: 0, + MaxResults: 1000, + }, + }, + { + name: "basic filters", + query: "id=gpt-4&name=GPT-4&provider=openai", + expected: ModelFilter{ + ID: "gpt-4", + Name: "GPT-4", + Provider: "openai", + Features: map[string]bool{}, + Limit: 100, + MaxResults: 1000, + }, + }, + { + name: "name contains filter", + query: "name_contains=gpt", + expected: ModelFilter{ + NameContains: "gpt", + Features: map[string]bool{}, + Limit: 100, + MaxResults: 1000, + }, + }, + { + name: "modality filters", + query: "modality_input=text,image&modality_output=text", + expected: ModelFilter{ + ModalityInput: []string{"text", "image"}, + ModalityOutput: []string{"text"}, + Features: map[string]bool{}, + Limit: 100, + MaxResults: 1000, + }, + }, + { + name: "feature filters - explicit", + query: "feature_streaming=true&feature_tool_calls=false", + expected: ModelFilter{ + Features: map[string]bool{ + "streaming": true, + "tool_calls": false, + }, + Limit: 100, + MaxResults: 1000, + }, + }, + { + name: "feature filters - shorthand", + query: "feature=streaming", + expected: ModelFilter{ + Features: map[string]bool{ + "streaming": true, + }, + Limit: 100, + MaxResults: 1000, + }, + }, + { + name: "tags filter", + query: "tag=audio,vision", + expected: ModelFilter{ + Tags: []string{"audio", "vision"}, + Features: map[string]bool{}, + Limit: 100, + MaxResults: 1000, + }, + }, + { + name: "open weights filter", + query: "open_weights=true", + expected: ModelFilter{ + OpenWeights: boolPtr(true), + Features: map[string]bool{}, + Limit: 100, + MaxResults: 1000, + }, + }, + { + name: "context window range", + query: "min_context=4096&max_context=128000", + expected: ModelFilter{ + MinContext: 4096, + MaxContext: 128000, + Features: map[string]bool{}, + Limit: 100, + 
MaxResults: 1000, + }, + }, + { + name: "output tokens range", + query: "min_output=1024&max_output=4096", + expected: ModelFilter{ + MinOutput: 1024, + MaxOutput: 4096, + Features: map[string]bool{}, + Limit: 100, + MaxResults: 1000, + }, + }, + { + name: "pagination", + query: "limit=50&offset=100&max_results=500", + expected: ModelFilter{ + Features: map[string]bool{}, + Limit: 50, + Offset: 100, + MaxResults: 500, + }, + }, + { + name: "sort and order", + query: "sort=name&order=desc", + expected: ModelFilter{ + Sort: "name", + Order: "desc", + Features: map[string]bool{}, + Limit: 100, + MaxResults: 1000, + }, + }, + { + name: "date range filters", + query: "released_after=2024-01-01T00:00:00Z&released_before=2024-12-31T23:59:59Z", + expected: ModelFilter{ + ReleasedAfter: timePtr(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)), + ReleasedBefore: timePtr(time.Date(2024, 12, 31, 23, 59, 59, 0, time.UTC)), + Features: map[string]bool{}, + Limit: 100, + MaxResults: 1000, + }, + }, + { + name: "combined complex filters", + query: "name_contains=gpt&provider=openai&modality_input=text&feature_streaming=true&min_context=8000&limit=25", + expected: ModelFilter{ + NameContains: "gpt", + Provider: "openai", + ModalityInput: []string{"text"}, + Features: map[string]bool{ + "streaming": true, + }, + MinContext: 8000, + Limit: 25, + MaxResults: 1000, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create mock request with query params + req := httptest.NewRequest("GET", "/models?"+tt.query, nil) + + // Parse filter + result := ParseModelFilter(req) + + // Verify basic fields + if result.ID != tt.expected.ID { + t.Errorf("ID: got %q, want %q", result.ID, tt.expected.ID) + } + if result.Name != tt.expected.Name { + t.Errorf("Name: got %q, want %q", result.Name, tt.expected.Name) + } + if result.NameContains != tt.expected.NameContains { + t.Errorf("NameContains: got %q, want %q", result.NameContains, tt.expected.NameContains) + } + if 
result.Provider != tt.expected.Provider { + t.Errorf("Provider: got %q, want %q", result.Provider, tt.expected.Provider) + } + + // Verify slices + if !stringSliceEqual(result.ModalityInput, tt.expected.ModalityInput) { + t.Errorf("ModalityInput: got %v, want %v", result.ModalityInput, tt.expected.ModalityInput) + } + if !stringSliceEqual(result.ModalityOutput, tt.expected.ModalityOutput) { + t.Errorf("ModalityOutput: got %v, want %v", result.ModalityOutput, tt.expected.ModalityOutput) + } + if !stringSliceEqual(result.Tags, tt.expected.Tags) { + t.Errorf("Tags: got %v, want %v", result.Tags, tt.expected.Tags) + } + + // Verify maps + if !boolMapEqual(result.Features, tt.expected.Features) { + t.Errorf("Features: got %v, want %v", result.Features, tt.expected.Features) + } + + // Verify pointers + if !boolPtrEqual(result.OpenWeights, tt.expected.OpenWeights) { + t.Errorf("OpenWeights: got %v, want %v", result.OpenWeights, tt.expected.OpenWeights) + } + if !timePtrEqual(result.ReleasedAfter, tt.expected.ReleasedAfter) { + t.Errorf("ReleasedAfter: got %v, want %v", result.ReleasedAfter, tt.expected.ReleasedAfter) + } + if !timePtrEqual(result.ReleasedBefore, tt.expected.ReleasedBefore) { + t.Errorf("ReleasedBefore: got %v, want %v", result.ReleasedBefore, tt.expected.ReleasedBefore) + } + + // Verify numeric fields + if result.MinContext != tt.expected.MinContext { + t.Errorf("MinContext: got %d, want %d", result.MinContext, tt.expected.MinContext) + } + if result.MaxContext != tt.expected.MaxContext { + t.Errorf("MaxContext: got %d, want %d", result.MaxContext, tt.expected.MaxContext) + } + if result.MinOutput != tt.expected.MinOutput { + t.Errorf("MinOutput: got %d, want %d", result.MinOutput, tt.expected.MinOutput) + } + if result.MaxOutput != tt.expected.MaxOutput { + t.Errorf("MaxOutput: got %d, want %d", result.MaxOutput, tt.expected.MaxOutput) + } + if result.Limit != tt.expected.Limit { + t.Errorf("Limit: got %d, want %d", result.Limit, tt.expected.Limit) + } 
+ if result.Offset != tt.expected.Offset { + t.Errorf("Offset: got %d, want %d", result.Offset, tt.expected.Offset) + } + if result.MaxResults != tt.expected.MaxResults { + t.Errorf("MaxResults: got %d, want %d", result.MaxResults, tt.expected.MaxResults) + } + + // Verify string fields + if result.Sort != tt.expected.Sort { + t.Errorf("Sort: got %q, want %q", result.Sort, tt.expected.Sort) + } + if result.Order != tt.expected.Order { + t.Errorf("Order: got %q, want %q", result.Order, tt.expected.Order) + } + }) + } +} + +// TestModelFilter_Apply tests the filtering logic. +func TestModelFilter_Apply(t *testing.T) { + // Create test models + models := []catalogs.Model{ + { + ID: "gpt-4", + Name: "GPT-4", + Features: &catalogs.ModelFeatures{ + Streaming: true, + ToolCalls: true, + Modalities: catalogs.ModelModalities{ + Input: []catalogs.ModelModality{"text"}, + Output: []catalogs.ModelModality{"text"}, + }, + }, + Limits: &catalogs.ModelLimits{ + ContextWindow: 128000, + OutputTokens: 4096, + }, + Metadata: &catalogs.ModelMetadata{ + Tags: []catalogs.ModelTag{"chat"}, + OpenWeights: false, + }, + }, + { + ID: "claude-3-opus", + Name: "Claude 3 Opus", + Features: &catalogs.ModelFeatures{ + Streaming: true, + ToolCalls: true, + Modalities: catalogs.ModelModalities{ + Input: []catalogs.ModelModality{"text", "image"}, + Output: []catalogs.ModelModality{"text"}, + }, + }, + Limits: &catalogs.ModelLimits{ + ContextWindow: 200000, + OutputTokens: 4096, + }, + Metadata: &catalogs.ModelMetadata{ + Tags: []catalogs.ModelTag{"chat", "vision"}, + OpenWeights: false, + }, + }, + { + ID: "llama-3-70b", + Name: "Llama 3 70B", + Features: &catalogs.ModelFeatures{ + Streaming: true, + ToolCalls: false, + Modalities: catalogs.ModelModalities{ + Input: []catalogs.ModelModality{"text"}, + Output: []catalogs.ModelModality{"text"}, + }, + }, + Limits: &catalogs.ModelLimits{ + ContextWindow: 8192, + OutputTokens: 2048, + }, + Metadata: &catalogs.ModelMetadata{ + Tags: 
[]catalogs.ModelTag{"chat", "open"}, + OpenWeights: true, + }, + }, + } + + tests := []struct { + name string + filter ModelFilter + expected []string // Expected model IDs in result + }{ + { + name: "no filters - return all", + filter: ModelFilter{Features: map[string]bool{}}, + expected: []string{"gpt-4", "claude-3-opus", "llama-3-70b"}, + }, + { + name: "filter by ID", + filter: ModelFilter{ + ID: "gpt-4", + Features: map[string]bool{}, + }, + expected: []string{"gpt-4"}, + }, + { + name: "filter by name (case insensitive)", + filter: ModelFilter{ + Name: "gpt-4", + Features: map[string]bool{}, + }, + expected: []string{"gpt-4"}, + }, + { + name: "filter by name contains", + filter: ModelFilter{ + NameContains: "claude", + Features: map[string]bool{}, + }, + expected: []string{"claude-3-opus"}, + }, + { + name: "filter by input modality", + filter: ModelFilter{ + ModalityInput: []string{"image"}, + Features: map[string]bool{}, + }, + expected: []string{"claude-3-opus"}, + }, + { + name: "filter by streaming feature", + filter: ModelFilter{ + Features: map[string]bool{ + "streaming": true, + }, + }, + expected: []string{"gpt-4", "claude-3-opus", "llama-3-70b"}, + }, + { + name: "filter by tool_calls feature", + filter: ModelFilter{ + Features: map[string]bool{ + "tool_calls": true, + }, + }, + expected: []string{"gpt-4", "claude-3-opus"}, + }, + { + name: "filter by open weights", + filter: ModelFilter{ + OpenWeights: boolPtr(true), + Features: map[string]bool{}, + }, + expected: []string{"llama-3-70b"}, + }, + { + name: "filter by tags", + filter: ModelFilter{ + Tags: []string{"vision"}, + Features: map[string]bool{}, + }, + expected: []string{"claude-3-opus"}, + }, + { + name: "filter by min context window", + filter: ModelFilter{ + MinContext: 100000, + Features: map[string]bool{}, + }, + expected: []string{"gpt-4", "claude-3-opus"}, + }, + { + name: "filter by max context window", + filter: ModelFilter{ + MaxContext: 10000, + Features: map[string]bool{}, + }, 
+ expected: []string{"llama-3-70b"}, + }, + { + name: "combined filters", + filter: ModelFilter{ + NameContains: "gpt", + Features: map[string]bool{ + "streaming": true, + }, + MinContext: 50000, + }, + expected: []string{"gpt-4"}, + }, + { + name: "no matches", + filter: ModelFilter{ + NameContains: "nonexistent", + Features: map[string]bool{}, + }, + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.filter.Apply(models) + + // Extract IDs from result + var resultIDs []string + for _, m := range result { + resultIDs = append(resultIDs, m.ID) + } + + // Compare + if !stringSliceEqual(resultIDs, tt.expected) { + t.Errorf("got %v, want %v", resultIDs, tt.expected) + } + }) + } +} + +// TestMatchFeature tests individual feature matching. +func TestMatchFeature(t *testing.T) { + features := &catalogs.ModelFeatures{ + Streaming: true, + ToolCalls: true, + Tools: false, + ToolChoice: false, + Reasoning: true, + Temperature: true, + MaxTokens: true, + } + + tests := []struct { + name string + feature string + required bool + expected bool + }{ + {"streaming matches true", "streaming", true, true}, + {"streaming matches false", "streaming", false, false}, + {"tool_calls matches true", "tool_calls", true, true}, + {"tools matches false", "tools", false, true}, + {"tools doesn't match true", "tools", true, false}, + {"unknown feature defaults to true", "unknown", true, true}, + {"unknown feature defaults to true", "unknown", false, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := matchFeature(features, tt.feature, tt.required) + if result != tt.expected { + t.Errorf("matchFeature(%q, %v) = %v, want %v", tt.feature, tt.required, result, tt.expected) + } + }) + } +} + +// TestModalityContainsAll tests modality matching. 
+func TestModalityContainsAll(t *testing.T) { + tests := []struct { + name string + slice []catalogs.ModelModality + required []string + expected bool + }{ + { + name: "empty required always matches", + slice: []catalogs.ModelModality{"text"}, + required: []string{}, + expected: true, + }, + { + name: "single match", + slice: []catalogs.ModelModality{"text"}, + required: []string{"text"}, + expected: true, + }, + { + name: "case insensitive match", + slice: []catalogs.ModelModality{"text"}, + required: []string{"TEXT"}, + expected: true, + }, + { + name: "multiple matches", + slice: []catalogs.ModelModality{"text", "image", "audio"}, + required: []string{"text", "image"}, + expected: true, + }, + { + name: "missing required modality", + slice: []catalogs.ModelModality{"text"}, + required: []string{"text", "image"}, + expected: false, + }, + { + name: "no match", + slice: []catalogs.ModelModality{"text"}, + required: []string{"image"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := modalityContainsAll(tt.slice, tt.required) + if result != tt.expected { + t.Errorf("modalityContainsAll(%v, %v) = %v, want %v", tt.slice, tt.required, result, tt.expected) + } + }) + } +} + +// TestTagContainsAny tests tag matching. 
+func TestTagContainsAny(t *testing.T) { + tests := []struct { + name string + slice []catalogs.ModelTag + values []string + expected bool + }{ + { + name: "empty values always false", + slice: []catalogs.ModelTag{"chat"}, + values: []string{}, + expected: false, + }, + { + name: "single match", + slice: []catalogs.ModelTag{"chat", "vision"}, + values: []string{"vision"}, + expected: true, + }, + { + name: "case insensitive match", + slice: []catalogs.ModelTag{"chat"}, + values: []string{"CHAT"}, + expected: true, + }, + { + name: "multiple options, one matches", + slice: []catalogs.ModelTag{"chat"}, + values: []string{"vision", "chat", "audio"}, + expected: true, + }, + { + name: "no match", + slice: []catalogs.ModelTag{"chat"}, + values: []string{"vision", "audio"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tagContainsAny(tt.slice, tt.values) + if result != tt.expected { + t.Errorf("tagContainsAny(%v, %v) = %v, want %v", tt.slice, tt.values, result, tt.expected) + } + }) + } +} + +// TestParseIntOrDefault tests integer parsing helper. 
+func TestParseIntOrDefault(t *testing.T) { + tests := []struct { + name string + input string + def int + expected int + }{ + {"empty string returns default", "", 100, 100}, + {"valid integer", "42", 100, 42}, + {"zero value", "0", 100, 0}, + {"negative value", "-5", 100, -5}, + {"invalid string returns default", "abc", 100, 100}, + {"float returns default", "3.14", 100, 100}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseIntOrDefault(tt.input, tt.def) + if result != tt.expected { + t.Errorf("parseIntOrDefault(%q, %d) = %d, want %d", tt.input, tt.def, result, tt.expected) + } + }) + } +} + +// Helper functions for test comparisons + +func stringSliceEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func boolMapEqual(a, b map[string]bool) bool { + if len(a) != len(b) { + return false + } + for k, v := range a { + if b[k] != v { + return false + } + } + return true +} + +func boolPtrEqual(a, b *bool) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return *a == *b +} + +func timePtrEqual(a, b *time.Time) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + return a.Equal(*b) +} + +func boolPtr(b bool) *bool { + return &b +} + +func timePtr(t time.Time) *time.Time { + return &t +} diff --git a/internal/server/filter/generate.go b/internal/server/filter/generate.go new file mode 100644 index 000000000..ea59827d6 --- /dev/null +++ b/internal/server/filter/generate.go @@ -0,0 +1,3 @@ +package filter + +//go:generate gomarkdoc -e -o README.md . 
--repository.path /internal/server/filter diff --git a/internal/server/generate.go b/internal/server/generate.go new file mode 100644 index 000000000..c854dfb97 --- /dev/null +++ b/internal/server/generate.go @@ -0,0 +1,26 @@ +// Package server provides HTTP server implementation for the Starmap API. +// +// The server package implements a clean, layered architecture following Go best practices: +// +// - Server: Core server struct with lifecycle management +// - Config: Server configuration with sensible defaults +// - Router: Route registration and middleware chain +// - Handlers: HTTP request handlers organized by domain +// +// The architecture follows the pattern: CLI → App → Server → Router → Handlers +// +// Usage: +// +// cfg := server.DefaultConfig() +// cfg.Port = 8080 +// +// srv, err := server.New(app, cfg) +// if err != nil { +// log.Fatal(err) +// } +// +// srv.Start() // Start background services +// http.ListenAndServe(":8080", srv.Handler()) +package server + +//go:generate gomarkdoc --output README.md . diff --git a/internal/server/handlers/admin.go b/internal/server/handlers/admin.go new file mode 100644 index 000000000..e7ef24634 --- /dev/null +++ b/internal/server/handlers/admin.go @@ -0,0 +1,95 @@ +package handlers + +import ( + "net/http" + + "github.com/agentstation/starmap/internal/server/events" + "github.com/agentstation/starmap/internal/server/response" + "github.com/agentstation/starmap/pkg/catalogs" + "github.com/agentstation/starmap/pkg/sync" +) + +// HandleUpdate handles POST /api/v1/update. +// @Summary Trigger catalog update +// @Description Manually trigger catalog synchronization +// @Tags admin +// @Accept json +// @Produce json +// @Param provider query string false "Update specific provider only" +// @Success 200 {object} response.Response{data=object} +// @Failure 500 {object} response.Response{error=response.Error} +// @Security ApiKeyAuth +// @Router /api/v1/update [post]. 
+func (h *Handlers) HandleUpdate(w http.ResponseWriter, r *http.Request) { + providerFilter := r.URL.Query().Get("provider") + + sm, err := h.app.Starmap() + if err != nil { + response.InternalError(w, err) + return + } + + // Build sync options + var opts []sync.Option + if providerFilter != "" { + opts = append(opts, sync.WithProvider(catalogs.ProviderID(providerFilter))) + } + + // Run sync + result, err := sm.Sync(r.Context(), opts...) + if err != nil { + response.InternalError(w, err) + return + } + + // Invalidate cache + h.cache.Clear() + + // Broadcast update event + h.broker.Publish(events.SyncCompleted, map[string]any{ + "total_changes": result.TotalChanges, + "providers_changed": result.ProvidersChanged, + }) + + response.OK(w, map[string]any{ + "status": "completed", + "total_changes": result.TotalChanges, + "providers_changed": result.ProvidersChanged, + "dry_run": result.DryRun, + }) +} + +// HandleStats handles GET /api/v1/stats. +// @Summary Catalog statistics +// @Description Get catalog statistics (model count, provider count, last sync) +// @Tags admin +// @Accept json +// @Produce json +// @Success 200 {object} response.Response{data=object} +// @Failure 500 {object} response.Response{error=response.Error} +// @Security ApiKeyAuth +// @Router /api/v1/stats [get]. 
+func (h *Handlers) HandleStats(w http.ResponseWriter, _ *http.Request) { + cat, err := h.app.Catalog() + if err != nil { + response.InternalError(w, err) + return + } + + models := cat.Models().List() + providers := cat.Providers().List() + + response.OK(w, map[string]any{ + "models": map[string]any{ + "total": len(models), + }, + "providers": map[string]any{ + "total": len(providers), + }, + "cache": h.cache.GetStats(), + "realtime": map[string]any{ + "websocket_clients": h.wsHub.ClientCount(), + "sse_clients": h.sseBroadcaster.ClientCount(), + }, + }) +} diff --git a/internal/server/handlers/generate.go b/internal/server/handlers/generate.go new file mode 100644 index 000000000..c23708849 --- /dev/null +++ b/internal/server/handlers/generate.go @@ -0,0 +1,25 @@ +// Package handlers provides HTTP request handlers for the Starmap API. +// +// Handlers are organized by domain for maintainability: +// +// - models.go: Model listing, retrieval, and search +// - providers.go: Provider listing, retrieval, and models +// - admin.go: Administrative operations (update, stats) +// - health.go: Health and readiness checks +// - realtime.go: WebSocket and SSE real-time updates +// - openapi.go: OpenAPI 3.1 specification endpoints +// +// All handlers follow a consistent pattern: +// +// 1. Validate input +// 2. Check cache (if applicable) +// 3. Query catalog/data source +// 4. Transform data +// 5. Cache result (if applicable) +// 6. Return response +// +// Handlers use dependency injection for testability and receive all +// dependencies through the Handlers struct. +package handlers + +//go:generate gomarkdoc --output README.md . diff --git a/internal/server/handlers/handlers.go b/internal/server/handlers/handlers.go new file mode 100644 index 000000000..0a1de7835 --- /dev/null +++ b/internal/server/handlers/handlers.go @@ -0,0 +1,45 @@ +// Package handlers provides HTTP request handlers for the Starmap API. 
+package handlers + +import ( + "github.com/gorilla/websocket" + "github.com/rs/zerolog" + + "github.com/agentstation/starmap/cmd/application" + "github.com/agentstation/starmap/internal/server/cache" + "github.com/agentstation/starmap/internal/server/events" + "github.com/agentstation/starmap/internal/server/sse" + ws "github.com/agentstation/starmap/internal/server/websocket" +) + +// Handlers provides access to all HTTP handlers. +type Handlers struct { + app application.Application + cache *cache.Cache + broker *events.Broker + wsHub *ws.Hub + sseBroadcaster *sse.Broadcaster + upgrader websocket.Upgrader + logger *zerolog.Logger +} + +// New creates a new Handlers instance. +func New( + app application.Application, + cache *cache.Cache, + broker *events.Broker, + wsHub *ws.Hub, + sseBroadcaster *sse.Broadcaster, + upgrader websocket.Upgrader, + logger *zerolog.Logger, +) *Handlers { + return &Handlers{ + app: app, + cache: cache, + broker: broker, + wsHub: wsHub, + sseBroadcaster: sseBroadcaster, + upgrader: upgrader, + logger: logger, + } +} diff --git a/internal/server/handlers/health.go b/internal/server/handlers/health.go new file mode 100644 index 000000000..c23607437 --- /dev/null +++ b/internal/server/handlers/health.go @@ -0,0 +1,50 @@ +package handlers + +import ( + "net/http" + + "github.com/agentstation/starmap/internal/server/response" +) + +// HandleHealth handles GET /api/v1/health. +// @Summary Health check +// @Description Health check endpoint (liveness probe) +// @Tags health +// @Accept json +// @Produce json +// @Success 200 {object} response.Response{data=object} +// @Router /api/v1/health [get]. +func (h *Handlers) HandleHealth(w http.ResponseWriter, _ *http.Request) { + response.OK(w, map[string]any{ + "status": "healthy", + "service": "starmap-api", + "version": "v1", + }) +} + +// HandleReady handles GET /api/v1/ready. 
+// @Summary Readiness check +// @Description Readiness check including cache and data source status +// @Tags health +// @Accept json +// @Produce json +// @Success 200 {object} response.Response{data=object} +// @Failure 503 {object} response.Response{error=response.Error} +// @Router /api/v1/ready [get]. +func (h *Handlers) HandleReady(w http.ResponseWriter, _ *http.Request) { + // Check catalog availability + _, err := h.app.Catalog() + if err != nil { + response.ServiceUnavailable(w, "Catalog not available") + return + } + + response.OK(w, map[string]any{ + "status": "ready", + "cache": map[string]any{ + "items": h.cache.ItemCount(), + }, + "websocket_clients": h.wsHub.ClientCount(), + "sse_clients": h.sseBroadcaster.ClientCount(), + }) +} diff --git a/internal/server/handlers/models.go b/internal/server/handlers/models.go new file mode 100644 index 000000000..50e7044d0 --- /dev/null +++ b/internal/server/handlers/models.go @@ -0,0 +1,251 @@ +package handlers + +import ( + "encoding/json" + "net/http" + + "github.com/agentstation/starmap/internal/server/filter" + "github.com/agentstation/starmap/internal/server/response" + "github.com/agentstation/starmap/pkg/catalogs" +) + +// HandleListModels handles GET /api/v1/models. 
+// @Summary List models
+// @Description List all models with optional filtering
+// @Tags models
+// @Accept json
+// @Produce json
+// @Param id query string false "Filter by exact model ID"
+// @Param name query string false "Filter by exact model name (case-insensitive)"
+// @Param name_contains query string false "Filter by partial model name match"
+// @Param provider query string false "Filter by provider ID"
+// @Param modality_input query string false "Filter by input modality (comma-separated)"
+// @Param modality_output query string false "Filter by output modality (comma-separated)"
+// @Param feature query string false "Filter by feature (streaming, tool_calls, etc.)"
+// @Param tag query string false "Filter by tag (comma-separated)"
+// @Param open_weights query boolean false "Filter by open weights status"
+// @Param min_context query integer false "Minimum context window size"
+// @Param max_context query integer false "Maximum context window size"
+// @Param min_output query integer false "Minimum output token limit"
+// @Param max_output query integer false "Maximum output token limit"
+// @Param released_after query string false "Filter by release date (RFC 3339 timestamp, models released after)"
+// @Param released_before query string false "Filter by release date (RFC 3339 timestamp, models released before)"
+// @Param max_results query integer false "Hard cap on total results (default: 1000)"
+// @Param sort query string false "Sort field (id, name, release_date, context_window, created_at, updated_at)"
+// @Param order query string false "Sort order (asc, desc)"
+// @Param limit query integer false "Maximum number of results (default: 100, max: 1000)"
+// @Param offset query integer false "Result offset for pagination"
+// @Success 200 {object} response.Response{data=object}
+// @Failure 400 {object} response.Response{error=response.Error}
+// @Failure 500 {object} response.Response{error=response.Error}
+// @Security ApiKeyAuth
+// @Router /api/v1/models [get].
+func (h *Handlers) HandleListModels(w http.ResponseWriter, r *http.Request) { + // Check cache + cacheKey := "models:" + r.URL.RawQuery + if cached, found := h.cache.Get(cacheKey); found { + response.OK(w, cached) + return + } + + // Get catalog + cat, err := h.app.Catalog() + if err != nil { + response.InternalError(w, err) + return + } + + // Parse filters + f := filter.ParseModelFilter(r) + + // Get and filter models + allModels := cat.Models().List() + filtered := f.Apply(allModels) + + // Apply pagination + total := len(filtered) + start := f.Offset + end := f.Offset + f.Limit + + if start >= total { + filtered = []catalogs.Model{} + } else { + if end > total { + end = total + } + filtered = filtered[start:end] + } + + // Build response + result := map[string]any{ + "models": filtered, + "pagination": map[string]any{ + "total": total, + "limit": f.Limit, + "offset": f.Offset, + "count": len(filtered), + }, + } + + // Cache result + h.cache.Set(cacheKey, result) + + response.OK(w, result) +} + +// HandleGetModel handles GET /api/v1/models/{id}. +// @Summary Get model by ID +// @Description Retrieve detailed information about a specific model +// @Tags models +// @Accept json +// @Produce json +// @Param id path string true "Model ID" +// @Success 200 {object} response.Response{data=catalogs.Model} +// @Failure 404 {object} response.Response{error=response.Error} +// @Failure 500 {object} response.Response{error=response.Error} +// @Security ApiKeyAuth +// @Router /api/v1/models/{id} [get]. 
+func (h *Handlers) HandleGetModel(w http.ResponseWriter, _ *http.Request, modelID string) { + // Check cache + cacheKey := "model:" + modelID + if cached, found := h.cache.Get(cacheKey); found { + response.OK(w, cached) + return + } + + // Get catalog + cat, err := h.app.Catalog() + if err != nil { + response.InternalError(w, err) + return + } + + // Find model + model, err := cat.FindModel(modelID) + if err != nil { + response.ErrorFromType(w, err) + return + } + + // Cache result + h.cache.Set(cacheKey, model) + + response.OK(w, model) +} + +// SearchRequest represents the POST /api/v1/models/search request body. +type SearchRequest struct { + IDs []string `json:"ids,omitempty"` + NameContains string `json:"name_contains,omitempty"` + Provider string `json:"provider,omitempty"` + Modalities *SearchModalities `json:"modalities,omitempty"` + Features map[string]bool `json:"features,omitempty"` + Tags []string `json:"tags,omitempty"` + OpenWeights *bool `json:"open_weights,omitempty"` + ContextWindow *IntRange `json:"context_window,omitempty"` + OutputTokens *IntRange `json:"output_tokens,omitempty"` + ReleaseDate *DateRange `json:"release_date,omitempty"` + Sort string `json:"sort,omitempty"` + Order string `json:"order,omitempty"` + MaxResults int `json:"max_results,omitempty"` +} + +// SearchModalities specifies modality requirements. +type SearchModalities struct { + Input []string `json:"input,omitempty"` + Output []string `json:"output,omitempty"` +} + +// IntRange represents an integer range filter. +type IntRange struct { + Min int64 `json:"min,omitempty"` + Max int64 `json:"max,omitempty"` +} + +// DateRange represents a date range filter. +type DateRange struct { + After string `json:"after,omitempty"` + Before string `json:"before,omitempty"` +} + +// HandleSearchModels handles POST /api/v1/models/search. 
+// @Summary Search models +// @Description Advanced search with multiple criteria +// @Tags models +// @Accept json +// @Produce json +// @Param search body SearchRequest true "Search criteria" +// @Success 200 {object} response.Response{data=object} +// @Failure 400 {object} response.Response{error=response.Error} +// @Failure 500 {object} response.Response{error=response.Error} +// @Security ApiKeyAuth +// @Router /api/v1/models/search [post]. +func (h *Handlers) HandleSearchModels(w http.ResponseWriter, r *http.Request) { + var req SearchRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + response.BadRequest(w, "Invalid JSON request body", err.Error()) + return + } + + // Get catalog + cat, err := h.app.Catalog() + if err != nil { + response.InternalError(w, err) + return + } + + // Convert search request to filter + f := filter.ModelFilter{ + NameContains: req.NameContains, + Provider: req.Provider, + Features: req.Features, + Tags: req.Tags, + OpenWeights: req.OpenWeights, + Sort: req.Sort, + Order: req.Order, + Limit: 100, + MaxResults: req.MaxResults, + } + + if req.Modalities != nil { + f.ModalityInput = req.Modalities.Input + f.ModalityOutput = req.Modalities.Output + } + + if req.ContextWindow != nil { + f.MinContext = req.ContextWindow.Min + f.MaxContext = req.ContextWindow.Max + } + + if req.OutputTokens != nil { + f.MinOutput = req.OutputTokens.Min + f.MaxOutput = req.OutputTokens.Max + } + + // Apply filters + allModels := cat.Models().List() + results := f.Apply(allModels) + + // Filter by IDs if specified + if len(req.IDs) > 0 { + filtered := make([]catalogs.Model, 0, len(req.IDs)) + idMap := make(map[string]bool) + for _, id := range req.IDs { + idMap[id] = true + } + for _, model := range results { + if idMap[model.ID] { + filtered = append(filtered, model) + } + } + results = filtered + } + + // Apply max results limit + if req.MaxResults > 0 && len(results) > req.MaxResults { + results = results[:req.MaxResults] + } + + // 
Build response + result := map[string]any{ + "models": results, + "count": len(results), + } + + response.OK(w, result) +} diff --git a/internal/server/handlers/openapi.go b/internal/server/handlers/openapi.go new file mode 100644 index 000000000..811f43439 --- /dev/null +++ b/internal/server/handlers/openapi.go @@ -0,0 +1,33 @@ +package handlers + +import ( + "net/http" + + "github.com/agentstation/starmap/internal/embedded/openapi" +) + +// HandleOpenAPIJSON serves the embedded OpenAPI 3.1 specification in JSON format. +// @Summary Get OpenAPI specification (JSON) +// @Description Returns the OpenAPI 3.1 specification for this API in JSON format +// @Tags meta +// @Produce json +// @Success 200 {object} object "OpenAPI 3.1 specification" +// @Router /api/v1/openapi.json [get]. +func (h *Handlers) HandleOpenAPIJSON(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "public, max-age=3600") // Cache for 1 hour + _, _ = w.Write(openapi.SpecJSON) +} + +// HandleOpenAPIYAML serves the embedded OpenAPI 3.1 specification in YAML format. +// @Summary Get OpenAPI specification (YAML) +// @Description Returns the OpenAPI 3.1 specification for this API in YAML format +// @Tags meta +// @Produce application/x-yaml +// @Success 200 {string} string "OpenAPI 3.1 specification" +// @Router /api/v1/openapi.yaml [get]. 
+func (h *Handlers) HandleOpenAPIYAML(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/x-yaml") + w.Header().Set("Cache-Control", "public, max-age=3600") // Cache for 1 hour + _, _ = w.Write(openapi.SpecYAML) +} diff --git a/internal/server/handlers/providers.go b/internal/server/handlers/providers.go new file mode 100644 index 000000000..1a8ae14ff --- /dev/null +++ b/internal/server/handlers/providers.go @@ -0,0 +1,151 @@ +package handlers + +import ( + "net/http" + + "github.com/agentstation/starmap/internal/cmd/provider" + "github.com/agentstation/starmap/internal/server/response" + "github.com/agentstation/starmap/pkg/catalogs" +) + +// HandleListProviders handles GET /api/v1/providers. +// @Summary List providers +// @Description List all providers +// @Tags providers +// @Accept json +// @Produce json +// @Success 200 {object} response.Response{data=object} +// @Failure 500 {object} response.Response{error=response.Error} +// @Security ApiKeyAuth +// @Router /api/v1/providers [get]. 
+func (h *Handlers) HandleListProviders(w http.ResponseWriter, _ *http.Request) { + // Check cache + if cached, found := h.cache.Get("providers"); found { + response.OK(w, cached) + return + } + + // Get catalog + cat, err := h.app.Catalog() + if err != nil { + response.InternalError(w, err) + return + } + + providers := cat.Providers().List() + + // Build simplified provider list + providerList := make([]map[string]any, 0, len(providers)) + for _, prov := range providers { + providerInfo := map[string]any{ + "id": prov.ID, + "name": prov.Name, + "model_count": len(prov.Models), + } + + if prov.Headquarters != nil { + providerInfo["headquarters"] = *prov.Headquarters + } + + if prov.Catalog != nil && prov.Catalog.Docs != nil { + providerInfo["docs_url"] = *prov.Catalog.Docs + } + + providerList = append(providerList, providerInfo) + } + + result := map[string]any{ + "providers": providerList, + "count": len(providerList), + } + + // Cache result + h.cache.Set("providers", result) + + response.OK(w, result) +} + +// HandleGetProvider handles GET /api/v1/providers/{id}. +// @Summary Get provider by ID +// @Description Retrieve detailed information about a specific provider +// @Tags providers +// @Accept json +// @Produce json +// @Param id path string true "Provider ID" +// @Success 200 {object} response.Response{data=catalogs.Provider} +// @Failure 404 {object} response.Response{error=response.Error} +// @Failure 500 {object} response.Response{error=response.Error} +// @Security ApiKeyAuth +// @Router /api/v1/providers/{id} [get]. 
+func (h *Handlers) HandleGetProvider(w http.ResponseWriter, _ *http.Request, providerID string) { + // Check cache + cacheKey := "provider:" + providerID + if cached, found := h.cache.Get(cacheKey); found { + response.OK(w, cached) + return + } + + // Get catalog + cat, err := h.app.Catalog() + if err != nil { + response.InternalError(w, err) + return + } + + // Get provider + prov, err := provider.Get(cat, providerID) + if err != nil { + response.ErrorFromType(w, err) + return + } + + // Cache result + h.cache.Set(cacheKey, prov) + + response.OK(w, prov) +} + +// HandleGetProviderModels handles GET /api/v1/providers/{id}/models. +// @Summary Get provider models +// @Description List all models for a specific provider +// @Tags providers +// @Accept json +// @Produce json +// @Param id path string true "Provider ID" +// @Success 200 {object} response.Response{data=object} +// @Failure 404 {object} response.Response{error=response.Error} +// @Failure 500 {object} response.Response{error=response.Error} +// @Security ApiKeyAuth +// @Router /api/v1/providers/{id}/models [get]. 
+func (h *Handlers) HandleGetProviderModels(w http.ResponseWriter, _ *http.Request, providerID string) { + // Get catalog + cat, err := h.app.Catalog() + if err != nil { + response.InternalError(w, err) + return + } + + // Get provider + prov, err := provider.Get(cat, providerID) + if err != nil { + response.ErrorFromType(w, err) + return + } + + // Convert map to slice + models := make([]*catalogs.Model, 0, len(prov.Models)) + for _, model := range prov.Models { + models = append(models, model) + } + + result := map[string]any{ + "provider": map[string]any{ + "id": prov.ID, + "name": prov.Name, + }, + "models": models, + "count": len(models), + } + + response.OK(w, result) +} diff --git a/internal/server/handlers/realtime.go b/internal/server/handlers/realtime.go new file mode 100644 index 000000000..735b6b8fa --- /dev/null +++ b/internal/server/handlers/realtime.go @@ -0,0 +1,45 @@ +package handlers + +import ( + "fmt" + "net/http" + "time" + + ws "github.com/agentstation/starmap/internal/server/websocket" +) + +// HandleWebSocket handles WebSocket connections at /api/v1/updates/ws. +// @Summary WebSocket updates +// @Description WebSocket connection for real-time catalog updates +// @Tags updates +// @Success 101 "Switching Protocols" +// @Router /api/v1/updates/ws [get]. +func (h *Handlers) HandleWebSocket(w http.ResponseWriter, r *http.Request) { + conn, err := h.upgrader.Upgrade(w, r, nil) + if err != nil { + h.logger.Error().Err(err).Msg("WebSocket upgrade failed") + return + } + + // Create and register client + clientID := fmt.Sprintf("%s-%d", r.RemoteAddr, time.Now().Unix()) + client := ws.NewClient(clientID, h.wsHub, conn) + + // Register client with hub (this connects it to the event stream) + h.wsHub.Register(client) + + // Start client pumps (read and write must run concurrently) + go client.WritePump() + go client.ReadPump() +} + +// HandleSSE handles Server-Sent Events at /api/v1/updates/stream. 
+// @Summary SSE updates stream +// @Description Server-Sent Events stream for catalog change notifications +// @Tags updates +// @Produce text/event-stream +// @Success 200 "Event stream" +// @Router /api/v1/updates/stream [get]. +func (h *Handlers) HandleSSE(w http.ResponseWriter, r *http.Request) { + h.sseBroadcaster.ServeHTTP(w, r) +} diff --git a/internal/server/middleware/README.md b/internal/server/middleware/README.md new file mode 100644 index 000000000..94f240f52 --- /dev/null +++ b/internal/server/middleware/README.md @@ -0,0 +1,153 @@ + + + + +# middleware + +```go +import "github.com/agentstation/starmap/internal/server/middleware" +``` + +Package middleware provides HTTP middleware for the Starmap API server. It includes logging, recovery, CORS, authentication, and rate limiting. + +## Index + +- [func Auth\(config AuthConfig, logger \*zerolog.Logger\) func\(http.Handler\) http.Handler](<#Auth>) +- [func CORS\(config CORSConfig\) func\(http.Handler\) http.Handler](<#CORS>) +- [func Chain\(middlewares ...func\(http.Handler\) http.Handler\) func\(http.Handler\) http.Handler](<#Chain>) +- [func Logger\(logger \*zerolog.Logger\) func\(http.Handler\) http.Handler](<#Logger>) +- [func RateLimit\(rl \*RateLimiter\) func\(http.Handler\) http.Handler](<#RateLimit>) +- [func Recovery\(logger \*zerolog.Logger\) func\(http.Handler\) http.Handler](<#Recovery>) +- [type AuthConfig](<#AuthConfig>) + - [func DefaultAuthConfig\(\) AuthConfig](<#DefaultAuthConfig>) +- [type CORSConfig](<#CORSConfig>) + - [func DefaultCORSConfig\(\) CORSConfig](<#DefaultCORSConfig>) +- [type RateLimiter](<#RateLimiter>) + - [func NewRateLimiter\(limit int, logger \*zerolog.Logger\) \*RateLimiter](<#NewRateLimiter>) + + + +## func [Auth]() + +```go +func Auth(config AuthConfig, logger *zerolog.Logger) func(http.Handler) http.Handler +``` + +Auth middleware validates API keys for protected endpoints. 
+ + +## func [CORS]() + +```go +func CORS(config CORSConfig) func(http.Handler) http.Handler +``` + +CORS middleware adds CORS headers to responses. + + +## func [Chain]() + +```go +func Chain(middlewares ...func(http.Handler) http.Handler) func(http.Handler) http.Handler +``` + +Chain combines multiple middleware functions into a single middleware. + + +## func [Logger]() + +```go +func Logger(logger *zerolog.Logger) func(http.Handler) http.Handler +``` + +Logger logs HTTP requests with structured logging. + + +## func [RateLimit]() + +```go +func RateLimit(rl *RateLimiter) func(http.Handler) http.Handler +``` + +RateLimit middleware limits requests per IP address. + + +## func [Recovery]() + +```go +func Recovery(logger *zerolog.Logger) func(http.Handler) http.Handler +``` + +Recovery recovers from panics and returns 500 error. + + +## type [AuthConfig]() + +AuthConfig holds authentication configuration. + +```go +type AuthConfig struct { + Enabled bool + APIKey string + HeaderName string + PublicPaths []string + BearerPrefix bool +} +``` + + +### func [DefaultAuthConfig]() + +```go +func DefaultAuthConfig() AuthConfig +``` + +DefaultAuthConfig returns default authentication configuration. + + +## type [CORSConfig]() + +CORSConfig holds CORS configuration. + +```go +type CORSConfig struct { + AllowedOrigins []string + AllowedMethods []string + AllowedHeaders []string + AllowAll bool +} +``` + + +### func [DefaultCORSConfig]() + +```go +func DefaultCORSConfig() CORSConfig +``` + +DefaultCORSConfig returns the default CORS configuration. + + +## type [RateLimiter]() + +RateLimiter implements token bucket rate limiting per IP address. + +```go +type RateLimiter struct { + // contains filtered or unexported fields +} +``` + + +### func [NewRateLimiter]() + +```go +func NewRateLimiter(limit int, logger *zerolog.Logger) *RateLimiter +``` + +NewRateLimiter creates a new rate limiter. limit is requests per minute per IP. 
+ +Generated by [gomarkdoc]() + + + \ No newline at end of file diff --git a/internal/server/middleware/auth.go b/internal/server/middleware/auth.go new file mode 100644 index 000000000..6f6ef00c9 --- /dev/null +++ b/internal/server/middleware/auth.go @@ -0,0 +1,98 @@ +package middleware + +import ( + "net/http" + "os" + "strings" + + "github.com/rs/zerolog" +) + +// AuthConfig holds authentication configuration. +type AuthConfig struct { + Enabled bool + APIKey string + HeaderName string + PublicPaths []string + BearerPrefix bool +} + +// DefaultAuthConfig returns default authentication configuration. +func DefaultAuthConfig() AuthConfig { + return AuthConfig{ + Enabled: false, + APIKey: os.Getenv("API_KEY"), + HeaderName: "X-API-Key", + PublicPaths: []string{"/health", "/api/v1/health", "/api/v1/ready", "/api/v1/openapi.json"}, + BearerPrefix: false, + } +} + +// Auth middleware validates API keys for protected endpoints. +func Auth(config AuthConfig, logger *zerolog.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip authentication if disabled + if !config.Enabled { + next.ServeHTTP(w, r) + return + } + + // Skip authentication for public paths + if isPublicPath(r.URL.Path, config.PublicPaths) { + next.ServeHTTP(w, r) + return + } + + // Extract API key from header + apiKey := extractAPIKey(r, config) + + // Validate API key + if apiKey == "" || apiKey != config.APIKey { + logger.Warn(). + Str("path", r.URL.Path). + Str("remote_addr", r.RemoteAddr). + Bool("key_provided", apiKey != ""). 
+ Msg("Authentication failed") + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"data":null,"error":{"code":"UNAUTHORIZED","message":"Invalid or missing API key","details":"Provide a valid API key in the ` + config.HeaderName + ` header"}}`)) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// isPublicPath checks if a path is in the public paths list. +func isPublicPath(path string, publicPaths []string) bool { + for _, p := range publicPaths { + if path == p { + return true + } + } + return false +} + +// extractAPIKey extracts the API key from the request. +func extractAPIKey(r *http.Request, config AuthConfig) string { + // Try custom header first (X-API-Key) + apiKey := r.Header.Get(config.HeaderName) + if apiKey != "" { + return apiKey + } + + // Try Authorization header + auth := r.Header.Get("Authorization") + if auth != "" { + // Support both "Bearer " and raw key + if strings.HasPrefix(auth, "Bearer ") { + return strings.TrimPrefix(auth, "Bearer ") + } + return auth + } + + return "" +} diff --git a/internal/server/middleware/auth_test.go b/internal/server/middleware/auth_test.go new file mode 100644 index 000000000..bceacfa92 --- /dev/null +++ b/internal/server/middleware/auth_test.go @@ -0,0 +1,432 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/rs/zerolog" +) + +// TestDefaultAuthConfig tests default configuration. +func TestDefaultAuthConfig(t *testing.T) { + config := DefaultAuthConfig() + + if config.Enabled { + t.Error("expected Enabled=false by default") + } + if config.HeaderName != "X-API-Key" { + t.Errorf("expected HeaderName=X-API-Key, got %s", config.HeaderName) + } + if len(config.PublicPaths) == 0 { + t.Error("expected default public paths to be set") + } + if config.BearerPrefix { + t.Error("expected BearerPrefix=false by default") + } +} + +// TestAuth tests the Auth middleware with various scenarios. 
+func TestAuth(t *testing.T) { + logger := zerolog.Nop() + + tests := []struct { + name string + config AuthConfig + path string + headers map[string]string + expectedStatus int + expectedPass bool + }{ + { + name: "auth disabled - always pass", + config: AuthConfig{ + Enabled: false, + APIKey: "secret-key", + HeaderName: "X-API-Key", + PublicPaths: []string{}, + }, + path: "/api/v1/models", + headers: map[string]string{}, + expectedStatus: http.StatusOK, + expectedPass: true, + }, + { + name: "public path - always pass", + config: AuthConfig{ + Enabled: true, + APIKey: "secret-key", + HeaderName: "X-API-Key", + PublicPaths: []string{"/health", "/api/v1/health"}, + }, + path: "/health", + headers: map[string]string{}, + expectedStatus: http.StatusOK, + expectedPass: true, + }, + { + name: "valid API key in custom header", + config: AuthConfig{ + Enabled: true, + APIKey: "secret-key", + HeaderName: "X-API-Key", + PublicPaths: []string{}, + }, + path: "/api/v1/models", + headers: map[string]string{ + "X-API-Key": "secret-key", + }, + expectedStatus: http.StatusOK, + expectedPass: true, + }, + { + name: "valid API key in Authorization header", + config: AuthConfig{ + Enabled: true, + APIKey: "secret-key", + HeaderName: "X-API-Key", + PublicPaths: []string{}, + }, + path: "/api/v1/models", + headers: map[string]string{ + "Authorization": "Bearer secret-key", + }, + expectedStatus: http.StatusOK, + expectedPass: true, + }, + { + name: "valid API key without Bearer prefix", + config: AuthConfig{ + Enabled: true, + APIKey: "secret-key", + HeaderName: "X-API-Key", + PublicPaths: []string{}, + }, + path: "/api/v1/models", + headers: map[string]string{ + "Authorization": "secret-key", + }, + expectedStatus: http.StatusOK, + expectedPass: true, + }, + { + name: "missing API key", + config: AuthConfig{ + Enabled: true, + APIKey: "secret-key", + HeaderName: "X-API-Key", + PublicPaths: []string{}, + }, + path: "/api/v1/models", + headers: map[string]string{}, + expectedStatus: 
http.StatusUnauthorized, + expectedPass: false, + }, + { + name: "invalid API key", + config: AuthConfig{ + Enabled: true, + APIKey: "secret-key", + HeaderName: "X-API-Key", + PublicPaths: []string{}, + }, + path: "/api/v1/models", + headers: map[string]string{ + "X-API-Key": "wrong-key", + }, + expectedStatus: http.StatusUnauthorized, + expectedPass: false, + }, + { + name: "invalid Bearer token", + config: AuthConfig{ + Enabled: true, + APIKey: "secret-key", + HeaderName: "X-API-Key", + PublicPaths: []string{}, + }, + path: "/api/v1/models", + headers: map[string]string{ + "Authorization": "Bearer wrong-key", + }, + expectedStatus: http.StatusUnauthorized, + expectedPass: false, + }, + { + name: "empty API key", + config: AuthConfig{ + Enabled: true, + APIKey: "secret-key", + HeaderName: "X-API-Key", + PublicPaths: []string{}, + }, + path: "/api/v1/models", + headers: map[string]string{ + "X-API-Key": "", + }, + expectedStatus: http.StatusUnauthorized, + expectedPass: false, + }, + { + name: "custom header name", + config: AuthConfig{ + Enabled: true, + APIKey: "custom-key", + HeaderName: "X-Custom-Auth", + PublicPaths: []string{}, + }, + path: "/api/v1/models", + headers: map[string]string{ + "X-Custom-Auth": "custom-key", + }, + expectedStatus: http.StatusOK, + expectedPass: true, + }, + { + name: "multiple public paths", + config: AuthConfig{ + Enabled: true, + APIKey: "secret-key", + HeaderName: "X-API-Key", + PublicPaths: []string{"/health", "/ready", "/api/v1/openapi.json"}, + }, + path: "/api/v1/openapi.json", + headers: map[string]string{}, + expectedStatus: http.StatusOK, + expectedPass: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create test handler that tracks if it was called + handlerCalled := false + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handlerCalled = true + w.WriteHeader(http.StatusOK) + }) + + // Wrap with auth middleware + middleware := Auth(tt.config, 
&logger) + handler := middleware(testHandler) + + // Create request + req := httptest.NewRequest("GET", tt.path, nil) + for key, value := range tt.headers { + req.Header.Set(key, value) + } + + // Record response + w := httptest.NewRecorder() + + // Execute + handler.ServeHTTP(w, req) + + // Verify status code + if w.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code) + } + + // Verify handler was called (or not) + if handlerCalled != tt.expectedPass { + t.Errorf("expected handler called=%v, got %v", tt.expectedPass, handlerCalled) + } + + // Verify unauthorized response format + if !tt.expectedPass { + contentType := w.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("expected Content-Type=application/json for error, got %s", contentType) + } + + // Body should contain error JSON + body := w.Body.String() + if body == "" { + t.Error("expected error response body") + } + if !contains(body, "UNAUTHORIZED") { + t.Error("expected UNAUTHORIZED in error response") + } + } + }) + } +} + +// TestIsPublicPath tests public path matching. +func TestIsPublicPath(t *testing.T) { + publicPaths := []string{"/health", "/ready", "/api/v1/openapi.json"} + + tests := []struct { + path string + expected bool + }{ + {"/health", true}, + {"/ready", true}, + {"/api/v1/openapi.json", true}, + {"/api/v1/models", false}, + {"/health/sub", false}, // Exact match only + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + result := isPublicPath(tt.path, publicPaths) + if result != tt.expected { + t.Errorf("isPublicPath(%q) = %v, want %v", tt.path, result, tt.expected) + } + }) + } +} + +// TestExtractAPIKey tests API key extraction from various headers. 
+func TestExtractAPIKey(t *testing.T) { + tests := []struct { + name string + config AuthConfig + headers map[string]string + expected string + }{ + { + name: "from custom header", + config: AuthConfig{ + HeaderName: "X-API-Key", + }, + headers: map[string]string{ + "X-API-Key": "test-key", + }, + expected: "test-key", + }, + { + name: "from Authorization with Bearer", + config: AuthConfig{ + HeaderName: "X-API-Key", + }, + headers: map[string]string{ + "Authorization": "Bearer test-key", + }, + expected: "test-key", + }, + { + name: "from Authorization without Bearer", + config: AuthConfig{ + HeaderName: "X-API-Key", + }, + headers: map[string]string{ + "Authorization": "test-key", + }, + expected: "test-key", + }, + { + name: "custom header takes precedence", + config: AuthConfig{ + HeaderName: "X-API-Key", + }, + headers: map[string]string{ + "X-API-Key": "custom-key", + "Authorization": "Bearer auth-key", + }, + expected: "custom-key", + }, + { + name: "no API key", + config: AuthConfig{ + HeaderName: "X-API-Key", + }, + headers: map[string]string{}, + expected: "", + }, + { + name: "empty header value", + config: AuthConfig{ + HeaderName: "X-API-Key", + }, + headers: map[string]string{ + "X-API-Key": "", + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + for key, value := range tt.headers { + req.Header.Set(key, value) + } + + result := extractAPIKey(req, tt.config) + if result != tt.expected { + t.Errorf("extractAPIKey() = %q, want %q", result, tt.expected) + } + }) + } +} + +// TestAuth_ConcurrentRequests tests auth middleware under concurrent load. 
+func TestAuth_ConcurrentRequests(t *testing.T) { + logger := zerolog.Nop() + config := AuthConfig{ + Enabled: true, + APIKey: "secret-key", + HeaderName: "X-API-Key", + PublicPaths: []string{"/health"}, + } + + middleware := Auth(config, &logger) + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler := middleware(testHandler) + + // Run concurrent requests + const numRequests = 100 + done := make(chan bool, numRequests) + + for i := 0; i < numRequests; i++ { + go func(id int) { + // Alternate between valid and invalid keys + key := "secret-key" + if id%2 == 0 { + key = "wrong-key" + } + + req := httptest.NewRequest("GET", "/api/v1/models", nil) + req.Header.Set("X-API-Key", key) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + // Verify expected status + if id%2 == 0 { + if w.Code != http.StatusUnauthorized { + t.Errorf("request %d: expected 401, got %d", id, w.Code) + } + } else { + if w.Code != http.StatusOK { + t.Errorf("request %d: expected 200, got %d", id, w.Code) + } + } + + done <- true + }(i) + } + + // Wait for all requests + for i := 0; i < numRequests; i++ { + <-done + } +} + +// Helper function to check if a string contains a substring. +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || containsInner(s, substr))) +} + +func containsInner(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/internal/server/middleware/cors.go b/internal/server/middleware/cors.go new file mode 100644 index 000000000..31e9e9ed2 --- /dev/null +++ b/internal/server/middleware/cors.go @@ -0,0 +1,63 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// CORSConfig holds CORS configuration. 
+type CORSConfig struct { + AllowedOrigins []string + AllowedMethods []string + AllowedHeaders []string + AllowAll bool +} + +// DefaultCORSConfig returns the default CORS configuration. +func DefaultCORSConfig() CORSConfig { + return CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + AllowedHeaders: []string{"Content-Type", "Authorization", "X-API-Key"}, + AllowAll: false, + } +} + +// CORS middleware adds CORS headers to responses. +func CORS(config CORSConfig) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + + // Set CORS headers + if config.AllowAll || len(config.AllowedOrigins) == 0 { + w.Header().Set("Access-Control-Allow-Origin", "*") + } else if origin != "" && isOriginAllowed(origin, config.AllowedOrigins) { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Vary", "Origin") + } + + w.Header().Set("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", ")) + w.Header().Set("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", ")) + w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours + + // Handle preflight requests + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// isOriginAllowed checks if an origin is in the allowed list. 
+func isOriginAllowed(origin string, allowed []string) bool { + for _, o := range allowed { + if o == "*" || o == origin { + return true + } + } + return false +} diff --git a/internal/server/middleware/cors_test.go b/internal/server/middleware/cors_test.go new file mode 100644 index 000000000..96c0f2eef --- /dev/null +++ b/internal/server/middleware/cors_test.go @@ -0,0 +1,375 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +// TestDefaultCORSConfig tests default CORS configuration. +func TestDefaultCORSConfig(t *testing.T) { + config := DefaultCORSConfig() + + if config.AllowAll { + t.Error("expected AllowAll=false by default") + } + if len(config.AllowedOrigins) == 0 { + t.Error("expected default allowed origins") + } + if config.AllowedOrigins[0] != "*" { + t.Errorf("expected first origin to be *, got %s", config.AllowedOrigins[0]) + } +} + +// TestCORS tests the CORS middleware with various scenarios. +func TestCORS(t *testing.T) { + tests := []struct { + name string + config CORSConfig + method string + origin string + expectHeaders map[string]string + expectNoHeader bool + }{ + { + name: "allow all - wildcard", + config: CORSConfig{ + AllowAll: true, + AllowedMethods: []string{"GET", "POST", "OPTIONS"}, + AllowedHeaders: []string{"Content-Type"}, + }, + method: "GET", + origin: "https://example.com", + expectHeaders: map[string]string{ + "Access-Control-Allow-Origin": "*", + }, + }, + { + name: "specific origin allowed", + config: CORSConfig{ + AllowAll: false, + AllowedOrigins: []string{"https://example.com", "https://app.example.com"}, + AllowedMethods: []string{"GET", "POST"}, + AllowedHeaders: []string{"Content-Type"}, + }, + method: "GET", + origin: "https://example.com", + expectHeaders: map[string]string{ + "Access-Control-Allow-Origin": "https://example.com", + }, + }, + { + name: "origin not allowed", + config: CORSConfig{ + AllowAll: false, + AllowedOrigins: []string{"https://example.com"}, + AllowedMethods: 
[]string{"GET"}, + AllowedHeaders: []string{"Content-Type"}, + }, + method: "GET", + origin: "https://evil.com", + expectNoHeader: true, + }, + { + name: "no origin header - allow all", + config: CORSConfig{ + AllowAll: true, + AllowedMethods: []string{"GET"}, + AllowedHeaders: []string{"Content-Type"}, + }, + method: "GET", + origin: "", + expectHeaders: map[string]string{ + "Access-Control-Allow-Origin": "*", + }, + }, + { + name: "preflight request", + config: CORSConfig{ + AllowAll: true, + AllowedMethods: []string{"GET", "POST", "OPTIONS"}, + AllowedHeaders: []string{"Content-Type", "Authorization"}, + }, + method: "OPTIONS", + origin: "https://example.com", + expectHeaders: map[string]string{ + "Access-Control-Allow-Origin": "*", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create test handler + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Wrap with CORS middleware + middleware := CORS(tt.config) + handler := middleware(testHandler) + + // Create request + req := httptest.NewRequest(tt.method, "/api/v1/models", nil) + if tt.origin != "" { + req.Header.Set("Origin", tt.origin) + } + + // Record response + w := httptest.NewRecorder() + + // Execute + handler.ServeHTTP(w, req) + + // Verify headers + if tt.expectNoHeader { + if w.Header().Get("Access-Control-Allow-Origin") != "" { + t.Error("expected no CORS headers, but found Access-Control-Allow-Origin") + } + } else { + for header, expectedValue := range tt.expectHeaders { + actualValue := w.Header().Get(header) + if actualValue != expectedValue { + t.Errorf("header %s: expected %q, got %q", header, expectedValue, actualValue) + } + } + } + + // Preflight should return 200 + if tt.method == "OPTIONS" && !tt.expectNoHeader { + if w.Code != http.StatusOK { + t.Errorf("preflight: expected status 200, got %d", w.Code) + } + } + }) + } +} + +// TestIsOriginAllowed tests origin matching logic. 
+func TestIsOriginAllowed(t *testing.T) { + tests := []struct { + name string + allowedOrigins []string + origin string + expected bool + }{ + { + name: "exact match", + allowedOrigins: []string{"https://example.com"}, + origin: "https://example.com", + expected: true, + }, + { + name: "no match", + allowedOrigins: []string{"https://example.com"}, + origin: "https://evil.com", + expected: false, + }, + { + name: "multiple origins - matches first", + allowedOrigins: []string{"https://example.com", "https://app.example.com"}, + origin: "https://example.com", + expected: true, + }, + { + name: "multiple origins - matches second", + allowedOrigins: []string{"https://example.com", "https://app.example.com"}, + origin: "https://app.example.com", + expected: true, + }, + { + name: "empty allowed list", + allowedOrigins: []string{}, + origin: "https://example.com", + expected: false, + }, + { + name: "empty origin", + allowedOrigins: []string{"https://example.com"}, + origin: "", + expected: false, + }, + { + name: "case sensitive", + allowedOrigins: []string{"https://example.com"}, + origin: "https://Example.com", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isOriginAllowed(tt.origin, tt.allowedOrigins) + if result != tt.expected { + t.Errorf("isOriginAllowed(%q, %v) = %v, want %v", tt.origin, tt.allowedOrigins, result, tt.expected) + } + }) + } +} + +// TestCORS_PreflightShortCircuit tests that preflight requests don't call the next handler. 
+func TestCORS_PreflightShortCircuit(t *testing.T) { + config := CORSConfig{ + AllowAll: true, + AllowedMethods: []string{"GET", "POST", "OPTIONS"}, + AllowedHeaders: []string{"Content-Type"}, + } + + handlerCalled := false + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handlerCalled = true + w.WriteHeader(http.StatusOK) + }) + + middleware := CORS(config) + handler := middleware(testHandler) + + // OPTIONS request (preflight) + req := httptest.NewRequest("OPTIONS", "/api/v1/models", nil) + req.Header.Set("Origin", "https://example.com") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + // Handler should NOT be called for preflight + if handlerCalled { + t.Error("expected handler to not be called for preflight request") + } + + // Should return 200 + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } +} + +// TestCORS_ActualRequestPassthrough tests that actual requests pass through to handler. +func TestCORS_ActualRequestPassthrough(t *testing.T) { + config := CORSConfig{ + AllowAll: true, + AllowedMethods: []string{"GET"}, + AllowedHeaders: []string{"Content-Type"}, + } + + handlerCalled := false + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handlerCalled = true + w.WriteHeader(http.StatusOK) + }) + + middleware := CORS(config) + handler := middleware(testHandler) + + // GET request (actual request) + req := httptest.NewRequest("GET", "/api/v1/models", nil) + req.Header.Set("Origin", "https://example.com") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + // Handler SHOULD be called for actual request + if !handlerCalled { + t.Error("expected handler to be called for actual request") + } + + // Should have CORS headers + if w.Header().Get("Access-Control-Allow-Origin") == "" { + t.Error("expected Access-Control-Allow-Origin header") + } +} + +// TestCORS_MultipleOrigins tests handling multiple allowed origins. 
+func TestCORS_MultipleOrigins(t *testing.T) { + config := CORSConfig{ + AllowAll: false, + AllowedOrigins: []string{ + "https://example.com", + "https://app.example.com", + "https://admin.example.com", + }, + AllowedMethods: []string{"GET"}, + AllowedHeaders: []string{"Content-Type"}, + } + + middleware := CORS(config) + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler := middleware(testHandler) + + // Test each allowed origin + for _, origin := range config.AllowedOrigins { + t.Run("origin_"+origin, func(t *testing.T) { + req := httptest.NewRequest("GET", "/api/v1/models", nil) + req.Header.Set("Origin", origin) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + allowedOrigin := w.Header().Get("Access-Control-Allow-Origin") + if allowedOrigin != origin { + t.Errorf("expected Access-Control-Allow-Origin=%s, got %s", origin, allowedOrigin) + } + }) + } + + // Test disallowed origin + t.Run("disallowed_origin", func(t *testing.T) { + req := httptest.NewRequest("GET", "/api/v1/models", nil) + req.Header.Set("Origin", "https://evil.com") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Header().Get("Access-Control-Allow-Origin") != "" { + t.Error("expected no Access-Control-Allow-Origin for disallowed origin") + } + }) +} + +// TestCORS_ConcurrentRequests tests CORS middleware under concurrent load. 
+func TestCORS_ConcurrentRequests(t *testing.T) { + config := CORSConfig{ + AllowAll: false, + AllowedOrigins: []string{"https://example.com", "https://app.example.com"}, + AllowedMethods: []string{"GET"}, + AllowedHeaders: []string{"Content-Type"}, + } + + middleware := CORS(config) + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler := middleware(testHandler) + + // Run concurrent requests with different origins + const numRequests = 100 + done := make(chan bool, numRequests) + + for i := 0; i < numRequests; i++ { + go func(id int) { + origin := "https://example.com" + if id%2 == 0 { + origin = "https://app.example.com" + } + + req := httptest.NewRequest("GET", "/api/v1/models", nil) + req.Header.Set("Origin", origin) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + // Verify CORS header is set correctly + allowedOrigin := w.Header().Get("Access-Control-Allow-Origin") + if allowedOrigin != origin { + t.Errorf("request %d: expected origin %s, got %s", id, origin, allowedOrigin) + } + + done <- true + }(i) + } + + // Wait for all requests + for i := 0; i < numRequests; i++ { + <-done + } +} diff --git a/internal/server/middleware/generate.go b/internal/server/middleware/generate.go new file mode 100644 index 000000000..f784e5fbe --- /dev/null +++ b/internal/server/middleware/generate.go @@ -0,0 +1,3 @@ +package middleware + +//go:generate gomarkdoc -e -o README.md . --repository.path /internal/server/middleware diff --git a/internal/server/middleware/middleware.go b/internal/server/middleware/middleware.go new file mode 100644 index 000000000..4977bd558 --- /dev/null +++ b/internal/server/middleware/middleware.go @@ -0,0 +1,88 @@ +// Package middleware provides HTTP middleware for the Starmap API server. +// It includes logging, recovery, CORS, authentication, and rate limiting. 
+package middleware + +import ( + "net/http" + "time" + + "github.com/rs/zerolog" +) + +// Chain combines multiple middleware functions into a single middleware. +func Chain(middlewares ...func(http.Handler) http.Handler) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + for i := len(middlewares) - 1; i >= 0; i-- { + next = middlewares[i](next) + } + return next + } +} + +// Logger logs HTTP requests with structured logging. +func Logger(logger *zerolog.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + // Wrap response writer to capture status code + wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} + + // Add logger to request context + ctx := logger.With(). + Str("method", r.Method). + Str("path", r.URL.Path). + Str("remote_addr", r.RemoteAddr). + Str("user_agent", r.UserAgent()). + Logger(). + WithContext(r.Context()) + + // Process request + next.ServeHTTP(wrapped, r.WithContext(ctx)) + + // Log request completion + duration := time.Since(start) + logger.Info(). + Str("method", r.Method). + Str("path", r.URL.Path). + Int("status", wrapped.statusCode). + Dur("duration_ms", duration). + Str("remote_addr", r.RemoteAddr). + Msg("HTTP request") + }) + } +} + +// Recovery recovers from panics and returns 500 error. +func Recovery(logger *zerolog.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + logger.Error(). + Interface("panic", err). + Str("method", r.Method). + Str("path", r.URL.Path). 
+ Msg("Panic recovered") + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"data":null,"error":{"code":"INTERNAL_ERROR","message":"Internal server error","details":"An unexpected error occurred"}}`)) + } + }() + + next.ServeHTTP(w, r) + }) + } +} + +// responseWriter wraps http.ResponseWriter to capture status code. +type responseWriter struct { + http.ResponseWriter + statusCode int +} + +func (rw *responseWriter) WriteHeader(code int) { + rw.statusCode = code + rw.ResponseWriter.WriteHeader(code) +} diff --git a/internal/server/middleware/middleware_test.go b/internal/server/middleware/middleware_test.go new file mode 100644 index 000000000..c5acf7c75 --- /dev/null +++ b/internal/server/middleware/middleware_test.go @@ -0,0 +1,567 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/rs/zerolog" +) + +// TestChain tests middleware composition. 
+func TestChain(t *testing.T) { + tests := []struct { + name string + numMiddleware int + expectedCallOrder []string + }{ + { + name: "no middleware", + numMiddleware: 0, + expectedCallOrder: []string{"handler"}, + }, + { + name: "single middleware", + numMiddleware: 1, + expectedCallOrder: []string{"m1", "handler"}, + }, + { + name: "two middleware", + numMiddleware: 2, + expectedCallOrder: []string{"m1", "m2", "handler"}, + }, + { + name: "three middleware", + numMiddleware: 3, + expectedCallOrder: []string{"m1", "m2", "m3", "handler"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var callOrder []string + + // Create middleware that track call order + middlewares := make([]func(http.Handler) http.Handler, tt.numMiddleware) + for i := 0; i < tt.numMiddleware; i++ { + name := "m" + string(rune('1'+i)) + middlewares[i] = func(n string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callOrder = append(callOrder, n) + next.ServeHTTP(w, r) + }) + } + }(name) + } + + // Create handler + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callOrder = append(callOrder, "handler") + w.WriteHeader(http.StatusOK) + }) + + // Chain middleware + chained := Chain(middlewares...)(handler) + + // Execute request + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + chained.ServeHTTP(w, req) + + // Verify call order + if len(callOrder) != len(tt.expectedCallOrder) { + t.Fatalf("expected %d calls, got %d", len(tt.expectedCallOrder), len(callOrder)) + } + + for i, expected := range tt.expectedCallOrder { + if callOrder[i] != expected { + t.Errorf("call %d: expected %s, got %s", i, expected, callOrder[i]) + } + } + + // Verify response + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + }) + } +} + +// TestChain_ExecutionOrder verifies first added is 
outermost middleware. +func TestChain_ExecutionOrder(t *testing.T) { + var executionLog []string + + // Middleware 1: Adds "start-1" before and "end-1" after + m1 := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + executionLog = append(executionLog, "start-1") + next.ServeHTTP(w, r) + executionLog = append(executionLog, "end-1") + }) + } + + // Middleware 2: Adds "start-2" before and "end-2" after + m2 := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + executionLog = append(executionLog, "start-2") + next.ServeHTTP(w, r) + executionLog = append(executionLog, "end-2") + }) + } + + // Handler + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + executionLog = append(executionLog, "handler") + w.WriteHeader(http.StatusOK) + }) + + // Chain: m1 first, then m2 + chained := Chain(m1, m2)(handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + chained.ServeHTTP(w, req) + + // Expected order: start-1 → start-2 → handler → end-2 → end-1 + expected := []string{"start-1", "start-2", "handler", "end-2", "end-1"} + if len(executionLog) != len(expected) { + t.Fatalf("expected %d log entries, got %d", len(expected), len(executionLog)) + } + + for i, exp := range expected { + if executionLog[i] != exp { + t.Errorf("log[%d]: expected %s, got %s", i, exp, executionLog[i]) + } + } +} + +// TestLogger tests request logging middleware. 
+func TestLogger(t *testing.T) { + tests := []struct { + name string + method string + path string + handlerStatus int + expectLogEntry bool + }{ + { + name: "GET request", + method: "GET", + path: "/api/v1/models", + handlerStatus: http.StatusOK, + expectLogEntry: true, + }, + { + name: "POST request", + method: "POST", + path: "/api/v1/sync", + handlerStatus: http.StatusCreated, + expectLogEntry: true, + }, + { + name: "DELETE request", + method: "DELETE", + path: "/api/v1/cache", + handlerStatus: http.StatusNoContent, + expectLogEntry: true, + }, + { + name: "error status", + method: "GET", + path: "/api/v1/unknown", + handlerStatus: http.StatusNotFound, + expectLogEntry: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create logger that writes to buffer + var buf bytes.Buffer + logger := zerolog.New(&buf).With().Timestamp().Logger() + + // Create test handler + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.handlerStatus) + }) + + // Wrap with logger middleware + middleware := Logger(&logger) + wrapped := middleware(handler) + + // Execute request + req := httptest.NewRequest(tt.method, tt.path, nil) + req.RemoteAddr = "192.168.1.1:12345" + req.Header.Set("User-Agent", "test-agent") + w := httptest.NewRecorder() + + wrapped.ServeHTTP(w, req) + + // Verify response status + if w.Code != tt.handlerStatus { + t.Errorf("expected status %d, got %d", tt.handlerStatus, w.Code) + } + + // Parse log output + logOutput := buf.String() + if !tt.expectLogEntry { + if logOutput != "" { + t.Error("expected no log output") + } + return + } + + // Verify log contains expected fields + if !strings.Contains(logOutput, tt.method) { + t.Errorf("log missing method %s: %s", tt.method, logOutput) + } + if !strings.Contains(logOutput, tt.path) { + t.Errorf("log missing path %s: %s", tt.path, logOutput) + } + if !strings.Contains(logOutput, "192.168.1.1:12345") { + t.Errorf("log missing remote_addr: 
%s", logOutput) + } + if !strings.Contains(logOutput, "HTTP request") { + t.Errorf("log missing message: %s", logOutput) + } + + // Verify log is valid JSON + var logEntry map[string]interface{} + if err := json.Unmarshal([]byte(logOutput), &logEntry); err != nil { + t.Errorf("log is not valid JSON: %v", err) + } + + // Verify required fields in JSON + if logEntry["method"] != tt.method { + t.Errorf("log method: expected %s, got %v", tt.method, logEntry["method"]) + } + if logEntry["path"] != tt.path { + t.Errorf("log path: expected %s, got %v", tt.path, logEntry["path"]) + } + if statusFloat, ok := logEntry["status"].(float64); !ok || int(statusFloat) != tt.handlerStatus { + t.Errorf("log status: expected %d, got %v", tt.handlerStatus, logEntry["status"]) + } + if _, ok := logEntry["duration_ms"]; !ok { + t.Error("log missing duration_ms field") + } + }) + } +} + +// TestLogger_StatusCodeCapture verifies responseWriter captures status codes. +func TestLogger_StatusCodeCapture(t *testing.T) { + var buf bytes.Buffer + logger := zerolog.New(&buf).With().Timestamp().Logger() + + statusCodes := []int{ + http.StatusOK, + http.StatusCreated, + http.StatusBadRequest, + http.StatusUnauthorized, + http.StatusNotFound, + http.StatusInternalServerError, + } + + for _, expectedStatus := range statusCodes { + t.Run("status_"+http.StatusText(expectedStatus), func(t *testing.T) { + buf.Reset() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(expectedStatus) + }) + + middleware := Logger(&logger) + wrapped := middleware(handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + wrapped.ServeHTTP(w, req) + + // Parse log and verify status + var logEntry map[string]interface{} + if err := json.Unmarshal(buf.Bytes(), &logEntry); err != nil { + t.Fatalf("failed to parse log: %v", err) + } + + statusFloat, ok := logEntry["status"].(float64) + if !ok { + t.Fatalf("status field not found or wrong type") + } 
+ + if int(statusFloat) != expectedStatus { + t.Errorf("expected status %d, got %d", expectedStatus, int(statusFloat)) + } + }) + } +} + +// TestLogger_Duration verifies duration logging. +func TestLogger_Duration(t *testing.T) { + var buf bytes.Buffer + logger := zerolog.New(&buf).With().Timestamp().Logger() + + // Handler that sleeps + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(50 * time.Millisecond) + w.WriteHeader(http.StatusOK) + }) + + middleware := Logger(&logger) + wrapped := middleware(handler) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + start := time.Now() + wrapped.ServeHTTP(w, req) + elapsed := time.Since(start) + + // Parse log + var logEntry map[string]interface{} + if err := json.Unmarshal(buf.Bytes(), &logEntry); err != nil { + t.Fatalf("failed to parse log: %v", err) + } + + // Verify duration is present and reasonable + durationFloat, ok := logEntry["duration_ms"].(float64) + if !ok { + t.Fatal("duration_ms field not found or wrong type") + } + + // Duration is logged in milliseconds as a float + durationMs := time.Duration(durationFloat * float64(time.Millisecond)) + + // Duration should be at least 50ms (sleep time) + if durationMs < 50*time.Millisecond { + t.Errorf("duration too short: %v (expected >= 50ms)", durationMs) + } + + // Duration should be close to actual elapsed time (within 100ms) + diff := elapsed - durationMs + if diff < 0 { + diff = -diff + } + if diff > 100*time.Millisecond { + t.Errorf("duration mismatch: logged %v, actual %v (diff %v)", durationMs, elapsed, diff) + } +} + +// TestRecovery tests panic recovery middleware. 
+func TestRecovery(t *testing.T) { + tests := []struct { + name string + shouldPanic bool + panicValue interface{} + expectStatus int + expectLogPanic bool + }{ + { + name: "no panic - normal execution", + shouldPanic: false, + expectStatus: http.StatusOK, + expectLogPanic: false, + }, + { + name: "panic with string", + shouldPanic: true, + panicValue: "something went wrong", + expectStatus: http.StatusInternalServerError, + expectLogPanic: true, + }, + { + name: "panic with error", + shouldPanic: true, + panicValue: http.ErrAbortHandler, + expectStatus: http.StatusInternalServerError, + expectLogPanic: true, + }, + { + name: "panic with nil", + shouldPanic: true, + panicValue: nil, + expectStatus: http.StatusInternalServerError, + expectLogPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create logger that writes to buffer + var buf bytes.Buffer + logger := zerolog.New(&buf).With().Timestamp().Logger() + + // Create handler that may panic + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if tt.shouldPanic { + panic(tt.panicValue) + } + w.WriteHeader(http.StatusOK) + }) + + // Wrap with recovery middleware + middleware := Recovery(&logger) + wrapped := middleware(handler) + + // Execute request + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + // Should not panic at this level + func() { + defer func() { + if r := recover(); r != nil { + t.Errorf("panic not recovered: %v", r) + } + }() + wrapped.ServeHTTP(w, req) + }() + + // Verify response status + if w.Code != tt.expectStatus { + t.Errorf("expected status %d, got %d", tt.expectStatus, w.Code) + } + + // Verify log output + logOutput := buf.String() + if tt.expectLogPanic { + if !strings.Contains(logOutput, "Panic recovered") { + t.Error("expected panic log entry") + } + if !strings.Contains(logOutput, "GET") { + t.Error("log missing method") + } + if !strings.Contains(logOutput, "/test") { + t.Error("log 
missing path") + } + + // Verify error response JSON + contentType := w.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("expected Content-Type=application/json, got %s", contentType) + } + + body := w.Body.String() + if !strings.Contains(body, "INTERNAL_ERROR") { + t.Error("response missing INTERNAL_ERROR code") + } + if !strings.Contains(body, "Internal server error") { + t.Error("response missing error message") + } + + // Verify valid JSON + var errorResp map[string]interface{} + if err := json.Unmarshal([]byte(body), &errorResp); err != nil { + t.Errorf("response is not valid JSON: %v", err) + } + } else { + if strings.Contains(logOutput, "Panic recovered") { + t.Error("unexpected panic log entry") + } + } + }) + } +} + +// TestRecovery_OtherRequestsStillWork verifies other requests work after panic. +func TestRecovery_OtherRequestsStillWork(t *testing.T) { + var buf bytes.Buffer + logger := zerolog.New(&buf).With().Timestamp().Logger() + + requestCount := 0 + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + if requestCount == 2 { + panic("intentional panic") + } + w.WriteHeader(http.StatusOK) + }) + + middleware := Recovery(&logger) + wrapped := middleware(handler) + + // Request 1: Success + req1 := httptest.NewRequest("GET", "/test1", nil) + w1 := httptest.NewRecorder() + wrapped.ServeHTTP(w1, req1) + if w1.Code != http.StatusOK { + t.Errorf("request 1: expected 200, got %d", w1.Code) + } + + // Request 2: Panic + req2 := httptest.NewRequest("GET", "/test2", nil) + w2 := httptest.NewRecorder() + wrapped.ServeHTTP(w2, req2) + if w2.Code != http.StatusInternalServerError { + t.Errorf("request 2: expected 500, got %d", w2.Code) + } + + // Request 3: Should still work + req3 := httptest.NewRequest("GET", "/test3", nil) + w3 := httptest.NewRecorder() + wrapped.ServeHTTP(w3, req3) + if w3.Code != http.StatusOK { + t.Errorf("request 3: expected 200, got %d", w3.Code) + } + + if 
requestCount != 3 { + t.Errorf("expected 3 requests, got %d", requestCount) + } +} + +// TestResponseWriter tests the responseWriter wrapper. +func TestResponseWriter(t *testing.T) { + tests := []struct { + name string + writeHeader bool + statusCode int + expectedCode int + }{ + { + name: "explicit WriteHeader", + writeHeader: true, + statusCode: http.StatusCreated, + expectedCode: http.StatusCreated, + }, + { + name: "default status (no WriteHeader)", + writeHeader: false, + expectedCode: http.StatusOK, + }, + { + name: "error status", + writeHeader: true, + statusCode: http.StatusBadRequest, + expectedCode: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + rw := &responseWriter{ + ResponseWriter: recorder, + statusCode: http.StatusOK, + } + + if tt.writeHeader { + rw.WriteHeader(tt.statusCode) + } + + // Verify wrapped status code + if rw.statusCode != tt.expectedCode { + t.Errorf("expected statusCode=%d, got %d", tt.expectedCode, rw.statusCode) + } + + // Verify underlying recorder + if tt.writeHeader && recorder.Code != tt.statusCode { + t.Errorf("expected recorder.Code=%d, got %d", tt.statusCode, recorder.Code) + } + }) + } +} diff --git a/internal/server/middleware/ratelimit.go b/internal/server/middleware/ratelimit.go new file mode 100644 index 000000000..3c54dd7cc --- /dev/null +++ b/internal/server/middleware/ratelimit.go @@ -0,0 +1,132 @@ +package middleware + +import ( + "net/http" + "sync" + "time" + + "github.com/rs/zerolog" +) + +// RateLimiter implements token bucket rate limiting per IP address. +type RateLimiter struct { + mu sync.RWMutex + visitors map[string]*visitor + limit int // requests per minute + interval time.Duration // cleanup interval + logger *zerolog.Logger +} + +// visitor tracks rate limit state for a single IP. +type visitor struct { + tokens int + lastReset time.Time + mu sync.Mutex +} + +// NewRateLimiter creates a new rate limiter. 
+// limit is requests per minute per IP. +func NewRateLimiter(limit int, logger *zerolog.Logger) *RateLimiter { + rl := &RateLimiter{ + visitors: make(map[string]*visitor), + limit: limit, + interval: time.Minute, + logger: logger, + } + + // Start cleanup goroutine + go rl.cleanup() + + return rl +} + +// cleanup removes stale visitors every 5 minutes. +func (rl *RateLimiter) cleanup() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + rl.mu.Lock() + for ip, v := range rl.visitors { + v.mu.Lock() + if time.Since(v.lastReset) > 10*time.Minute { + delete(rl.visitors, ip) + } + v.mu.Unlock() + } + rl.mu.Unlock() + } +} + +// getVisitor returns or creates a visitor for the IP. +func (rl *RateLimiter) getVisitor(ip string) *visitor { + rl.mu.RLock() + v, exists := rl.visitors[ip] + rl.mu.RUnlock() + + if !exists { + rl.mu.Lock() + // Double-check after acquiring write lock + v, exists = rl.visitors[ip] + if !exists { + v = &visitor{ + tokens: rl.limit, + lastReset: time.Now(), + } + rl.visitors[ip] = v + } + rl.mu.Unlock() + } + + return v +} + +// allow checks if a request from the IP is allowed. +func (rl *RateLimiter) allow(ip string) bool { + v := rl.getVisitor(ip) + + v.mu.Lock() + defer v.mu.Unlock() + + // Reset tokens if interval has passed + if time.Since(v.lastReset) > rl.interval { + v.tokens = rl.limit + v.lastReset = time.Now() + } + + // Check if tokens available + if v.tokens > 0 { + v.tokens-- + return true + } + + return false +} + +// RateLimit middleware limits requests per IP address. +func RateLimit(rl *RateLimiter) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Extract IP address (handle X-Forwarded-For) + ip := r.RemoteAddr + if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" { + ip = forwarded + } + + // Check rate limit + if !rl.allow(ip) { + rl.logger.Warn(). + Str("ip", ip). 
+ Str("path", r.URL.Path). + Msg("Rate limit exceeded") + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"data":null,"error":{"code":"RATE_LIMITED","message":"Rate limit exceeded","details":"Too many requests. Please try again later."}}`)) + return + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/server/middleware/ratelimit_test.go b/internal/server/middleware/ratelimit_test.go new file mode 100644 index 000000000..c01884317 --- /dev/null +++ b/internal/server/middleware/ratelimit_test.go @@ -0,0 +1,476 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/rs/zerolog" +) + +// TestNewRateLimiter tests rate limiter creation. +func TestNewRateLimiter(t *testing.T) { + logger := zerolog.Nop() + rl := NewRateLimiter(100, &logger) + + if rl == nil { + t.Fatal("NewRateLimiter returned nil") + } + if rl.visitors == nil { + t.Error("visitors map not initialized") + } + if rl.limit != 100 { + t.Errorf("expected limit=100, got %d", rl.limit) + } + if rl.interval != time.Minute { + t.Errorf("expected interval=1m, got %v", rl.interval) + } +} + +// TestRateLimiter_Allow tests basic rate limiting logic. 
+func TestRateLimiter_Allow(t *testing.T) { + logger := zerolog.Nop() + + tests := []struct { + name string + limit int + requests int + expectedAllow int // How many should be allowed + }{ + { + name: "within limit", + limit: 10, + requests: 5, + expectedAllow: 5, + }, + { + name: "at limit", + limit: 10, + requests: 10, + expectedAllow: 10, + }, + { + name: "exceeds limit", + limit: 10, + requests: 15, + expectedAllow: 10, + }, + { + name: "zero limit", + limit: 0, + requests: 5, + expectedAllow: 0, + }, + { + name: "single request limit", + limit: 1, + requests: 3, + expectedAllow: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rl := NewRateLimiter(tt.limit, &logger) + ip := "192.168.1.1" + + allowed := 0 + for i := 0; i < tt.requests; i++ { + if rl.allow(ip) { + allowed++ + } + } + + if allowed != tt.expectedAllow { + t.Errorf("expected %d allowed, got %d", tt.expectedAllow, allowed) + } + }) + } +} + +// TestRateLimiter_MultipleIPs tests independent rate limiting per IP. +func TestRateLimiter_MultipleIPs(t *testing.T) { + logger := zerolog.Nop() + rl := NewRateLimiter(5, &logger) + + ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"} + + // Each IP should get their own limit + for _, ip := range ips { + allowed := 0 + for i := 0; i < 10; i++ { + if rl.allow(ip) { + allowed++ + } + } + if allowed != 5 { + t.Errorf("IP %s: expected 5 allowed, got %d", ip, allowed) + } + } + + // Verify each IP is tracked separately + if len(rl.visitors) != 3 { + t.Errorf("expected 3 visitors, got %d", len(rl.visitors)) + } +} + +// TestRateLimiter_TokenRefresh tests token bucket refresh after interval. 
+func TestRateLimiter_TokenRefresh(t *testing.T) { + logger := zerolog.Nop() + rl := NewRateLimiter(3, &logger) + + // Override interval for faster testing + rl.interval = 100 * time.Millisecond + + ip := "192.168.1.1" + + // Use all tokens + for i := 0; i < 3; i++ { + if !rl.allow(ip) { + t.Fatalf("expected request %d to be allowed", i) + } + } + + // Next request should be denied + if rl.allow(ip) { + t.Error("expected request to be denied (no tokens)") + } + + // Wait for token refresh + time.Sleep(150 * time.Millisecond) + + // Tokens should be refreshed + if !rl.allow(ip) { + t.Error("expected request to be allowed after token refresh") + } +} + +// TestRateLimiter_ConcurrentRequests tests thread-safety with concurrent requests. +func TestRateLimiter_ConcurrentRequests(t *testing.T) { + logger := zerolog.Nop() + limit := 100 + rl := NewRateLimiter(limit, &logger) + + ip := "192.168.1.1" + numGoroutines := 50 + requestsPerGoroutine := 10 + + var wg sync.WaitGroup + var mu sync.Mutex + allowed := 0 + + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < requestsPerGoroutine; j++ { + if rl.allow(ip) { + mu.Lock() + allowed++ + mu.Unlock() + } + } + }() + } + + wg.Wait() + + // Should allow exactly the limit + if allowed != limit { + t.Errorf("expected %d allowed, got %d", limit, allowed) + } +} + +// TestRateLimiter_ConcurrentMultipleIPs tests concurrent requests from multiple IPs. +func TestRateLimiter_ConcurrentMultipleIPs(t *testing.T) { + logger := zerolog.Nop() + limit := 10 + rl := NewRateLimiter(limit, &logger) + + numIPs := 20 + requestsPerIP := 15 + + var wg sync.WaitGroup + results := make(map[string]int) + var mu sync.Mutex + + wg.Add(numIPs) + for i := 0; i < numIPs; i++ { + go func(id int) { + defer wg.Done() + ip := "192.168.1." 
+ string(rune(id+1)) + allowed := 0 + + for j := 0; j < requestsPerIP; j++ { + if rl.allow(ip) { + allowed++ + } + } + + mu.Lock() + results[ip] = allowed + mu.Unlock() + }(i) + } + + wg.Wait() + + // Each IP should be allowed exactly the limit + for ip, count := range results { + if count != limit { + t.Errorf("IP %s: expected %d allowed, got %d", ip, limit, count) + } + } +} + +// TestRateLimiter_Middleware tests the RateLimit middleware function. +func TestRateLimiter_Middleware(t *testing.T) { + logger := zerolog.Nop() + + tests := []struct { + name string + limit int + requests int + expectedSuccess int + expectedBlocked int + }{ + { + name: "within limit", + limit: 5, + requests: 3, + expectedSuccess: 3, + expectedBlocked: 0, + }, + { + name: "at limit", + limit: 5, + requests: 5, + expectedSuccess: 5, + expectedBlocked: 0, + }, + { + name: "exceeds limit", + limit: 5, + requests: 8, + expectedSuccess: 5, + expectedBlocked: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rl := NewRateLimiter(tt.limit, &logger) + middleware := RateLimit(rl) + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler := middleware(testHandler) + + success := 0 + blocked := 0 + + for i := 0; i < tt.requests; i++ { + req := httptest.NewRequest("GET", "/api/v1/models", nil) + req.RemoteAddr = "192.168.1.1:12345" + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code == http.StatusOK { + success++ + } else if w.Code == http.StatusTooManyRequests { + blocked++ + } + } + + if success != tt.expectedSuccess { + t.Errorf("expected %d successful requests, got %d", tt.expectedSuccess, success) + } + if blocked != tt.expectedBlocked { + t.Errorf("expected %d blocked requests, got %d", tt.expectedBlocked, blocked) + } + }) + } +} + +// TestRateLimiter_Middleware_XForwardedFor tests X-Forwarded-For header handling. 
+func TestRateLimiter_Middleware_XForwardedFor(t *testing.T) { + logger := zerolog.Nop() + rl := NewRateLimiter(3, &logger) + middleware := RateLimit(rl) + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler := middleware(testHandler) + + // Use X-Forwarded-For + for i := 0; i < 3; i++ { + req := httptest.NewRequest("GET", "/api/v1/models", nil) + req.RemoteAddr = "proxy:8080" + req.Header.Set("X-Forwarded-For", "10.0.0.1") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("request %d: expected 200, got %d", i, w.Code) + } + } + + // Next request should be blocked + req := httptest.NewRequest("GET", "/api/v1/models", nil) + req.RemoteAddr = "proxy:8080" + req.Header.Set("X-Forwarded-For", "10.0.0.1") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusTooManyRequests { + t.Errorf("expected 429, got %d", w.Code) + } +} + +// TestRateLimiter_Middleware_ErrorResponse tests rate limit error response format. 
+func TestRateLimiter_Middleware_ErrorResponse(t *testing.T) { + logger := zerolog.Nop() + rl := NewRateLimiter(1, &logger) + middleware := RateLimit(rl) + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler := middleware(testHandler) + + // First request succeeds + req := httptest.NewRequest("GET", "/api/v1/models", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + // Second request is rate limited + req = httptest.NewRequest("GET", "/api/v1/models", nil) + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Code != http.StatusTooManyRequests { + t.Errorf("expected 429, got %d", w.Code) + } + + // Check response format + contentType := w.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("expected Content-Type=application/json, got %s", contentType) + } + + body := w.Body.String() + if body == "" { + t.Error("expected error response body") + } + if !contains(body, "RATE_LIMITED") { + t.Error("expected RATE_LIMITED in response body") + } +} + +// TestRateLimiter_Cleanup tests the cleanup goroutine behavior. +func TestRateLimiter_Cleanup(t *testing.T) { + logger := zerolog.Nop() + rl := NewRateLimiter(5, &logger) + + // Add some visitors + for i := 0; i < 10; i++ { + ip := "192.168.1." 
+ string(rune(i+1)) + rl.allow(ip) + } + + initialCount := len(rl.visitors) + if initialCount != 10 { + t.Errorf("expected 10 visitors, got %d", initialCount) + } + + // Manually set lastReset to old time to trigger cleanup + rl.mu.Lock() + for _, v := range rl.visitors { + v.mu.Lock() + v.lastReset = time.Now().Add(-15 * time.Minute) + v.mu.Unlock() + } + rl.mu.Unlock() + + // Trigger cleanup by calling the internal cleanup logic + // Note: In production, cleanup runs every 5 minutes + // For testing, we'll simulate it + rl.mu.Lock() + for ip, v := range rl.visitors { + v.mu.Lock() + if time.Since(v.lastReset) > 10*time.Minute { + delete(rl.visitors, ip) + } + v.mu.Unlock() + } + rl.mu.Unlock() + + // Verify cleanup occurred + if len(rl.visitors) != 0 { + t.Errorf("expected 0 visitors after cleanup, got %d", len(rl.visitors)) + } +} + +// TestRateLimiter_VisitorCreation tests double-checked locking pattern. +func TestRateLimiter_VisitorCreation(t *testing.T) { + logger := zerolog.Nop() + rl := NewRateLimiter(100, &logger) + + ip := "192.168.1.1" + + // Concurrent creation should only create one visitor + var wg sync.WaitGroup + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + rl.allow(ip) + }() + } + wg.Wait() + + // Should only have one visitor + if len(rl.visitors) != 1 { + t.Errorf("expected 1 visitor, got %d", len(rl.visitors)) + } +} + +// TestRateLimiter_BurstTraffic tests handling burst traffic patterns. 
+func TestRateLimiter_BurstTraffic(t *testing.T) { + logger := zerolog.Nop() + limit := 50 + rl := NewRateLimiter(limit, &logger) + + ip := "192.168.1.1" + + // Simulate burst of requests + burstSize := 100 + allowed := 0 + + start := time.Now() + for i := 0; i < burstSize; i++ { + if rl.allow(ip) { + allowed++ + } + } + duration := time.Since(start) + + // Should handle burst quickly + if duration > 100*time.Millisecond { + t.Errorf("burst took too long: %v", duration) + } + + // Should respect limit + if allowed != limit { + t.Errorf("expected %d allowed, got %d", limit, allowed) + } +} + +// contains helper is defined in auth_test.go diff --git a/internal/server/response/README.md b/internal/server/response/README.md new file mode 100644 index 000000000..2ed5ad5d4 --- /dev/null +++ b/internal/server/response/README.md @@ -0,0 +1,177 @@ + + + + +# response + +```go +import "github.com/agentstation/starmap/internal/server/response" +``` + +Package response provides standardized HTTP response structures and helpers for the Starmap API server. All API responses follow a consistent format with a data field for successful responses and an error field for failures. 
+ +## Index + +- [func BadRequest\(w http.ResponseWriter, message, details string\)](<#BadRequest>) +- [func Created\(w http.ResponseWriter, data any\)](<#Created>) +- [func ErrorFromType\(w http.ResponseWriter, err error\)](<#ErrorFromType>) +- [func InternalError\(w http.ResponseWriter, \_ error\)](<#InternalError>) +- [func JSON\(w http.ResponseWriter, status int, resp Response\)](<#JSON>) +- [func MethodNotAllowed\(w http.ResponseWriter, method string\)](<#MethodNotAllowed>) +- [func NotFound\(w http.ResponseWriter, message, details string\)](<#NotFound>) +- [func OK\(w http.ResponseWriter, data any\)](<#OK>) +- [func RateLimited\(w http.ResponseWriter, message string\)](<#RateLimited>) +- [func ServiceUnavailable\(w http.ResponseWriter, message string\)](<#ServiceUnavailable>) +- [func Unauthorized\(w http.ResponseWriter, message, details string\)](<#Unauthorized>) +- [type Error](<#Error>) +- [type Response](<#Response>) + - [func Fail\(code, message, details string\) Response](<#Fail>) + - [func Success\(data any\) Response](<#Success>) + + + +## func [BadRequest]() + +```go +func BadRequest(w http.ResponseWriter, message, details string) +``` + +BadRequest writes a 400 error response. + + +## func [Created]() + +```go +func Created(w http.ResponseWriter, data any) +``` + +Created writes a successful response with 201 status. + + +## func [ErrorFromType]() + +```go +func ErrorFromType(w http.ResponseWriter, err error) +``` + +ErrorFromType maps typed errors to appropriate HTTP responses. + + +## func [InternalError]() + +```go +func InternalError(w http.ResponseWriter, _ error) +``` + +InternalError writes a 500 error response. + + +## func [JSON]() + +```go +func JSON(w http.ResponseWriter, status int, resp Response) +``` + +JSON writes a JSON response with the given status code. + + +## func [MethodNotAllowed]() + +```go +func MethodNotAllowed(w http.ResponseWriter, method string) +``` + +MethodNotAllowed writes a 405 error response. 
+ + +## func [NotFound]() + +```go +func NotFound(w http.ResponseWriter, message, details string) +``` + +NotFound writes a 404 error response. + + +## func [OK]() + +```go +func OK(w http.ResponseWriter, data any) +``` + +OK writes a successful response with 200 status. + + +## func [RateLimited]() + +```go +func RateLimited(w http.ResponseWriter, message string) +``` + +RateLimited writes a 429 error response. + + +## func [ServiceUnavailable]() + +```go +func ServiceUnavailable(w http.ResponseWriter, message string) +``` + +ServiceUnavailable writes a 503 error response. + + +## func [Unauthorized]() + +```go +func Unauthorized(w http.ResponseWriter, message, details string) +``` + +Unauthorized writes a 401 error response. + + +## type [Error]() + +Error represents an API error with code, message, and optional details. + +```go +type Error struct { + Code string `json:"code"` + Message string `json:"message"` + Details string `json:"details,omitempty"` +} +``` + + +## type [Response]() + +Response represents the standardized API response structure. All endpoints return this format for consistency. + +```go +type Response struct { + Data any `json:"data"` + Error *Error `json:"error"` +} +``` + + +### func [Fail]() + +```go +func Fail(code, message, details string) Response +``` + +Fail creates an error response. + + +### func [Success]() + +```go +func Success(data any) Response +``` + +Success creates a successful response with data. + +Generated by [gomarkdoc]() + + + \ No newline at end of file diff --git a/internal/server/response/generate.go b/internal/server/response/generate.go new file mode 100644 index 000000000..7ac8361c6 --- /dev/null +++ b/internal/server/response/generate.go @@ -0,0 +1,3 @@ +package response + +//go:generate gomarkdoc -e -o README.md . 
--repository.path /internal/server/response diff --git a/internal/server/response/response.go b/internal/server/response/response.go new file mode 100644 index 000000000..6c79b583d --- /dev/null +++ b/internal/server/response/response.go @@ -0,0 +1,137 @@ +// Package response provides standardized HTTP response structures and helpers +// for the Starmap API server. All API responses follow a consistent format +// with a data field for successful responses and an error field for failures. +package response + +import ( + "encoding/json" + "net/http" + + "github.com/agentstation/starmap/pkg/errors" +) + +// Response represents the standardized API response structure. +// All endpoints return this format for consistency. +type Response struct { + Data any `json:"data"` + Error *Error `json:"error"` +} + +// Error represents an API error with code, message, and optional details. +type Error struct { + Code string `json:"code"` + Message string `json:"message"` + Details string `json:"details,omitempty"` +} + +// Success creates a successful response with data. +func Success(data any) Response { + return Response{ + Data: data, + Error: nil, + } +} + +// Fail creates an error response. +func Fail(code, message, details string) Response { + return Response{ + Data: nil, + Error: &Error{ + Code: code, + Message: message, + Details: details, + }, + } +} + +// JSON writes a JSON response with the given status code. +func JSON(w http.ResponseWriter, status int, resp Response) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + // Encoding errors are ignored as headers are already sent (best effort) + _ = json.NewEncoder(w).Encode(resp) +} + +// OK writes a successful response with 200 status. +func OK(w http.ResponseWriter, data any) { + JSON(w, http.StatusOK, Success(data)) +} + +// Created writes a successful response with 201 status. 
+func Created(w http.ResponseWriter, data any) { + JSON(w, http.StatusCreated, Success(data)) +} + +// BadRequest writes a 400 error response. +func BadRequest(w http.ResponseWriter, message, details string) { + JSON(w, http.StatusBadRequest, Fail("BAD_REQUEST", message, details)) +} + +// Unauthorized writes a 401 error response. +func Unauthorized(w http.ResponseWriter, message, details string) { + JSON(w, http.StatusUnauthorized, Fail("UNAUTHORIZED", message, details)) +} + +// NotFound writes a 404 error response. +func NotFound(w http.ResponseWriter, message, details string) { + JSON(w, http.StatusNotFound, Fail("NOT_FOUND", message, details)) +} + +// MethodNotAllowed writes a 405 error response. +func MethodNotAllowed(w http.ResponseWriter, method string) { + JSON(w, http.StatusMethodNotAllowed, Fail( + "METHOD_NOT_ALLOWED", + "Method not allowed", + "Method "+method+" is not supported for this endpoint", + )) +} + +// RateLimited writes a 429 error response. +func RateLimited(w http.ResponseWriter, message string) { + JSON(w, http.StatusTooManyRequests, Fail( + "RATE_LIMITED", + "Rate limit exceeded", + message, + )) +} + +// InternalError writes a 500 error response. +func InternalError(w http.ResponseWriter, _ error) { + // Log the actual error but don't expose details to client + // Note: Logging should be handled by middleware or passed via context + JSON(w, http.StatusInternalServerError, Fail( + "INTERNAL_ERROR", + "Internal server error", + "An unexpected error occurred", + )) +} + +// ServiceUnavailable writes a 503 error response. +func ServiceUnavailable(w http.ResponseWriter, message string) { + JSON(w, http.StatusServiceUnavailable, Fail( + "SERVICE_UNAVAILABLE", + "Service unavailable", + message, + )) +} + +// ErrorFromType maps typed errors to appropriate HTTP responses. 
+func ErrorFromType(w http.ResponseWriter, err error) { + switch e := err.(type) { + case *errors.NotFoundError: + NotFound(w, e.Error(), "") + case *errors.ValidationError: + BadRequest(w, e.Error(), "") + case *errors.SyncError: + InternalError(w, err) + case *errors.APIError: + if e.StatusCode >= 500 { + InternalError(w, err) + } else { + BadRequest(w, e.Error(), "") + } + default: + InternalError(w, err) + } +} + diff --git a/internal/server/response/response_test.go b/internal/server/response/response_test.go new file mode 100644 index 000000000..f66ea51e9 --- /dev/null +++ b/internal/server/response/response_test.go @@ -0,0 +1,410 @@ +package response + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + starmapErrors "github.com/agentstation/starmap/pkg/errors" +) + +// TestSuccess tests the Success helper function. +func TestSuccess(t *testing.T) { + data := map[string]string{"message": "success"} + resp := Success(data) + + if resp.Data == nil { + t.Error("expected Data to be set") + } + if resp.Error != nil { + t.Error("expected Error to be nil") + } +} + +// TestFail tests the Fail helper function. +func TestFail(t *testing.T) { + resp := Fail("TEST_ERROR", "Test error message", "Additional details") + + if resp.Data != nil { + t.Error("expected Data to be nil") + } + if resp.Error == nil { + t.Fatal("expected Error to be set") + } + if resp.Error.Code != "TEST_ERROR" { + t.Errorf("expected Code=TEST_ERROR, got %s", resp.Error.Code) + } + if resp.Error.Message != "Test error message" { + t.Errorf("expected Message=Test error message, got %s", resp.Error.Message) + } + if resp.Error.Details != "Additional details" { + t.Errorf("expected Details=Additional details, got %s", resp.Error.Details) + } +} + +// TestJSON tests the JSON helper function. 
+func TestJSON(t *testing.T) { + w := httptest.NewRecorder() + resp := Success(map[string]string{"test": "data"}) + + JSON(w, http.StatusOK, resp) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + contentType := w.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("expected Content-Type=application/json, got %s", contentType) + } + + // Verify JSON is valid + var decoded Response + if err := json.NewDecoder(w.Body).Decode(&decoded); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if decoded.Data == nil { + t.Error("expected decoded Data to be set") + } + if decoded.Error != nil { + t.Error("expected decoded Error to be nil") + } +} + +// TestOK tests the OK helper function. +func TestOK(t *testing.T) { + w := httptest.NewRecorder() + data := map[string]int{"count": 42} + + OK(w, data) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var resp Response + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if resp.Error != nil { + t.Error("expected no error in response") + } +} + +// TestCreated tests the Created helper function. +func TestCreated(t *testing.T) { + w := httptest.NewRecorder() + data := map[string]string{"id": "new-resource"} + + Created(w, data) + + if w.Code != http.StatusCreated { + t.Errorf("expected status 201, got %d", w.Code) + } +} + +// TestErrorHelpers tests all error response helpers. 
+func TestErrorHelpers(t *testing.T) { + tests := []struct { + name string + fn func(w http.ResponseWriter) + expectedStatus int + expectedCode string + }{ + { + name: "BadRequest", + fn: func(w http.ResponseWriter) { + BadRequest(w, "Invalid request", "Missing field") + }, + expectedStatus: http.StatusBadRequest, + expectedCode: "BAD_REQUEST", + }, + { + name: "Unauthorized", + fn: func(w http.ResponseWriter) { + Unauthorized(w, "Auth failed", "Invalid key") + }, + expectedStatus: http.StatusUnauthorized, + expectedCode: "UNAUTHORIZED", + }, + { + name: "NotFound", + fn: func(w http.ResponseWriter) { + NotFound(w, "Resource not found", "ID not found") + }, + expectedStatus: http.StatusNotFound, + expectedCode: "NOT_FOUND", + }, + { + name: "MethodNotAllowed", + fn: func(w http.ResponseWriter) { + MethodNotAllowed(w, "POST") + }, + expectedStatus: http.StatusMethodNotAllowed, + expectedCode: "METHOD_NOT_ALLOWED", + }, + { + name: "RateLimited", + fn: func(w http.ResponseWriter) { + RateLimited(w, "Too many requests") + }, + expectedStatus: http.StatusTooManyRequests, + expectedCode: "RATE_LIMITED", + }, + { + name: "InternalError", + fn: func(w http.ResponseWriter) { + InternalError(w, errors.New("internal error")) + }, + expectedStatus: http.StatusInternalServerError, + expectedCode: "INTERNAL_ERROR", + }, + { + name: "ServiceUnavailable", + fn: func(w http.ResponseWriter) { + ServiceUnavailable(w, "Service down") + }, + expectedStatus: http.StatusServiceUnavailable, + expectedCode: "SERVICE_UNAVAILABLE", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + tt.fn(w) + + if w.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code) + } + + var resp Response + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if resp.Data != nil { + t.Error("expected Data to be nil for error response") + } + if 
resp.Error == nil { + t.Fatal("expected Error to be set") + } + if resp.Error.Code != tt.expectedCode { + t.Errorf("expected Code=%s, got %s", tt.expectedCode, resp.Error.Code) + } + }) + } +} + +// TestErrorFromType tests typed error mapping. +func TestErrorFromType(t *testing.T) { + tests := []struct { + name string + err error + expectedStatus int + expectedCode string + }{ + { + name: "NotFoundError", + err: &starmapErrors.NotFoundError{Resource: "model", ID: "gpt-4"}, + expectedStatus: http.StatusNotFound, + expectedCode: "NOT_FOUND", + }, + { + name: "ValidationError", + err: &starmapErrors.ValidationError{Field: "name", Value: "", Message: "required"}, + expectedStatus: http.StatusBadRequest, + expectedCode: "BAD_REQUEST", + }, + { + name: "SyncError", + err: &starmapErrors.SyncError{Provider: "openai", Err: errors.New("sync failed")}, + expectedStatus: http.StatusInternalServerError, + expectedCode: "INTERNAL_ERROR", + }, + { + name: "APIError - 4xx", + err: &starmapErrors.APIError{Provider: "openai", Endpoint: "/models", StatusCode: 400}, + expectedStatus: http.StatusBadRequest, + expectedCode: "BAD_REQUEST", + }, + { + name: "APIError - 5xx", + err: &starmapErrors.APIError{Provider: "openai", Endpoint: "/models", StatusCode: 503}, + expectedStatus: http.StatusInternalServerError, + expectedCode: "INTERNAL_ERROR", + }, + { + name: "Generic error", + err: errors.New("generic error"), + expectedStatus: http.StatusInternalServerError, + expectedCode: "INTERNAL_ERROR", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + ErrorFromType(w, tt.err) + + if w.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code) + } + + var resp Response + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if resp.Data != nil { + t.Error("expected Data to be nil for error response") + } + if resp.Error == nil { + 
t.Fatal("expected Error to be set") + } + if resp.Error.Code != tt.expectedCode { + t.Errorf("expected Code=%s, got %s", tt.expectedCode, resp.Error.Code) + } + }) + } +} + +// TestResponseStructure tests the Response struct marshaling. +func TestResponseStructure(t *testing.T) { + t.Run("success response structure", func(t *testing.T) { + resp := Success(map[string]string{"key": "value"}) + data, err := json.Marshal(resp) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var unmarshaled map[string]any + if err := json.Unmarshal(data, &unmarshaled); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + // Check structure + if _, ok := unmarshaled["data"]; !ok { + t.Error("expected 'data' field in JSON") + } + if _, ok := unmarshaled["error"]; !ok { + t.Error("expected 'error' field in JSON") + } + }) + + t.Run("error response structure", func(t *testing.T) { + resp := Fail("TEST", "message", "details") + data, err := json.Marshal(resp) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var unmarshaled map[string]any + if err := json.Unmarshal(data, &unmarshaled); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + // Check error structure + if unmarshaled["data"] != nil { + t.Error("expected 'data' to be null") + } + + errorField, ok := unmarshaled["error"].(map[string]any) + if !ok { + t.Fatal("expected 'error' to be an object") + } + + if errorField["code"] != "TEST" { + t.Errorf("expected code=TEST, got %v", errorField["code"]) + } + if errorField["message"] != "message" { + t.Errorf("expected message=message, got %v", errorField["message"]) + } + if errorField["details"] != "details" { + t.Errorf("expected details=details, got %v", errorField["details"]) + } + }) +} + +// TestErrorDetails tests error details omitempty behavior. 
+func TestErrorDetails(t *testing.T) { + t.Run("with details", func(t *testing.T) { + resp := Fail("TEST", "message", "details") + data, err := json.Marshal(resp) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var unmarshaled map[string]any + if err := json.Unmarshal(data, &unmarshaled); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + errorField := unmarshaled["error"].(map[string]any) + if _, ok := errorField["details"]; !ok { + t.Error("expected 'details' field when provided") + } + }) + + t.Run("without details", func(t *testing.T) { + resp := Fail("TEST", "message", "") + data, err := json.Marshal(resp) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var unmarshaled map[string]any + if err := json.Unmarshal(data, &unmarshaled); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + errorField := unmarshaled["error"].(map[string]any) + // omitempty should exclude empty details + if details, ok := errorField["details"]; ok && details != "" { + t.Errorf("expected 'details' to be omitted when empty, got %v", details) + } + }) +} + +// TestComplexDataTypes tests response with various data types. 
+func TestComplexDataTypes(t *testing.T) { + type TestStruct struct { + Name string `json:"name"` + Count int `json:"count"` + Active bool `json:"active"` + Tags []string `json:"tags"` + } + + tests := []struct { + name string + data any + }{ + {"string", "hello"}, + {"int", 42}, + {"bool", true}, + {"slice", []string{"a", "b", "c"}}, + {"map", map[string]int{"one": 1, "two": 2}}, + {"struct", TestStruct{Name: "test", Count: 123, Active: true, Tags: []string{"tag1"}}}, + {"nil", nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + OK(w, tt.data) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var resp Response + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + }) + } +} diff --git a/internal/server/router.go b/internal/server/router.go new file mode 100644 index 000000000..bea195bf4 --- /dev/null +++ b/internal/server/router.go @@ -0,0 +1,201 @@ +package server + +import ( + "fmt" + "net/http" + "strings" + + "github.com/agentstation/starmap/internal/server/handlers" + "github.com/agentstation/starmap/internal/server/middleware" +) + +// setupRouter creates the HTTP handler with routes and middleware. +func (s *Server) setupRouter() http.Handler { + mux := http.NewServeMux() + + // Create handlers instance + h := handlers.New( + s.app, + s.cache, + s.broker, + s.wsHub, + s.sseBroadcaster, + s.upgrader, + s.logger, + ) + + // Register routes + s.registerRoutes(mux, h) + + // Apply middleware chain + handler := s.applyMiddleware(mux) + + return handler +} + +// registerRoutes registers all HTTP routes. 
+func (s *Server) registerRoutes(mux *http.ServeMux, h *handlers.Handlers) { + prefix := s.config.PathPrefix + + // Public health endpoints (no auth required) + mux.HandleFunc("/health", h.HandleHealth) + mux.HandleFunc(prefix+"/health", h.HandleHealth) + mux.HandleFunc(prefix+"/ready", h.HandleReady) + + // Models endpoints + mux.HandleFunc(prefix+"/models", func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + // POST /api/v1/models is treated as search + if r.URL.Path == prefix+"/models" || r.URL.Path == prefix+"/models/" { + h.HandleSearchModels(w, r) + return + } + } + + if r.Method == http.MethodGet { + h.HandleListModels(w, r) + return + } + + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + }) + + mux.HandleFunc(prefix+"/models/", func(w http.ResponseWriter, r *http.Request) { + modelID := extractPathParam(r.URL.Path, prefix+"/models/") + if modelID != "" && r.Method == http.MethodGet { + h.HandleGetModel(w, r, modelID) + return + } + http.Error(w, "Not found", http.StatusNotFound) + }) + + // Providers endpoints + mux.HandleFunc(prefix+"/providers", func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + h.HandleListProviders(w, r) + return + } + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + }) + + mux.HandleFunc(prefix+"/providers/", func(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path[len(prefix+"/providers/"):] + parts := splitPath(path) + + if len(parts) == 0 { + http.Error(w, "Provider ID required", http.StatusBadRequest) + return + } + + providerID := parts[0] + + if len(parts) == 1 { + // GET /providers/{id} + if r.Method == http.MethodGet { + h.HandleGetProvider(w, r, providerID) + return + } + } else if len(parts) == 2 && parts[1] == "models" { + // GET /providers/{id}/models + if r.Method == http.MethodGet { + h.HandleGetProviderModels(w, r, providerID) + return + } + } + + http.Error(w, "Not found", http.StatusNotFound) + }) + + // 
Admin endpoints + mux.HandleFunc(prefix+"/update", func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + h.HandleUpdate(w, r) + return + } + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + }) + + mux.HandleFunc(prefix+"/stats", func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + h.HandleStats(w, r) + return + } + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + }) + + // Real-time endpoints + mux.HandleFunc(prefix+"/updates/ws", h.HandleWebSocket) + mux.HandleFunc(prefix+"/updates/stream", h.HandleSSE) + + // OpenAPI specification endpoints + mux.HandleFunc(prefix+"/openapi.json", h.HandleOpenAPIJSON) + mux.HandleFunc(prefix+"/openapi.yaml", h.HandleOpenAPIYAML) + + // Metrics endpoint (optional) + if s.config.MetricsEnabled { + mux.HandleFunc("/metrics", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/plain") + _, _ = fmt.Fprintf(w, "# Starmap API Metrics\n") + _, _ = fmt.Fprintf(w, "# TYPE starmap_api_info gauge\n") + _, _ = fmt.Fprintf(w, "starmap_api_info{version=\"v1\"} 1\n") + }) + } +} + +// applyMiddleware wraps handler with middleware chain. 
+func (s *Server) applyMiddleware(handler http.Handler) http.Handler { + cfg := s.config + + // Rate limiting (if enabled) + if cfg.RateLimit > 0 { + rateLimiter := middleware.NewRateLimiter(cfg.RateLimit, s.logger) + handler = middleware.RateLimit(rateLimiter)(handler) + } + + // Authentication (if enabled) + if cfg.AuthEnabled { + authConfig := middleware.DefaultAuthConfig() + authConfig.Enabled = true + authConfig.HeaderName = cfg.AuthHeader + handler = middleware.Auth(authConfig, s.logger)(handler) + } + + // CORS (if enabled) + if cfg.CORSEnabled { + corsConfig := middleware.DefaultCORSConfig() + if len(cfg.CORSOrigins) > 0 { + corsConfig.AllowedOrigins = cfg.CORSOrigins + corsConfig.AllowAll = false + } else { + corsConfig.AllowAll = true + } + handler = middleware.CORS(corsConfig)(handler) + } + + // Logging and recovery (always enabled) + handler = middleware.Logger(s.logger)(handler) + handler = middleware.Recovery(s.logger)(handler) + + return handler +} + +// extractPathParam extracts path parameter from URL. +func extractPathParam(path, prefix string) string { + trimmed := strings.TrimPrefix(path, prefix) + parts := strings.Split(trimmed, "/") + if len(parts) > 0 { + return parts[0] + } + return "" +} + +// splitPath splits a URL path into parts, removing empty strings. +func splitPath(path string) []string { + parts := []string{} + for _, part := range strings.Split(path, "/") { + if part != "" { + parts = append(parts, part) + } + } + return parts +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 000000000..1921a7f00 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,178 @@ +// Package server provides HTTP server implementation for the Starmap API. 
+package server + +import ( + "context" + "net/http" + "time" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog" + + "github.com/agentstation/starmap/cmd/application" + "github.com/agentstation/starmap/internal/server/cache" + "github.com/agentstation/starmap/internal/server/events" + "github.com/agentstation/starmap/internal/server/events/adapters" + "github.com/agentstation/starmap/internal/server/sse" + ws "github.com/agentstation/starmap/internal/server/websocket" + "github.com/agentstation/starmap/pkg/catalogs" +) + +// Server holds the HTTP server state and dependencies. +type Server struct { + app application.Application + cache *cache.Cache + broker *events.Broker + wsHub *ws.Hub + sseBroadcaster *sse.Broadcaster + upgrader websocket.Upgrader + logger *zerolog.Logger + config Config + ctx context.Context + cancel context.CancelFunc +} + +// New creates a new server instance with the given configuration. +func New(app application.Application, cfg Config) (*Server, error) { + logger := app.Logger() + + // Set defaults + if cfg.CacheTTL == 0 { + cfg.CacheTTL = 5 * time.Minute + } + + // Create unified event broker + broker := events.NewBroker(logger) + + // Create transport layers + wsHub := ws.NewHub(logger) + sseBroadcaster := sse.NewBroadcaster(logger) + + // Subscribe transports to broker + broker.Subscribe(adapters.NewWebSocketSubscriber(wsHub)) + broker.Subscribe(adapters.NewSSESubscriber(sseBroadcaster)) + + // Create context for managing background services + ctx, cancel := context.WithCancel(context.Background()) + + server := &Server{ + app: app, + cache: cache.New(cfg.CacheTTL, cfg.CacheTTL*2), + broker: broker, + wsHub: wsHub, + sseBroadcaster: sseBroadcaster, + upgrader: websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(_ *http.Request) bool { + return true // Allow all origins for WebSocket + }, + }, + logger: logger, + config: cfg, + ctx: ctx, + cancel: cancel, + } + + // Connect Starmap hooks to 
event broker + if err := server.connectHooks(); err != nil { + return nil, err + } + + return server, nil +} + +// connectHooks registers Starmap event hooks to publish to the broker. +func (s *Server) connectHooks() error { + sm, err := s.app.Starmap() + if err != nil { + return err + } + + // Model added + sm.OnModelAdded(func(model catalogs.Model) { + s.broker.Publish(events.ModelAdded, map[string]any{ + "model": model, + }) + s.logger.Debug(). + Str("model_id", model.ID). + Msg("Model added event published") + }) + + // Model updated + sm.OnModelUpdated(func(old, updated catalogs.Model) { + s.broker.Publish(events.ModelUpdated, map[string]any{ + "old_model": old, + "new_model": updated, + }) + s.logger.Debug(). + Str("model_id", updated.ID). + Msg("Model updated event published") + }) + + // Model removed + sm.OnModelRemoved(func(model catalogs.Model) { + s.broker.Publish(events.ModelDeleted, map[string]any{ + "model": model, + }) + s.logger.Debug(). + Str("model_id", model.ID). + Msg("Model deleted event published") + }) + + s.logger.Info().Msg("Starmap hooks connected to event broker") + return nil +} + +// Start starts background services (broker, WebSocket hub, SSE broadcaster). +func (s *Server) Start() { + go s.broker.Run(s.ctx) + go s.wsHub.Run(s.ctx) + go s.sseBroadcaster.Run(s.ctx) +} + +// Handler returns the configured http.Handler with middleware chain applied. +func (s *Server) Handler() http.Handler { + return s.setupRouter() +} + +// Shutdown gracefully shuts down background services. 
+func (s *Server) Shutdown(ctx context.Context) error { + s.logger.Info().Msg("Shutting down server background services") + + // Cancel the context to stop all background services + s.cancel() + + // Give services time to shutdown gracefully + shutdownTimeout := time.NewTimer(5 * time.Second) + defer shutdownTimeout.Stop() + + select { + case <-shutdownTimeout.C: + s.logger.Warn().Msg("Background services shutdown timed out") + case <-time.After(100 * time.Millisecond): + s.logger.Info().Msg("Background services shut down successfully") + } + + return nil +} + +// Cache returns the server's cache instance. +func (s *Server) Cache() *cache.Cache { + return s.cache +} + +// WSHub returns the WebSocket hub. +func (s *Server) WSHub() *ws.Hub { + return s.wsHub +} + +// SSEBroadcaster returns the SSE broadcaster. +func (s *Server) SSEBroadcaster() *sse.Broadcaster { + return s.sseBroadcaster +} + +// Broker returns the event broker for publishing events. +func (s *Server) Broker() *events.Broker { + return s.broker +} diff --git a/internal/server/sse/README.md b/internal/server/sse/README.md new file mode 100644 index 000000000..a98f822f6 --- /dev/null +++ b/internal/server/sse/README.md @@ -0,0 +1,96 @@ + + + + +# sse + +```go +import "github.com/agentstation/starmap/internal/server/sse" +``` + +Package sse provides Server\-Sent Events support for real\-time updates. + +## Index + +- [type Broadcaster](<#Broadcaster>) + - [func NewBroadcaster\(logger \*zerolog.Logger\) \*Broadcaster](<#NewBroadcaster>) + - [func \(b \*Broadcaster\) Broadcast\(event Event\)](<#Broadcaster.Broadcast>) + - [func \(b \*Broadcaster\) ClientCount\(\) int](<#Broadcaster.ClientCount>) + - [func \(b \*Broadcaster\) Run\(\)](<#Broadcaster.Run>) + - [func \(b \*Broadcaster\) ServeHTTP\(w http.ResponseWriter, r \*http.Request\)](<#Broadcaster.ServeHTTP>) +- [type Event](<#Event>) + + + +## type [Broadcaster]() + +Broadcaster manages Server\-Sent Events connections. 
+ +```go +type Broadcaster struct { + // contains filtered or unexported fields +} +``` + + +### func [NewBroadcaster]() + +```go +func NewBroadcaster(logger *zerolog.Logger) *Broadcaster +``` + +NewBroadcaster creates a new SSE broadcaster. + + +### func \(\*Broadcaster\) [Broadcast]() + +```go +func (b *Broadcaster) Broadcast(event Event) +``` + +Broadcast sends an event to all connected SSE clients. + + +### func \(\*Broadcaster\) [ClientCount]() + +```go +func (b *Broadcaster) ClientCount() int +``` + +ClientCount returns the number of connected SSE clients. + + +### func \(\*Broadcaster\) [Run]() + +```go +func (b *Broadcaster) Run() +``` + +Run starts the broadcaster's main loop. Should be called in a goroutine. + + +### func \(\*Broadcaster\) [ServeHTTP]() + +```go +func (b *Broadcaster) ServeHTTP(w http.ResponseWriter, r *http.Request) +``` + +ServeHTTP handles SSE connections. + + +## type [Event]() + +Event represents an SSE event. + +```go +type Event struct { + Event string `json:"event,omitempty"` // Event type (optional) + ID string `json:"id,omitempty"` // Event ID (optional) + Data any `json:"data"` // Event data +} +``` + +Generated by [gomarkdoc]() + + + \ No newline at end of file diff --git a/internal/server/sse/broadcaster.go b/internal/server/sse/broadcaster.go new file mode 100644 index 000000000..f9cfaf87f --- /dev/null +++ b/internal/server/sse/broadcaster.go @@ -0,0 +1,177 @@ +// Package sse provides Server-Sent Events support for real-time updates. +package sse + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "sync" + "time" + + "github.com/rs/zerolog" +) + +// Broadcaster manages Server-Sent Events connections. +type Broadcaster struct { + clients map[chan Event]bool + newClients chan chan Event + closed chan chan Event + events chan Event + mu sync.RWMutex + logger *zerolog.Logger +} + +// NewBroadcaster creates a new SSE broadcaster. 
+func NewBroadcaster(logger *zerolog.Logger) *Broadcaster { + return &Broadcaster{ + clients: make(map[chan Event]bool), + newClients: make(chan chan Event), + closed: make(chan chan Event), + events: make(chan Event, 256), + logger: logger, + } +} + +// Run starts the broadcaster's main loop. Should be called in a goroutine. +// The broadcaster will run until the context is cancelled. +func (b *Broadcaster) Run(ctx context.Context) { + for { + select { + case <-ctx.Done(): + // Graceful shutdown: close all client connections + b.mu.Lock() + for client := range b.clients { + close(client) + } + b.clients = make(map[chan Event]bool) + b.mu.Unlock() + b.logger.Info().Msg("SSE broadcaster shut down") + return + + case client := <-b.newClients: + b.mu.Lock() + b.clients[client] = true + b.mu.Unlock() + b.logger.Info(). + Int("total_clients", len(b.clients)). + Msg("SSE client connected") + + case client := <-b.closed: + b.mu.Lock() + delete(b.clients, client) + close(client) + b.mu.Unlock() + b.logger.Info(). + Int("total_clients", len(b.clients)). + Msg("SSE client disconnected") + + case event := <-b.events: + b.mu.RLock() + for client := range b.clients { + select { + case client <- event: + default: + // Client buffer full, skip this event for this client + b.logger.Warn().Msg("SSE client buffer full, event skipped") + } + } + b.mu.RUnlock() + } + } +} + +// Broadcast sends an event to all connected SSE clients. +func (b *Broadcaster) Broadcast(event Event) { + select { + case b.events <- event: + default: + b.logger.Warn().Msg("SSE broadcast channel full, event dropped") + } +} + +// ClientCount returns the number of connected SSE clients. +func (b *Broadcaster) ClientCount() int { + b.mu.RLock() + defer b.mu.RUnlock() + return len(b.clients) +} + +// ServeHTTP handles SSE connections. 
+func (b *Broadcaster) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Set SSE headers + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + + // Create client channel + client := make(chan Event, 256) + + // Register client + b.newClients <- client + + // Ensure cleanup + defer func() { + b.closed <- client + }() + + // Get flusher + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming not supported", http.StatusInternalServerError) + return + } + + // Send initial connection event + initialEvent := Event{ + Event: "connected", + Data: map[string]any{ + "message": "Connected to Starmap updates stream", + "timestamp": time.Now(), + }, + } + b.writeEvent(w, flusher, initialEvent) + + // Stream events + for { + select { + case event := <-client: + b.writeEvent(w, flusher, event) + + case <-r.Context().Done(): + return + } + } +} + +// writeEvent writes an SSE event to the response writer. +func (b *Broadcaster) writeEvent(w http.ResponseWriter, flusher http.Flusher, event Event) { + // Write event type if specified + if event.Event != "" { + _, _ = fmt.Fprintf(w, "event: %s\n", event.Event) + } + + // Write event ID if specified + if event.ID != "" { + _, _ = fmt.Fprintf(w, "id: %s\n", event.ID) + } + + // Write data as JSON + data, err := json.Marshal(event.Data) + if err != nil { + b.logger.Error().Err(err).Msg("Failed to marshal SSE event data") + return + } + _, _ = fmt.Fprintf(w, "data: %s\n\n", data) + + // Flush the response + flusher.Flush() +} + +// Event represents an SSE event. 
+type Event struct { + Event string `json:"event,omitempty"` // Event type (optional) + ID string `json:"id,omitempty"` // Event ID (optional) + Data any `json:"data"` // Event data +} diff --git a/internal/server/sse/broadcaster_test.go b/internal/server/sse/broadcaster_test.go new file mode 100644 index 000000000..0b5128c4f --- /dev/null +++ b/internal/server/sse/broadcaster_test.go @@ -0,0 +1,595 @@ +package sse + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/rs/zerolog" +) + +// TestBroadcaster_NewBroadcaster tests broadcaster creation. +func TestBroadcaster_NewBroadcaster(t *testing.T) { + logger := zerolog.Nop() + b := NewBroadcaster(&logger) + + if b == nil { + t.Fatal("NewBroadcaster returned nil") + } + + if b.clients == nil { + t.Error("clients map not initialized") + } + + if b.newClients == nil { + t.Error("newClients channel not initialized") + } + + if b.closed == nil { + t.Error("closed channel not initialized") + } + + if b.events == nil { + t.Error("events channel not initialized") + } +} + +// TestBroadcaster_BasicOperation tests basic broadcaster operations. 
+func TestBroadcaster_BasicOperation(t *testing.T) { + logger := zerolog.Nop() + b := NewBroadcaster(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go b.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Register client + client := make(chan Event, 256) + b.newClients <- client + time.Sleep(10 * time.Millisecond) + + if count := b.ClientCount(); count != 1 { + t.Fatalf("expected 1 client, got %d", count) + } + + // Broadcast event + event := Event{ + Event: "test", + Data: map[string]any{"test": true}, + } + b.Broadcast(event) + + // Verify client received event + select { + case received := <-client: + if received.Event != event.Event { + t.Errorf("expected event %s, got %s", event.Event, received.Event) + } + case <-time.After(100 * time.Millisecond): + t.Error("client did not receive event") + } +} + +// TestBroadcaster_Shutdown tests graceful shutdown. +func TestBroadcaster_Shutdown(t *testing.T) { + logger := zerolog.Nop() + b := NewBroadcaster(&logger) + + ctx, cancel := context.WithCancel(context.Background()) + + go b.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Register clients + client1 := make(chan Event, 256) + client2 := make(chan Event, 256) + b.newClients <- client1 + b.newClients <- client2 + time.Sleep(10 * time.Millisecond) + + if count := b.ClientCount(); count != 2 { + t.Fatalf("expected 2 clients, got %d", count) + } + + // Trigger shutdown + cancel() + time.Sleep(50 * time.Millisecond) + + // Verify all clients disconnected + if count := b.ClientCount(); count != 0 { + t.Errorf("expected 0 clients after shutdown, got %d", count) + } +} + +// TestBroadcaster_MultipleClients tests multiple concurrent SSE clients. 
+func TestBroadcaster_MultipleClients(t *testing.T) { + logger := zerolog.Nop() + b := NewBroadcaster(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go b.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Register multiple clients + const numClients = 10 + clients := make([]chan Event, numClients) + for i := 0; i < numClients; i++ { + clients[i] = make(chan Event, 256) + b.newClients <- clients[i] + } + time.Sleep(50 * time.Millisecond) + + // Verify all registered + if count := b.ClientCount(); count != numClients { + t.Fatalf("expected %d clients, got %d", numClients, count) + } + + // Broadcast event + testEvent := Event{ + Event: "test", + ID: "123", + Data: map[string]any{"message": "hello"}, + } + b.Broadcast(testEvent) + + // Verify all clients received event + for i, client := range clients { + select { + case event := <-client: + if event.Event != testEvent.Event { + t.Errorf("client %d: expected event %s, got %s", i, testEvent.Event, event.Event) + } + if event.ID != testEvent.ID { + t.Errorf("client %d: expected ID %s, got %s", i, testEvent.ID, event.ID) + } + case <-time.After(200 * time.Millisecond): + t.Errorf("client %d: did not receive event", i) + } + } +} + +// TestBroadcaster_BroadcastChannelFull tests behavior when broadcast channel is full. 
+func TestBroadcaster_BroadcastChannelFull(t *testing.T) { + logger := zerolog.Nop() + b := NewBroadcaster(&logger) + + // Don't start Run() so events won't be consumed + // This will cause the channel to fill up + + // Fill the channel (capacity is 256) + for i := 0; i < 256; i++ { + b.Broadcast(Event{ + Event: "fill", + Data: map[string]any{"i": i}, + }) + } + + // Next broadcast should not block (should drop the event) + done := make(chan bool, 1) + go func() { + b.Broadcast(Event{ + Event: "overflow", + Data: map[string]any{"test": true}, + }) + done <- true + }() + + select { + case <-done: + // Success - broadcast didn't block + case <-time.After(100 * time.Millisecond): + t.Error("Broadcast blocked when channel was full") + } +} + +// TestBroadcaster_ClientDisconnect tests client disconnect handling. +func TestBroadcaster_ClientDisconnect(t *testing.T) { + logger := zerolog.Nop() + b := NewBroadcaster(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go b.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Register clients + client1 := make(chan Event, 256) + client2 := make(chan Event, 256) + b.newClients <- client1 + b.newClients <- client2 + time.Sleep(10 * time.Millisecond) + + if count := b.ClientCount(); count != 2 { + t.Fatalf("expected 2 clients, got %d", count) + } + + // Disconnect client1 + b.closed <- client1 + time.Sleep(10 * time.Millisecond) + + if count := b.ClientCount(); count != 1 { + t.Errorf("expected 1 client after disconnect, got %d", count) + } + + // Broadcast event - only client2 should receive + testEvent := Event{Event: "test", Data: map[string]any{"value": 42}} + b.Broadcast(testEvent) + + // client2 should receive + select { + case <-client2: + // Success + case <-time.After(100 * time.Millisecond): + t.Error("client2 did not receive event") + } + + // client1 should be closed + select { + case _, ok := <-client1: + if ok { + t.Error("client1 channel should be closed") + } 
+ default: + t.Error("client1 channel not closed") + } +} + +// TestBroadcaster_ClientBufferFull tests behavior when client buffer is full. +func TestBroadcaster_ClientBufferFull(t *testing.T) { + logger := zerolog.Nop() + b := NewBroadcaster(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go b.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Register client with small buffer + client := make(chan Event, 5) + b.newClients <- client + time.Sleep(10 * time.Millisecond) + + // Fill client buffer + for i := 0; i < 5; i++ { + b.Broadcast(Event{Event: "fill", Data: map[string]any{"i": i}}) + time.Sleep(5 * time.Millisecond) + } + + // Broadcast more events - should skip when buffer full + for i := 0; i < 5; i++ { + b.Broadcast(Event{Event: "overflow", Data: map[string]any{"i": i}}) + time.Sleep(5 * time.Millisecond) + } + + // Verify client still connected + if count := b.ClientCount(); count != 1 { + t.Errorf("expected 1 client, got %d", count) + } + + // Drain client buffer + received := 0 + timeout := time.After(100 * time.Millisecond) + for { + select { + case <-client: + received++ + case <-timeout: + goto verify + } + } +verify: + // Should have received at least the initial 5 events + if received < 5 { + t.Errorf("expected at least 5 events, got %d", received) + } +} + +// TestBroadcaster_ServeHTTP tests the SSE HTTP handler. 
+func TestBroadcaster_ServeHTTP(t *testing.T) { + logger := zerolog.Nop() + b := NewBroadcaster(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + go b.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Create request with cancellable context + req := httptest.NewRequest("GET", "/events", nil) + reqCtx, reqCancel := context.WithCancel(req.Context()) + req = req.WithContext(reqCtx) + + // Create response recorder + w := httptest.NewRecorder() + + // Start ServeHTTP in goroutine + done := make(chan bool) + go func() { + b.ServeHTTP(w, req) + done <- true + }() + + // Wait for client to register + for i := 0; i < 100; i++ { + if b.ClientCount() == 1 { + break + } + time.Sleep(10 * time.Millisecond) + } + + // Verify client registered + if count := b.ClientCount(); count != 1 { + t.Fatalf("expected 1 client, got %d", count) + } + + // Broadcast test event + testEvent := Event{ + Event: "test.event", + ID: "evt-123", + Data: map[string]any{"message": "hello"}, + } + b.Broadcast(testEvent) + time.Sleep(100 * time.Millisecond) + + // Cancel request to stop ServeHTTP + reqCancel() + + // Wait for handler to finish + select { + case <-done: + // Success + case <-time.After(500 * time.Millisecond): + t.Error("ServeHTTP did not finish after context cancel") + } + + // Now it's safe to check headers and body since ServeHTTP has finished + // Verify headers + if ct := w.Header().Get("Content-Type"); ct != "text/event-stream" { + t.Errorf("expected Content-Type=text/event-stream, got %s", ct) + } + if cc := w.Header().Get("Cache-Control"); cc != "no-cache" { + t.Errorf("expected Cache-Control=no-cache, got %s", cc) + } + if conn := w.Header().Get("Connection"); conn != "keep-alive" { + t.Errorf("expected Connection=keep-alive, got %s", conn) + } + + // Verify response body contains SSE formatted data + body := w.Body.String() + + // Should contain initial connection event + if !strings.Contains(body, "event: connected") { + 
t.Error("missing initial connection event") + } + if !strings.Contains(body, "Connected to Starmap updates stream") { + t.Error("missing connection message") + } + + // Should contain test event + if !strings.Contains(body, "event: test.event") { + t.Error("missing test event type") + } + if !strings.Contains(body, "id: evt-123") { + t.Error("missing test event ID") + } + if !strings.Contains(body, `"message":"hello"`) { + t.Error("missing test event data") + } +} + +// TestBroadcaster_ServeHTTP_NoFlusher tests ServeHTTP with non-flushing writer. +func TestBroadcaster_ServeHTTP_NoFlusher(t *testing.T) { + logger := zerolog.Nop() + b := NewBroadcaster(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + go b.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Create request + req := httptest.NewRequest("GET", "/events", nil) + + // Create custom ResponseWriter that doesn't implement Flusher + w := &nonFlushingWriter{ + header: make(http.Header), + buffer: &strings.Builder{}, + } + + // ServeHTTP should detect lack of flusher and return error + b.ServeHTTP(w, req) + + // Verify error response + if w.statusCode != http.StatusInternalServerError { + t.Errorf("expected status 500, got %d", w.statusCode) + } + if !strings.Contains(w.buffer.String(), "Streaming not supported") { + t.Error("missing streaming not supported error") + } +} + +// TestBroadcaster_WriteEvent tests SSE event formatting. 
+func TestBroadcaster_WriteEvent(t *testing.T) { + tests := []struct { + name string + event Event + expectedOutput []string + }{ + { + name: "full event with type, ID, and data", + event: Event{ + Event: "update", + ID: "123", + Data: map[string]any{"status": "ok"}, + }, + expectedOutput: []string{ + "event: update", + "id: 123", + `data: {"status":"ok"}`, + }, + }, + { + name: "event without type", + event: Event{ + ID: "456", + Data: map[string]any{"value": 42}, + }, + expectedOutput: []string{ + "id: 456", + `data: {"value":42}`, + }, + }, + { + name: "event without ID", + event: Event{ + Event: "ping", + Data: map[string]any{"timestamp": 12345}, + }, + expectedOutput: []string{ + "event: ping", + `data: {"timestamp":12345}`, + }, + }, + { + name: "event with only data", + event: Event{ + Data: map[string]any{"test": true}, + }, + expectedOutput: []string{ + `data: {"test":true}`, + }, + }, + { + name: "event with string data", + event: Event{ + Event: "message", + Data: "hello world", + }, + expectedOutput: []string{ + "event: message", + `data: "hello world"`, + }, + }, + { + name: "event with null data", + event: Event{ + Event: "empty", + Data: nil, + }, + expectedOutput: []string{ + "event: empty", + "data: null", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := zerolog.Nop() + b := NewBroadcaster(&logger) + + w := httptest.NewRecorder() + flusher := w + + b.writeEvent(w, flusher, tt.event) + + output := w.Body.String() + + // Verify all expected strings are present + for _, expected := range tt.expectedOutput { + if !strings.Contains(output, expected) { + t.Errorf("output missing expected string %q\nGot: %s", expected, output) + } + } + + // Verify SSE format (ends with double newline) + if !strings.HasSuffix(output, "\n\n") { + t.Error("SSE event should end with double newline") + } + }) + } +} + +// TestBroadcaster_ConcurrentBroadcast tests concurrent broadcasting. 
+func TestBroadcaster_ConcurrentBroadcast(t *testing.T) { + logger := zerolog.Nop() + b := NewBroadcaster(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go b.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Register client + client := make(chan Event, 256) + b.newClients <- client + time.Sleep(10 * time.Millisecond) + + // Broadcast multiple events concurrently + const numEvents = 50 + done := make(chan bool) + go func() { + for i := 0; i < numEvents; i++ { + b.Broadcast(Event{ + Event: "concurrent", + Data: map[string]any{"i": i}, + }) + } + done <- true + }() + + // Wait for broadcasts + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("concurrent broadcast timeout") + } + + // Drain and count messages + time.Sleep(100 * time.Millisecond) + count := 0 + timeout := time.After(200 * time.Millisecond) + for { + select { + case <-client: + count++ + case <-timeout: + goto verify + } + } +verify: + if count != numEvents { + t.Errorf("expected %d events, got %d", numEvents, count) + } +} + +// nonFlushingWriter is a ResponseWriter that doesn't implement Flusher. +type nonFlushingWriter struct { + header http.Header + buffer *strings.Builder + statusCode int +} + +func (w *nonFlushingWriter) Header() http.Header { + return w.header +} + +func (w *nonFlushingWriter) Write(data []byte) (int, error) { + return w.buffer.Write(data) +} + +func (w *nonFlushingWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode +} diff --git a/internal/server/sse/generate.go b/internal/server/sse/generate.go new file mode 100644 index 000000000..5589ea542 --- /dev/null +++ b/internal/server/sse/generate.go @@ -0,0 +1,3 @@ +package sse + +//go:generate gomarkdoc -e -o README.md . 
--repository.path /internal/server/sse diff --git a/internal/server/websocket/README.md b/internal/server/websocket/README.md new file mode 100644 index 000000000..a80faa60a --- /dev/null +++ b/internal/server/websocket/README.md @@ -0,0 +1,128 @@ + + + + +# websocket + +```go +import "github.com/agentstation/starmap/internal/server/websocket" +``` + +Package websocket provides WebSocket support for real\-time catalog updates. + +## Index + +- [type Client](<#Client>) + - [func NewClient\(id string, hub \*Hub, conn \*websocket.Conn\) \*Client](<#NewClient>) + - [func \(c \*Client\) ReadPump\(\)](<#Client.ReadPump>) + - [func \(c \*Client\) WritePump\(\)](<#Client.WritePump>) +- [type Hub](<#Hub>) + - [func NewHub\(logger \*zerolog.Logger\) \*Hub](<#NewHub>) + - [func \(h \*Hub\) Broadcast\(message Message\)](<#Hub.Broadcast>) + - [func \(h \*Hub\) ClientCount\(\) int](<#Hub.ClientCount>) + - [func \(h \*Hub\) Run\(\)](<#Hub.Run>) +- [type Message](<#Message>) + + + +## type [Client]() + +Client represents a WebSocket client connection. + +```go +type Client struct { + // contains filtered or unexported fields +} +``` + + +### func [NewClient]() + +```go +func NewClient(id string, hub *Hub, conn *websocket.Conn) *Client +``` + +NewClient creates a new WebSocket client. + + +### func \(\*Client\) [ReadPump]() + +```go +func (c *Client) ReadPump() +``` + +ReadPump pumps messages from the WebSocket connection to the hub. + + +### func \(\*Client\) [WritePump]() + +```go +func (c *Client) WritePump() +``` + +WritePump pumps messages from the hub to the WebSocket connection. + + +## type [Hub]() + +Hub maintains active WebSocket connections and broadcasts messages. + +```go +type Hub struct { + // contains filtered or unexported fields +} +``` + + +### func [NewHub]() + +```go +func NewHub(logger *zerolog.Logger) *Hub +``` + +NewHub creates a new WebSocket hub. 
+ + +### func \(\*Hub\) [Broadcast]() + +```go +func (h *Hub) Broadcast(message Message) +``` + +Broadcast sends a message to all connected clients. + + +### func \(\*Hub\) [ClientCount]() + +```go +func (h *Hub) ClientCount() int +``` + +ClientCount returns the number of connected clients. + + +### func \(\*Hub\) [Run]() + +```go +func (h *Hub) Run() +``` + +Run starts the hub's main loop. Should be called in a goroutine. + + +## type [Message]() + +Message represents a WebSocket message. + +```go +type Message struct { + Type string `json:"type"` + Timestamp time.Time `json:"timestamp"` + Data any `json:"data"` +} +``` + +Generated by [gomarkdoc]() + + + \ No newline at end of file diff --git a/internal/server/websocket/generate.go b/internal/server/websocket/generate.go new file mode 100644 index 000000000..01e95c63d --- /dev/null +++ b/internal/server/websocket/generate.go @@ -0,0 +1,3 @@ +package websocket + +//go:generate gomarkdoc -e -o README.md . --repository.path /internal/server/websocket diff --git a/internal/server/websocket/hub.go b/internal/server/websocket/hub.go new file mode 100644 index 000000000..f3124a341 --- /dev/null +++ b/internal/server/websocket/hub.go @@ -0,0 +1,215 @@ +// Package websocket provides WebSocket support for real-time catalog updates. +package websocket + +import ( + "context" + "encoding/json" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog" +) + +// Hub maintains active WebSocket connections and broadcasts messages. +type Hub struct { + clients map[*Client]bool + broadcast chan Message + register chan *Client + unregister chan *Client + mu sync.RWMutex + logger *zerolog.Logger +} + +// NewHub creates a new WebSocket hub. +func NewHub(logger *zerolog.Logger) *Hub { + return &Hub{ + clients: make(map[*Client]bool), + broadcast: make(chan Message, 256), + register: make(chan *Client), + unregister: make(chan *Client), + logger: logger, + } +} + +// Run starts the hub's main loop. 
Should be called in a goroutine.
+// The hub will run until the context is cancelled.
+func (h *Hub) Run(ctx context.Context) {
+	for {
+		select {
+		case <-ctx.Done():
+			// Graceful shutdown: close all client connections
+			h.mu.Lock()
+			for client := range h.clients {
+				close(client.send)
+			}
+			h.clients = make(map[*Client]bool)
+			h.mu.Unlock()
+			h.logger.Info().Msg("WebSocket hub shut down")
+			return
+
+		case client := <-h.register:
+			h.mu.Lock()
+			h.clients[client] = true
+			h.mu.Unlock()
+			h.logger.Info().
+				Str("client_id", client.id).
+				Int("total_clients", len(h.clients)).
+				Msg("WebSocket client connected")
+
+		case client := <-h.unregister:
+			h.mu.Lock()
+			if _, ok := h.clients[client]; ok {
+				delete(h.clients, client)
+				close(client.send)
+			}
+			h.mu.Unlock()
+			h.logger.Info().
+				Str("client_id", client.id).
+				Int("total_clients", len(h.clients)).
+				Msg("WebSocket client disconnected")
+
+		case message := <-h.broadcast:
+			h.mu.RLock()
+			// Take snapshot of clients for safe iteration
+			clients := make([]*Client, 0, len(h.clients))
+			for client := range h.clients {
+				clients = append(clients, client)
+			}
+			h.mu.RUnlock()
+
+			// Send to clients (some may need disconnection)
+			for _, client := range clients {
+				select {
+				case client.send <- message:
+				default:
+					// Buffer full: unregister asynchronously — a direct send here deadlocks Run, its own receiver
+					go func(c *Client) { h.unregister <- c }(client)
+				}
+			}
+		}
+	}
+}
+
+// Register registers a client with the hub.
+func (h *Hub) Register(client *Client) {
+	h.register <- client
+}
+
+// Broadcast sends a message to all connected clients.
+func (h *Hub) Broadcast(message Message) {
+	select {
+	case h.broadcast <- message:
+	default:
+		h.logger.Warn().Msg("Broadcast channel full, message dropped")
+	}
+}
+
+// ClientCount returns the number of connected clients.
+func (h *Hub) ClientCount() int {
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+	return len(h.clients)
+}
+
+// Message represents a WebSocket message.
+type Message struct { + Type string `json:"type"` + Timestamp time.Time `json:"timestamp"` + Data any `json:"data"` +} + +// Client represents a WebSocket client connection. +type Client struct { + id string + hub *Hub + conn *websocket.Conn + send chan Message +} + +// NewClient creates a new WebSocket client. +func NewClient(id string, hub *Hub, conn *websocket.Conn) *Client { + return &Client{ + id: id, + hub: hub, + conn: conn, + send: make(chan Message, 256), + } +} + +const ( + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the peer. + pongWait = 60 * time.Second + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 + + // Maximum message size allowed from peer. + maxMessageSize = 512 +) + +// ReadPump pumps messages from the WebSocket connection to the hub. +func (c *Client) ReadPump() { + defer func() { + c.hub.unregister <- c + _ = c.conn.Close() + }() + + c.conn.SetReadLimit(maxMessageSize) + _ = c.conn.SetReadDeadline(time.Now().Add(pongWait)) + c.conn.SetPongHandler(func(string) error { + _ = c.conn.SetReadDeadline(time.Now().Add(pongWait)) + return nil + }) + + for { + _, _, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + c.hub.logger.Error().Err(err).Str("client_id", c.id).Msg("WebSocket read error") + } + break + } + } +} + +// WritePump pumps messages from the hub to the WebSocket connection. 
+func (c *Client) WritePump() { + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + _ = c.conn.Close() + }() + + for { + select { + case message, ok := <-c.send: + _ = c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if !ok { + // Hub closed the channel + _ = c.conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + // Write message as JSON + data, err := json.Marshal(message) + if err != nil { + c.hub.logger.Error().Err(err).Msg("Failed to marshal WebSocket message") + continue + } + + if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil { + return + } + + case <-ticker.C: + _ = c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +} diff --git a/internal/server/websocket/hub_test.go b/internal/server/websocket/hub_test.go new file mode 100644 index 000000000..a32a39d14 --- /dev/null +++ b/internal/server/websocket/hub_test.go @@ -0,0 +1,862 @@ +package websocket + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog" +) + +// TestHub_NewHub tests hub creation. +func TestHub_NewHub(t *testing.T) { + logger := zerolog.Nop() + hub := NewHub(&logger) + + if hub == nil { + t.Fatal("NewHub returned nil") + } + + if hub.clients == nil { + t.Error("clients map not initialized") + } + + if hub.broadcast == nil { + t.Error("broadcast channel not initialized") + } + + if hub.register == nil { + t.Error("register channel not initialized") + } + + if hub.unregister == nil { + t.Error("unregister channel not initialized") + } +} + +// TestHub_BasicOperation tests basic hub operations with proper cleanup. 
+func TestHub_BasicOperation(t *testing.T) { + logger := zerolog.Nop() + hub := NewHub(&logger) + + // Create context with timeout for this test + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Start hub + go hub.Run(ctx) + + // Give hub time to start + time.Sleep(10 * time.Millisecond) + + // Create and register client + client := NewClient("test-1", hub, nil) + hub.Register(client) + + // Wait for registration + time.Sleep(10 * time.Millisecond) + + // Verify client count + if count := hub.ClientCount(); count != 1 { + t.Errorf("expected 1 client, got %d", count) + } + + // Broadcast message + msg := Message{ + Type: "test.event", + Timestamp: time.Now(), + Data: map[string]any{"test": true}, + } + hub.Broadcast(msg) + + // Verify client received message + select { + case received := <-client.send: + if received.Type != msg.Type { + t.Errorf("expected type %s, got %s", msg.Type, received.Type) + } + case <-time.After(100 * time.Millisecond): + t.Error("client did not receive message") + } + + // Test passes - context cleanup happens automatically +} + +// TestHub_Shutdown tests graceful shutdown. +func TestHub_Shutdown(t *testing.T) { + logger := zerolog.Nop() + hub := NewHub(&logger) + + ctx, cancel := context.WithCancel(context.Background()) + + // Start hub + go hub.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Register clients + client1 := NewClient("test-1", hub, nil) + client2 := NewClient("test-2", hub, nil) + hub.Register(client1) + hub.Register(client2) + + time.Sleep(10 * time.Millisecond) + + if count := hub.ClientCount(); count != 2 { + t.Fatalf("expected 2 clients, got %d", count) + } + + // Trigger shutdown + cancel() + + // Wait for shutdown + time.Sleep(50 * time.Millisecond) + + // Verify all clients disconnected + if count := hub.ClientCount(); count != 0 { + t.Errorf("expected 0 clients after shutdown, got %d", count) + } +} + +// TestHub_ConcurrentBroadcast tests concurrent broadcasting. 
+func TestHub_ConcurrentBroadcast(t *testing.T) { + logger := zerolog.Nop() + hub := NewHub(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go hub.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Register client + client := NewClient("test", hub, nil) + hub.Register(client) + time.Sleep(10 * time.Millisecond) + + // Broadcast multiple messages concurrently + done := make(chan bool) + go func() { + for i := 0; i < 10; i++ { + hub.Broadcast(Message{ + Type: "test", + Data: map[string]any{"i": i}, + }) + } + done <- true + }() + + // Wait for broadcasts + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("broadcast timeout") + } + + // Drain messages + count := 0 + timeout := time.After(200 * time.Millisecond) + for { + select { + case <-client.send: + count++ + case <-timeout: + goto verify + } + } +verify: + if count != 10 { + t.Errorf("expected 10 messages, got %d", count) + } +} + +// TestHub_MultipleClients tests multiple concurrent clients. 
+func TestHub_MultipleClients(t *testing.T) { + logger := zerolog.Nop() + hub := NewHub(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go hub.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Register multiple clients + const numClients = 20 + clients := make([]*Client, numClients) + for i := 0; i < numClients; i++ { + clients[i] = NewClient("client-"+string(rune(i)), hub, nil) + hub.Register(clients[i]) + } + time.Sleep(50 * time.Millisecond) + + // Verify all registered + if count := hub.ClientCount(); count != numClients { + t.Fatalf("expected %d clients, got %d", numClients, count) + } + + // Broadcast message + testMsg := Message{ + Type: "test.event", + Data: map[string]any{"message": "hello"}, + } + hub.Broadcast(testMsg) + + // Verify all clients received message + for i, client := range clients { + select { + case msg := <-client.send: + if msg.Type != testMsg.Type { + t.Errorf("client %d: expected type %s, got %s", i, testMsg.Type, msg.Type) + } + case <-time.After(200 * time.Millisecond): + t.Errorf("client %d: did not receive message", i) + } + } +} + +// TestHub_ClientBufferFull tests client behavior when buffer approaches full. 
+func TestHub_ClientBufferFull(t *testing.T) { + logger := zerolog.Nop() + hub := NewHub(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go hub.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Create client with small buffer + client := &Client{ + id: "test-client", + hub: hub, + conn: nil, + send: make(chan Message, 10), // Small buffer + } + hub.Register(client) + time.Sleep(10 * time.Millisecond) + + // Send messages rapidly + for i := 0; i < 20; i++ { + hub.Broadcast(Message{ + Type: "rapid", + Data: map[string]any{"i": i}, + }) + } + + // Wait for processing + time.Sleep(100 * time.Millisecond) + + // Client should either handle messages or be unregistered + // (implementation dependent, so we just verify no panic occurred) + count := hub.ClientCount() + if count < 0 || count > 1 { + t.Errorf("unexpected client count: %d", count) + } +} + +// TestHub_ConcurrentRegisterUnregister tests concurrent register/unregister operations. 
+func TestHub_ConcurrentRegisterUnregister(t *testing.T) { + logger := zerolog.Nop() + hub := NewHub(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go hub.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Concurrently register and unregister clients + const numOperations = 50 + done := make(chan bool, numOperations*2) + + // Registrations + for i := 0; i < numOperations; i++ { + go func(id int) { + client := NewClient("client-"+string(rune(id)), hub, nil) + hub.Register(client) + done <- true + }(i) + } + + // Unregistrations (for some clients) via unregister channel + for i := 0; i < numOperations/2; i++ { + go func(id int) { + time.Sleep(5 * time.Millisecond) + client := NewClient("client-"+string(rune(id)), hub, nil) + hub.unregister <- client + done <- true + }(i) + } + + // Wait for all operations + for i := 0; i < numOperations+numOperations/2; i++ { + <-done + } + + time.Sleep(50 * time.Millisecond) + + // Final count should be reasonable + count := hub.ClientCount() + if count < 0 || count > numOperations { + t.Errorf("unexpected client count: %d", count) + } +} + +// TestHub_MessageOrdering tests that messages maintain order for each client. 
+func TestHub_MessageOrdering(t *testing.T) { + logger := zerolog.Nop() + hub := NewHub(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go hub.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Register client + client := NewClient("test", hub, nil) + hub.Register(client) + time.Sleep(10 * time.Millisecond) + + // Send ordered messages + const numMessages = 20 + for i := 0; i < numMessages; i++ { + hub.Broadcast(Message{ + Type: "ordered", + Data: map[string]any{"seq": i}, + }) + } + + // Verify order is maintained + for i := 0; i < numMessages; i++ { + select { + case msg := <-client.send: + data, ok := msg.Data.(map[string]any) + if !ok { + t.Fatal("invalid message data type") + } + seq, ok := data["seq"].(int) + if !ok { + t.Fatal("invalid seq type") + } + if seq != i { + t.Errorf("expected seq=%d, got %d (out of order)", i, seq) + } + case <-time.After(200 * time.Millisecond): + t.Fatalf("timeout waiting for message %d", i) + } + } +} + +// TestHub_StressTest tests hub under heavy concurrent load. 
+func TestHub_StressTest(t *testing.T) { + if testing.Short() { + t.Skip("skipping stress test in short mode") + } + + logger := zerolog.Nop() + hub := NewHub(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + go hub.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Register many clients + const numClients = 100 + clients := make([]*Client, numClients) + for i := 0; i < numClients; i++ { + clients[i] = NewClient("stress-"+string(rune(i)), hub, nil) + hub.Register(clients[i]) + } + time.Sleep(100 * time.Millisecond) + + // Broadcast many messages + const numMessages = 100 + done := make(chan bool) + go func() { + for i := 0; i < numMessages; i++ { + hub.Broadcast(Message{ + Type: "stress", + Data: map[string]any{"id": i}, + }) + } + done <- true + }() + + // Wait for broadcasts to complete + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("stress test timeout") + } + + // Let messages propagate + time.Sleep(200 * time.Millisecond) + + // Verify all clients still connected + if count := hub.ClientCount(); count != numClients { + t.Errorf("expected %d clients, got %d", numClients, count) + } +} + +// TestClient_WritePump tests the WritePump method with mock WebSocket connection. 
+func TestClient_WritePump(t *testing.T) { + logger := zerolog.Nop() + hub := NewHub(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go hub.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Create test WebSocket server + upgrader := websocket.Upgrader{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Fatalf("failed to upgrade: %v", err) + return + } + defer conn.Close() + + // Read messages from client + for { + messageType, message, err := conn.ReadMessage() + if err != nil { + break + } + + // Verify message type is text + if messageType != websocket.TextMessage { + // Could be ping/close + continue + } + + // Verify message is valid JSON + var msg Message + if err := json.Unmarshal(message, &msg); err == nil { + // Message received successfully + t.Logf("Server received: %s", message) + } + } + })) + defer server.Close() + + // Connect client to server + wsURL := "ws" + server.URL[4:] // Convert http:// to ws:// + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + + // Create client and start WritePump + client := NewClient("test-client", hub, conn) + go client.WritePump() + + // Send messages through hub + for i := 0; i < 5; i++ { + msg := Message{ + Type: "test", + Timestamp: time.Now(), + Data: map[string]any{"i": i}, + } + client.send <- msg + time.Sleep(10 * time.Millisecond) + } + + // Close client send channel to trigger shutdown + close(client.send) + + // Wait for WritePump to finish + time.Sleep(100 * time.Millisecond) +} + +// TestClient_ReadPump tests the ReadPump method with mock WebSocket connection. 
+func TestClient_ReadPump(t *testing.T) { + logger := zerolog.Nop() + hub := NewHub(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go hub.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Create test WebSocket server + upgrader := websocket.Upgrader{} + serverDone := make(chan bool) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Fatalf("failed to upgrade: %v", err) + return + } + defer conn.Close() + defer func() { serverDone <- true }() + + // Send test messages to client + for i := 0; i < 3; i++ { + msg := Message{ + Type: "server.test", + Timestamp: time.Now(), + Data: map[string]any{"i": i}, + } + data, _ := json.Marshal(msg) + if err := conn.WriteMessage(websocket.TextMessage, data); err != nil { + break + } + time.Sleep(10 * time.Millisecond) + } + + // Close connection + conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + })) + defer server.Close() + + // Connect client to server + wsURL := "ws" + server.URL[4:] + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + + // Create client and register + client := NewClient("test-client", hub, conn) + hub.Register(client) + time.Sleep(10 * time.Millisecond) + + // Start ReadPump in goroutine + go client.ReadPump() + + // Wait for server to finish sending + select { + case <-serverDone: + // Success + case <-time.After(1 * time.Second): + t.Error("server did not finish") + } + + // Wait for ReadPump to process and unregister + time.Sleep(100 * time.Millisecond) + + // Client should be unregistered after connection close + if count := hub.ClientCount(); count != 0 { + t.Errorf("expected 0 clients after close, got %d", count) + } +} + +// TestClient_PingPong tests ping/pong mechanism in WritePump. 
+func TestClient_PingPong(t *testing.T) { + logger := zerolog.Nop() + hub := NewHub(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + go hub.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Track pings received + pingsReceived := 0 + var mu sync.Mutex + + // Create test WebSocket server + upgrader := websocket.Upgrader{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + // Set ping handler to count pings + conn.SetPingHandler(func(appData string) error { + mu.Lock() + pingsReceived++ + mu.Unlock() + // Send pong response + return conn.WriteControl(websocket.PongMessage, []byte{}, time.Now().Add(time.Second)) + }) + + // Read messages (including pings) + for { + _, _, err := conn.ReadMessage() + if err != nil { + break + } + } + })) + defer server.Close() + + // Connect client to server + wsURL := "ws" + server.URL[4:] + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + + // Create client and start WritePump + client := NewClient("test-client", hub, conn) + done := make(chan bool) + go func() { + client.WritePump() + done <- true + }() + + // Wait for at least one ping (pingPeriod is 54 seconds in production) + // For testing, we'll wait a bit and then close + time.Sleep(200 * time.Millisecond) + + // Close client to stop WritePump + close(client.send) + + // Wait for WritePump to finish + select { + case <-done: + case <-time.After(1 * time.Second): + t.Error("WritePump did not finish") + } + + // Note: In production, ping period is 54 seconds, so we might not see pings in this test + // The important thing is that WritePump runs without error + t.Logf("Pings received: %d (may be 0 due to short test duration)", pingsReceived) +} + +// TestClient_Integration tests full client lifecycle with real 
WebSocket connection. +func TestClient_Integration(t *testing.T) { + logger := zerolog.Nop() + hub := NewHub(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go hub.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Track messages received by server + serverMessages := make([]Message, 0) + var serverMu sync.Mutex + + // Create test WebSocket server + upgrader := websocket.Upgrader{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + // Read messages from client + for { + messageType, data, err := conn.ReadMessage() + if err != nil { + break + } + + if messageType == websocket.TextMessage { + var msg Message + if err := json.Unmarshal(data, &msg); err == nil { + serverMu.Lock() + serverMessages = append(serverMessages, msg) + serverMu.Unlock() + } + } + } + })) + defer server.Close() + + // Connect client to server + wsURL := "ws" + server.URL[4:] + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + + // Create client and register with hub + client := NewClient("integration-test", hub, conn) + hub.Register(client) + time.Sleep(10 * time.Millisecond) + + // Verify client registered + if count := hub.ClientCount(); count != 1 { + t.Fatalf("expected 1 client, got %d", count) + } + + // Start client pumps + go client.WritePump() + go client.ReadPump() + + // Broadcast messages via hub + testMessages := []Message{ + {Type: "event.1", Timestamp: time.Now(), Data: map[string]any{"value": 1}}, + {Type: "event.2", Timestamp: time.Now(), Data: map[string]any{"value": 2}}, + {Type: "event.3", Timestamp: time.Now(), Data: map[string]any{"value": 3}}, + } + + for _, msg := range testMessages { + hub.Broadcast(msg) + time.Sleep(20 * time.Millisecond) + } + + // Wait for messages to be received + time.Sleep(100 * 
time.Millisecond) + + // Verify server received messages + serverMu.Lock() + receivedCount := len(serverMessages) + serverMu.Unlock() + + if receivedCount != len(testMessages) { + t.Errorf("expected %d messages, server received %d", len(testMessages), receivedCount) + } + + // Verify message types + serverMu.Lock() + for i, msg := range serverMessages { + if i < len(testMessages) && msg.Type != testMessages[i].Type { + t.Errorf("message %d: expected type %s, got %s", i, testMessages[i].Type, msg.Type) + } + } + serverMu.Unlock() + + // Close connection + conn.Close() + + // Wait for client to unregister + time.Sleep(100 * time.Millisecond) + + // Verify client unregistered + if count := hub.ClientCount(); count != 0 { + t.Errorf("expected 0 clients after close, got %d", count) + } +} + +// TestClient_WriteDeadline tests write deadline handling in WritePump. +func TestClient_WriteDeadline(t *testing.T) { + logger := zerolog.Nop() + hub := NewHub(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go hub.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Create test WebSocket server + upgrader := websocket.Upgrader{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + // Just keep connection open, read nothing + // This tests that WritePump handles writes correctly + time.Sleep(2 * time.Second) + })) + defer server.Close() + + // Connect client to server + wsURL := "ws" + server.URL[4:] + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + + // Create client and start WritePump + client := NewClient("test-client", hub, conn) + done := make(chan bool) + go func() { + client.WritePump() + done <- true + }() + + // Send a message + msg := Message{ + Type: "test", + Timestamp: time.Now(), + Data: 
map[string]any{"test": true}, + } + client.send <- msg + + // Wait a bit for write to complete + time.Sleep(100 * time.Millisecond) + + // Close send channel + close(client.send) + + // WritePump should finish gracefully + select { + case <-done: + // Success + case <-time.After(2 * time.Second): + t.Error("WritePump did not finish after close") + } +} + +// TestClient_ConnectionClose tests handling of unexpected connection close. +func TestClient_ConnectionClose(t *testing.T) { + logger := zerolog.Nop() + hub := NewHub(&logger) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go hub.Run(ctx) + time.Sleep(10 * time.Millisecond) + + // Create test WebSocket server that closes abruptly + upgrader := websocket.Upgrader{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + // Close connection immediately + conn.Close() + })) + defer server.Close() + + // Connect client to server + wsURL := "ws" + server.URL[4:] + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + + // Create client and register + client := NewClient("test-client", hub, conn) + hub.Register(client) + time.Sleep(10 * time.Millisecond) + + // Start ReadPump + done := make(chan bool) + go func() { + client.ReadPump() + done <- true + }() + + // ReadPump should detect close and finish + select { + case <-done: + // Success + case <-time.After(1 * time.Second): + t.Error("ReadPump did not finish after connection close") + } + + // Client should be unregistered + time.Sleep(50 * time.Millisecond) + if count := hub.ClientCount(); count != 0 { + t.Errorf("expected 0 clients after close, got %d", count) + } +} diff --git a/internal/utils/ptr/ptr.go b/internal/utils/ptr/ptr.go index 81b549e9f..5feffdb7e 100644 --- a/internal/utils/ptr/ptr.go +++ b/internal/utils/ptr/ptr.go @@ -1,3 
+1,4 @@ +// Package ptr provides utility functions for creating pointers to values. package ptr // To creates a pointer to the given value. diff --git a/pkg/sources/README.md b/pkg/sources/README.md index fb9e07999..69d576c6c 100644 --- a/pkg/sources/README.md +++ b/pkg/sources/README.md @@ -122,7 +122,7 @@ func (id ID) String() string String returns the string representation of a source name. -## type [Option]() +## type [Option]() Option is a function that configures options. @@ -131,7 +131,7 @@ type Option func(*Options) ``` -### func [WithCleanupRepo]() +### func [WithCleanupRepo]() ```go func WithCleanupRepo(cleanup bool) Option @@ -140,7 +140,7 @@ func WithCleanupRepo(cleanup bool) Option WithCleanupRepo configures whether to clean up temporary repositories after fetch. -### func [WithFresh]() +### func [WithFresh]() ```go func WithFresh(fresh bool) Option @@ -149,7 +149,7 @@ func WithFresh(fresh bool) Option WithFresh configures fresh sync mode for sources. -### func [WithProviderFilter]() +### func [WithProviderFilter]() ```go func WithProviderFilter(providerID catalogs.ProviderID) Option @@ -158,7 +158,7 @@ func WithProviderFilter(providerID catalogs.ProviderID) Option WithProviderFilter configures filtering for a specific provider. -### func [WithReformat]() +### func [WithReformat]() ```go func WithReformat(reformat bool) Option @@ -167,7 +167,7 @@ func WithReformat(reformat bool) Option WithReformat configures whether to reformat output files. -### func [WithSafeMode]() +### func [WithSafeMode]() ```go func WithSafeMode(safeMode bool) Option @@ -176,7 +176,7 @@ func WithSafeMode(safeMode bool) Option WithSafeMode configures safe mode for sources. -## type [Options]() +## type [Options]() Options is the configuration for sources. @@ -196,7 +196,7 @@ type Options struct { ``` -### func [Defaults]() +### func [Defaults]() ```go func Defaults() *Options @@ -205,7 +205,7 @@ func Defaults() *Options Defaults returns source options with default values. 
-### func \(\*Options\) [Apply]() +### func \(\*Options\) [Apply]() ```go func (o *Options) Apply(opts ...Option) *Options diff --git a/update.go b/update.go index 1d45bd222..7a21294e7 100644 --- a/update.go +++ b/update.go @@ -33,23 +33,23 @@ var _ Updater = (*client)(nil) // Update manually triggers a catalog update. func (c *client) Update(ctx context.Context) error { - if c.options.remoteServerURL != nil { - return c.updateFromServer(ctx) + if c.options.remoteServerURL != nil { + return c.updateFromServer(ctx) } - if c.options.autoUpdateFunc != nil { + if c.options.autoUpdateFunc != nil { c.mu.RLock() - currentCatalog := c.catalog + currentCatalog := c.catalog c.mu.RUnlock() - newCatalog, err := c.options.autoUpdateFunc(currentCatalog) + newCatalog, err := c.options.autoUpdateFunc(currentCatalog) if err != nil { return err } c.setCatalog(newCatalog) } else { // Use pipeline-based update as default - return c.updateWithPipeline(ctx) + return c.updateWithPipeline(ctx) } return nil @@ -64,14 +64,14 @@ func (c *client) updateWithPipeline(ctx context.Context) error { } // Perform a sync operation with default options - _, err := c.Sync(ctx, opts...) + _, err := c.Sync(ctx, opts...) return err } // updateFromServer fetches catalog updates from the remote server. func (c *client) updateFromServer(ctx context.Context) error { - if c.options.remoteServerURL == nil { + if c.options.remoteServerURL == nil { return &errors.ConfigError{ Component: "starmap", Message: "remote server URL is not set", @@ -88,7 +88,7 @@ func (c *client) updateFromServer(ctx context.Context) error { return errors.WrapResource("create", "request", "", err) } - if c.options.remoteServerAPIKey != nil { + if c.options.remoteServerAPIKey != nil { req.Header.Set("Authorization", "Bearer "+*c.options.remoteServerAPIKey) } @@ -196,7 +196,7 @@ func (c *client) updateFromServer(ctx context.Context) error { // setCatalog updates the catalog and triggers appropriate event hooks. 
func (c *client) setCatalog(newCatalog catalogs.Catalog) { c.mu.Lock() - oldCatalog := c.catalog + oldCatalog := c.catalog c.catalog = newCatalog c.mu.Unlock()