Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/docker-mcp/oauth/revoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func Revoke(ctx context.Context, app string) error {
func revokeDesktopMode(ctx context.Context, app string) error {
client := desktop.NewAuthClient()

// Revoke tokens
// Revoke tokens via Docker Desktop
if err := client.DeleteOAuthApp(ctx, app); err != nil {
return fmt.Errorf("failed to revoke OAuth access: %w", err)
}
Expand Down
13 changes: 11 additions & 2 deletions cmd/docker-mcp/server/enable.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func update(ctx context.Context, docker docker.Client, dockerCli command.Cli, ad
}

// DCR flag enabled AND type="remote" AND oauth present
if mcpOAuthDcrEnabled && server.IsRemoteOAuthServer() {
if mcpOAuthDcrEnabled && server.HasExplicitOAuthProviders() {
// In CE mode, skip lazy setup - DCR happens during oauth authorize
if pkgoauth.IsCEMode() {
fmt.Printf("OAuth server %s enabled. Run 'docker mcp oauth authorize %s' to authenticate\n", serverName, serverName)
Expand All @@ -74,11 +74,20 @@ func update(ctx context.Context, docker docker.Client, dockerCli command.Cli, ad
fmt.Printf("OAuth provider configured for %s - use 'docker mcp oauth authorize %s' to authenticate\n", serverName, serverName)
}
}
} else if !mcpOAuthDcrEnabled && server.IsRemoteOAuthServer() {
} else if !mcpOAuthDcrEnabled && server.HasExplicitOAuthProviders() {
// Provide guidance when DCR is needed but disabled
fmt.Printf("Server %s requires OAuth authentication but DCR is disabled.\n", serverName)
fmt.Printf(" To enable automatic OAuth setup, run: docker mcp feature enable mcp-oauth-dcr\n")
fmt.Printf(" Or set up OAuth manually using: docker mcp oauth authorize %s\n", serverName)
} else if mcpOAuthDcrEnabled && server.Type == "remote" && !server.IsOAuthServer() && server.Remote.URL != "" {
// Community server without oauth.providers — probe for OAuth
if pkgoauth.IsCEMode() {
fmt.Printf("Remote server %s enabled. Run 'docker mcp oauth authorize %s' if authentication is required\n", serverName, serverName)
} else {
if err := pkgoauth.RegisterProviderForDynamicDiscovery(ctx, serverName, server.Remote.URL); err != nil {
fmt.Printf("Warning: Dynamic OAuth discovery failed for %s: %v\n", serverName, err)
}
}
}
} else {
return fmt.Errorf("server %s not found in catalog", serverName)
Expand Down
5 changes: 4 additions & 1 deletion pkg/catalog/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ func (s *Server) IsOAuthServer() bool {
return s.OAuth != nil && len(s.OAuth.Providers) > 0
}

func (s *Server) IsRemoteOAuthServer() bool {
// HasExplicitOAuthProviders returns true if this is a remote server with
// explicit OAuth provider metadata in the catalog (e.g. oauth.providers YAML).
// Community servers that discover OAuth dynamically will return false here.
func (s *Server) HasExplicitOAuthProviders() bool {
return s.Type == "remote" && s.IsOAuthServer()
}

Expand Down
33 changes: 16 additions & 17 deletions pkg/gateway/clientpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,24 +179,23 @@ func (cp *clientPool) InvalidateOAuthClients(provider string) {

var invalidatedKeys []clientKey
for key, keptClient := range cp.keptClients {
// Check if this client uses OAuth for the specified provider
if keptClient.Config.Spec.OAuth != nil {
// Match by server name (for DCR providers, server name matches provider)
if keptClient.Config.Name == provider {
log.Log(fmt.Sprintf("ClientPool: Closing OAuth connection for server: %s", keptClient.Config.Name))

// Close the connection
client, err := keptClient.Getter.GetClient(context.TODO())
if err == nil {
client.Session().Close()
log.Log(fmt.Sprintf("ClientPool: Successfully closed connection for %s", keptClient.Config.Name))
} else {
log.Log(fmt.Sprintf("ClientPool: Warning - failed to get client for %s during invalidation: %v", keptClient.Config.Name, err))
}

// Mark for removal from kept clients
invalidatedKeys = append(invalidatedKeys, key)
// Check if this remote client matches the OAuth provider
// Matches both catalog servers (explicit OAuth metadata) and community servers
// (dynamic OAuth discovery via DCR without Spec.OAuth)
if keptClient.Config.Name == provider && keptClient.Config.IsRemote() {
log.Log(fmt.Sprintf("ClientPool: Closing OAuth connection for server: %s", keptClient.Config.Name))

// Close the connection
client, err := keptClient.Getter.GetClient(context.TODO())
if err == nil {
client.Session().Close()
log.Log(fmt.Sprintf("ClientPool: Successfully closed connection for %s", keptClient.Config.Name))
} else {
log.Log(fmt.Sprintf("ClientPool: Warning - failed to get client for %s during invalidation: %v", keptClient.Config.Name, err))
}

// Mark for removal from kept clients
invalidatedKeys = append(invalidatedKeys, key)
}
}

Expand Down
192 changes: 192 additions & 0 deletions pkg/gateway/clientpool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gateway

import (
"context"
"fmt"
"os"
"testing"
"time"
Expand Down Expand Up @@ -277,6 +278,197 @@ func parseConfig(t *testing.T, contentYAML string) map[string]any {
return config
}

func TestInvalidateOAuthClients_MatchesCommunityServer(t *testing.T) {
// Community server: remote URL set, but no Spec.OAuth metadata.
// This verifies Gap 3: InvalidateOAuthClients matches community servers
// that use dynamic OAuth discovery without explicit OAuth config.
cp := &clientPool{
keptClients: make(map[clientKey]keptClient),
}

getter := &clientGetter{}
getter.once.Do(func() {}) // mark as executed
getter.err = fmt.Errorf("mock: no real client")

key := clientKey{serverName: "com-notion-mcp"}
cp.keptClients[key] = keptClient{
Name: "com-notion-mcp",
Getter: getter,
Config: &catalog.ServerConfig{
Name: "com-notion-mcp",
Spec: catalog.Server{
Type: "remote",
Remote: catalog.Remote{
URL: "https://mcp.notion.so/mcp",
Transport: "streamable-http",
},
// No OAuth field - community server
},
},
}

cp.InvalidateOAuthClients("com-notion-mcp")

assert.Empty(t, cp.keptClients, "community server should be invalidated by name")
}

func TestInvalidateOAuthClients_MatchesCatalogServer(t *testing.T) {
// Catalog server: remote URL set WITH Spec.OAuth metadata.
// Verifies backward compatibility: catalog servers still get invalidated.
cp := &clientPool{
keptClients: make(map[clientKey]keptClient),
}

getter := &clientGetter{}
getter.once.Do(func() {})
getter.err = fmt.Errorf("mock: no real client")

key := clientKey{serverName: "notion-remote"}
cp.keptClients[key] = keptClient{
Name: "notion-remote",
Getter: getter,
Config: &catalog.ServerConfig{
Name: "notion-remote",
Spec: catalog.Server{
Type: "remote",
Remote: catalog.Remote{
URL: "https://mcp.notion.so/mcp",
Transport: "streamable-http",
},
OAuth: &catalog.OAuth{
Providers: []catalog.OAuthProvider{{Provider: "notion"}},
},
},
},
}

cp.InvalidateOAuthClients("notion-remote")

assert.Empty(t, cp.keptClients, "catalog server should be invalidated by name")
}

func TestInvalidateOAuthClients_SkipsNonRemoteServer(t *testing.T) {
// Docker container server: not remote, should NOT be invalidated.
cp := &clientPool{
keptClients: make(map[clientKey]keptClient),
}

getter := &clientGetter{}
getter.once.Do(func() {})
getter.err = fmt.Errorf("mock: no real client")

key := clientKey{serverName: "my-container-server"}
cp.keptClients[key] = keptClient{
Name: "my-container-server",
Getter: getter,
Config: &catalog.ServerConfig{
Name: "my-container-server",
Spec: catalog.Server{
Type: "server",
Image: "mcp/my-server:latest",
// Not remote - no URL
},
},
}

cp.InvalidateOAuthClients("my-container-server")

assert.Len(t, cp.keptClients, 1, "non-remote server should NOT be invalidated")
}

func TestInvalidateOAuthClients_SkipsMismatchedName(t *testing.T) {
// Remote server with different name: should NOT be invalidated.
cp := &clientPool{
keptClients: make(map[clientKey]keptClient),
}

getter := &clientGetter{}
getter.once.Do(func() {})
getter.err = fmt.Errorf("mock: no real client")

key := clientKey{serverName: "other-server"}
cp.keptClients[key] = keptClient{
Name: "other-server",
Getter: getter,
Config: &catalog.ServerConfig{
Name: "other-server",
Spec: catalog.Server{
Type: "remote",
Remote: catalog.Remote{
URL: "https://other.example.com/mcp",
},
},
},
}

cp.InvalidateOAuthClients("com-notion-mcp")

assert.Len(t, cp.keptClients, 1, "server with different name should NOT be invalidated")
}

func TestInvalidateOAuthClients_OnlyMatchingRemoved(t *testing.T) {
// Multiple clients: only the matching remote server should be removed.
cp := &clientPool{
keptClients: make(map[clientKey]keptClient),
}

makeGetter := func() *clientGetter {
g := &clientGetter{}
g.once.Do(func() {})
g.err = fmt.Errorf("mock: no real client")
return g
}

// Community OAuth server (should be invalidated)
cp.keptClients[clientKey{serverName: "com-notion-mcp"}] = keptClient{
Name: "com-notion-mcp",
Getter: makeGetter(),
Config: &catalog.ServerConfig{
Name: "com-notion-mcp",
Spec: catalog.Server{
Type: "remote",
Remote: catalog.Remote{URL: "https://mcp.notion.so/mcp"},
},
},
}

// Different remote server (should NOT be invalidated)
cp.keptClients[clientKey{serverName: "github-remote"}] = keptClient{
Name: "github-remote",
Getter: makeGetter(),
Config: &catalog.ServerConfig{
Name: "github-remote",
Spec: catalog.Server{
Type: "remote",
Remote: catalog.Remote{URL: "https://mcp.github.com/mcp"},
},
},
}

// Docker container server (should NOT be invalidated)
cp.keptClients[clientKey{serverName: "local-server"}] = keptClient{
Name: "local-server",
Getter: makeGetter(),
Config: &catalog.ServerConfig{
Name: "local-server",
Spec: catalog.Server{
Type: "server",
Image: "mcp/local:latest",
},
},
}

cp.InvalidateOAuthClients("com-notion-mcp")

assert.Len(t, cp.keptClients, 2, "only the matching remote server should be removed")
_, hasNotion := cp.keptClients[clientKey{serverName: "com-notion-mcp"}]
assert.False(t, hasNotion, "com-notion-mcp should have been removed")
_, hasGithub := cp.keptClients[clientKey{serverName: "github-remote"}]
assert.True(t, hasGithub, "github-remote should remain")
_, hasLocal := cp.keptClients[clientKey{serverName: "local-server"}]
assert.True(t, hasLocal, "local-server should remain")
}

func TestStdioClientInitialization(t *testing.T) {
// This is an integration test that requires Docker
if testing.Short() {
Expand Down
26 changes: 23 additions & 3 deletions pkg/gateway/mcpadd.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,12 @@ func addServerHandler(g *Gateway, clientConfig *clientConfig) mcp.ToolHandler {
}
}

// Handle OAuth DCR only when the client supports elicitation (e.g. not stdio-based clients)
// Handle OAuth DCR for any remote server — covers both catalog servers
// (explicit OAuth metadata) and community servers (dynamic discovery).
// getRemoteOAuthServerStatus handles the case where OAuth is not needed.
if g.McpOAuthDcrEnabled &&
serverConfig != nil &&
serverConfig.Spec.IsRemoteOAuthServer() {
serverConfig.IsRemote() {

init := req.Session.InitializeParams()
if init != nil &&
Expand Down Expand Up @@ -444,7 +446,25 @@ func (g *Gateway) getRemoteOAuthServerStatus(ctx context.Context, serverName str
if !providerExists {
// Register DCR client with DD so user can authorize
if err := oauth.RegisterProviderForLazySetup(ctx, serverName); err != nil {
log.Logf("Warning: Failed to register OAuth provider for %s: %v", serverName, err)
// Fallback: try dynamic discovery for community servers without oauth.providers
if serverConfig, _, found := g.configuration.Find(serverName); found && serverConfig.Spec.Remote.URL != "" {
if err := oauth.RegisterProviderForDynamicDiscovery(ctx, serverName, serverConfig.Spec.Remote.URL); err != nil {
log.Logf("Warning: Failed to register OAuth provider for %s: %v", serverName, err)
}
} else {
log.Logf("Warning: Failed to register OAuth provider for %s: %v", serverName, err)
}
}

// Verify DCR entry was created — dynamic discovery may have found no OAuth requirement.
// Distinguish "not found" (server doesn't need OAuth) from transient API errors.
authClient := desktop.NewAuthClient()
if _, err := authClient.GetDCRClient(ctx, serverName); err != nil {
if strings.Contains(err.Error(), "HTTP 404") {
return true, "" // Server doesn't require OAuth
}
log.Logf("Warning: Failed to verify DCR entry for %s (may be transient): %v", serverName, err)
return true, "" // Fail open — avoid blocking the add flow on transient errors
}

// Start provider (CE mode only - Desktop mode doesn't need polling)
Expand Down
20 changes: 17 additions & 3 deletions pkg/gateway/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,25 @@ func (g *Gateway) Run(ctx context.Context) error {
// Start OAuth provider for each OAuth server.
// Each provider runs in its own goroutine with dynamic timing based on token expiry.
log.Log("- Starting OAuth provider loops...")
credHelper := oauth.NewOAuthCredentialHelper()
for _, serverName := range configuration.ServerNames() {
serverConfig, _, found := configuration.Find(serverName)
if !found || serverConfig == nil || !serverConfig.Spec.IsRemoteOAuthServer() {
if !found || serverConfig == nil {
continue
}

g.startProvider(ctx, serverName)
if serverConfig.Spec.HasExplicitOAuthProviders() {
g.startProvider(ctx, serverName)
} else if serverConfig.IsRemote() {
// Community servers: start provider if they have a stored OAuth token
// from dynamic discovery (DCR without explicit OAuth metadata)
if exists, err := credHelper.TokenExists(ctx, serverName); err != nil {
log.Logf("Warning: Failed to check OAuth token for %s: %v", serverName, err)
} else if exists {
log.Logf("- Starting OAuth provider for community server: %s", serverName)
g.startProvider(ctx, serverName)
}
}
}
}

Expand Down Expand Up @@ -697,7 +709,9 @@ func (g *Gateway) routeEventToProvider(event oauth.Event) {
g.clientPool.InvalidateOAuthClients(event.Provider)

case oauth.EventLogoutSuccess:
// User logged out - stop provider if exists
// Invalidate cached OAuth client connections (clear stale bearer tokens)
g.clientPool.InvalidateOAuthClients(event.Provider)
// Stop provider if exists
if exists {
log.Logf("- Stopping provider for %s after logout", event.Provider)
g.stopProvider(event.Provider)
Expand Down
11 changes: 11 additions & 0 deletions pkg/mcp/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,17 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeParam
} else if token != "" {
headers["Authorization"] = "Bearer " + token
}
} else if c.config.Spec.Remote.URL != "" {
// Community servers may have OAuth tokens via dynamic discovery (DCR)
// without explicit OAuth metadata in the catalog. Try to get a stored token.
credHelper := oauth.NewOAuthCredentialHelper()
token, err := credHelper.GetOAuthToken(ctx, c.config.Name)
if err == nil && token != "" {
if verbose {
log.Logf(" - Using dynamic OAuth token for: %s", c.config.Name)
}
headers["Authorization"] = "Bearer " + token
}
}

var mcpTransport mcp.Transport
Expand Down
Loading
Loading