From c25c5a171bc617c837af5a3e006bf203fd04241a Mon Sep 17 00:00:00 2001 From: j-mcnally Date: Thu, 21 Aug 2025 06:59:53 -0400 Subject: [PATCH 01/11] Update oauth2 implementation --- .gitignore | 6 +- CLAUDE.md | 116 +++ README.md | 240 +++++- config.json => config.example.json | 17 +- config.go | 12 + docs/index.html | 13 + http.go | 93 ++- oauth.go | 1196 ++++++++++++++++++++++++++++ 8 files changed, 1684 insertions(+), 9 deletions(-) create mode 100644 CLAUDE.md rename config.json => config.example.json (73%) create mode 100644 oauth.go diff --git a/.gitignore b/.gitignore index 536bd99..f8705ef 100644 --- a/.gitignore +++ b/.gitignore @@ -141,4 +141,8 @@ fabric.properties /mcp-proxy -/build \ No newline at end of file +/build + +# Configuration files with personal details +config.json +*.log \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..a7de26d --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,116 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +MCP Proxy is a Go-based proxy server that aggregates multiple Model Context Protocol (MCP) resource servers through a single HTTP interface. It acts as a unified gateway, collecting tools, prompts, and resources from various MCP clients and exposing them via HTTP endpoints with SSE or streamable-HTTP transport. + +## Development Commands + +### Build and Run +```bash +# Build the project +make build + +# Run the built binary +./build/mcp-proxy --config config.json + +# Build for Linux x86_64 +make buildLinuxX86 + +# Install via Go +go install github.com/TBXark/mcp-proxy@latest +``` + +### Code Quality +```bash +# Run linter +make lint + +# Format code and fix issues +make format +``` + +### Docker +```bash +# Build multi-arch Docker image +make buildImage + +# Run with Docker +docker run -d -p 9090:9090 -v /path/to/config.json:/config/config.json ghcr.io/tbxark/mcp-proxy:latest +``` + +## Architecture Overview + +### Single Package Design +All code is in the root package with clear separation of concerns across files: +- `main.go`: CLI entry point and argument parsing +- `config.go`: Configuration structures and V2 format with migration from V1 +- `client.go`: MCP client management and server integration logic +- `http.go`: HTTP server with middleware chain (auth, logging, recovery) +- `oauth.go`: OAuth 2.1 authorization server implementation + +### Key Patterns +- **Proxy Aggregation**: Collects capabilities from multiple upstream MCP servers +- **Transport Abstraction**: Supports three transport types seamlessly +- **Middleware Chain**: Modular HTTP middleware for cross-cutting concerns +- **Configuration Migration**: Automatic V1 to V2 config format upgrade +- **OAuth 2.1 Server**: Complete authorization server with user access control + +### Transport Types +1. **stdio**: Command-line subprocess communication (e.g., npx, uvx commands) +2. **sse**: Server-Sent Events for real-time updates +3. **streamable-http**: HTTP streaming transport + +Each transport type is automatically detected based on configuration fields present. + +## Configuration System + +Uses V2 configuration format with backward compatibility: +- **mcpProxy**: Server settings (baseURL, addr, type, auth tokens, OAuth 2 config) +- **mcpServers**: Individual MCP client configurations +- **Tool Filtering**: Allow/block specific tools per server with `toolFilter.mode` and `toolFilter.list` +- **Per-client Auth**: Individual auth tokens override global defaults +- **OAuth 2 Client Credentials**: Supports OAuth 2 authentication for streamable HTTP transport +- **Access Control**: IP allowlist/blocklist, client approval workflows, user restrictions + +Configuration can be loaded from local files or HTTP URLs. The system automatically migrates V1 configs to V2 format. + +## Authentication + +### OAuth 2 Client Credentials Flow +For `streamable-http` transport only: +- Enable via `options.oauth2.enabled: true` in configuration +- Full OAuth 2.1 authorization server with Dynamic Client Registration (RFC 7591) +- PKCE support for enhanced security +- Per-server OAuth discovery endpoints +- Client persistence across server restarts +- Comprehensive access control system + +### Bearer Token Authentication +For all transport types: +- Configure via `options.authTokens` array in configuration +- Uses `Authorization: Bearer ` header format +- Falls back to this method when OAuth 2 is not configured + +### Access Control Features +- **IP Restrictions**: `allowedIPs` and `blockedIPs` arrays +- **Client Management**: `allowedClients` and `blockedClients` arrays +- **Approval Workflow**: `requireApproval` flag for manual client approval +- **Client Tracking**: Automatic logging of client IP addresses and metadata + +## Important Notes + +- **No Tests**: This codebase currently lacks automated tests - consider this when making changes +- **Error Handling**: Uses panic recovery middleware and optional `panicIfInvalid` per client +- **Logging**: Configurable per-client logging with request tracing +- **Health Monitoring**: Automatic ping/health checking for SSE and HTTP transport clients +- **Graceful Shutdown**: Proper signal handling for clean resource cleanup +- **Client Persistence**: OAuth clients saved to `oauth_clients.json` for persistence + +## Dependencies + +- `github.com/mark3labs/mcp-go`: Core MCP protocol implementation +- `github.com/TBXark/confstore`: Configuration management with HTTP loading +- `github.com/TBXark/optional-go`: Optional field handling for config migration \ No newline at end of file diff --git a/README.md b/README.md index 6938623..6500ee4 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,22 @@ The server is configured using a JSON file. Below is an example configuration: "logEnabled": true, "authTokens": [ "DefaultTokens" - ] + ], + "oauth2": { + "enabled": true, + "users": { + "admin": "password123", + "user": "mypassword" + }, + "persistenceDir": "/custom/path/oauth", + "allowedIPs": [ + "34.162.46.92", + "34.162.102.82", + "34.162.136.91", + "34.162.142.92", + "34.162.183.95" + ] + } } }, "mcpServers": { @@ -107,6 +122,12 @@ Common options for `mcpProxy` and `mcpServers`. - `panicIfInvalid`: If true, the server will panic if the client is invalid. - `logEnabled`: If true, the server will log the client's requests. - `authTokens`: A list of authentication tokens for the client. The `Authorization` header will be checked against this list. +- `oauth2`: OAuth 2.1 Authorization Server configuration. **Only applies when proxy type is `streamable-http`.** + - `enabled`: Enable/disable the OAuth 2.1 server. Set to `true` for Claude Desktop integration. + - `users`: Username/password pairs for authentication. Users must provide valid credentials to authorize access. + - `persistenceDir`: Directory for storing OAuth client registrations. Defaults to `$HOME/.mcpproxy` if not specified. + - `allowedIPs`: IP addresses permitted to register OAuth clients. Use Claude's official IPs for security. Empty array allows all IPs. + - `tokenExpirationMinutes`: Access token expiration time in minutes. Defaults to 60 minutes (1 hour) if not specified. - `toolFilter`: Optional tool filtering configuration. **This configuration is only effective in `mcpServers`.** - `mode`: Specifies the filtering mode. Must be explicitly set to `allow` or `block` if `list` is provided. If `list` is present but `mode` is missing or invalid, the filter will be ignored for this server. - `list`: A list of tool names to filter (either allow or block based on the `mode`). @@ -154,6 +175,223 @@ For http streaming mcp servers, the `url` field is required. and `transportType` - `headers`: The headers to send with the request to the MCP client. - `timeout`: The timeout for the request to the MCP client. +### OAuth 2.1 Authorization Server + +When using `streamable-http` transport, the proxy acts as a complete OAuth 2.1 Authorization Server designed for Claude Desktop integration. This provides secure, standards-compliant authentication with advanced security features. + +```jsonc +{ + "mcpProxy": { + "baseURL": "https://mcp.example.com", + "addr": ":9090", + "name": "MCP Proxy", + "version": "1.0.0", + "type": "streamable-http", + "options": { + "oauth2": { + "enabled": true, + "users": { + "admin": "password123", + "user": "mypassword" + }, + "persistenceDir": "/custom/path/for/oauth", + "allowedIPs": [ + "34.162.46.92", + "34.162.102.82", + "34.162.136.91", + "34.162.142.92", + "34.162.183.95" + ] + } + } + }, + "mcpServers": { + "neo4j-memory": { + "command": "docker", + "args": ["run", "-i", "--rm", "mcp/neo4j-memory"] + } + } +} +``` + +#### OAuth Flow Features + +- **🔐 RFC 7591 Dynamic Client Registration**: Claude Desktop automatically registers without manual setup +- **🛡️ PKCE Support**: Proof Key for Code Exchange prevents authorization code interception attacks +- **👤 Username/Password Authentication**: Secure login form validates against configured user credentials +- **🎫 Bearer Token Authorization**: All MCP endpoints require valid OAuth access tokens +- **💾 Token Persistence**: Clients, access tokens, and refresh tokens survive server restarts +- **🔄 Refresh Token Support**: Automatic token renewal for seamless long-term access +- **🌐 IP Allowlisting**: Restrict client registration to Claude's official IP addresses +- **🔒 Callback URL Validation**: Only official Claude callback URLs are accepted + +#### OAuth Endpoints + +The proxy automatically exposes these OAuth endpoints: + +- `GET /.well-known/oauth-authorization-server` - Server metadata discovery +- `POST /oauth/register` - Dynamic client registration +- `GET /oauth/authorize` - Authorization endpoint with login form +- `POST /oauth/token` - Token exchange endpoint + +#### Persistence Directory + +OAuth data (clients, access tokens, refresh tokens) is persisted across server restarts. Default location is `$HOME/.mcpproxy/oauth_clients.json`. + +You can customize the persistence directory: + +```jsonc +{ + "mcpProxy": { + "options": { + "oauth2": { + "enabled": true, + "users": { + "admin": "password123" + }, + "persistenceDir": "/var/lib/mcpproxy" + } + } + } +} +``` + +#### IP Allowlisting + +You can restrict OAuth client registration to specific IP addresses for enhanced security: + +```jsonc +{ + "mcpProxy": { + "options": { + "oauth2": { + "enabled": true, + "users": { + "admin": "password123" + }, + "allowedIPs": [ + "34.162.46.92", + "34.162.102.82", + "34.162.136.91", + "34.162.142.92", + "34.162.183.95" + ] + } + } + } +} +``` + +**Note**: The IP addresses above are Claude's official IP addresses as documented at https://docs.anthropic.com/en/api/ip-addresses#ipv4-2. Using this allowlist ensures only Claude Desktop can register OAuth clients with your proxy. + +**Proxy Support**: The IP detection works correctly with various proxy configurations: +- **Cloudflare**: `CF-Connecting-IP`, `True-Client-IP` +- **nginx**: `X-Real-IP`, `X-Forwarded-For` +- **AWS ALB/ELB**: `X-Forwarded-For` +- **Kubernetes Ingress**: `X-Cluster-Client-IP` +- **RFC 7239 Standard**: `Forwarded` header +- **ngrok/tunnels**: `X-Forwarded-For` +- **Direct connections**: `RemoteAddr` + +#### Configuration Examples + +**Minimal OAuth Setup (Development)**: +```jsonc +{ + "mcpProxy": { + "type": "streamable-http", + "options": { + "oauth2": { + "enabled": true, + "users": { + "developer": "dev-password" + } + } + } + } +} +``` + +**Production OAuth Setup (Recommended)**: +```jsonc +{ + "mcpProxy": { + "type": "streamable-http", + "options": { + "oauth2": { + "enabled": true, + "users": { + "admin": "secure-admin-password", + "user": "secure-user-password" + }, + "persistenceDir": "/var/lib/mcpproxy/oauth", + "allowedIPs": [ + "34.162.46.92", + "34.162.102.82", + "34.162.136.91", + "34.162.142.92", + "34.162.183.95" + ] + } + } + } +} +``` + +**Development/Testing Setup (No IP Restrictions)**: +```jsonc +{ + "mcpProxy": { + "type": "streamable-http", + "options": { + "oauth2": { + "enabled": true, + "users": { + "test": "test123" + }, + "allowedIPs": [], + "tokenExpirationMinutes": 60 + } + } + } +} +``` + +#### Security Features + +- **🔐 Username/Password Authentication**: All OAuth flows require valid user credentials +- **🎫 Bearer Token Access**: MCP endpoints require `Authorization: Bearer ` header +- **🔑 Client Secret Validation**: Generated client secrets are cryptographically validated +- **📁 Secure Persistence**: OAuth data (clients + tokens) stored with 0700 permissions (owner-only) +- **🌐 IP Allowlisting**: Optional restriction to Claude's official IP addresses +- **🔒 Callback URL Validation**: Only official Claude URLs accepted as redirect targets +- **🔄 Configurable Token Expiration**: Access tokens expire after configurable time (default: 1 hour) +- **♻️ Refresh Token Rotation**: New refresh token issued on each refresh (OAuth 2.1 best practice) + +#### Claude Desktop Setup + +Once your proxy is running with OAuth enabled, configure Claude Desktop: + +1. **Add MCP Server**: In Claude Desktop settings, add a new MCP server +2. **Server URL**: Use your proxy's base URL (e.g., `https://your-domain.com` or `https://your-tunnel.ngrok.io`) +3. **Authentication**: Claude Desktop will automatically: + - Discover the OAuth endpoints via `.well-known/oauth-authorization-server` + - Register as an OAuth client via Dynamic Client Registration + - Present a login form for username/password authentication + - Handle token refresh automatically + +**Example Claude Desktop MCP Configuration**: +```json +{ + "mcpServers": { + "your-proxy": { + "command": "mcp", + "args": ["--server", "https://your-domain.com/your-mcp-server"] + } + } +} +``` + ## Usage ```bash diff --git a/config.json b/config.example.json similarity index 73% rename from config.json rename to config.example.json index 4fcf87e..ebf9fd5 100644 --- a/config.json +++ b/config.example.json @@ -10,7 +10,22 @@ "logEnabled": true, "authTokens": [ "DefaultTokens" - ] + ], + "oauth2": { + "enabled": true, + "users": { + "admin": "password123", + "user": "mypassword" + }, + "persistenceDir": "/custom/path/oauth", + "allowedIPs": [ + "34.162.46.92", + "34.162.102.82", + "34.162.136.91", + "34.162.142.92", + "34.162.183.95" + ] + } } }, "mcpServers": { diff --git a/config.go b/config.go index 1af8d25..2cb968a 100644 --- a/config.go +++ b/config.go @@ -56,10 +56,19 @@ type ToolFilterConfig struct { List []string `json:"list,omitempty"` } +type OAuth2Config struct { + Enabled bool `json:"enabled,omitempty"` + Users map[string]string `json:"users,omitempty"` + PersistenceDir string `json:"persistenceDir,omitempty"` + AllowedIPs []string `json:"allowedIPs,omitempty"` + TokenExpirationMinutes int `json:"tokenExpirationMinutes,omitempty"` +} + type OptionsV2 struct { PanicIfInvalid optional.Field[bool] `json:"panicIfInvalid,omitempty"` LogEnabled optional.Field[bool] `json:"logEnabled,omitempty"` AuthTokens []string `json:"authTokens,omitempty"` + OAuth2 *OAuth2Config `json:"oauth2,omitempty"` ToolFilter *ToolFilterConfig `json:"toolFilter,omitempty"` } @@ -161,6 +170,9 @@ func load(path string, insecure bool) (*Config, error) { if clientConfig.Options.AuthTokens == nil { clientConfig.Options.AuthTokens = conf.McpProxy.Options.AuthTokens } + if clientConfig.Options.OAuth2 == nil && conf.McpProxy.Options.OAuth2 != nil { + clientConfig.Options.OAuth2 = conf.McpProxy.Options.OAuth2 + } if !clientConfig.Options.PanicIfInvalid.Present() { clientConfig.Options.PanicIfInvalid = conf.McpProxy.Options.PanicIfInvalid } diff --git a/docs/index.html b/docs/index.html index 6782bb5..6017c16 100644 --- a/docs/index.html +++ b/docs/index.html @@ -143,6 +143,19 @@

+
+

+ +

+
+
+ Complete OAuth 2.1 Authorization Server with PKCE support, username/password authentication, Dynamic Client Registration (RFC 7591), refresh token support, and IP allowlisting for Claude Desktop integration. +
+
+
diff --git a/http.go b/http.go index eb8ae7b..c2daeda 100644 --- a/http.go +++ b/http.go @@ -50,6 +50,50 @@ func newAuthMiddleware(tokens []string) MiddlewareFunc { } } +func newOAuth2Middleware(oauth2Config *OAuth2Config, oauthServer *OAuthServer) MiddlewareFunc { + if oauth2Config == nil || !oauth2Config.Enabled { + return func(next http.Handler) http.Handler { + return next + } + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + http.Error(w, "Missing Authorization header", http.StatusUnauthorized) + return + } + + // Check for Bearer token (OAuth 2.1) + if strings.HasPrefix(authHeader, "Bearer ") { + token := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer ")) + if token == "" { + http.Error(w, "Missing access token", http.StatusUnauthorized) + return + } + + // Validate token with OAuth server + accessToken, valid := oauthServer.ValidateToken(token) + if !valid { + http.Error(w, "Invalid or expired access token", http.StatusUnauthorized) + return + } + + // Add token info to request context for potential use + r.Header.Set("X-OAuth-Client-ID", accessToken.ClientID) + r.Header.Set("X-OAuth-Scope", accessToken.Scope) + r.Header.Set("X-OAuth-Resource", accessToken.Resource) + + next.ServeHTTP(w, r) + return + } + + http.Error(w, "Unsupported authorization method. Use Bearer token", http.StatusUnauthorized) + }) + } +} + func loggerMiddleware(prefix string) MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -92,6 +136,16 @@ func startHTTPServer(config *Config) error { Name: config.McpProxy.Name, } + // Create OAuth 2.1 server with access control config + var oauthAccessConfig *OAuth2Config + if config.McpProxy.Options != nil && config.McpProxy.Options.OAuth2 != nil { + oauthAccessConfig = config.McpProxy.Options.OAuth2 + } + oauthServer := NewOAuthServer(config.McpProxy.BaseURL, oauthAccessConfig) + + // Register OAuth routes + oauthServer.RegisterRoutes(httpMux) + for name, clientConfig := range config.McpServers { mcpClient, err := newMCPClient(name, clientConfig) if err != nil { @@ -118,18 +172,45 @@ func startHTTPServer(config *Config) error { if clientConfig.Options.LogEnabled.OrElse(false) { middlewares = append(middlewares, loggerMiddleware(name)) } - if len(clientConfig.Options.AuthTokens) > 0 { + + // Apply authentication middleware based on proxy configuration + // OAuth2 authentication applies when the proxy itself uses streamable-http transport + if config.McpProxy.Type == MCPServerTypeStreamable && config.McpProxy.Options.OAuth2 != nil && config.McpProxy.Options.OAuth2.Enabled { + middlewares = append(middlewares, newOAuth2Middleware(config.McpProxy.Options.OAuth2, oauthServer)) + } else if len(clientConfig.Options.AuthTokens) > 0 { + // Fall back to token authentication if OAuth2 is not configured middlewares = append(middlewares, newAuthMiddleware(clientConfig.Options.AuthTokens)) } mcpRoute := path.Join(baseURL.Path, name) + log.Printf("<%s> baseURL.Path='%s', name='%s', initial mcpRoute='%s'", name, baseURL.Path, name, mcpRoute) + if !strings.HasPrefix(mcpRoute, "/") { mcpRoute = "/" + mcpRoute } - if !strings.HasSuffix(mcpRoute, "/") { - mcpRoute += "/" - } - log.Printf("<%s> Handling requests at %s", name, mcpRoute) - httpMux.Handle(mcpRoute, chainMiddleware(server.handler, middlewares...)) + + baseHandler := chainMiddleware(server.handler, middlewares...) + + // Register exact paths to avoid Go's automatic redirect behavior + mcpRouteWithoutSlash := strings.TrimSuffix(mcpRoute, "/") + mcpRouteWithSlash := mcpRouteWithoutSlash + "/" + + log.Printf("<%s> Registering exact routes: '%s' and '%s'", name, mcpRouteWithoutSlash, mcpRouteWithSlash) + + // Register both exact patterns + httpMux.HandleFunc(mcpRouteWithoutSlash, func(w http.ResponseWriter, r *http.Request) { + // Only handle exact matches to prevent Go's redirect behavior + if r.URL.Path == mcpRouteWithoutSlash { + baseHandler.ServeHTTP(w, r) + } else { + http.NotFound(w, r) + } + }) + + httpMux.HandleFunc(mcpRouteWithSlash, func(w http.ResponseWriter, r *http.Request) { + baseHandler.ServeHTTP(w, r) + }) + + log.Printf("<%s> Routes registered successfully", name) httpServer.RegisterOnShutdown(func() { log.Printf("<%s> Shutting down", name) _ = mcpClient.Close() diff --git a/oauth.go b/oauth.go new file mode 100644 index 0000000..3ecb5b3 --- /dev/null +++ b/oauth.go @@ -0,0 +1,1196 @@ +package main + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "io/ioutil" + "log" + "net" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +// OAuth 2.1 Server Implementation for MCP + +type OAuthServer struct { + baseURL string + clients map[string]*OAuthClient + authCodes map[string]*AuthorizationCode + accessTokens map[string]*AccessToken + mutex sync.RWMutex + tokenExpiration time.Duration + persistenceFile string + accessConfig *OAuth2Config +} + +type OAuthClient struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret,omitempty"` + RedirectURIs []string `json:"redirect_uris"` + GrantTypes []string `json:"grant_types"` + CreatedAt time.Time `json:"client_id_issued_at"` + ClientName string `json:"client_name,omitempty"` +} + +type AuthorizationCode struct { + Code string + ClientID string + RedirectURI string + Scope string + CodeChallenge string // PKCE challenge + ExpiresAt time.Time + Resource string +} + +type AccessToken struct { + Token string + RefreshToken string + ClientID string + Scope string + Resource string + ExpiresAt time.Time +} + +// OAuth Server Metadata Response +type ServerMetadata struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + RegistrationEndpoint string `json:"registration_endpoint"` + ScopesSupported []string `json:"scopes_supported"` + ResponseTypesSupported []string `json:"response_types_supported"` + GrantTypesSupported []string `json:"grant_types_supported"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"` + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` +} + +// Dynamic Client Registration Request +type ClientRegistrationRequest struct { + RedirectURIs []string `json:"redirect_uris"` + GrantTypes []string `json:"grant_types,omitempty"` + Scope string `json:"scope,omitempty"` + ClientName string `json:"client_name,omitempty"` +} + +// Dynamic Client Registration Response +type ClientRegistrationResponse struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret,omitempty"` + RedirectURIs []string `json:"redirect_uris"` + GrantTypes []string `json:"grant_types"` + CreatedAt int64 `json:"client_id_issued_at"` +} + +// Token Request +type TokenRequest struct { + GrantType string `json:"grant_type"` + Code string `json:"code,omitempty"` + RedirectURI string `json:"redirect_uri,omitempty"` + ClientID string `json:"client_id,omitempty"` + CodeVerifier string `json:"code_verifier,omitempty"` + Resource string `json:"resource,omitempty"` +} + +// Token Response +type TokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` +} + +// Error Response +type OAuthError struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` +} + +func NewOAuthServer(baseURL string, accessConfig *OAuth2Config) *OAuthServer { + var mcpProxyDir string + var persistenceFile string + + // Check if persistence directory is specified in config + if accessConfig != nil && accessConfig.PersistenceDir != "" { + // Use the configured directory + mcpProxyDir = accessConfig.PersistenceDir + } else { + // Use default $HOME/.mcpproxy + homeDir, err := os.UserHomeDir() + if err != nil { + log.Printf("OAuth: Could not determine home directory, using current directory: %v", err) + homeDir = "." + } + mcpProxyDir = filepath.Join(homeDir, ".mcpproxy") + } + + persistenceFile = filepath.Join(mcpProxyDir, "oauth_clients.json") + + // Create directory if it doesn't exist + if err := os.MkdirAll(mcpProxyDir, 0700); err != nil { + log.Printf("OAuth: Could not create directory %s: %v", mcpProxyDir, err) + // Fall back to current directory + persistenceFile = "oauth_clients.json" + } + + + // Set token expiration from config or default to 1 hour + tokenExpiration := time.Hour // Default 1 hour + if accessConfig != nil && accessConfig.TokenExpirationMinutes > 0 { + tokenExpiration = time.Duration(accessConfig.TokenExpirationMinutes) * time.Minute + log.Printf("OAuth: Using custom token expiration: %v", tokenExpiration) + } + + server := &OAuthServer{ + baseURL: baseURL, + clients: make(map[string]*OAuthClient), + authCodes: make(map[string]*AuthorizationCode), + accessTokens: make(map[string]*AccessToken), + tokenExpiration: tokenExpiration, + persistenceFile: persistenceFile, + accessConfig: accessConfig, + } + + // Load persisted clients + server.loadClients() + + return server +} + +// OAuth persistence data structure +type OAuthPersistenceData struct { + Clients map[string]*OAuthClient `json:"clients"` + AccessTokens map[string]*AccessToken `json:"accessTokens"` + SavedAt time.Time `json:"savedAt"` +} + +func (s *OAuthServer) loadClients() { + if _, err := os.Stat(s.persistenceFile); os.IsNotExist(err) { + return + } + + data, err := ioutil.ReadFile(s.persistenceFile) + if err != nil { + log.Printf("OAuth: Failed to read persistence file: %v", err) + return + } + + // Try to load new format first (with tokens) + var persistenceData OAuthPersistenceData + if err := json.Unmarshal(data, &persistenceData); err == nil && persistenceData.Clients != nil { + s.mutex.Lock() + s.clients = persistenceData.Clients + + // Load tokens, filtering out expired ones + validAccessTokens := make(map[string]*AccessToken) + + now := time.Now() + for token, accessToken := range persistenceData.AccessTokens { + if accessToken.ExpiresAt.After(now) { + validAccessTokens[token] = accessToken + } + } + + s.accessTokens = validAccessTokens + s.mutex.Unlock() + + log.Printf("OAuth: Loaded %d clients, %d active access tokens", + len(persistenceData.Clients), len(validAccessTokens)) + return + } + + // Fallback to old format (clients only) for backward compatibility + var clients map[string]*OAuthClient + if err := json.Unmarshal(data, &clients); err != nil { + log.Printf("OAuth: Failed to unmarshal persistence data: %v", err) + return + } + + s.mutex.Lock() + s.clients = clients + s.mutex.Unlock() + + log.Printf("OAuth: Loaded %d persisted clients (legacy format)", len(clients)) +} + +func (s *OAuthServer) saveClients() { + s.mutex.RLock() + + // Copy all data for persistence + clients := make(map[string]*OAuthClient) + for k, v := range s.clients { + clients[k] = v + } + + accessTokens := make(map[string]*AccessToken) + for k, v := range s.accessTokens { + accessTokens[k] = v + } + + s.mutex.RUnlock() + + // Create persistence data structure + persistenceData := OAuthPersistenceData{ + Clients: clients, + AccessTokens: accessTokens, + SavedAt: time.Now(), + } + + data, err := json.MarshalIndent(persistenceData, "", " ") + if err != nil { + log.Printf("OAuth: Failed to marshal persistence data: %v", err) + return + } + + if err := ioutil.WriteFile(s.persistenceFile, data, 0600); err != nil { + log.Printf("OAuth: Failed to save persistence data: %v", err) + return + } + + log.Printf("OAuth: Saved %d clients, %d access tokens to persistence file", + len(clients), len(accessTokens)) +} + +func (s *OAuthServer) generateRandomString(length int) string { + bytes := make([]byte, length) + rand.Read(bytes) + return base64.URLEncoding.EncodeToString(bytes)[:length] +} + + + + + +// Server Metadata Discovery Handler - Per MCP Server +func (s *OAuthServer) handleServerMetadata(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Extract server name from path if present + // Path format: /.well-known/oauth-authorization-server/server-name + path := r.URL.Path + serverName := "" + if path != "/.well-known/oauth-authorization-server" { + parts := strings.Split(strings.TrimPrefix(path, "/.well-known/oauth-authorization-server/"), "/") + if len(parts) > 0 && parts[0] != "" { + serverName = parts[0] + } + } + + metadata := ServerMetadata{ + Issuer: s.baseURL, + AuthorizationEndpoint: s.baseURL + "/oauth/authorize", + TokenEndpoint: s.baseURL + "/oauth/token", + RegistrationEndpoint: s.baseURL + "/oauth/register", + ScopesSupported: []string{"mcp"}, + ResponseTypesSupported: []string{"code"}, + GrantTypesSupported: []string{"authorization_code", "refresh_token"}, + TokenEndpointAuthMethodsSupported: []string{"client_secret_post", "none"}, + CodeChallengeMethodsSupported: []string{"S256"}, + } + + // If this is for a specific server, add server-specific metadata + if serverName != "" { + // Add server-specific resource URI + metadata.Issuer = s.baseURL + "/" + serverName + // Update endpoints to include server context + metadata.AuthorizationEndpoint = s.baseURL + "/oauth/authorize?resource=" + url.QueryEscape(s.baseURL+"/"+serverName) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(metadata) +} + +// Protected Resource Metadata Handler +func (s *OAuthServer) handleProtectedResourceMetadata(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Extract server name from path + // Path format: /.well-known/oauth-protected-resource/server-name + path := r.URL.Path + serverName := "" + parts := strings.Split(strings.TrimPrefix(path, "/.well-known/oauth-protected-resource/"), "/") + if len(parts) > 0 && parts[0] != "" { + serverName = parts[0] + } + + if serverName == "" { + http.Error(w, "Server name required", http.StatusBadRequest) + return + } + + resourceMetadata := map[string]interface{}{ + "resource": s.baseURL + "/" + serverName, + "authorization_servers": []string{s.baseURL}, + "scopes_supported": []string{"mcp"}, + "bearer_methods_supported": []string{"header"}, + "resource_documentation": s.baseURL + "/" + serverName + "/mcp", + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resourceMetadata) +} + +// Dynamic Client Registration Handler +func (s *OAuthServer) handleClientRegistration(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Validate client IP against allowlist + if s.accessConfig != nil && len(s.accessConfig.AllowedIPs) > 0 { + clientIP := s.getClientIP(r) + if !s.isIPAllowed(clientIP, s.accessConfig.AllowedIPs) { + log.Printf("OAuth: Client registration blocked - IP %s not in allowlist %v", clientIP, s.accessConfig.AllowedIPs) + s.writeOAuthError(w, "access_denied", "Client registration not allowed from this IP", http.StatusForbidden) + return + } + log.Printf("OAuth: Client registration allowed from IP %s", clientIP) + } + + var req ClientRegistrationRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Printf("OAuth: Client registration failed - invalid JSON: %v", err) + s.writeOAuthError(w, "invalid_request", "Invalid JSON request", http.StatusBadRequest) + return + } + + log.Printf("OAuth: Client registration request: %+v", req) + + // Validate redirect URIs + if len(req.RedirectURIs) == 0 { + log.Printf("OAuth: Client registration failed - no redirect URIs") + s.writeOAuthError(w, "invalid_redirect_uri", "At least one redirect URI is required", http.StatusBadRequest) + return + } + + // Validate that redirect URIs are from Claude (allowlist) + allowedCallbackURLs := []string{ + "https://claude.ai/api/mcp/auth_callback", + "https://claude.com/api/mcp/auth_callback", // Future URL + } + + for _, uri := range req.RedirectURIs { + validURI := false + for _, allowed := range allowedCallbackURLs { + if uri == allowed { + validURI = true + break + } + } + if !validURI { + log.Printf("OAuth: Client registration failed - invalid redirect URI: %s", uri) + s.writeOAuthError(w, "invalid_redirect_uri", "Redirect URI not allowed", http.StatusBadRequest) + return + } + } + + + // Generate client credentials + clientID := s.generateRandomString(32) + clientSecret := s.generateRandomString(48) + + client := &OAuthClient{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURIs: req.RedirectURIs, + GrantTypes: []string{"authorization_code"}, + CreatedAt: time.Now(), + ClientName: req.ClientName, + } + + if len(req.GrantTypes) > 0 { + client.GrantTypes = req.GrantTypes + } + + s.mutex.Lock() + s.clients[clientID] = client + s.mutex.Unlock() + + // Save clients to persistence file + s.saveClients() + + log.Printf("OAuth: Registered client ID: %s, redirect URIs: %v", clientID, client.RedirectURIs) + + response := ClientRegistrationResponse{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURIs: client.RedirectURIs, + GrantTypes: client.GrantTypes, + CreatedAt: client.CreatedAt.Unix(), + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(response) +} + +// Authorization Endpoint Handler +func (s *OAuthServer) handleAuthorization(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + s.handleAuthorizationGET(w, r) + } else if r.Method == http.MethodPost { + s.handleAuthorizationPOST(w, r) + } else { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +func (s *OAuthServer) handleAuthorizationGET(w http.ResponseWriter, r *http.Request) { + // Parse authorization request + clientID := r.URL.Query().Get("client_id") + redirectURI := r.URL.Query().Get("redirect_uri") + responseType := r.URL.Query().Get("response_type") + scope := r.URL.Query().Get("scope") + state := r.URL.Query().Get("state") + codeChallenge := r.URL.Query().Get("code_challenge") + resource := r.URL.Query().Get("resource") + + log.Printf("OAuth: Authorization request - client_id=%s, redirect_uri=%s, response_type=%s, resource=%s", + clientID, redirectURI, responseType, resource) + + // Validate request + if clientID == "" || redirectURI == "" || responseType != "code" { + log.Printf("OAuth: Authorization failed - missing parameters") + s.writeOAuthError(w, "invalid_request", "Missing or invalid required parameters", http.StatusBadRequest) + return + } + + // Show authorization/consent page instead of auto-approving + s.showAuthorizationPage(w, r, clientID, redirectURI, responseType, scope, state, codeChallenge, resource, "") +} + +func (s *OAuthServer) showAuthorizationPage(w http.ResponseWriter, r *http.Request, clientID, redirectURI, responseType, scope, state, codeChallenge, resource, errorMsg string) { + // Skip client validation at authorization endpoint per Claude DCR spec + // Client validation will happen at token endpoint where invalid_client triggers re-registration + log.Printf("OAuth: Authorization request for client_id '%s' - proceeding to login", clientID) + + // Show login page for authentication + clientName := "Claude" // Default to Claude since that's the expected client + + resourceName := "MCP Proxy" + if resource != "" { + // Extract resource name from URL + if u, err := url.Parse(resource); err == nil { + parts := strings.Split(strings.Trim(u.Path, "/"), "/") + if len(parts) > 0 && parts[len(parts)-1] != "" { + resourceName = parts[len(parts)-1] + } + } + } + + // Add error message display + errorHTML := "" + if errorMsg != "" { + errorHTML = `
` + errorMsg + `
` + } + + // HTML login page with error handling + html := ` + + + Sign In + + + + + + +` + + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusOK) + w.Write([]byte(html)) +} + +func (s *OAuthServer) handleAuthorizationPOST(w http.ResponseWriter, r *http.Request) { + // Parse form data from login page + err := r.ParseForm() + if err != nil { + s.writeOAuthError(w, "invalid_request", "Failed to parse form", http.StatusBadRequest) + return + } + + username := r.FormValue("username") + password := r.FormValue("password") + clientID := r.FormValue("client_id") + redirectURI := r.FormValue("redirect_uri") + scope := r.FormValue("scope") + state := r.FormValue("state") + codeChallenge := r.FormValue("code_challenge") + resource := r.FormValue("resource") + + log.Printf("OAuth: Login attempt - username=%s, client_id=%s", username, clientID) + + // Validate credentials against configuration + if s.accessConfig == nil || s.accessConfig.Users == nil { + log.Printf("OAuth: No users configured in OAuth2 config") + s.writeOAuthError(w, "server_error", "Authentication not configured", http.StatusInternalServerError) + return + } + + expectedPassword, exists := s.accessConfig.Users[username] + if !exists || expectedPassword != password { + log.Printf("OAuth: Authentication failed for username: %s", username) + + // Show login page again with error message + s.showAuthorizationPage(w, r, clientID, redirectURI, "code", scope, state, codeChallenge, resource, "Invalid username or password. Please try again.") + return + } + + log.Printf("OAuth: Authentication successful for username: %s", username) + + // Generate authorization code after successful authentication + code := s.generateRandomString(32) + authCode := &AuthorizationCode{ + Code: code, + ClientID: clientID, + RedirectURI: redirectURI, + Scope: scope, + CodeChallenge: codeChallenge, + ExpiresAt: time.Now().Add(10 * time.Minute), + Resource: resource, + } + + s.mutex.Lock() + s.authCodes[code] = authCode + s.mutex.Unlock() + + // Show success page before redirecting + s.showSuccessPage(w, r, redirectURI, code, state, username) + + log.Printf("OAuth: User authenticated successfully, showing success page for code: %s", code) +} + +func (s *OAuthServer) showSuccessPage(w http.ResponseWriter, r *http.Request, redirectURI, code, state, username string) { + // Build redirect URL + redirectURL, _ := url.Parse(redirectURI) + params := redirectURL.Query() + params.Set("code", code) + if state != "" { + params.Set("state", state) + } + redirectURL.RawQuery = params.Encode() + + // HTML success page with auto-redirect + html := ` + + + Sign In Successful + + + + +
+
+
+

Sign In Successful!

+
+ +
+ Welcome back, ` + username + `!
+ You have been successfully authenticated. +
+ +
+ Redirecting to Claude Desktop in 3 seconds... +
+ +
+
+
+ +
+

Taking longer than expected?

+ Click here to continue to Claude Desktop +
+
+ +` + + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusOK) + w.Write([]byte(html)) +} + +// Token Endpoint Handler +func (s *OAuthServer) handleToken(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + log.Printf("OAuth: Token request received - Method: %s, Content-Type: %s", r.Method, r.Header.Get("Content-Type")) + + var grantType, code, redirectURI, clientID, codeVerifier, resource string + + contentType := r.Header.Get("Content-Type") + if strings.Contains(contentType, "application/json") { + // Handle JSON request body + log.Printf("OAuth: Parsing JSON request body") + var req TokenRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + log.Printf("OAuth: Failed to parse JSON body: %v", err) + s.writeOAuthError(w, "invalid_request", "Invalid JSON request", http.StatusBadRequest) + return + } + + grantType = req.GrantType + code = req.Code + redirectURI = req.RedirectURI + clientID = req.ClientID + codeVerifier = req.CodeVerifier + resource = req.Resource + + log.Printf("OAuth: JSON request params - grant_type=%s, code=%s, redirect_uri=%s, client_id=%s, resource=%s", + grantType, code, redirectURI, clientID, resource) + } else { + // Handle form data + log.Printf("OAuth: Parsing form data") + if err := r.ParseForm(); err != nil { + log.Printf("OAuth: Failed to parse form data: %v", err) + s.writeOAuthError(w, "invalid_request", "Invalid form data", http.StatusBadRequest) + return + } + + // Log all form values for debugging + log.Printf("OAuth: Token request form data:") + for key, values := range r.PostForm { + log.Printf(" %s: %v", key, values) + } + + grantType = r.FormValue("grant_type") + code = r.FormValue("code") + redirectURI = r.FormValue("redirect_uri") + clientID = r.FormValue("client_id") + codeVerifier = r.FormValue("code_verifier") + resource = r.FormValue("resource") + + log.Printf("OAuth: Form request params - grant_type=%s, code=%s, redirect_uri=%s, client_id=%s, resource=%s", + grantType, code, redirectURI, clientID, resource) + } + + if grantType == "refresh_token" { + s.handleRefreshToken(w, r, clientID) + return + } + + if grantType != "authorization_code" { + s.writeOAuthError(w, "unsupported_grant_type", "Only authorization_code and refresh_token grant types are supported", http.StatusBadRequest) + return + } + + if code == "" || redirectURI == "" || clientID == "" { + s.writeOAuthError(w, "invalid_request", "Missing required parameters", http.StatusBadRequest) + return + } + + // First, validate that the client exists + s.mutex.RLock() + _, clientExists := s.clients[clientID] + s.mutex.RUnlock() + + if !clientExists { + log.Printf("OAuth: Client ID '%s' not found in token endpoint, returning invalid_client", clientID) + s.writeOAuthError(w, "invalid_client", "Client not found", http.StatusUnauthorized) + return + } + + // Validate authorization code + s.mutex.Lock() + authCode, exists := s.authCodes[code] + if exists { + delete(s.authCodes, code) // Use authorization code only once + } + s.mutex.Unlock() + + if !exists { + s.writeOAuthError(w, "invalid_grant", "Invalid or expired authorization code", http.StatusBadRequest) + return + } + + if time.Now().After(authCode.ExpiresAt) { + s.writeOAuthError(w, "invalid_grant", "Authorization code expired", http.StatusBadRequest) + return + } + + if authCode.ClientID != clientID || authCode.RedirectURI != redirectURI { + s.writeOAuthError(w, "invalid_grant", "Authorization code does not match client", http.StatusBadRequest) + return + } + + // PKCE verification (if code_verifier provided) + if codeVerifier != "" && authCode.CodeChallenge != "" { + hash := sha256.Sum256([]byte(codeVerifier)) + challenge := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:]) + log.Printf("OAuth: PKCE verification - code_verifier='%s', stored_challenge='%s', computed_challenge='%s'", codeVerifier, authCode.CodeChallenge, challenge) + if challenge != authCode.CodeChallenge { + s.writeOAuthError(w, "invalid_grant", "PKCE verification failed", http.StatusBadRequest) + return + } + log.Printf("OAuth: PKCE verification passed") + } + + // Generate access token and refresh token + accessToken := s.generateRandomString(48) + refreshToken := s.generateRandomString(48) + token := &AccessToken{ + Token: accessToken, + RefreshToken: refreshToken, + ClientID: clientID, + Scope: authCode.Scope, + Resource: resource, + ExpiresAt: time.Now().Add(s.tokenExpiration), + } + + s.mutex.Lock() + s.accessTokens[accessToken] = token + s.mutex.Unlock() + + // Persist tokens to disk + s.saveClients() + + response := TokenResponse{ + AccessToken: accessToken, + TokenType: "Bearer", + ExpiresIn: int(s.tokenExpiration.Seconds()), + RefreshToken: refreshToken, + Scope: authCode.Scope, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} + +// Helper function to get the real client IP address +func (s *OAuthServer) getClientIP(r *http.Request) string { + // Priority order for different proxy scenarios: + + // 1. CF-Connecting-IP (Cloudflare) + if cfIP := r.Header.Get("CF-Connecting-IP"); cfIP != "" { + if net.ParseIP(cfIP) != nil { + return cfIP + } + } + + // 2. True-Client-IP (Cloudflare Enterprise, some CDNs) + if tcIP := r.Header.Get("True-Client-IP"); tcIP != "" { + if net.ParseIP(tcIP) != nil { + return tcIP + } + } + + // 3. X-Real-IP (nginx, some proxies) + if xrip := r.Header.Get("X-Real-IP"); xrip != "" { + if net.ParseIP(xrip) != nil { + return xrip + } + } + + // 4. X-Forwarded-For (most proxies/load balancers) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // Take the first IP in the chain (the original client) + ips := strings.Split(xff, ",") + if len(ips) > 0 { + ip := strings.TrimSpace(ips[0]) + if net.ParseIP(ip) != nil { + return ip + } + } + } + + // 5. X-Cluster-Client-IP (some Kubernetes ingresses) + if ccIP := r.Header.Get("X-Cluster-Client-IP"); ccIP != "" { + if net.ParseIP(ccIP) != nil { + return ccIP + } + } + + // 6. X-Forwarded (less common, but some proxies use it) + if xf := r.Header.Get("X-Forwarded"); xf != "" { + // Format: X-Forwarded: for=192.0.2.60;proto=http;by=203.0.113.43 + if strings.HasPrefix(xf, "for=") { + forPart := strings.Split(xf, ";")[0] + ip := strings.TrimPrefix(forPart, "for=") + if net.ParseIP(ip) != nil { + return ip + } + } + } + + // 7. Forwarded (RFC 7239 standard) + if fwd := r.Header.Get("Forwarded"); fwd != "" { + // Format: Forwarded: for=192.0.2.60;proto=http;by=203.0.113.43 + if strings.Contains(fwd, "for=") { + parts := strings.Split(fwd, ";") + for _, part := range parts { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, "for=") { + ip := strings.TrimPrefix(part, "for=") + // Handle quoted IPs: for="192.0.2.60" + ip = strings.Trim(ip, "\"") + if net.ParseIP(ip) != nil { + return ip + } + } + } + } + } + + // 8. Fall back to RemoteAddr (direct connection or unknown proxy) + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} + +// Validate IP against allowlist +func (s *OAuthServer) isIPAllowed(clientIP string, allowedIPs []string) bool { + if len(allowedIPs) == 0 { + return true // No restrictions if allowlist is empty + } + + for _, allowedIP := range allowedIPs { + if clientIP == allowedIP { + return true + } + } + return false +} + +func (s *OAuthServer) handleRefreshToken(w http.ResponseWriter, r *http.Request, clientID string) { + var refreshToken string + + // Parse refresh token from request + contentType := r.Header.Get("Content-Type") + if strings.Contains(contentType, "application/json") { + var req struct { + RefreshToken string `json:"refresh_token"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.writeOAuthError(w, "invalid_request", "Invalid JSON request", http.StatusBadRequest) + return + } + refreshToken = req.RefreshToken + } else { + refreshToken = r.FormValue("refresh_token") + } + + if refreshToken == "" { + s.writeOAuthError(w, "invalid_request", "Missing refresh_token", http.StatusBadRequest) + return + } + + // Validate client exists + s.mutex.RLock() + _, clientExists := s.clients[clientID] + s.mutex.RUnlock() + + if !clientExists { + log.Printf("OAuth: Client ID '%s' not found in refresh token endpoint", clientID) + s.writeOAuthError(w, "invalid_client", "Client not found", http.StatusUnauthorized) + return + } + + // Find the access token that has this refresh token + s.mutex.Lock() + var oldToken *AccessToken + var oldAccessTokenKey string + exists := false + + for accessTokenKey, token := range s.accessTokens { + if token.RefreshToken == refreshToken { + oldToken = token + oldAccessTokenKey = accessTokenKey + exists = true + break + } + } + + if exists { + // Remove the old access token (which also removes the refresh token) + delete(s.accessTokens, oldAccessTokenKey) + } + s.mutex.Unlock() + + // Persist token deletions to disk + if exists { + s.saveClients() + } + + if !exists { + s.writeOAuthError(w, "invalid_grant", "Invalid refresh token", http.StatusBadRequest) + return + } + + if oldToken.ClientID != clientID { + s.writeOAuthError(w, "invalid_grant", "Refresh token does not belong to client", http.StatusBadRequest) + return + } + + // Generate new access token and refresh token + newAccessToken := s.generateRandomString(48) + newRefreshToken := s.generateRandomString(48) + token := &AccessToken{ + Token: newAccessToken, + RefreshToken: newRefreshToken, + ClientID: clientID, + Scope: oldToken.Scope, + Resource: oldToken.Resource, + ExpiresAt: time.Now().Add(s.tokenExpiration), + } + + s.mutex.Lock() + s.accessTokens[newAccessToken] = token + s.mutex.Unlock() + + // Persist tokens to disk + s.saveClients() + + log.Printf("OAuth: Refreshed tokens for client %s", clientID) + + response := TokenResponse{ + AccessToken: newAccessToken, + TokenType: "Bearer", + ExpiresIn: int(s.tokenExpiration.Seconds()), + RefreshToken: newRefreshToken, + Scope: oldToken.Scope, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} + +// Token Validation +func (s *OAuthServer) ValidateToken(tokenString string) (*AccessToken, bool) { + s.mutex.RLock() + defer s.mutex.RUnlock() + + token, exists := s.accessTokens[tokenString] + if !exists { + return nil, false + } + + if time.Now().After(token.ExpiresAt) { + // Token expired, clean it up + go func() { + s.mutex.Lock() + delete(s.accessTokens, tokenString) + s.mutex.Unlock() + + // Persist cleanup to disk + s.saveClients() + }() + return nil, false + } + + return token, true +} + +func (s *OAuthServer) writeOAuthError(w http.ResponseWriter, error, description string, statusCode int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + json.NewEncoder(w).Encode(OAuthError{ + Error: error, + ErrorDescription: description, + }) +} + +// Register OAuth routes +func (s *OAuthServer) RegisterRoutes(mux *http.ServeMux) { + // Global OAuth endpoints + mux.HandleFunc("/.well-known/oauth-authorization-server", s.handleServerMetadata) + mux.HandleFunc("/oauth/register", s.handleClientRegistration) + mux.HandleFunc("/oauth/authorize", s.handleAuthorization) + mux.HandleFunc("/oauth/token", s.handleToken) + + // Per-server OAuth discovery endpoints + mux.HandleFunc("/.well-known/oauth-authorization-server/", s.handleServerMetadata) + mux.HandleFunc("/.well-known/oauth-protected-resource/", s.handleProtectedResourceMetadata) +} \ No newline at end of file From 81eda3844475fcd83c98866c1cd6d9942f640a9e Mon Sep 17 00:00:00 2001 From: j-mcnally Date: Thu, 21 Aug 2025 07:09:04 -0400 Subject: [PATCH 02/11] fix: Address code review feedback - Use crypto/subtle.ConstantTimeCompare for password comparison to prevent timing attacks - Replace deprecated ioutil.ReadFile/WriteFile with os.ReadFile/WriteFile - Use request context instead of headers for OAuth middleware data passing - Replace HTML string concatenation with html/template package for better security and maintainability - Improve template separation with proper data structures --- http.go | 8 +- oauth.go | 512 ++++++++++++++++++++++++++++++------------------------- 2 files changed, 286 insertions(+), 234 deletions(-) diff --git a/http.go b/http.go index c2daeda..7b28725 100644 --- a/http.go +++ b/http.go @@ -81,11 +81,11 @@ func newOAuth2Middleware(oauth2Config *OAuth2Config, oauthServer *OAuthServer) M } // Add token info to request context for potential use - r.Header.Set("X-OAuth-Client-ID", accessToken.ClientID) - r.Header.Set("X-OAuth-Scope", accessToken.Scope) - r.Header.Set("X-OAuth-Resource", accessToken.Resource) + ctx := context.WithValue(r.Context(), "X-OAuth-Client-ID", accessToken.ClientID) + ctx = context.WithValue(ctx, "X-OAuth-Scope", accessToken.Scope) + ctx = context.WithValue(ctx, "X-OAuth-Resource", accessToken.Resource) - next.ServeHTTP(w, r) + next.ServeHTTP(w, r.WithContext(ctx)) return } diff --git a/oauth.go b/oauth.go index 3ecb5b3..96b5fe3 100644 --- a/oauth.go +++ b/oauth.go @@ -3,9 +3,10 @@ package main import ( "crypto/rand" "crypto/sha256" + "crypto/subtle" "encoding/base64" "encoding/json" - "io/ioutil" + "html/template" "log" "net" "net/http" @@ -113,6 +114,244 @@ type OAuthError struct { ErrorDescription string `json:"error_description,omitempty"` } +// Template data structures +type AuthPageData struct { + ClientID string + ClientName string + ResourceName string + RedirectURI string + ResponseType string + Scope string + State string + CodeChallenge string + Resource string + ErrorMessage string +} + +type SuccessPageData struct { + RedirectURL string + Username string +} + +// HTML Templates +const authorizationPageTemplate = ` + + + Sign In + + + + + + +` + +const successPageTemplate = ` + + + Sign In Successful + + + + +
+
+
+

Sign In Successful!

+
+ +
+ Welcome, {{.Username}}! You have been successfully authenticated. +
+ +
+ Redirecting to Claude in 3 seconds... +
+ +
+
+
+ +
+ If you are not automatically redirected, + click here to continue +
+
+ +` + func NewOAuthServer(baseURL string, accessConfig *OAuth2Config) *OAuthServer { var mcpProxyDir string var persistenceFile string @@ -176,7 +415,7 @@ func (s *OAuthServer) loadClients() { return } - data, err := ioutil.ReadFile(s.persistenceFile) + data, err := os.ReadFile(s.persistenceFile) if err != nil { log.Printf("OAuth: Failed to read persistence file: %v", err) return @@ -249,7 +488,7 @@ func (s *OAuthServer) saveClients() { return } - if err := ioutil.WriteFile(s.persistenceFile, data, 0600); err != nil { + if err := os.WriteFile(s.persistenceFile, data, 0600); err != nil { log.Printf("OAuth: Failed to save persistence data: %v", err) return } @@ -492,122 +731,33 @@ func (s *OAuthServer) showAuthorizationPage(w http.ResponseWriter, r *http.Reque } } - // Add error message display - errorHTML := "" - if errorMsg != "" { - errorHTML = `
` + errorMsg + `
` + // Prepare template data + data := AuthPageData{ + ClientID: clientID, + ClientName: clientName, + ResourceName: resourceName, + RedirectURI: redirectURI, + ResponseType: responseType, + Scope: scope, + State: state, + CodeChallenge: codeChallenge, + Resource: resource, + ErrorMessage: errorMsg, } - // HTML login page with error handling - html := ` - - - Sign In - - - - - - -` + // Parse and execute template + tmpl, err := template.New("authPage").Parse(authorizationPageTemplate) + if err != nil { + log.Printf("OAuth: Failed to parse authorization template: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } w.Header().Set("Content-Type", "text/html") w.WriteHeader(http.StatusOK) - w.Write([]byte(html)) + if err := tmpl.Execute(w, data); err != nil { + log.Printf("OAuth: Failed to execute authorization template: %v", err) + } } func (s *OAuthServer) handleAuthorizationPOST(w http.ResponseWriter, r *http.Request) { @@ -637,7 +787,7 @@ func (s *OAuthServer) handleAuthorizationPOST(w http.ResponseWriter, r *http.Req } expectedPassword, exists := s.accessConfig.Users[username] - if !exists || expectedPassword != password { + if !exists || subtle.ConstantTimeCompare([]byte(expectedPassword), []byte(password)) != 1 { log.Printf("OAuth: Authentication failed for username: %s", username) // Show login page again with error message @@ -678,124 +828,26 @@ func (s *OAuthServer) showSuccessPage(w http.ResponseWriter, r *http.Request, re params.Set("state", state) } redirectURL.RawQuery = params.Encode() - - // HTML success page with auto-redirect - html := ` - - - Sign In Successful - - - - -
-
-
-

Sign In Successful!

-
- -
- Welcome back, ` + username + `!
- You have been successfully authenticated. -
- -
- Redirecting to Claude Desktop in 3 seconds... -
- -
-
-
- -
-

Taking longer than expected?

- Click here to continue to Claude Desktop -
-
- -` + + // Prepare template data + data := SuccessPageData{ + RedirectURL: redirectURL.String(), + Username: username, + } + + // Parse and execute template + tmpl, err := template.New("successPage").Parse(successPageTemplate) + if err != nil { + log.Printf("OAuth: Failed to parse success template: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } w.Header().Set("Content-Type", "text/html") w.WriteHeader(http.StatusOK) - w.Write([]byte(html)) + if err := tmpl.Execute(w, data); err != nil { + log.Printf("OAuth: Failed to execute success template: %v", err) + } } // Token Endpoint Handler From 884c45c48c1145b36f3e6967d1b216ebc85f6146 Mon Sep 17 00:00:00 2001 From: j-mcnally Date: Thu, 21 Aug 2025 07:12:33 -0400 Subject: [PATCH 03/11] fix: Fix automatic redirect JavaScript error in OAuth success page - Add null checks for DOM elements before accessing them - Use DOMContentLoaded event to ensure DOM is ready before executing JavaScript - Fix countdown functionality and manual redirect fallback - Resolves 'Cannot set properties of null' error --- oauth.go | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/oauth.go b/oauth.go index 96b5fe3..c7d8bb3 100644 --- a/oauth.go +++ b/oauth.go @@ -305,24 +305,39 @@ const successPageTemplate = ` let countdown = 3; function updateCountdown() { - document.getElementById('countdown').textContent = countdown; + const countdownElement = document.getElementById('countdown'); + const redirectTextElement = document.getElementById('redirect-text'); + + if (countdownElement) { + countdownElement.textContent = countdown; + } countdown--; if (countdown < 0) { - document.getElementById('redirect-text').textContent = 'Redirecting now...'; + if (redirectTextElement) { + redirectTextElement.textContent = 'Redirecting now...'; + } window.location.href = '{{.RedirectURL}}'; } else { setTimeout(updateCountdown, 1000); } } - // Fallback for manual redirect after 10 seconds - setTimeout(function() { - document.getElementById('manual-redirect').style.display = 'block'; - }, 10000); + function showManualRedirect() { + const manualRedirectElement = document.getElementById('manual-redirect'); + if (manualRedirectElement) { + manualRedirectElement.style.display = 'block'; + } + } - // Start countdown immediately - updateCountdown(); + // Wait for DOM to be fully loaded + document.addEventListener('DOMContentLoaded', function() { + // Start countdown after DOM is ready + updateCountdown(); + + // Fallback for manual redirect after 10 seconds + setTimeout(showManualRedirect, 10000); + }); From 4f5149fdcd4ab2f5b8c7345029e44b7efdbc51b9 Mon Sep 17 00:00:00 2001 From: j-mcnally Date: Thu, 21 Aug 2025 07:22:35 -0400 Subject: [PATCH 04/11] feat: Implement built-in OAuth templates with ejection capability - Add oauth_templates.go with embedded default templates as constants - Implement template override system: external templates take precedence over built-in - Add --eject-templates CLI flag to export built-in templates for customization - Update template loading logic with fallback mechanism - Remove go:embed dependency - templates now work without external files - Update documentation with new template system and ejection workflow - Templates work out-of-the-box but can be customized when needed Benefits: - Zero external dependencies - OAuth works immediately after build - Easy customization via template ejection - Graceful fallback from external to built-in templates - Clear upgrade path for template customization --- README.md | 59 +++++++++ main.go | 41 +++++++ oauth.go | 297 +++++++-------------------------------------- oauth_templates.go | 237 ++++++++++++++++++++++++++++++++++++ 4 files changed, 383 insertions(+), 251 deletions(-) create mode 100644 oauth_templates.go diff --git a/README.md b/README.md index 6500ee4..28e9722 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,8 @@ make build ./build/mcp-proxy --config path/to/config.json ``` +**Note**: OAuth templates are built-in by default. Use `--eject-templates` to customize them. + ### Install by go ```bash @@ -392,14 +394,71 @@ Once your proxy is running with OAuth enabled, configure Claude Desktop: } ``` +### OAuth Template Customization + +The OAuth 2.1 server includes built-in HTML templates for the authorization and success pages, with support for customization. + +#### Built-in Templates + +By default, the OAuth templates are embedded in the binary and require no external files. The server automatically uses these built-in templates. + +#### Ejecting Templates for Customization + +To customize the OAuth pages, first eject the templates: + +```bash +./mcp-proxy --eject-templates +``` + +This creates: +``` +templates/ +└── oauth/ + ├── authorize.html # Login form page + └── success.html # Success/redirect page +``` + +#### Template Override Behavior + +- **No `templates/oauth/` directory**: Uses built-in templates +- **`templates/oauth/` exists**: Uses external templates with fallback to built-in if loading fails +- **To revert to built-in**: Simply remove the `templates/` directory + +#### Template Data + +**authorize.html** receives: +- `ClientName` - Application name (usually "Claude") +- `ResourceName` - Resource being accessed +- `ClientID`, `RedirectURI`, `ResponseType`, `Scope`, `State`, `CodeChallenge`, `Resource` - OAuth parameters +- `ErrorMessage` - Error message to display (if any) + +**success.html** receives: +- `Username` - Authenticated user's username +- `RedirectURL` - Complete redirect URL with authorization code + +#### Customizing Templates + +After ejecting templates: + +1. Edit the HTML files in `templates/oauth/` +2. Maintain the form structure and hidden fields in `authorize.html` +3. Keep the JavaScript redirect functionality in `success.html` +4. Restart the proxy server to reload templates + +The templates use Go's `html/template` package with automatic XSS protection and context-aware escaping. + ## Usage ```bash Usage of mcp-proxy: -config string path to config file or a http(s) url (default "config.json") + -eject-templates + eject OAuth templates to templates/oauth/ directory for customization -help print help and exit + -insecure + allow insecure HTTPS connections by skipping TLS certificate verification -version print version and exit ``` diff --git a/main.go b/main.go index 1762076..526e5f7 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,8 @@ import ( "flag" "fmt" "log" + "os" + "path/filepath" ) var BuildVersion = "dev" @@ -13,6 +15,7 @@ func main() { insecure := flag.Bool("insecure", false, "allow insecure HTTPS connections by skipping TLS certificate verification") version := flag.Bool("version", false, "print version and exit") help := flag.Bool("help", false, "print help and exit") + ejectTemplates := flag.Bool("eject-templates", false, "eject OAuth templates to templates/oauth/ directory for customization") flag.Parse() if *help { flag.Usage() @@ -22,6 +25,12 @@ func main() { fmt.Println(BuildVersion) return } + if *ejectTemplates { + if err := ejectOAuthTemplates(); err != nil { + log.Fatalf("Failed to eject templates: %v", err) + } + return + } config, err := load(*conf, *insecure) if err != nil { log.Fatalf("Failed to load config: %v", err) @@ -31,3 +40,35 @@ func main() { log.Fatalf("Failed to start server: %v", err) } } + +func ejectOAuthTemplates() error { + templatesDir := "templates/oauth" + + // Create templates directory + if err := os.MkdirAll(templatesDir, 0755); err != nil { + return fmt.Errorf("failed to create templates directory: %v", err) + } + + // Write authorize.html + authorizePath := filepath.Join(templatesDir, "authorize.html") + if err := os.WriteFile(authorizePath, []byte(defaultAuthorizePage), 0644); err != nil { + return fmt.Errorf("failed to write authorize.html: %v", err) + } + + // Write success.html + successPath := filepath.Join(templatesDir, "success.html") + if err := os.WriteFile(successPath, []byte(defaultSuccessPage), 0644); err != nil { + return fmt.Errorf("failed to write success.html: %v", err) + } + + fmt.Printf("OAuth templates ejected to %s/\n", templatesDir) + fmt.Println("You can now customize the HTML templates and restart the server to use them.") + fmt.Println() + fmt.Println("Template files created:") + fmt.Printf(" %s - OAuth authorization/login page\n", authorizePath) + fmt.Printf(" %s - OAuth success/redirect page\n", successPath) + fmt.Println() + fmt.Println("To use the built-in templates again, simply remove the templates/ directory.") + + return nil +} diff --git a/oauth.go b/oauth.go index c7d8bb3..c99d2b7 100644 --- a/oauth.go +++ b/oauth.go @@ -29,6 +29,7 @@ type OAuthServer struct { tokenExpiration time.Duration persistenceFile string accessConfig *OAuth2Config + templates *template.Template } type OAuthClient struct { @@ -133,239 +134,6 @@ type SuccessPageData struct { Username string } -// HTML Templates -const authorizationPageTemplate = ` - - - Sign In - - - - - - -` - -const successPageTemplate = ` - - - Sign In Successful - - - - -
-
-
-

Sign In Successful!

-
- -
- Welcome, {{.Username}}! You have been successfully authenticated. -
- -
- Redirecting to Claude in 3 seconds... -
- -
-
-
- -
- If you are not automatically redirected, - click here to continue -
-
- -` func NewOAuthServer(baseURL string, accessConfig *OAuth2Config) *OAuthServer { var mcpProxyDir string @@ -402,6 +170,44 @@ func NewOAuthServer(baseURL string, accessConfig *OAuth2Config) *OAuthServer { log.Printf("OAuth: Using custom token expiration: %v", tokenExpiration) } + // Load templates with fallback mechanism + var templates *template.Template + var err error + + // First try to load external templates (for customization) + if _, statErr := os.Stat("templates/oauth"); statErr == nil { + log.Printf("OAuth: Found external templates directory, loading custom templates") + templates, err = template.ParseGlob("templates/oauth/*.html") + if err != nil { + log.Printf("OAuth: Failed to load external templates: %v", err) + log.Printf("OAuth: Falling back to built-in templates") + } else { + log.Printf("OAuth: Successfully loaded external templates") + } + } + + // Fall back to built-in templates if external ones failed or don't exist + if templates == nil { + log.Printf("OAuth: Loading built-in templates") + templates = template.New("") + + // Parse built-in authorize template + _, err = templates.New("authorize.html").Parse(defaultAuthorizePage) + if err != nil { + log.Printf("OAuth: Failed to parse built-in authorize template: %v", err) + return nil + } + + // Parse built-in success template + _, err = templates.New("success.html").Parse(defaultSuccessPage) + if err != nil { + log.Printf("OAuth: Failed to parse built-in success template: %v", err) + return nil + } + + log.Printf("OAuth: Successfully loaded built-in templates") + } + server := &OAuthServer{ baseURL: baseURL, clients: make(map[string]*OAuthClient), @@ -410,6 +216,7 @@ func NewOAuthServer(baseURL string, accessConfig *OAuth2Config) *OAuthServer { tokenExpiration: tokenExpiration, persistenceFile: persistenceFile, accessConfig: accessConfig, + templates: templates, } // Load persisted clients @@ -760,18 +567,12 @@ func (s *OAuthServer) showAuthorizationPage(w http.ResponseWriter, r *http.Reque ErrorMessage: errorMsg, } - // Parse and execute template - tmpl, err := template.New("authPage").Parse(authorizationPageTemplate) - if err != nil { - log.Printf("OAuth: Failed to parse authorization template: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return - } - + // Execute authorization template w.Header().Set("Content-Type", "text/html") w.WriteHeader(http.StatusOK) - if err := tmpl.Execute(w, data); err != nil { + if err := s.templates.ExecuteTemplate(w, "authorize.html", data); err != nil { log.Printf("OAuth: Failed to execute authorization template: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) } } @@ -850,18 +651,12 @@ func (s *OAuthServer) showSuccessPage(w http.ResponseWriter, r *http.Request, re Username: username, } - // Parse and execute template - tmpl, err := template.New("successPage").Parse(successPageTemplate) - if err != nil { - log.Printf("OAuth: Failed to parse success template: %v", err) - http.Error(w, "Internal server error", http.StatusInternalServerError) - return - } - + // Execute success template w.Header().Set("Content-Type", "text/html") w.WriteHeader(http.StatusOK) - if err := tmpl.Execute(w, data); err != nil { + if err := s.templates.ExecuteTemplate(w, "success.html", data); err != nil { log.Printf("OAuth: Failed to execute success template: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) } } diff --git a/oauth_templates.go b/oauth_templates.go new file mode 100644 index 0000000..bc21651 --- /dev/null +++ b/oauth_templates.go @@ -0,0 +1,237 @@ +package main + +// Default embedded OAuth templates +// These are used as fallback when external templates are not found + +const defaultAuthorizePage = ` + + + Sign In + + + + + + +` + +const defaultSuccessPage = ` + + + Sign In Successful + + + + +
+
+
+

Sign In Successful!

+
+ +
+ Welcome, {{.Username}}! You have been successfully authenticated. +
+ +
+ Redirecting to Claude in 3 seconds... +
+ +
+
+
+ +
+ If you are not automatically redirected, + click here to continue +
+
+ +` \ No newline at end of file From 14b6d2fb9d670734fd6306d141879d7a3bb61873 Mon Sep 17 00:00:00 2001 From: j-mcnally Date: Thu, 21 Aug 2025 07:26:06 -0400 Subject: [PATCH 05/11] feat: Add configurable template directory with CLI override support - Add templateDir field to OAuth2Config for custom template locations - Update template loading to use configured directory with fallback - Add --eject-templates-to flag for CLI override of template ejection path - Enhance ejectOAuthTemplates to support custom base directories - Update template loading priority: configured dir > built-in templates - Improve documentation with templateDir configuration examples - Update CLI usage documentation and config examples Template directory behavior: - Configuration: Set oauth2.templateDir in config (default: 'templates') - Ejection: --eject-templates uses config dir, --eject-templates-to overrides - Loading: Server loads from {templateDir}/oauth/ with built-in fallback This enables flexible deployment scenarios while maintaining zero-config defaults. --- README.md | 15 ++++++++++----- config.example.json | 3 ++- config.go | 1 + main.go | 28 ++++++++++++++++++++++++---- oauth.go | 17 ++++++++++++----- 5 files changed, 49 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 28e9722..444157f 100644 --- a/README.md +++ b/README.md @@ -74,7 +74,8 @@ The server is configured using a JSON file. Below is an example configuration: "34.162.136.91", "34.162.142.92", "34.162.183.95" - ] + ], + "templateDir": "/custom/templates" } } }, @@ -130,6 +131,7 @@ Common options for `mcpProxy` and `mcpServers`. - `persistenceDir`: Directory for storing OAuth client registrations. Defaults to `$HOME/.mcpproxy` if not specified. - `allowedIPs`: IP addresses permitted to register OAuth clients. Use Claude's official IPs for security. Empty array allows all IPs. - `tokenExpirationMinutes`: Access token expiration time in minutes. Defaults to 60 minutes (1 hour) if not specified. + - `templateDir`: Base directory for OAuth HTML templates. Server looks for templates in `{templateDir}/oauth/`. Defaults to `templates` if not specified. - `toolFilter`: Optional tool filtering configuration. **This configuration is only effective in `mcpServers`.** - `mode`: Specifies the filtering mode. Must be explicitly set to `allow` or `block` if `list` is provided. If `list` is present but `mode` is missing or invalid, the filter will be ignored for this server. - `list`: A list of tool names to filter (either allow or block based on the `mode`). @@ -420,9 +422,10 @@ templates/ #### Template Override Behavior -- **No `templates/oauth/` directory**: Uses built-in templates -- **`templates/oauth/` exists**: Uses external templates with fallback to built-in if loading fails -- **To revert to built-in**: Simply remove the `templates/` directory +The server loads templates in this priority order: +1. **External templates**: `{templateDir}/oauth/*.html` (where `templateDir` is from config, default: `templates`) +2. **Built-in templates**: Embedded defaults if external templates don't exist or fail to load +3. **To revert to built-in**: Remove the template directory or set `templateDir` to a non-existent path #### Template Data @@ -454,7 +457,9 @@ Usage of mcp-proxy: -config string path to config file or a http(s) url (default "config.json") -eject-templates - eject OAuth templates to templates/oauth/ directory for customization + eject OAuth templates to configured templateDir/oauth/ (or templates/oauth/ if not configured) + -eject-templates-to string + eject OAuth templates to specified directory (overrides config templateDir) -help print help and exit -insecure diff --git a/config.example.json b/config.example.json index ebf9fd5..db982a6 100644 --- a/config.example.json +++ b/config.example.json @@ -24,7 +24,8 @@ "34.162.136.91", "34.162.142.92", "34.162.183.95" - ] + ], + "templateDir": "/custom/templates" } } }, diff --git a/config.go b/config.go index 2cb968a..9371c7d 100644 --- a/config.go +++ b/config.go @@ -62,6 +62,7 @@ type OAuth2Config struct { PersistenceDir string `json:"persistenceDir,omitempty"` AllowedIPs []string `json:"allowedIPs,omitempty"` TokenExpirationMinutes int `json:"tokenExpirationMinutes,omitempty"` + TemplateDir string `json:"templateDir,omitempty"` } type OptionsV2 struct { diff --git a/main.go b/main.go index 526e5f7..a343425 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ func main() { version := flag.Bool("version", false, "print version and exit") help := flag.Bool("help", false, "print help and exit") ejectTemplates := flag.Bool("eject-templates", false, "eject OAuth templates to templates/oauth/ directory for customization") + ejectTemplatesTo := flag.String("eject-templates-to", "", "eject OAuth templates to specified directory (overrides config templateDir)") flag.Parse() if *help { flag.Usage() @@ -25,8 +26,27 @@ func main() { fmt.Println(BuildVersion) return } - if *ejectTemplates { - if err := ejectOAuthTemplates(); err != nil { + if *ejectTemplates || *ejectTemplatesTo != "" { + var templateDir string + + if *ejectTemplatesTo != "" { + // Use specified directory directly + templateDir = *ejectTemplatesTo + } else { + // Load config to get templateDir if configured + config, err := load(*conf, *insecure) + if err != nil { + log.Printf("Warning: Failed to load config for template directory: %v", err) + log.Printf("Using default templates directory") + templateDir = "templates" + } else if config.McpProxy.Options != nil && config.McpProxy.Options.OAuth2 != nil && config.McpProxy.Options.OAuth2.TemplateDir != "" { + templateDir = config.McpProxy.Options.OAuth2.TemplateDir + } else { + templateDir = "templates" + } + } + + if err := ejectOAuthTemplates(templateDir); err != nil { log.Fatalf("Failed to eject templates: %v", err) } return @@ -41,8 +61,8 @@ func main() { } } -func ejectOAuthTemplates() error { - templatesDir := "templates/oauth" +func ejectOAuthTemplates(baseTemplateDir string) error { + templatesDir := filepath.Join(baseTemplateDir, "oauth") // Create templates directory if err := os.MkdirAll(templatesDir, 0755); err != nil { diff --git a/oauth.go b/oauth.go index c99d2b7..a77938c 100644 --- a/oauth.go +++ b/oauth.go @@ -174,15 +174,22 @@ func NewOAuthServer(baseURL string, accessConfig *OAuth2Config) *OAuthServer { var templates *template.Template var err error + // Determine template directory + templateDir := "templates/oauth" + if accessConfig != nil && accessConfig.TemplateDir != "" { + templateDir = filepath.Join(accessConfig.TemplateDir, "oauth") + } + // First try to load external templates (for customization) - if _, statErr := os.Stat("templates/oauth"); statErr == nil { - log.Printf("OAuth: Found external templates directory, loading custom templates") - templates, err = template.ParseGlob("templates/oauth/*.html") + if _, statErr := os.Stat(templateDir); statErr == nil { + templateGlob := filepath.Join(templateDir, "*.html") + log.Printf("OAuth: Found external templates directory at '%s', loading custom templates", templateDir) + templates, err = template.ParseGlob(templateGlob) if err != nil { - log.Printf("OAuth: Failed to load external templates: %v", err) + log.Printf("OAuth: Failed to load external templates from '%s': %v", templateDir, err) log.Printf("OAuth: Falling back to built-in templates") } else { - log.Printf("OAuth: Successfully loaded external templates") + log.Printf("OAuth: Successfully loaded external templates from '%s'", templateDir) } } From 332ba48138728e223f2b724fc0d8ab5d6767b03b Mon Sep 17 00:00:00 2001 From: j-mcnally Date: Thu, 21 Aug 2025 07:34:04 -0400 Subject: [PATCH 06/11] feat: Add automatic hot reloading for external OAuth templates - Implement file modification time tracking for external templates - Add reloadTemplatesIfChanged() method to check and reload templates automatically - Hot reload triggers on every OAuth request when using external templates - Zero-config hot reloading - always enabled for external templates - Graceful fallback if template reload fails (keeps using previous templates) - Update documentation with hot reload feature explanation Benefits: - Live development: Edit templates and see changes immediately - No server restart required for template updates - Seamless template customization workflow - Only applies to external templates - built-in templates remain static for performance Hot reload detection: - Checks authorize.html and success.html modification times - Reloads entire template set if any file changed - Logs template reload events for debugging --- README.md | 11 +++++++ oauth.go | 88 +++++++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 90 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 444157f..027cae9 100644 --- a/README.md +++ b/README.md @@ -427,6 +427,17 @@ The server loads templates in this priority order: 2. **Built-in templates**: Embedded defaults if external templates don't exist or fail to load 3. **To revert to built-in**: Remove the template directory or set `templateDir` to a non-existent path +#### Hot Reloading + +When using external templates, the server automatically detects file changes and reloads templates **without requiring a restart**: + +- **Automatic detection**: Checks file modification times on every OAuth request +- **Zero-config**: Hot reloading is always enabled for external templates +- **Live development**: Edit templates and see changes immediately in your browser +- **Fallback protection**: If reloading fails, continues using the previous templates + +This makes template customization seamless during development and testing. + #### Template Data **authorize.html** receives: diff --git a/oauth.go b/oauth.go index a77938c..9d32e79 100644 --- a/oauth.go +++ b/oauth.go @@ -30,6 +30,9 @@ type OAuthServer struct { persistenceFile string accessConfig *OAuth2Config templates *template.Template + templateDir string + useExternalTemplates bool + templateModTimes map[string]time.Time } type OAuthClient struct { @@ -180,16 +183,32 @@ func NewOAuthServer(baseURL string, accessConfig *OAuth2Config) *OAuthServer { templateDir = filepath.Join(accessConfig.TemplateDir, "oauth") } + var useExternalTemplates bool + var templateModTimes map[string]time.Time + // First try to load external templates (for customization) if _, statErr := os.Stat(templateDir); statErr == nil { templateGlob := filepath.Join(templateDir, "*.html") - log.Printf("OAuth: Found external templates directory at '%s', loading custom templates", templateDir) + log.Printf("OAuth: Found external templates directory at '%s', loading custom templates with hot reload", templateDir) templates, err = template.ParseGlob(templateGlob) if err != nil { log.Printf("OAuth: Failed to load external templates from '%s': %v", templateDir, err) log.Printf("OAuth: Falling back to built-in templates") } else { log.Printf("OAuth: Successfully loaded external templates from '%s'", templateDir) + useExternalTemplates = true + templateModTimes = make(map[string]time.Time) + + // Record initial modification times + authPath := filepath.Join(templateDir, "authorize.html") + successPath := filepath.Join(templateDir, "success.html") + + if authStat, err := os.Stat(authPath); err == nil { + templateModTimes["authorize.html"] = authStat.ModTime() + } + if successStat, err := os.Stat(successPath); err == nil { + templateModTimes["success.html"] = successStat.ModTime() + } } } @@ -216,14 +235,17 @@ func NewOAuthServer(baseURL string, accessConfig *OAuth2Config) *OAuthServer { } server := &OAuthServer{ - baseURL: baseURL, - clients: make(map[string]*OAuthClient), - authCodes: make(map[string]*AuthorizationCode), - accessTokens: make(map[string]*AccessToken), - tokenExpiration: tokenExpiration, - persistenceFile: persistenceFile, - accessConfig: accessConfig, - templates: templates, + baseURL: baseURL, + clients: make(map[string]*OAuthClient), + authCodes: make(map[string]*AuthorizationCode), + accessTokens: make(map[string]*AccessToken), + tokenExpiration: tokenExpiration, + persistenceFile: persistenceFile, + accessConfig: accessConfig, + templates: templates, + templateDir: templateDir, + useExternalTemplates: useExternalTemplates, + templateModTimes: templateModTimes, } // Load persisted clients @@ -326,6 +348,48 @@ func (s *OAuthServer) saveClients() { len(clients), len(accessTokens)) } +func (s *OAuthServer) reloadTemplatesIfChanged() { + if !s.useExternalTemplates || s.templateDir == "" { + return // Nothing to reload for built-in templates + } + + // Check modification times + authPath := filepath.Join(s.templateDir, "authorize.html") + successPath := filepath.Join(s.templateDir, "success.html") + + var needReload bool + + // Check authorize.html + if authStat, err := os.Stat(authPath); err == nil { + if lastMod, exists := s.templateModTimes["authorize.html"]; !exists || authStat.ModTime().After(lastMod) { + s.templateModTimes["authorize.html"] = authStat.ModTime() + needReload = true + } + } + + // Check success.html + if successStat, err := os.Stat(successPath); err == nil { + if lastMod, exists := s.templateModTimes["success.html"]; !exists || successStat.ModTime().After(lastMod) { + s.templateModTimes["success.html"] = successStat.ModTime() + needReload = true + } + } + + if needReload { + log.Printf("OAuth: Template files changed, reloading from '%s'", s.templateDir) + + // Reload templates + templateGlob := filepath.Join(s.templateDir, "*.html") + if newTemplates, err := template.ParseGlob(templateGlob); err != nil { + log.Printf("OAuth: Failed to reload templates: %v", err) + // Keep using existing templates + } else { + s.templates = newTemplates + log.Printf("OAuth: Templates reloaded successfully") + } + } +} + func (s *OAuthServer) generateRandomString(length int) string { bytes := make([]byte, length) rand.Read(bytes) @@ -542,6 +606,9 @@ func (s *OAuthServer) handleAuthorizationGET(w http.ResponseWriter, r *http.Requ } func (s *OAuthServer) showAuthorizationPage(w http.ResponseWriter, r *http.Request, clientID, redirectURI, responseType, scope, state, codeChallenge, resource, errorMsg string) { + // Check for template updates if using external templates + s.reloadTemplatesIfChanged() + // Skip client validation at authorization endpoint per Claude DCR spec // Client validation will happen at token endpoint where invalid_client triggers re-registration log.Printf("OAuth: Authorization request for client_id '%s' - proceeding to login", clientID) @@ -643,6 +710,9 @@ func (s *OAuthServer) handleAuthorizationPOST(w http.ResponseWriter, r *http.Req } func (s *OAuthServer) showSuccessPage(w http.ResponseWriter, r *http.Request, redirectURI, code, state, username string) { + // Check for template updates if using external templates + s.reloadTemplatesIfChanged() + // Build redirect URL redirectURL, _ := url.Parse(redirectURI) params := redirectURL.Query() From c3d8e32cb497648216e063df1feb463474582e9a Mon Sep 17 00:00:00 2001 From: j-mcnally Date: Fri, 22 Aug 2025 23:16:13 -0400 Subject: [PATCH 07/11] Clean up auth --- Makefile | 6 +++++- oauth.go | 18 +++++++++--------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/Makefile b/Makefile index 9cdf1ea..cf72149 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,11 @@ build: .PHONY: buildLinuxX86 buildLinuxX86: - GOOS=linux GOARCH=amd64 $(GO_BUILD) -o $(BUILD_DIR)/ ./... + GOOS=linux GOARCH=amd64 $(GO_BUILD) -o $(BUILD_DIR) ./... + +.PHONY: buildMacIntel +buildMacIntel: + GOOS=darwin GOARCH=amd64 $(GO_BUILD) -o $(BUILD_DIR)/mcp-proxy-intel ./... .PHONY: buildImage buildImage: diff --git a/oauth.go b/oauth.go index 9d32e79..41574c4 100644 --- a/oauth.go +++ b/oauth.go @@ -1097,15 +1097,15 @@ func (s *OAuthServer) ValidateToken(tokenString string) (*AccessToken, bool) { } if time.Now().After(token.ExpiresAt) { - // Token expired, clean it up - go func() { - s.mutex.Lock() - delete(s.accessTokens, tokenString) - s.mutex.Unlock() - - // Persist cleanup to disk - s.saveClients() - }() + // Token expired, clean it up synchronously to prevent race conditions + s.mutex.RUnlock() // Release read lock + s.mutex.Lock() // Acquire write lock + delete(s.accessTokens, tokenString) + s.mutex.Unlock() + + // Persist cleanup to disk + s.saveClients() + return nil, false } From ce09a7291380547729b8e980a3b8782bafc30dce Mon Sep 17 00:00:00 2001 From: j-mcnally Date: Tue, 26 Aug 2025 13:24:15 -0400 Subject: [PATCH 08/11] Fix the lock --- oauth.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/oauth.go b/oauth.go index 41574c4..519fb14 100644 --- a/oauth.go +++ b/oauth.go @@ -1089,10 +1089,10 @@ func (s *OAuthServer) handleRefreshToken(w http.ResponseWriter, r *http.Request, // Token Validation func (s *OAuthServer) ValidateToken(tokenString string) (*AccessToken, bool) { s.mutex.RLock() - defer s.mutex.RUnlock() token, exists := s.accessTokens[tokenString] if !exists { + s.mutex.RUnlock() return nil, false } @@ -1109,6 +1109,7 @@ func (s *OAuthServer) ValidateToken(tokenString string) (*AccessToken, bool) { return nil, false } + s.mutex.RUnlock() return token, true } From b6e98a1353c8ade630a52362e0a1304db070faea Mon Sep 17 00:00:00 2001 From: j-mcnally Date: Thu, 4 Sep 2025 10:55:58 -0400 Subject: [PATCH 09/11] Allow disable token expiration --- README.md | 4 +++- config.go | 1 + oauth.go | 64 +++++++++++++++++++++++++++++++++++++++++++++---------- 3 files changed, 57 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 027cae9..55756ea 100644 --- a/README.md +++ b/README.md @@ -131,6 +131,7 @@ Common options for `mcpProxy` and `mcpServers`. - `persistenceDir`: Directory for storing OAuth client registrations. Defaults to `$HOME/.mcpproxy` if not specified. - `allowedIPs`: IP addresses permitted to register OAuth clients. Use Claude's official IPs for security. Empty array allows all IPs. - `tokenExpirationMinutes`: Access token expiration time in minutes. Defaults to 60 minutes (1 hour) if not specified. + - `disableTokenExpiration`: When set to `true`, disables token expiration entirely. Tokens will never expire. Overrides `tokenExpirationMinutes` setting. - `templateDir`: Base directory for OAuth HTML templates. Server looks for templates in `{templateDir}/oauth/`. Defaults to `templates` if not specified. - `toolFilter`: Optional tool filtering configuration. **This configuration is only effective in `mcpServers`.** - `mode`: Specifies the filtering mode. Must be explicitly set to `allow` or `block` if `list` is provided. If `list` is present but `mode` is missing or invalid, the filter will be ignored for this server. @@ -354,7 +355,8 @@ You can restrict OAuth client registration to specific IP addresses for enhanced "test": "test123" }, "allowedIPs": [], - "tokenExpirationMinutes": 60 + "tokenExpirationMinutes": 60, + "disableTokenExpiration": false } } } diff --git a/config.go b/config.go index 9371c7d..71db871 100644 --- a/config.go +++ b/config.go @@ -62,6 +62,7 @@ type OAuth2Config struct { PersistenceDir string `json:"persistenceDir,omitempty"` AllowedIPs []string `json:"allowedIPs,omitempty"` TokenExpirationMinutes int `json:"tokenExpirationMinutes,omitempty"` + DisableTokenExpiration bool `json:"disableTokenExpiration,omitempty"` TemplateDir string `json:"templateDir,omitempty"` } diff --git a/oauth.go b/oauth.go index 519fb14..7143a8b 100644 --- a/oauth.go +++ b/oauth.go @@ -27,6 +27,7 @@ type OAuthServer struct { accessTokens map[string]*AccessToken mutex sync.RWMutex tokenExpiration time.Duration + disableTokenExpiration bool persistenceFile string accessConfig *OAuth2Config templates *template.Template @@ -168,9 +169,17 @@ func NewOAuthServer(baseURL string, accessConfig *OAuth2Config) *OAuthServer { // Set token expiration from config or default to 1 hour tokenExpiration := time.Hour // Default 1 hour - if accessConfig != nil && accessConfig.TokenExpirationMinutes > 0 { - tokenExpiration = time.Duration(accessConfig.TokenExpirationMinutes) * time.Minute - log.Printf("OAuth: Using custom token expiration: %v", tokenExpiration) + disableTokenExpiration := false + + if accessConfig != nil { + if accessConfig.DisableTokenExpiration { + disableTokenExpiration = true + tokenExpiration = 100 * 365 * 24 * time.Hour // Set to 100 years + log.Printf("OAuth: Token expiration disabled - tokens will not expire") + } else if accessConfig.TokenExpirationMinutes > 0 { + tokenExpiration = time.Duration(accessConfig.TokenExpirationMinutes) * time.Minute + log.Printf("OAuth: Using custom token expiration: %v", tokenExpiration) + } } // Load templates with fallback mechanism @@ -240,6 +249,7 @@ func NewOAuthServer(baseURL string, accessConfig *OAuth2Config) *OAuthServer { authCodes: make(map[string]*AuthorizationCode), accessTokens: make(map[string]*AccessToken), tokenExpiration: tokenExpiration, + disableTokenExpiration: disableTokenExpiration, persistenceFile: persistenceFile, accessConfig: accessConfig, templates: templates, @@ -278,13 +288,19 @@ func (s *OAuthServer) loadClients() { s.mutex.Lock() s.clients = persistenceData.Clients - // Load tokens, filtering out expired ones + // Load tokens, filtering out expired ones only if expiration is enabled validAccessTokens := make(map[string]*AccessToken) - now := time.Now() - for token, accessToken := range persistenceData.AccessTokens { - if accessToken.ExpiresAt.After(now) { - validAccessTokens[token] = accessToken + if s.disableTokenExpiration { + // Keep all tokens if expiration is disabled + validAccessTokens = persistenceData.AccessTokens + } else { + // Filter out expired tokens + now := time.Now() + for token, accessToken := range persistenceData.AccessTokens { + if accessToken.ExpiresAt.After(now) { + validAccessTokens[token] = accessToken + } } } @@ -517,12 +533,25 @@ func (s *OAuthServer) handleClientRegistration(w http.ResponseWriter, r *http.Re for _, uri := range req.RedirectURIs { validURI := false + + // Check exact matches first for _, allowed := range allowedCallbackURLs { if uri == allowed { validURI = true break } } + + // If not an exact match, check if it's a localhost callback for Claude Code + if !validURI { + if parsedURI, err := url.Parse(uri); err == nil { + if parsedURI.Scheme == "http" && + parsedURI.Hostname() == "localhost" { + validURI = true + } + } + } + if !validURI { log.Printf("OAuth: Client registration failed - invalid redirect URI: %s", uri) s.writeOAuthError(w, "invalid_redirect_uri", "Redirect URI not allowed", http.StatusBadRequest) @@ -874,10 +903,16 @@ func (s *OAuthServer) handleToken(w http.ResponseWriter, r *http.Request) { // Persist tokens to disk s.saveClients() + // Set expires_in to 0 when expiration is disabled (RFC 6749 - 0 means no expiration) + expiresIn := int(s.tokenExpiration.Seconds()) + if s.disableTokenExpiration { + expiresIn = 0 + } + response := TokenResponse{ AccessToken: accessToken, TokenType: "Bearer", - ExpiresIn: int(s.tokenExpiration.Seconds()), + ExpiresIn: expiresIn, RefreshToken: refreshToken, Scope: authCode.Scope, } @@ -1074,10 +1109,16 @@ func (s *OAuthServer) handleRefreshToken(w http.ResponseWriter, r *http.Request, log.Printf("OAuth: Refreshed tokens for client %s", clientID) + // Set expires_in to 0 when expiration is disabled (RFC 6749 - 0 means no expiration) + expiresIn := int(s.tokenExpiration.Seconds()) + if s.disableTokenExpiration { + expiresIn = 0 + } + response := TokenResponse{ AccessToken: newAccessToken, TokenType: "Bearer", - ExpiresIn: int(s.tokenExpiration.Seconds()), + ExpiresIn: expiresIn, RefreshToken: newRefreshToken, Scope: oldToken.Scope, } @@ -1096,7 +1137,8 @@ func (s *OAuthServer) ValidateToken(tokenString string) (*AccessToken, bool) { return nil, false } - if time.Now().After(token.ExpiresAt) { + // Skip expiration check if token expiration is disabled + if !s.disableTokenExpiration && time.Now().After(token.ExpiresAt) { // Token expired, clean it up synchronously to prevent race conditions s.mutex.RUnlock() // Release read lock s.mutex.Lock() // Acquire write lock From dd0a2ede200292c4fa54acfdc9096cb7cb42b92c Mon Sep 17 00:00:00 2001 From: j-mcnally Date: Tue, 23 Sep 2025 21:10:03 -0400 Subject: [PATCH 10/11] More token expiration changes --- http.go | 59 ++++++++++++++++++++++++++++++++++++-------------------- oauth.go | 6 +++--- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/http.go b/http.go index 7b28725..ffa11a5 100644 --- a/http.go +++ b/http.go @@ -50,11 +50,11 @@ func newAuthMiddleware(tokens []string) MiddlewareFunc { } } -func newOAuth2Middleware(oauth2Config *OAuth2Config, oauthServer *OAuthServer) MiddlewareFunc { - if oauth2Config == nil || !oauth2Config.Enabled { - return func(next http.Handler) http.Handler { - return next - } +func newCombinedAuthMiddleware(authTokens []string, oauth2Config *OAuth2Config, oauthServer *OAuthServer) MiddlewareFunc { + // Create token set for fast lookup + tokenSet := make(map[string]struct{}, len(authTokens)) + for _, token := range authTokens { + tokenSet[token] = struct{}{} } return func(next http.Handler) http.Handler { @@ -65,7 +65,7 @@ func newOAuth2Middleware(oauth2Config *OAuth2Config, oauthServer *OAuthServer) M return } - // Check for Bearer token (OAuth 2.1) + // Check for Bearer token if strings.HasPrefix(authHeader, "Bearer ") { token := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer ")) if token == "" { @@ -73,19 +73,32 @@ func newOAuth2Middleware(oauth2Config *OAuth2Config, oauthServer *OAuthServer) M return } - // Validate token with OAuth server - accessToken, valid := oauthServer.ValidateToken(token) - if !valid { - http.Error(w, "Invalid or expired access token", http.StatusUnauthorized) - return + // First, try predefined auth tokens (fastest check) + if len(authTokens) > 0 { + if _, ok := tokenSet[token]; ok { + log.Printf("Request authenticated with predefined bearer token") + next.ServeHTTP(w, r) + return + } } - // Add token info to request context for potential use - ctx := context.WithValue(r.Context(), "X-OAuth-Client-ID", accessToken.ClientID) - ctx = context.WithValue(ctx, "X-OAuth-Scope", accessToken.Scope) - ctx = context.WithValue(ctx, "X-OAuth-Resource", accessToken.Resource) + // Second, try OAuth validation if enabled + if oauth2Config != nil && oauth2Config.Enabled && oauthServer != nil { + accessToken, valid := oauthServer.ValidateToken(token) + if valid { + log.Printf("Request authenticated with OAuth token for client: %s", accessToken.ClientID) + // Add token info to request context for potential use + ctx := context.WithValue(r.Context(), "X-OAuth-Client-ID", accessToken.ClientID) + ctx = context.WithValue(ctx, "X-OAuth-Scope", accessToken.Scope) + ctx = context.WithValue(ctx, "X-OAuth-Resource", accessToken.Resource) + + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + } - next.ServeHTTP(w, r.WithContext(ctx)) + // If neither auth method worked + http.Error(w, "Invalid or expired access token", http.StatusUnauthorized) return } @@ -173,12 +186,16 @@ func startHTTPServer(config *Config) error { middlewares = append(middlewares, loggerMiddleware(name)) } - // Apply authentication middleware based on proxy configuration - // OAuth2 authentication applies when the proxy itself uses streamable-http transport - if config.McpProxy.Type == MCPServerTypeStreamable && config.McpProxy.Options.OAuth2 != nil && config.McpProxy.Options.OAuth2.Enabled { - middlewares = append(middlewares, newOAuth2Middleware(config.McpProxy.Options.OAuth2, oauthServer)) + // Apply combined authentication middleware (supports both predefined tokens and OAuth) + if config.McpProxy.Type == MCPServerTypeStreamable && + ((config.McpProxy.Options.OAuth2 != nil && config.McpProxy.Options.OAuth2.Enabled) || len(clientConfig.Options.AuthTokens) > 0) { + middlewares = append(middlewares, newCombinedAuthMiddleware( + clientConfig.Options.AuthTokens, + config.McpProxy.Options.OAuth2, + oauthServer, + )) } else if len(clientConfig.Options.AuthTokens) > 0 { - // Fall back to token authentication if OAuth2 is not configured + // For non-streamable transports, use simple auth middleware middlewares = append(middlewares, newAuthMiddleware(clientConfig.Options.AuthTokens)) } mcpRoute := path.Join(baseURL.Path, name) diff --git a/oauth.go b/oauth.go index 7143a8b..19fb7e4 100644 --- a/oauth.go +++ b/oauth.go @@ -172,7 +172,7 @@ func NewOAuthServer(baseURL string, accessConfig *OAuth2Config) *OAuthServer { disableTokenExpiration := false if accessConfig != nil { - if accessConfig.DisableTokenExpiration { + if accessConfig.DisableTokenExpiration || accessConfig.TokenExpirationMinutes == 0 { disableTokenExpiration = true tokenExpiration = 100 * 365 * 24 * time.Hour // Set to 100 years log.Printf("OAuth: Token expiration disabled - tokens will not expire") @@ -545,8 +545,8 @@ func (s *OAuthServer) handleClientRegistration(w http.ResponseWriter, r *http.Re // If not an exact match, check if it's a localhost callback for Claude Code if !validURI { if parsedURI, err := url.Parse(uri); err == nil { - if parsedURI.Scheme == "http" && - parsedURI.Hostname() == "localhost" { + if parsedURI.Scheme == "http" && + (parsedURI.Hostname() == "localhost" || parsedURI.Hostname() == "127.0.0.1") { validURI = true } } From cce39166982fa69bbb5250b33e5004585d2c3a2d Mon Sep 17 00:00:00 2001 From: j-mcnally Date: Tue, 23 Sep 2025 22:53:36 -0400 Subject: [PATCH 11/11] Per server user restrictions --- Makefile | 2 +- USER_SERVER_FILTERING.md | 128 +++++++++++++++++++++ config.go | 42 +++++++ example_config_with_user_restrictions.json | 56 +++++++++ http.go | 37 +++++- oauth.go | 16 ++- 6 files changed, 274 insertions(+), 7 deletions(-) create mode 100644 USER_SERVER_FILTERING.md create mode 100644 example_config_with_user_restrictions.json diff --git a/Makefile b/Makefile index cf72149..162bb49 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ build: .PHONY: buildLinuxX86 buildLinuxX86: - GOOS=linux GOARCH=amd64 $(GO_BUILD) -o $(BUILD_DIR) ./... + GOOS=linux GOARCH=amd64 $(GO_BUILD) -o $(BUILD_DIR)/mcp-proxy-linux-amd64 ./... .PHONY: buildMacIntel buildMacIntel: diff --git a/USER_SERVER_FILTERING.md b/USER_SERVER_FILTERING.md new file mode 100644 index 0000000..4bf0f16 --- /dev/null +++ b/USER_SERVER_FILTERING.md @@ -0,0 +1,128 @@ +# Per-User Server Access Control + +This feature allows you to restrict which MCP servers individual users can access when using OAuth 2.1 authentication. It uses per-server user filters similar to the existing tool filter system. + +## Configuration + +Add `userFilter` options to individual MCP servers in your `config.json`: + +```json +{ + "mcpProxy": { + "options": { + "oauth2": { + "enabled": true, + "users": { + "alice": "password123", + "bob": "password456", + "admin": "adminpass" + } + } + } + }, + "mcpServers": { + "server1": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], + "options": { + "userFilter": { + "mode": "allow", + "list": ["alice", "admin"] + } + } + }, + "server2": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-git"], + "options": { + "userFilter": { + "mode": "allow", + "list": ["alice", "bob", "admin"] + } + } + }, + "server3": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-brave-search"], + "options": { + "userFilter": { + "mode": "block", + "list": ["alice"] + } + } + } + } +} +``` + +## User Filter Modes + +### Allow Mode (`"mode": "allow"`) +- **Purpose**: Only allow specified users +- **Behavior**: Users must be in the `list` to access the server +- **Example**: `{"mode": "allow", "list": ["alice", "bob"]}` - only alice and bob can access + +### Block Mode (`"mode": "block"`) +- **Purpose**: Block specified users, allow all others +- **Behavior**: Users in the `list` are denied access, everyone else is allowed +- **Example**: `{"mode": "block", "list": ["alice"]}` - alice is blocked, others can access + +## Example Access Patterns + +Based on the configuration above: + +- **server1**: Only `alice` and `admin` can access (allow mode) +- **server2**: `alice`, `bob`, and `admin` can access (allow mode) +- **server3**: Everyone except `alice` can access (block mode) + +So the effective access is: +- **alice**: Can access `server1` and `server2`, blocked from `server3` +- **bob**: Can access `server2` and `server3`, blocked from `server1` +- **admin**: Can access `server1` and `server2` and `server3` + +## How It Works + +1. When a user authenticates via OAuth 2.1, their username is stored in the access token +2. Each request includes the username in the request context +3. The `newServerAccessMiddleware` checks the server's `userFilter` configuration +4. If access is denied, the request returns HTTP 403 Forbidden + +## Default Behavior + +- **No userFilter**: All authenticated users have access (backward compatibility) +- **Empty list**: Behavior depends on mode: + - Allow mode with empty list: No users allowed + - Block mode with empty list: All users allowed +- **Token-based authentication**: Bypasses user-specific restrictions (no username available) + +## Testing + +Use the provided `example_config_with_user_restrictions.json` to test the feature: + +1. Start the server: `./build/mcp-proxy --config example_config_with_user_restrictions.json` +2. Authenticate as different users and try accessing different server endpoints +3. Verify that access is properly restricted based on each server's configuration + +## Error Messages + +When access is denied, users will see: +``` +HTTP 403 Forbidden +Access denied: You don't have permission to access this server +``` + +## Logging + +The server logs access decisions with filter details: +``` +User alice granted access to server1 +User bob denied access to server1 (mode: allow, list: [alice admin]) +``` + +## Benefits of Per-Server Approach + +- **Granular Control**: Each server can have different user access rules +- **Flexible Modes**: Use allow-lists for restricted servers, block-lists for open servers +- **Consistent API**: Follows the same pattern as `toolFilter` +- **Independent Configuration**: Server access rules are self-contained +- **Mix and Match**: Some servers can be unrestricted while others have filters \ No newline at end of file diff --git a/config.go b/config.go index 71db871..7c8dc49 100644 --- a/config.go +++ b/config.go @@ -56,6 +56,47 @@ type ToolFilterConfig struct { List []string `json:"list,omitempty"` } +type UserFilterMode string + +const ( + UserFilterModeAllow UserFilterMode = "allow" + UserFilterModeBlock UserFilterMode = "block" +) + +type UserFilterConfig struct { + Mode UserFilterMode `json:"mode,omitempty"` + List []string `json:"list,omitempty"` +} + +// IsUserAllowed checks if a user is allowed based on the user filter configuration +func (ufc *UserFilterConfig) IsUserAllowed(username string) bool { + if ufc == nil || username == "" { + // No filter configured or empty username - allow by default + return true + } + + // Check if username is in the list + userInList := false + for _, user := range ufc.List { + if user == username { + userInList = true + break + } + } + + switch ufc.Mode { + case UserFilterModeAllow: + // Allow mode: user must be in the list to be allowed + return userInList + case UserFilterModeBlock: + // Block mode: user must NOT be in the list to be allowed + return !userInList + default: + // No mode specified or unknown mode - allow by default + return true + } +} + type OAuth2Config struct { Enabled bool `json:"enabled,omitempty"` Users map[string]string `json:"users,omitempty"` @@ -72,6 +113,7 @@ type OptionsV2 struct { AuthTokens []string `json:"authTokens,omitempty"` OAuth2 *OAuth2Config `json:"oauth2,omitempty"` ToolFilter *ToolFilterConfig `json:"toolFilter,omitempty"` + UserFilter *UserFilterConfig `json:"userFilter,omitempty"` } type MCPProxyConfigV2 struct { diff --git a/example_config_with_user_restrictions.json b/example_config_with_user_restrictions.json new file mode 100644 index 0000000..ec7932c --- /dev/null +++ b/example_config_with_user_restrictions.json @@ -0,0 +1,56 @@ +{ + "mcpProxy": { + "baseURL": "http://localhost:9090", + "addr": ":9090", + "name": "MCP Proxy with User Restrictions", + "version": "1.0.0", + "type": "streamable-http", + "options": { + "oauth2": { + "enabled": true, + "users": { + "alice": "password123", + "bob": "password456", + "admin": "adminpass" + }, + "disableTokenExpiration": true, + "persistenceDir": ".mcpproxy-test" + } + } + }, + "mcpServers": { + "server1": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], + "options": { + "logEnabled": true, + "userFilter": { + "mode": "allow", + "list": ["alice", "admin"] + } + } + }, + "server2": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-git", "--repository", "."], + "options": { + "logEnabled": true, + "userFilter": { + "mode": "allow", + "list": ["alice", "bob", "admin"] + } + } + }, + "server3": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-brave-search"], + "options": { + "logEnabled": true, + "userFilter": { + "mode": "block", + "list": ["alice"] + } + } + } + } +} \ No newline at end of file diff --git a/http.go b/http.go index ffa11a5..c62e035 100644 --- a/http.go +++ b/http.go @@ -86,11 +86,12 @@ func newCombinedAuthMiddleware(authTokens []string, oauth2Config *OAuth2Config, if oauth2Config != nil && oauth2Config.Enabled && oauthServer != nil { accessToken, valid := oauthServer.ValidateToken(token) if valid { - log.Printf("Request authenticated with OAuth token for client: %s", accessToken.ClientID) + log.Printf("Request authenticated with OAuth token for client: %s, username: %s", accessToken.ClientID, accessToken.Username) // Add token info to request context for potential use ctx := context.WithValue(r.Context(), "X-OAuth-Client-ID", accessToken.ClientID) ctx = context.WithValue(ctx, "X-OAuth-Scope", accessToken.Scope) ctx = context.WithValue(ctx, "X-OAuth-Resource", accessToken.Resource) + ctx = context.WithValue(ctx, "X-OAuth-Username", accessToken.Username) next.ServeHTTP(w, r.WithContext(ctx)) return @@ -130,6 +131,36 @@ func recoverMiddleware(prefix string) MiddlewareFunc { } } +func newServerAccessMiddleware(serverName string, userFilter *UserFilterConfig) MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip access control if no user filter configured + if userFilter == nil { + next.ServeHTTP(w, r) + return + } + + username, ok := r.Context().Value("X-OAuth-Username").(string) + if !ok || username == "" { + // No OAuth username in context, proceed normally (this might be token-based auth) + next.ServeHTTP(w, r) + return + } + + // Check if user has access to this server based on user filter + if !userFilter.IsUserAllowed(username) { + log.Printf("User %s denied access to server %s (mode: %s, list: %v)", + username, serverName, userFilter.Mode, userFilter.List) + http.Error(w, "Access denied: You don't have permission to access this server", http.StatusForbidden) + return + } + + log.Printf("User %s granted access to server %s", username, serverName) + next.ServeHTTP(w, r) + }) + } +} + func startHTTPServer(config *Config) error { baseURL, uErr := url.Parse(config.McpProxy.BaseURL) if uErr != nil { @@ -194,6 +225,10 @@ func startHTTPServer(config *Config) error { config.McpProxy.Options.OAuth2, oauthServer, )) + // Add server access middleware for OAuth-enabled servers with user filters + if config.McpProxy.Options.OAuth2 != nil && config.McpProxy.Options.OAuth2.Enabled && clientConfig.Options.UserFilter != nil { + middlewares = append(middlewares, newServerAccessMiddleware(name, clientConfig.Options.UserFilter)) + } } else if len(clientConfig.Options.AuthTokens) > 0 { // For non-streamable transports, use simple auth middleware middlewares = append(middlewares, newAuthMiddleware(clientConfig.Options.AuthTokens)) diff --git a/oauth.go b/oauth.go index 19fb7e4..97f010b 100644 --- a/oauth.go +++ b/oauth.go @@ -53,6 +53,7 @@ type AuthorizationCode struct { CodeChallenge string // PKCE challenge ExpiresAt time.Time Resource string + Username string } type AccessToken struct { @@ -62,6 +63,7 @@ type AccessToken struct { Scope string Resource string ExpiresAt time.Time + Username string } // OAuth Server Metadata Response @@ -726,6 +728,7 @@ func (s *OAuthServer) handleAuthorizationPOST(w http.ResponseWriter, r *http.Req CodeChallenge: codeChallenge, ExpiresAt: time.Now().Add(10 * time.Minute), Resource: resource, + Username: username, } s.mutex.Lock() @@ -894,6 +897,7 @@ func (s *OAuthServer) handleToken(w http.ResponseWriter, r *http.Request) { Scope: authCode.Scope, Resource: resource, ExpiresAt: time.Now().Add(s.tokenExpiration), + Username: authCode.Username, } s.mutex.Lock() @@ -903,10 +907,10 @@ func (s *OAuthServer) handleToken(w http.ResponseWriter, r *http.Request) { // Persist tokens to disk s.saveClients() - // Set expires_in to 0 when expiration is disabled (RFC 6749 - 0 means no expiration) + // Set expires_in to 5 years when expiration is disabled (clients handle this better than 0) expiresIn := int(s.tokenExpiration.Seconds()) if s.disableTokenExpiration { - expiresIn = 0 + expiresIn = int((5 * 365 * 24 * time.Hour).Seconds()) // 5 years } response := TokenResponse{ @@ -1098,6 +1102,7 @@ func (s *OAuthServer) handleRefreshToken(w http.ResponseWriter, r *http.Request, Scope: oldToken.Scope, Resource: oldToken.Resource, ExpiresAt: time.Now().Add(s.tokenExpiration), + Username: oldToken.Username, } s.mutex.Lock() @@ -1109,10 +1114,10 @@ func (s *OAuthServer) handleRefreshToken(w http.ResponseWriter, r *http.Request, log.Printf("OAuth: Refreshed tokens for client %s", clientID) - // Set expires_in to 0 when expiration is disabled (RFC 6749 - 0 means no expiration) + // Set expires_in to 5 years when expiration is disabled (clients handle this better than 0) expiresIn := int(s.tokenExpiration.Seconds()) if s.disableTokenExpiration { - expiresIn = 0 + expiresIn = int((5 * 365 * 24 * time.Hour).Seconds()) // 5 years } response := TokenResponse{ @@ -1164,6 +1169,7 @@ func (s *OAuthServer) writeOAuthError(w http.ResponseWriter, error, description }) } + // Register OAuth routes func (s *OAuthServer) RegisterRoutes(mux *http.ServeMux) { // Global OAuth endpoints @@ -1171,7 +1177,7 @@ func (s *OAuthServer) RegisterRoutes(mux *http.ServeMux) { mux.HandleFunc("/oauth/register", s.handleClientRegistration) mux.HandleFunc("/oauth/authorize", s.handleAuthorization) mux.HandleFunc("/oauth/token", s.handleToken) - + // Per-server OAuth discovery endpoints mux.HandleFunc("/.well-known/oauth-authorization-server/", s.handleServerMetadata) mux.HandleFunc("/.well-known/oauth-protected-resource/", s.handleProtectedResourceMetadata)