From 61b149f1995bfc2fb67facdfc1f6ed56cd834e6c Mon Sep 17 00:00:00 2001 From: Justin Chang Date: Fri, 13 Feb 2026 19:34:12 -0800 Subject: [PATCH 1/5] Add dynamic OAuth discovery for community MCP servers Community servers from third-party registries lack oauth.providers metadata, so the existing IsRemoteOAuthServer() gate prevents DCR entries from being created and OAuth tokens from being consumed. This adds support for the full dynamic OAuth lifecycle: discovery, token attachment, provider refresh, and connection invalidation. Discovery (DCR registration): - Add RegisterProviderForDynamicDiscovery() in pkg/oauth/dcr_registration.go - Add fallback branch in workingset, server enable, and gateway mcpadd Token consumption: - Attach stored OAuth tokens for community servers during Initialize (fallback when Spec.OAuth is nil but Remote.URL is set) - Start OAuth provider refresh loop for community servers with stored tokens - Fix InvalidateOAuthClients to match remote servers by name, not Spec.OAuth --- cmd/docker-mcp/oauth/revoke.go | 2 +- cmd/docker-mcp/server/enable.go | 9 ++ pkg/gateway/clientpool.go | 33 +++-- pkg/gateway/clientpool_test.go | 200 +++++++++++++++++++++++++++++ pkg/gateway/mcpadd.go | 18 ++- pkg/gateway/run.go | 18 ++- pkg/mcp/remote.go | 11 ++ pkg/mcp/remote_test.go | 144 +++++++++++++++++++++ pkg/oauth/dcr_registration.go | 51 ++++++++ pkg/oauth/dcr_registration_test.go | 96 ++++++++++++++ pkg/oauth/provider.go | 6 +- pkg/workingset/oauth.go | 21 +-- pkg/workingset/server.go | 4 +- 13 files changed, 580 insertions(+), 33 deletions(-) create mode 100644 pkg/mcp/remote_test.go create mode 100644 pkg/oauth/dcr_registration_test.go diff --git a/cmd/docker-mcp/oauth/revoke.go b/cmd/docker-mcp/oauth/revoke.go index e7c03f698..448b1ae9a 100644 --- a/cmd/docker-mcp/oauth/revoke.go +++ b/cmd/docker-mcp/oauth/revoke.go @@ -26,7 +26,7 @@ func Revoke(ctx context.Context, app string) error { func revokeDesktopMode(ctx context.Context, app string) error { client := desktop.NewAuthClient() - // Revoke tokens + // Revoke tokens via Docker Desktop if err := client.DeleteOAuthApp(ctx, app); err != nil { return fmt.Errorf("failed to revoke OAuth access: %w", err) } diff --git a/cmd/docker-mcp/server/enable.go b/cmd/docker-mcp/server/enable.go index 62cbd834c..0c045aaa4 100644 --- a/cmd/docker-mcp/server/enable.go +++ b/cmd/docker-mcp/server/enable.go @@ -79,6 +79,15 @@ func update(ctx context.Context, docker docker.Client, dockerCli command.Cli, ad fmt.Printf("Server %s requires OAuth authentication but DCR is disabled.\n", serverName) fmt.Printf(" To enable automatic OAuth setup, run: docker mcp feature enable mcp-oauth-dcr\n") fmt.Printf(" Or set up OAuth manually using: docker mcp oauth authorize %s\n", serverName) + } else if mcpOAuthDcrEnabled && server.Type == "remote" && !server.IsOAuthServer() && server.Remote.URL != "" { + // Community server without oauth.providers — probe for OAuth + if pkgoauth.IsCEMode() { + fmt.Printf("Remote server %s enabled. Run 'docker mcp oauth authorize %s' if authentication is required\n", serverName, serverName) + } else { + if err := pkgoauth.RegisterProviderForDynamicDiscovery(ctx, serverName, server.Remote.URL); err != nil { + fmt.Printf("Warning: Dynamic OAuth discovery failed for %s: %v\n", serverName, err) + } + } } } else { return fmt.Errorf("server %s not found in catalog", serverName) diff --git a/pkg/gateway/clientpool.go b/pkg/gateway/clientpool.go index 66f97231d..b9afe3268 100644 --- a/pkg/gateway/clientpool.go +++ b/pkg/gateway/clientpool.go @@ -179,24 +179,23 @@ func (cp *clientPool) InvalidateOAuthClients(provider string) { var invalidatedKeys []clientKey for key, keptClient := range cp.keptClients { - // Check if this client uses OAuth for the specified provider - if keptClient.Config.Spec.OAuth != nil { - // Match by server name (for DCR providers, server name matches provider) - if keptClient.Config.Name == provider { - log.Log(fmt.Sprintf("ClientPool: Closing OAuth connection for server: %s", keptClient.Config.Name)) - - // Close the connection - client, err := keptClient.Getter.GetClient(context.TODO()) - if err == nil { - client.Session().Close() - log.Log(fmt.Sprintf("ClientPool: Successfully closed connection for %s", keptClient.Config.Name)) - } else { - log.Log(fmt.Sprintf("ClientPool: Warning - failed to get client for %s during invalidation: %v", keptClient.Config.Name, err)) - } - - // Mark for removal from kept clients - invalidatedKeys = append(invalidatedKeys, key) + // Check if this remote client matches the OAuth provider + // Matches both catalog servers (explicit OAuth metadata) and community servers + // (dynamic OAuth discovery via DCR without Spec.OAuth) + if keptClient.Config.Name == provider && keptClient.Config.IsRemote() { + log.Log(fmt.Sprintf("ClientPool: Closing OAuth connection for server: %s", keptClient.Config.Name)) + + // Close the connection + client, err := keptClient.Getter.GetClient(context.TODO()) + if err == nil { + client.Session().Close() + log.Log(fmt.Sprintf("ClientPool: Successfully closed connection for %s", keptClient.Config.Name)) + } else { + log.Log(fmt.Sprintf("ClientPool: Warning - failed to get client for %s during invalidation: %v", keptClient.Config.Name, err)) } + + // Mark for removal from kept clients + invalidatedKeys = append(invalidatedKeys, key) } } diff --git a/pkg/gateway/clientpool_test.go b/pkg/gateway/clientpool_test.go index ef1fbccd5..b4db37cdb 100644 --- a/pkg/gateway/clientpool_test.go +++ b/pkg/gateway/clientpool_test.go @@ -2,6 +2,7 @@ package gateway import ( "context" + "fmt" "os" "testing" "time" @@ -277,6 +278,205 @@ func parseConfig(t *testing.T, contentYAML string) map[string]any { return config } +func readOnly() *bool { + return boolPtr(true) +} + +func boolPtr(b bool) *bool { + return &b +} + +func TestInvalidateOAuthClients_MatchesCommunityServer(t *testing.T) { + // Community server: remote URL set, but no Spec.OAuth metadata. + // This verifies Gap 3: InvalidateOAuthClients matches community servers + // that use dynamic OAuth discovery without explicit OAuth config. + cp := &clientPool{ + keptClients: make(map[clientKey]keptClient), + } + + getter := &clientGetter{} + getter.once.Do(func() {}) // mark as executed + getter.err = fmt.Errorf("mock: no real client") + + key := clientKey{serverName: "com-notion-mcp"} + cp.keptClients[key] = keptClient{ + Name: "com-notion-mcp", + Getter: getter, + Config: &catalog.ServerConfig{ + Name: "com-notion-mcp", + Spec: catalog.Server{ + Type: "remote", + Remote: catalog.Remote{ + URL: "https://mcp.notion.so/mcp", + Transport: "streamable-http", + }, + // No OAuth field - community server + }, + }, + } + + cp.InvalidateOAuthClients("com-notion-mcp") + + assert.Empty(t, cp.keptClients, "community server should be invalidated by name") +} + +func TestInvalidateOAuthClients_MatchesCatalogServer(t *testing.T) { + // Catalog server: remote URL set WITH Spec.OAuth metadata. + // Verifies backward compatibility: catalog servers still get invalidated. + cp := &clientPool{ + keptClients: make(map[clientKey]keptClient), + } + + getter := &clientGetter{} + getter.once.Do(func() {}) + getter.err = fmt.Errorf("mock: no real client") + + key := clientKey{serverName: "notion-remote"} + cp.keptClients[key] = keptClient{ + Name: "notion-remote", + Getter: getter, + Config: &catalog.ServerConfig{ + Name: "notion-remote", + Spec: catalog.Server{ + Type: "remote", + Remote: catalog.Remote{ + URL: "https://mcp.notion.so/mcp", + Transport: "streamable-http", + }, + OAuth: &catalog.OAuth{ + Providers: []catalog.OAuthProvider{{Provider: "notion"}}, + }, + }, + }, + } + + cp.InvalidateOAuthClients("notion-remote") + + assert.Empty(t, cp.keptClients, "catalog server should be invalidated by name") +} + +func TestInvalidateOAuthClients_SkipsNonRemoteServer(t *testing.T) { + // Docker container server: not remote, should NOT be invalidated. + cp := &clientPool{ + keptClients: make(map[clientKey]keptClient), + } + + getter := &clientGetter{} + getter.once.Do(func() {}) + getter.err = fmt.Errorf("mock: no real client") + + key := clientKey{serverName: "my-container-server"} + cp.keptClients[key] = keptClient{ + Name: "my-container-server", + Getter: getter, + Config: &catalog.ServerConfig{ + Name: "my-container-server", + Spec: catalog.Server{ + Type: "server", + Image: "mcp/my-server:latest", + // Not remote - no URL + }, + }, + } + + cp.InvalidateOAuthClients("my-container-server") + + assert.Len(t, cp.keptClients, 1, "non-remote server should NOT be invalidated") +} + +func TestInvalidateOAuthClients_SkipsMismatchedName(t *testing.T) { + // Remote server with different name: should NOT be invalidated. + cp := &clientPool{ + keptClients: make(map[clientKey]keptClient), + } + + getter := &clientGetter{} + getter.once.Do(func() {}) + getter.err = fmt.Errorf("mock: no real client") + + key := clientKey{serverName: "other-server"} + cp.keptClients[key] = keptClient{ + Name: "other-server", + Getter: getter, + Config: &catalog.ServerConfig{ + Name: "other-server", + Spec: catalog.Server{ + Type: "remote", + Remote: catalog.Remote{ + URL: "https://other.example.com/mcp", + }, + }, + }, + } + + cp.InvalidateOAuthClients("com-notion-mcp") + + assert.Len(t, cp.keptClients, 1, "server with different name should NOT be invalidated") +} + +func TestInvalidateOAuthClients_OnlyMatchingRemoved(t *testing.T) { + // Multiple clients: only the matching remote server should be removed. + cp := &clientPool{ + keptClients: make(map[clientKey]keptClient), + } + + makeGetter := func() *clientGetter { + g := &clientGetter{} + g.once.Do(func() {}) + g.err = fmt.Errorf("mock: no real client") + return g + } + + // Community OAuth server (should be invalidated) + cp.keptClients[clientKey{serverName: "com-notion-mcp"}] = keptClient{ + Name: "com-notion-mcp", + Getter: makeGetter(), + Config: &catalog.ServerConfig{ + Name: "com-notion-mcp", + Spec: catalog.Server{ + Type: "remote", + Remote: catalog.Remote{URL: "https://mcp.notion.so/mcp"}, + }, + }, + } + + // Different remote server (should NOT be invalidated) + cp.keptClients[clientKey{serverName: "github-remote"}] = keptClient{ + Name: "github-remote", + Getter: makeGetter(), + Config: &catalog.ServerConfig{ + Name: "github-remote", + Spec: catalog.Server{ + Type: "remote", + Remote: catalog.Remote{URL: "https://mcp.github.com/mcp"}, + }, + }, + } + + // Docker container server (should NOT be invalidated) + cp.keptClients[clientKey{serverName: "local-server"}] = keptClient{ + Name: "local-server", + Getter: makeGetter(), + Config: &catalog.ServerConfig{ + Name: "local-server", + Spec: catalog.Server{ + Type: "server", + Image: "mcp/local:latest", + }, + }, + } + + cp.InvalidateOAuthClients("com-notion-mcp") + + assert.Len(t, cp.keptClients, 2, "only the matching remote server should be removed") + _, hasNotion := cp.keptClients[clientKey{serverName: "com-notion-mcp"}] + assert.False(t, hasNotion, "com-notion-mcp should have been removed") + _, hasGithub := cp.keptClients[clientKey{serverName: "github-remote"}] + assert.True(t, hasGithub, "github-remote should remain") + _, hasLocal := cp.keptClients[clientKey{serverName: "local-server"}] + assert.True(t, hasLocal, "local-server should remain") +} + func TestStdioClientInitialization(t *testing.T) { // This is an integration test that requires Docker if testing.Short() { diff --git a/pkg/gateway/mcpadd.go b/pkg/gateway/mcpadd.go index e89f26490..94b364b97 100644 --- a/pkg/gateway/mcpadd.go +++ b/pkg/gateway/mcpadd.go @@ -289,7 +289,8 @@ func addServerHandler(g *Gateway, clientConfig *clientConfig) mcp.ToolHandler { // Handle OAuth DCR only when the client supports elicitation (e.g. not stdio-based clients) if g.McpOAuthDcrEnabled && serverConfig != nil && - serverConfig.Spec.IsRemoteOAuthServer() { + (serverConfig.Spec.IsRemoteOAuthServer() || + (serverConfig.Spec.Type == "remote" && !serverConfig.Spec.IsOAuthServer() && serverConfig.Spec.Remote.URL != "")) { init := req.Session.InitializeParams() if init != nil && @@ -444,7 +445,20 @@ func (g *Gateway) getRemoteOAuthServerStatus(ctx context.Context, serverName str if !providerExists { // Register DCR client with DD so user can authorize if err := oauth.RegisterProviderForLazySetup(ctx, serverName); err != nil { - log.Logf("Warning: Failed to register OAuth provider for %s: %v", serverName, err) + // Fallback: try dynamic discovery for community servers without oauth.providers + if serverConfig, _, found := g.configuration.Find(serverName); found && serverConfig.Spec.Remote.URL != "" { + if err := oauth.RegisterProviderForDynamicDiscovery(ctx, serverName, serverConfig.Spec.Remote.URL); err != nil { + log.Logf("Warning: Failed to register OAuth provider for %s: %v", serverName, err) + } + } else { + log.Logf("Warning: Failed to register OAuth provider for %s: %v", serverName, err) + } + } + + // Verify DCR entry was created — dynamic discovery may have found no OAuth requirement + authClient := desktop.NewAuthClient() + if _, err := authClient.GetDCRClient(ctx, serverName); err != nil { + return true, "" // Server doesn't require OAuth } // Start provider (CE mode only - Desktop mode doesn't need polling) diff --git a/pkg/gateway/run.go b/pkg/gateway/run.go index fe53c344b..8f51c3a1b 100644 --- a/pkg/gateway/run.go +++ b/pkg/gateway/run.go @@ -352,13 +352,23 @@ func (g *Gateway) Run(ctx context.Context) error { // Start OAuth provider for each OAuth server. // Each provider runs in its own goroutine with dynamic timing based on token expiry. log.Log("- Starting OAuth provider loops...") + credHelper := oauth.NewOAuthCredentialHelper() for _, serverName := range configuration.ServerNames() { serverConfig, _, found := configuration.Find(serverName) - if !found || serverConfig == nil || !serverConfig.Spec.IsRemoteOAuthServer() { + if !found || serverConfig == nil { continue } - g.startProvider(ctx, serverName) + if serverConfig.Spec.IsRemoteOAuthServer() { + g.startProvider(ctx, serverName) + } else if serverConfig.IsRemote() { + // Community servers: start provider if they have a stored OAuth token + // from dynamic discovery (DCR without explicit OAuth metadata) + if exists, _ := credHelper.TokenExists(ctx, serverName); exists { + log.Logf("- Starting OAuth provider for community server: %s", serverName) + g.startProvider(ctx, serverName) + } + } } } @@ -697,7 +707,9 @@ func (g *Gateway) routeEventToProvider(event oauth.Event) { g.clientPool.InvalidateOAuthClients(event.Provider) case oauth.EventLogoutSuccess: - // User logged out - stop provider if exists + // Invalidate cached OAuth client connections (clear stale bearer tokens) + g.clientPool.InvalidateOAuthClients(event.Provider) + // Stop provider if exists if exists { log.Logf("- Stopping provider for %s after logout", event.Provider) g.stopProvider(event.Provider) diff --git a/pkg/mcp/remote.go b/pkg/mcp/remote.go index 9931c1c2f..2bf706187 100644 --- a/pkg/mcp/remote.go +++ b/pkg/mcp/remote.go @@ -93,6 +93,17 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeParam } else if token != "" { headers["Authorization"] = "Bearer " + token } + } else if c.config.Spec.Remote.URL != "" { + // Community servers may have OAuth tokens via dynamic discovery (DCR) + // without explicit OAuth metadata in the catalog. Try to get a stored token. + credHelper := oauth.NewOAuthCredentialHelper() + token, err := credHelper.GetOAuthToken(ctx, c.config.Name) + if err == nil && token != "" { + if verbose { + log.Logf(" - Using dynamic OAuth token for: %s", c.config.Name) + } + headers["Authorization"] = "Bearer " + token + } } var mcpTransport mcp.Transport diff --git a/pkg/mcp/remote_test.go b/pkg/mcp/remote_test.go new file mode 100644 index 000000000..53bdc2467 --- /dev/null +++ b/pkg/mcp/remote_test.go @@ -0,0 +1,144 @@ +package mcp + +import ( + "context" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// roundTripFunc is an adapter to use functions as http.RoundTripper. +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestHeaderRoundTripper_AttachesAuthorizationHeader(t *testing.T) { + // Verifies that headerRoundTripper propagates Authorization headers to requests. + // This is the mechanism through which OAuth tokens (both catalog and dynamic) reach + // the remote MCP server. + var capturedReq *http.Request + base := roundTripFunc(func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + rt := &headerRoundTripper{ + base: base, + headers: map[string]string{ + "Authorization": "Bearer test-oauth-token", + }, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com/mcp", nil) + require.NoError(t, err) + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + require.NotNil(t, capturedReq) + assert.Equal(t, "Bearer test-oauth-token", capturedReq.Header.Get("Authorization")) +} + +func TestHeaderRoundTripper_DoesNotOverrideExistingAccept(t *testing.T) { + var capturedReq *http.Request + base := roundTripFunc(func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + rt := &headerRoundTripper{ + base: base, + headers: map[string]string{ + "Accept": "application/json", + "Authorization": "Bearer token", + }, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com/mcp", nil) + require.NoError(t, err) + req.Header.Set("Accept", "text/event-stream") + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + require.NotNil(t, capturedReq) + assert.Equal(t, "text/event-stream", capturedReq.Header.Get("Accept"), + "Accept should not be overridden when already set") + assert.Equal(t, "Bearer token", capturedReq.Header.Get("Authorization"), + "Authorization should still be set") +} + +func TestHeaderRoundTripper_DoesNotMutateOriginalRequest(t *testing.T) { + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + rt := &headerRoundTripper{ + base: base, + headers: map[string]string{ + "Authorization": "Bearer token", + }, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com/mcp", nil) + require.NoError(t, err) + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + assert.Empty(t, req.Header.Get("Authorization"), + "original request should not be mutated") +} + +func TestHeaderRoundTripper_MultipleCustomHeaders(t *testing.T) { + var capturedReq *http.Request + base := roundTripFunc(func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + rt := &headerRoundTripper{ + base: base, + headers: map[string]string{ + "Authorization": "Bearer dynamic-oauth-token", + "X-Custom": "custom-value", + }, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com/mcp", nil) + require.NoError(t, err) + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + require.NotNil(t, capturedReq) + assert.Equal(t, "Bearer dynamic-oauth-token", capturedReq.Header.Get("Authorization")) + assert.Equal(t, "custom-value", capturedReq.Header.Get("X-Custom")) +} + +func TestHeaderRoundTripper_EmptyHeaders(t *testing.T) { + var capturedReq *http.Request + base := roundTripFunc(func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + rt := &headerRoundTripper{ + base: base, + headers: map[string]string{}, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com/mcp", nil) + require.NoError(t, err) + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + require.NotNil(t, capturedReq) + assert.Empty(t, capturedReq.Header.Get("Authorization"), + "no Authorization header when headers map is empty") +} diff --git a/pkg/oauth/dcr_registration.go b/pkg/oauth/dcr_registration.go index c6039dd2f..85656de77 100644 --- a/pkg/oauth/dcr_registration.go +++ b/pkg/oauth/dcr_registration.go @@ -3,11 +3,33 @@ package oauth import ( "context" "fmt" + "time" + + oauthhelpers "github.com/docker/mcp-gateway-oauth-helpers" "github.com/docker/mcp-gateway/pkg/catalog" "github.com/docker/mcp-gateway/pkg/desktop" ) +// dcrRegistrationClient is the subset of desktop.Tools used for DCR registration. +// Extracted as an interface to enable testing. +type dcrRegistrationClient interface { + GetDCRClient(ctx context.Context, app string) (*desktop.DCRClient, error) + RegisterDCRClientPending(ctx context.Context, app string, req desktop.RegisterDCRRequest) error +} + +// oauthProber abstracts OAuth discovery to enable testing. +type oauthProber interface { + DiscoverOAuthRequirements(ctx context.Context, serverURL string) (*oauthhelpers.Discovery, error) +} + +// defaultOAuthProber wraps the package-level function. +type defaultOAuthProber struct{} + +func (defaultOAuthProber) DiscoverOAuthRequirements(ctx context.Context, serverURL string) (*oauthhelpers.Discovery, error) { + return oauthhelpers.DiscoverOAuthRequirements(ctx, serverURL) +} + // RegisterProviderForLazySetup registers a DCR provider with Docker Desktop // This allows 'docker mcp oauth authorize' to work before full DCR is complete // Idempotent - safe to call multiple times for the same server @@ -46,6 +68,35 @@ func RegisterProviderForLazySetup(ctx context.Context, serverName string) error return client.RegisterDCRClientPending(ctx, serverName, dcrRequest) } +// RegisterProviderForDynamicDiscovery probes a remote server for OAuth support +// and creates a pending DCR entry if the server requires OAuth. +// This is used for community servers that lack oauth.providers metadata in the catalog. +// Idempotent - safe to call multiple times for the same server. +func RegisterProviderForDynamicDiscovery(ctx context.Context, serverName, serverURL string) error { + return registerProviderForDynamicDiscovery(ctx, serverName, serverURL, desktop.NewAuthClient(), defaultOAuthProber{}) +} + +func registerProviderForDynamicDiscovery(ctx context.Context, serverName, serverURL string, client dcrRegistrationClient, prober oauthProber) error { + // Idempotent check - already registered? + _, err := client.GetDCRClient(ctx, serverName) + if err == nil { + return nil // Already registered + } + + // Probe the server with a timeout to discover OAuth requirements + probeCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + discovery, err := prober.DiscoverOAuthRequirements(probeCtx, serverURL) + if err != nil || !discovery.RequiresOAuth { + return nil // Server doesn't need OAuth, not an error + } + + // Register with DD (pending DCR state) using server name as provider name + return client.RegisterDCRClientPending(ctx, serverName, desktop.RegisterDCRRequest{ + ProviderName: serverName, + }) +} + // RegisterProviderWithSnapshot registers a DCR provider using OAuth metadata from the server snapshot // This avoids querying the catalog since the snapshot already contains all necessary OAuth information // Idempotent - safe to call multiple times for the same server diff --git a/pkg/oauth/dcr_registration_test.go b/pkg/oauth/dcr_registration_test.go new file mode 100644 index 000000000..211555c85 --- /dev/null +++ b/pkg/oauth/dcr_registration_test.go @@ -0,0 +1,96 @@ +package oauth + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + oauthhelpers "github.com/docker/mcp-gateway-oauth-helpers" + "github.com/docker/mcp-gateway/pkg/desktop" +) + +// mockDCRClient implements dcrRegistrationClient for testing. +type mockDCRClient struct { + clients map[string]*desktop.DCRClient + registered map[string]desktop.RegisterDCRRequest +} + +func newMockDCRClient() *mockDCRClient { + return &mockDCRClient{ + clients: make(map[string]*desktop.DCRClient), + registered: make(map[string]desktop.RegisterDCRRequest), + } +} + +func (m *mockDCRClient) GetDCRClient(_ context.Context, app string) (*desktop.DCRClient, error) { + c, ok := m.clients[app] + if !ok { + return nil, errors.New("not found") + } + return c, nil +} + +func (m *mockDCRClient) RegisterDCRClientPending(_ context.Context, app string, req desktop.RegisterDCRRequest) error { + m.registered[app] = req + return nil +} + +// mockProber implements oauthProber for testing. +type mockProber struct { + discovery *oauthhelpers.Discovery + err error +} + +func (m *mockProber) DiscoverOAuthRequirements(_ context.Context, _ string) (*oauthhelpers.Discovery, error) { + return m.discovery, m.err +} + +func TestRegisterProviderForDynamicDiscovery_SkipsAlreadyRegistered(t *testing.T) { + client := newMockDCRClient() + client.clients["my-server"] = &desktop.DCRClient{State: "unregistered"} + + prober := &mockProber{} // should not be called + + err := registerProviderForDynamicDiscovery(t.Context(), "my-server", "https://example.com/mcp", client, prober) + require.NoError(t, err) + assert.Empty(t, client.registered, "should not register when already exists") +} + +func TestRegisterProviderForDynamicDiscovery_RegistersWhenOAuthRequired(t *testing.T) { + client := newMockDCRClient() + prober := &mockProber{ + discovery: &oauthhelpers.Discovery{RequiresOAuth: true}, + } + + err := registerProviderForDynamicDiscovery(t.Context(), "ai-kubit-mcp-server", "https://mcp.kubit.ai/mcp", client, prober) + require.NoError(t, err) + + req, ok := client.registered["ai-kubit-mcp-server"] + require.True(t, ok, "should have registered DCR client") + assert.Equal(t, "ai-kubit-mcp-server", req.ProviderName, "provider name should match server name") +} + +func TestRegisterProviderForDynamicDiscovery_SkipsWhenNoOAuthRequired(t *testing.T) { + client := newMockDCRClient() + prober := &mockProber{ + discovery: &oauthhelpers.Discovery{RequiresOAuth: false}, + } + + err := registerProviderForDynamicDiscovery(t.Context(), "my-server", "https://example.com/mcp", client, prober) + require.NoError(t, err) + assert.Empty(t, client.registered, "should not register when OAuth not required") +} + +func TestRegisterProviderForDynamicDiscovery_SkipsOnProbeError(t *testing.T) { + client := newMockDCRClient() + prober := &mockProber{ + err: errors.New("connection refused"), + } + + err := registerProviderForDynamicDiscovery(t.Context(), "my-server", "https://unreachable.example.com/mcp", client, prober) + require.NoError(t, err) + assert.Empty(t, client.registered, "should not register when probe fails") +} diff --git a/pkg/oauth/provider.go b/pkg/oauth/provider.go index 92e28547b..1256edf51 100644 --- a/pkg/oauth/provider.go +++ b/pkg/oauth/provider.go @@ -152,10 +152,14 @@ func (p *Provider) Run(ctx context.Context) { // Trigger refresh if needed if shouldTriggerRefresh { if IsCEMode() { - // CE mode: Refresh token directly + // CE mode: Refresh token directly, then reload server connection go func() { if err := p.refreshTokenCE(); err != nil { log.Logf("! Token refresh failed for %s: %v", p.name, err) + return + } + if err := p.reloadFn(ctx, p.name); err != nil { + log.Logf("! Failed to reload %s after token refresh: %v", p.name, err) } }() } else { diff --git a/pkg/workingset/oauth.go b/pkg/workingset/oauth.go index fa3f37dd0..5ebe51290 100644 --- a/pkg/workingset/oauth.go +++ b/pkg/workingset/oauth.go @@ -24,15 +24,20 @@ func RegisterOAuthProvidersForServers(ctx context.Context, servers []Server) { if server.Snapshot == nil { continue } - if !server.Snapshot.Server.IsRemoteOAuthServer() { - continue - } - - serverName := server.Snapshot.Server.Name - providerName := server.Snapshot.Server.OAuth.Providers[0].Provider + if server.Snapshot.Server.IsRemoteOAuthServer() { + serverName := server.Snapshot.Server.Name + providerName := server.Snapshot.Server.OAuth.Providers[0].Provider - if err := oauth.RegisterProviderWithSnapshot(ctx, serverName, providerName); err != nil { - log.Log(fmt.Sprintf("Warning: Failed to register OAuth provider for %s: %v", serverName, err)) + if err := oauth.RegisterProviderWithSnapshot(ctx, serverName, providerName); err != nil { + log.Log(fmt.Sprintf("Warning: Failed to register OAuth provider for %s: %v", serverName, err)) + } + } else if server.Snapshot.Server.Type == "remote" && server.Snapshot.Server.Remote.URL != "" { + // Community servers without oauth.providers: probe for OAuth dynamically + serverName := server.Snapshot.Server.Name + serverURL := server.Snapshot.Server.Remote.URL + if err := oauth.RegisterProviderForDynamicDiscovery(ctx, serverName, serverURL); err != nil { + log.Log(fmt.Sprintf("Warning: Failed dynamic OAuth discovery for %s: %v", serverName, err)) + } } } } diff --git a/pkg/workingset/server.go b/pkg/workingset/server.go index 437cad066..d276429c5 100644 --- a/pkg/workingset/server.go +++ b/pkg/workingset/server.go @@ -158,6 +158,7 @@ func RemoveServers(ctx context.Context, dao db.DAO, id string, serverNames []str // Tests can override this to verify the call without requiring Docker Desktop. var cleanupDCREntriesFunc = CleanupOrphanedDCREntries + // dcrClient abstracts the Desktop API operations needed for cleanup, // allowing tests to substitute a mock implementation. type dcrClient interface { @@ -167,7 +168,8 @@ type dcrClient interface { } // CleanupOrphanedDCREntries removes DCR entries for servers that no longer -// exist in any profile. This prevents stale OAuth entries from accumulating. +// exist in any profile and are not authorized. This prevents stale OAuth +// entries from accumulating. func CleanupOrphanedDCREntries(ctx context.Context, dao db.DAO, serverNames []string) { if oauth.IsCEMode() { return From 4f4219c4ee475c63b6fb91246844ab85e715d4ae Mon Sep 17 00:00:00 2001 From: Justin Chang Date: Sat, 14 Feb 2026 19:56:56 -0800 Subject: [PATCH 2/5] Limit DCR cleanup to successfully removed servers Pass only the names that were actually removed from the profile to CleanupOrphanedDCREntries, instead of the full user-supplied list. Co-Authored-By: Claude Opus 4.6 --- pkg/gateway/clientpool_test.go | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pkg/gateway/clientpool_test.go b/pkg/gateway/clientpool_test.go index b4db37cdb..96a149752 100644 --- a/pkg/gateway/clientpool_test.go +++ b/pkg/gateway/clientpool_test.go @@ -278,14 +278,6 @@ func parseConfig(t *testing.T, contentYAML string) map[string]any { return config } -func readOnly() *bool { - return boolPtr(true) -} - -func boolPtr(b bool) *bool { - return &b -} - func TestInvalidateOAuthClients_MatchesCommunityServer(t *testing.T) { // Community server: remote URL set, but no Spec.OAuth metadata. // This verifies Gap 3: InvalidateOAuthClients matches community servers From 9511378ea96bfe773f25b1063cfcf5688c1077ed Mon Sep 17 00:00:00 2001 From: Justin Chang Date: Wed, 25 Feb 2026 11:04:30 -0800 Subject: [PATCH 3/5] Address PR review feedback for dynamic OAuth discovery - Guard against nil discovery response in registerProviderForDynamicDiscovery - Distinguish HTTP 404 from transient errors in DCR verification - Consolidate duplicate OAuth credential helper in remote.go - Simplify mcpadd OAuth condition to use serverConfig.IsRemote() - Remove extra blank line in server.go Co-Authored-By: Claude Opus 4.6 --- pkg/gateway/mcpadd.go | 16 +++++++++++----- pkg/mcp/remote.go | 18 ++++-------------- pkg/oauth/dcr_registration.go | 2 +- pkg/oauth/dcr_registration_test.go | 12 ++++++++++++ pkg/workingset/server.go | 1 - 5 files changed, 28 insertions(+), 21 deletions(-) diff --git a/pkg/gateway/mcpadd.go b/pkg/gateway/mcpadd.go index 94b364b97..a1eacbd3b 100644 --- a/pkg/gateway/mcpadd.go +++ b/pkg/gateway/mcpadd.go @@ -286,11 +286,12 @@ func addServerHandler(g *Gateway, clientConfig *clientConfig) mcp.ToolHandler { } } - // Handle OAuth DCR only when the client supports elicitation (e.g. not stdio-based clients) + // Handle OAuth DCR for any remote server — covers both catalog servers + // (explicit OAuth metadata) and community servers (dynamic discovery). + // getRemoteOAuthServerStatus handles the case where OAuth is not needed. if g.McpOAuthDcrEnabled && serverConfig != nil && - (serverConfig.Spec.IsRemoteOAuthServer() || - (serverConfig.Spec.Type == "remote" && !serverConfig.Spec.IsOAuthServer() && serverConfig.Spec.Remote.URL != "")) { + serverConfig.IsRemote() { init := req.Session.InitializeParams() if init != nil && @@ -455,10 +456,15 @@ func (g *Gateway) getRemoteOAuthServerStatus(ctx context.Context, serverName str } } - // Verify DCR entry was created — dynamic discovery may have found no OAuth requirement + // Verify DCR entry was created — dynamic discovery may have found no OAuth requirement. + // Distinguish "not found" (server doesn't need OAuth) from transient API errors. authClient := desktop.NewAuthClient() if _, err := authClient.GetDCRClient(ctx, serverName); err != nil { - return true, "" // Server doesn't require OAuth + if strings.Contains(err.Error(), "HTTP 404") { + return true, "" // Server doesn't require OAuth + } + log.Logf("Warning: Failed to verify DCR entry for %s (may be transient): %v", serverName, err) + return true, "" // Fail open — avoid blocking the add flow on transient errors } // Start provider (CE mode only - Desktop mode doesn't need polling) diff --git a/pkg/mcp/remote.go b/pkg/mcp/remote.go index 2bf706187..c0c17d3b8 100644 --- a/pkg/mcp/remote.go +++ b/pkg/mcp/remote.go @@ -81,26 +81,16 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeParam headers[k] = expandEnv(v, env) } - // Add OAuth token if remote server has OAuth configuration - if c.config.Spec.OAuth != nil && len(c.config.Spec.OAuth.Providers) > 0 { - if verbose { - log.Logf(" - Using OAuth token for: %s", c.config.Name) - } + // Add OAuth token for remote servers — covers both catalog servers (explicit + // OAuth metadata) and community servers (dynamic discovery via DCR). + if c.config.Spec.Remote.URL != "" { credHelper := oauth.NewOAuthCredentialHelper() token, err := credHelper.GetOAuthToken(ctx, c.config.Name) if err != nil { log.Logf("Failed to get OAuth token for %s: %v", c.config.Name, err) } else if token != "" { - headers["Authorization"] = "Bearer " + token - } - } else if c.config.Spec.Remote.URL != "" { - // Community servers may have OAuth tokens via dynamic discovery (DCR) - // without explicit OAuth metadata in the catalog. Try to get a stored token. - credHelper := oauth.NewOAuthCredentialHelper() - token, err := credHelper.GetOAuthToken(ctx, c.config.Name) - if err == nil && token != "" { if verbose { - log.Logf(" - Using dynamic OAuth token for: %s", c.config.Name) + log.Logf(" - Using OAuth token for: %s", c.config.Name) } headers["Authorization"] = "Bearer " + token } diff --git a/pkg/oauth/dcr_registration.go b/pkg/oauth/dcr_registration.go index 85656de77..cf8d5baf6 100644 --- a/pkg/oauth/dcr_registration.go +++ b/pkg/oauth/dcr_registration.go @@ -87,7 +87,7 @@ func registerProviderForDynamicDiscovery(ctx context.Context, serverName, server probeCtx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() discovery, err := prober.DiscoverOAuthRequirements(probeCtx, serverURL) - if err != nil || !discovery.RequiresOAuth { + if err != nil || discovery == nil || !discovery.RequiresOAuth { return nil // Server doesn't need OAuth, not an error } diff --git a/pkg/oauth/dcr_registration_test.go b/pkg/oauth/dcr_registration_test.go index 211555c85..a1d8ae986 100644 --- a/pkg/oauth/dcr_registration_test.go +++ b/pkg/oauth/dcr_registration_test.go @@ -84,6 +84,18 @@ func TestRegisterProviderForDynamicDiscovery_SkipsWhenNoOAuthRequired(t *testing assert.Empty(t, client.registered, "should not register when OAuth not required") } +func TestRegisterProviderForDynamicDiscovery_SkipsOnNilDiscovery(t *testing.T) { + client := newMockDCRClient() + prober := &mockProber{ + discovery: nil, + err: nil, + } + + err := registerProviderForDynamicDiscovery(t.Context(), "my-server", "https://example.com/mcp", client, prober) + require.NoError(t, err) + assert.Empty(t, client.registered, "should not register when discovery is nil") +} + func TestRegisterProviderForDynamicDiscovery_SkipsOnProbeError(t *testing.T) { client := newMockDCRClient() prober := &mockProber{ diff --git a/pkg/workingset/server.go b/pkg/workingset/server.go index d276429c5..cd9dd2eb7 100644 --- a/pkg/workingset/server.go +++ b/pkg/workingset/server.go @@ -158,7 +158,6 @@ func RemoveServers(ctx context.Context, dao db.DAO, id string, serverNames []str // Tests can override this to verify the call without requiring Docker Desktop. var cleanupDCREntriesFunc = CleanupOrphanedDCREntries - // dcrClient abstracts the Desktop API operations needed for cleanup, // allowing tests to substitute a mock implementation. type dcrClient interface { From ac2a77c9982d716ecca515198a4c60ec6d6a153e Mon Sep 17 00:00:00 2001 From: Justin Chang Date: Wed, 25 Feb 2026 11:33:43 -0800 Subject: [PATCH 4/5] Revert remote.go OAuth token consolidation to avoid log noise The consolidated check called GetOAuthToken for all remote servers, which logs errors for servers without stored tokens. Restore the two-branch structure so catalog servers (explicit OAuth metadata) log errors while community servers silently ignore missing tokens. Co-Authored-By: Claude Opus 4.6 --- pkg/mcp/remote.go | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/pkg/mcp/remote.go b/pkg/mcp/remote.go index c0c17d3b8..2bf706187 100644 --- a/pkg/mcp/remote.go +++ b/pkg/mcp/remote.go @@ -81,16 +81,26 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeParam headers[k] = expandEnv(v, env) } - // Add OAuth token for remote servers — covers both catalog servers (explicit - // OAuth metadata) and community servers (dynamic discovery via DCR). - if c.config.Spec.Remote.URL != "" { + // Add OAuth token if remote server has OAuth configuration + if c.config.Spec.OAuth != nil && len(c.config.Spec.OAuth.Providers) > 0 { + if verbose { + log.Logf(" - Using OAuth token for: %s", c.config.Name) + } credHelper := oauth.NewOAuthCredentialHelper() token, err := credHelper.GetOAuthToken(ctx, c.config.Name) if err != nil { log.Logf("Failed to get OAuth token for %s: %v", c.config.Name, err) } else if token != "" { + headers["Authorization"] = "Bearer " + token + } + } else if c.config.Spec.Remote.URL != "" { + // Community servers may have OAuth tokens via dynamic discovery (DCR) + // without explicit OAuth metadata in the catalog. Try to get a stored token. + credHelper := oauth.NewOAuthCredentialHelper() + token, err := credHelper.GetOAuthToken(ctx, c.config.Name) + if err == nil && token != "" { if verbose { - log.Logf(" - Using OAuth token for: %s", c.config.Name) + log.Logf(" - Using dynamic OAuth token for: %s", c.config.Name) } headers["Authorization"] = "Bearer " + token } From 025eceef0897d4fabf508c38e8f12a30af4eef28 Mon Sep 17 00:00:00 2001 From: Justin Chang Date: Wed, 25 Feb 2026 15:06:36 -0800 Subject: [PATCH 5/5] Address reviewer feedback and rename IsRemoteOAuthServer - Remove redundant 10s probe timeout (discovery library uses 30s internally) - Log dynamic OAuth discovery failures instead of silently returning nil - Rename IsRemoteOAuthServer to HasExplicitOAuthProviders for clarity - Log TokenExists errors during community server provider startup Co-Authored-By: Claude Opus 4.6 --- cmd/docker-mcp/server/enable.go | 4 ++-- pkg/catalog/types.go | 5 ++++- pkg/gateway/run.go | 6 ++++-- pkg/oauth/dcr_registration.go | 19 +++++++++++-------- pkg/workingset/oauth.go | 2 +- 5 files changed, 22 insertions(+), 14 deletions(-) diff --git a/cmd/docker-mcp/server/enable.go b/cmd/docker-mcp/server/enable.go index 0c045aaa4..08787c212 100644 --- a/cmd/docker-mcp/server/enable.go +++ b/cmd/docker-mcp/server/enable.go @@ -61,7 +61,7 @@ func update(ctx context.Context, docker docker.Client, dockerCli command.Cli, ad } // DCR flag enabled AND type="remote" AND oauth present - if mcpOAuthDcrEnabled && server.IsRemoteOAuthServer() { + if mcpOAuthDcrEnabled && server.HasExplicitOAuthProviders() { // In CE mode, skip lazy setup - DCR happens during oauth authorize if pkgoauth.IsCEMode() { fmt.Printf("OAuth server %s enabled. Run 'docker mcp oauth authorize %s' to authenticate\n", serverName, serverName) @@ -74,7 +74,7 @@ func update(ctx context.Context, docker docker.Client, dockerCli command.Cli, ad fmt.Printf("OAuth provider configured for %s - use 'docker mcp oauth authorize %s' to authenticate\n", serverName, serverName) } } - } else if !mcpOAuthDcrEnabled && server.IsRemoteOAuthServer() { + } else if !mcpOAuthDcrEnabled && server.HasExplicitOAuthProviders() { // Provide guidance when DCR is needed but disabled fmt.Printf("Server %s requires OAuth authentication but DCR is disabled.\n", serverName) fmt.Printf(" To enable automatic OAuth setup, run: docker mcp feature enable mcp-oauth-dcr\n") diff --git a/pkg/catalog/types.go b/pkg/catalog/types.go index e08d7223f..9ee7768c6 100644 --- a/pkg/catalog/types.go +++ b/pkg/catalog/types.go @@ -63,7 +63,10 @@ func (s *Server) IsOAuthServer() bool { return s.OAuth != nil && len(s.OAuth.Providers) > 0 } -func (s *Server) IsRemoteOAuthServer() bool { +// HasExplicitOAuthProviders returns true if this is a remote server with +// explicit OAuth provider metadata in the catalog (e.g. oauth.providers YAML). +// Community servers that discover OAuth dynamically will return false here. +func (s *Server) HasExplicitOAuthProviders() bool { return s.Type == "remote" && s.IsOAuthServer() } diff --git a/pkg/gateway/run.go b/pkg/gateway/run.go index 8f51c3a1b..6a4776dea 100644 --- a/pkg/gateway/run.go +++ b/pkg/gateway/run.go @@ -359,12 +359,14 @@ func (g *Gateway) Run(ctx context.Context) error { continue } - if serverConfig.Spec.IsRemoteOAuthServer() { + if serverConfig.Spec.HasExplicitOAuthProviders() { g.startProvider(ctx, serverName) } else if serverConfig.IsRemote() { // Community servers: start provider if they have a stored OAuth token // from dynamic discovery (DCR without explicit OAuth metadata) - if exists, _ := credHelper.TokenExists(ctx, serverName); exists { + if exists, err := credHelper.TokenExists(ctx, serverName); err != nil { + log.Logf("Warning: Failed to check OAuth token for %s: %v", serverName, err) + } else if exists { log.Logf("- Starting OAuth provider for community server: %s", serverName) g.startProvider(ctx, serverName) } diff --git a/pkg/oauth/dcr_registration.go b/pkg/oauth/dcr_registration.go index cf8d5baf6..d7070cb8b 100644 --- a/pkg/oauth/dcr_registration.go +++ b/pkg/oauth/dcr_registration.go @@ -3,12 +3,12 @@ package oauth import ( "context" "fmt" - "time" oauthhelpers "github.com/docker/mcp-gateway-oauth-helpers" "github.com/docker/mcp-gateway/pkg/catalog" "github.com/docker/mcp-gateway/pkg/desktop" + "github.com/docker/mcp-gateway/pkg/log" ) // dcrRegistrationClient is the subset of desktop.Tools used for DCR registration. @@ -54,7 +54,7 @@ func RegisterProviderForLazySetup(ctx context.Context, serverName string) error } // Verify this is a remote OAuth server (Type="remote" && OAuth providers exist) - if !server.IsRemoteOAuthServer() { + if !server.HasExplicitOAuthProviders() { return fmt.Errorf("server %s is not a remote OAuth server", serverName) } @@ -83,12 +83,15 @@ func registerProviderForDynamicDiscovery(ctx context.Context, serverName, server return nil // Already registered } - // Probe the server with a timeout to discover OAuth requirements - probeCtx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - discovery, err := prober.DiscoverOAuthRequirements(probeCtx, serverURL) - if err != nil || discovery == nil || !discovery.RequiresOAuth { - return nil // Server doesn't need OAuth, not an error + // Probe the server to discover OAuth requirements. + // The discovery library uses its own 30s HTTP timeout internally. + discovery, err := prober.DiscoverOAuthRequirements(ctx, serverURL) + if err != nil { + log.Logf("Dynamic OAuth discovery failed for %s: %v", serverName, err) + return nil // Probe failed, not fatal + } + if discovery == nil || !discovery.RequiresOAuth { + return nil // Server doesn't need OAuth } // Register with DD (pending DCR state) using server name as provider name diff --git a/pkg/workingset/oauth.go b/pkg/workingset/oauth.go index 5ebe51290..d019e0e6b 100644 --- a/pkg/workingset/oauth.go +++ b/pkg/workingset/oauth.go @@ -24,7 +24,7 @@ func RegisterOAuthProvidersForServers(ctx context.Context, servers []Server) { if server.Snapshot == nil { continue } - if server.Snapshot.Server.IsRemoteOAuthServer() { + if server.Snapshot.Server.HasExplicitOAuthProviders() { serverName := server.Snapshot.Server.Name providerName := server.Snapshot.Server.OAuth.Providers[0].Provider