diff --git a/pkg/auth/discovery/discovery.go b/pkg/auth/discovery/discovery.go index 9e93f7510..c3c72510f 100644 --- a/pkg/auth/discovery/discovery.go +++ b/pkg/auth/discovery/discovery.go @@ -15,6 +15,7 @@ import ( "io" "net/http" "net/url" + "path" "strings" "time" @@ -252,9 +253,16 @@ func DeriveIssuerFromURL(remoteURL string) string { host = fmt.Sprintf("%s:%s", host, port) } + // For localhost, preserve the original scheme (HTTP or HTTPS) + // This supports local development and testing scenarios + scheme := networking.HttpsScheme + if networking.IsLocalhost(host) && parsedURL.Scheme != "" { + scheme = parsedURL.Scheme + } + // General pattern: use the domain as the issuer // This works for most OAuth providers that use their domain as the issuer - issuer := fmt.Sprintf("%s://%s", networking.HttpsScheme, host) + issuer := fmt.Sprintf("%s://%s", scheme, host) logger.Debugf("Derived issuer from URL - remoteURL: %s, issuer: %s", remoteURL, issuer) return issuer @@ -327,11 +335,22 @@ func DeriveIssuerFromRealm(realm string) string { // RFC 8414: The issuer identifier MUST be a URL using the "https" scheme // with no query or fragment components - if parsedURL.Scheme != "https" { + if parsedURL.Scheme != "https" && !networking.IsLocalhost(parsedURL.Host) { logger.Debugf("Realm is not using HTTPS scheme: %s", realm) return "" } + // Normalize the path to prevent path traversal attacks + if parsedURL.Path != "" { + // Clean the path to resolve . and .. elements + cleanPath := path.Clean(parsedURL.Path) + // Ensure the path doesn't escape the root + if !strings.HasPrefix(cleanPath, "/") { + cleanPath = "/" + cleanPath + } + parsedURL.Path = cleanPath + } + if parsedURL.RawQuery != "" || parsedURL.Fragment != "" { logger.Debugf("Realm contains query or fragment components: %s", realm) // Remove query and fragment to make it a valid issuer diff --git a/pkg/runner/config_test.go b/pkg/runner/config_test.go index 0b91a7743..4da81a1f7 100644 --- a/pkg/runner/config_test.go +++ b/pkg/runner/config_test.go @@ -21,6 +21,10 @@ import ( "github.com/stacklok/toolhive/pkg/transport/types" ) +const ( + localhostStr = "localhost" +) + func TestNewRunConfig(t *testing.T) { t.Parallel() config := NewRunConfig() @@ -531,13 +535,13 @@ func TestRunConfigBuilder(t *testing.T) { TargetPort: 9090, Args: []string{"--metadata-arg"}, } - host := "localhost" + host := localhostStr debug := true volumes := []string{"/host:/container"} secretsList := []string{"secret1,target=ENV_VAR1"} authzConfigPath := "" // Empty to skip loading the authorization configuration permissionProfile := permissions.ProfileNone - targetHost := "localhost" + targetHost := localhostStr mcpTransport := "sse" proxyPort := 60000 targetPort := 9000 @@ -803,8 +807,8 @@ func TestRunConfigBuilder_MetadataOverrides(t *testing.T) { WithCmdArgs(nil), WithName("test-server"), WithImage("test-image"), - WithHost("localhost"), - WithTargetHost("localhost"), + WithHost(localhostStr), + WithTargetHost(localhostStr), WithDebug(false), WithVolumes(nil), WithSecrets(nil), @@ -850,8 +854,8 @@ func TestRunConfigBuilder_EnvironmentVariableTransportDependency(t *testing.T) { WithCmdArgs(nil), WithName("test-server"), WithImage("test-image"), - WithHost("localhost"), - WithTargetHost("localhost"), + WithHost(localhostStr), + WithTargetHost(localhostStr), WithDebug(false), WithVolumes(nil), WithSecrets(nil), @@ -902,8 +906,8 @@ func TestRunConfigBuilder_CmdArgsMetadataOverride(t *testing.T) { WithCmdArgs(userArgs), WithName("test-server"), WithImage("test-image"), - WithHost("localhost"), - WithTargetHost("localhost"), + WithHost(localhostStr), + WithTargetHost(localhostStr), WithDebug(false), WithVolumes(nil), WithSecrets(nil), @@ -957,8 +961,8 @@ func TestRunConfigBuilder_CmdArgsMetadataDefaults(t *testing.T) { WithCmdArgs(userArgs), WithName("test-server"), WithImage("test-image"), - WithHost("localhost"), - WithTargetHost("localhost"), + WithHost(localhostStr), + WithTargetHost(localhostStr), WithDebug(false), WithVolumes(nil), WithSecrets(nil), @@ -1008,8 +1012,8 @@ func TestRunConfigBuilder_VolumeProcessing(t *testing.T) { WithCmdArgs(nil), WithName("test-server"), WithImage("test-image"), - WithHost("localhost"), - WithTargetHost("localhost"), + WithHost(localhostStr), + WithTargetHost(localhostStr), WithDebug(false), WithVolumes(volumes), WithSecrets(nil), @@ -1081,8 +1085,8 @@ func TestRunConfigBuilder_FilesystemMCPScenario(t *testing.T) { WithCmdArgs(userArgs), WithName("filesystem"), WithImage("mcp/filesystem:latest"), - WithHost("localhost"), - WithTargetHost("localhost"), + WithHost(localhostStr), + WithTargetHost(localhostStr), WithDebug(false), WithVolumes(nil), WithSecrets(nil), diff --git a/pkg/runner/remote_auth.go b/pkg/runner/remote_auth.go index 601d2182f..9c673b64b 100644 --- a/pkg/runner/remote_auth.go +++ b/pkg/runner/remote_auth.go @@ -114,9 +114,19 @@ func (h *RemoteAuthHandler) discoverIssuerAndScopes( return h.tryDiscoverFromResourceMetadata(ctx, authInfo.ResourceMetadata) } - issuer := discovery.DeriveIssuerFromURL(remoteURL) - if issuer != "" { - return issuer, h.config.Scopes, nil, nil + // Priority 4: Try to discover actual issuer from the server's well-known endpoint + // This handles cases where the issuer differs from the server URL (e.g., Atlassian) + issuer, scopes, authServerInfo, err := h.tryDiscoverFromWellKnown(ctx, remoteURL) + if err == nil { + return issuer, scopes, authServerInfo, nil + } + logger.Debugf("Could not discover from well-known endpoint: %v", err) + + // Priority 5: Last resort - derive issuer from URL without discovery + derivedIssuer := discovery.DeriveIssuerFromURL(remoteURL) + if derivedIssuer != "" { + logger.Infof("Using derived issuer from URL: %s", derivedIssuer) + return derivedIssuer, h.config.Scopes, nil, nil } // No issuer could be determined @@ -183,3 +193,40 @@ func (*RemoteAuthHandler) findValidAuthServer( return nil, "" } + +// tryDiscoverFromWellKnown attempts to discover the actual OAuth issuer +// by probing the server's well-known endpoints without validating issuer match +// This is useful when the issuer differs from the server URL (e.g., Atlassian case) +func (h *RemoteAuthHandler) tryDiscoverFromWellKnown( + ctx context.Context, + remoteURL string, +) (string, []string, *discovery.AuthServerInfo, error) { + // First try to derive a base URL from the remote URL + derivedURL := discovery.DeriveIssuerFromURL(remoteURL) + if derivedURL == "" { + return "", nil, nil, fmt.Errorf("could not derive base URL from %s", remoteURL) + } + + // Try to discover the actual issuer without validation + // This uses DiscoverActualIssuer which doesn't validate issuer match + authServerInfo, err := discovery.ValidateAndDiscoverAuthServer(ctx, derivedURL) + if err != nil { + return "", nil, nil, fmt.Errorf("well-known discovery failed: %w", err) + } + + // Successfully discovered the actual issuer + if authServerInfo.Issuer != derivedURL { + logger.Infof("Discovered actual issuer: %s (differs from server URL: %s)", + authServerInfo.Issuer, derivedURL) + } + + // Determine scopes - use configured or fall back to defaults + scopes := h.config.Scopes + if len(scopes) == 0 { + // Use some reasonable defaults if no scopes configured + scopes = []string{"openid", "profile"} + logger.Debugf("No scopes configured, using defaults: %v", scopes) + } + + return authServerInfo.Issuer, scopes, authServerInfo, nil +} diff --git a/pkg/runner/remote_auth_test.go b/pkg/runner/remote_auth_test.go new file mode 100644 index 000000000..a378fc13f --- /dev/null +++ b/pkg/runner/remote_auth_test.go @@ -0,0 +1,619 @@ +package runner + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/auth/discovery" + "github.com/stacklok/toolhive/pkg/logger" +) + +const ( + resourceMetadataPath = "/.well-known/resource-metadata" +) + +func init() { + // Initialize logger for tests + logger.Initialize() +} + +func TestDiscoverIssuerAndScopes(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config *RemoteAuthConfig + authInfo *discovery.AuthInfo + remoteURL string + mockServers map[string]*httptest.Server + expectedIssuer string + expectedScopes []string + expectedAuthServer bool + expectError bool + errorContains string + }{ + // Priority 1: Configured issuer takes precedence + { + name: "configured issuer takes precedence", + config: &RemoteAuthConfig{ + Issuer: "https://configured.example.com", + Scopes: []string{"openid", "profile"}, + }, + authInfo: &discovery.AuthInfo{ + Type: "OAuth", + Realm: "https://realm.example.com", + ResourceMetadata: "https://metadata.example.com", + }, + remoteURL: "https://server.example.com", + expectedIssuer: "https://configured.example.com", + expectedScopes: []string{"openid", "profile"}, + expectError: false, + }, + + // Priority 2: Realm-derived issuer + { + name: "valid realm URL derives issuer", + config: &RemoteAuthConfig{}, + authInfo: &discovery.AuthInfo{ + Type: "OAuth", + Realm: "https://auth.example.com/realm/mcp", + }, + remoteURL: "https://server.example.com", + expectedIssuer: "https://auth.example.com/realm/mcp", + expectedScopes: nil, + expectError: false, + }, + { + name: "realm with query and fragment stripped", + config: &RemoteAuthConfig{}, + authInfo: &discovery.AuthInfo{ + Type: "OAuth", + Realm: "https://auth.example.com/realm?param=value#fragment", + }, + remoteURL: "https://server.example.com", + expectedIssuer: "https://auth.example.com/realm", + expectedScopes: nil, + expectError: false, + }, + + // Priority 3: Resource metadata + // These tests use dynamic setup to create properly linked servers + { + name: "valid resource metadata", + config: &RemoteAuthConfig{}, + authInfo: &discovery.AuthInfo{ + Type: "OAuth", + ResourceMetadata: "dynamic", // Special marker for dynamic setup + }, + remoteURL: "https://server.example.com", + mockServers: map[string]*httptest.Server{ + "dynamic": nil, // Will be created with linked servers + }, + expectedIssuer: "dynamic", // Will be set to auth server URL + expectedScopes: nil, + expectedAuthServer: true, + expectError: false, + }, + { + name: "resource metadata with multiple auth servers", + config: &RemoteAuthConfig{}, + authInfo: &discovery.AuthInfo{ + Type: "OAuth", + ResourceMetadata: "dynamic-multi", // Special marker for dynamic setup + }, + remoteURL: "https://server.example.com", + mockServers: map[string]*httptest.Server{ + "dynamic": nil, // Will be created with linked servers + }, + expectedIssuer: "dynamic", // Will be set to second auth server URL + expectedScopes: nil, + expectedAuthServer: true, + expectError: false, + }, + + // Priority 4: Well-known discovery (Atlassian scenario) + { + name: "well-known discovery with issuer mismatch", + config: &RemoteAuthConfig{}, + authInfo: &discovery.AuthInfo{ + Type: "OAuth", + }, + remoteURL: "https://mcp.atlassian.com/v1/sse", + mockServers: map[string]*httptest.Server{ + "mcp.atlassian.com": createMockAuthServer(t, "https://atlassian-workers.example.com"), + }, + expectedIssuer: "https://atlassian-workers.example.com", + expectedScopes: []string{"openid", "profile"}, + expectedAuthServer: true, + expectError: false, + }, + + // Priority 5: URL-derived fallback + { + name: "url derived fallback when well-known fails", + config: &RemoteAuthConfig{}, + authInfo: &discovery.AuthInfo{ + Type: "OAuth", + }, + remoteURL: "", // Will be set from mock server + mockServers: map[string]*httptest.Server{ + "localhost": createMock404Server(t), + }, + expectedIssuer: "", // Will be set dynamically to match server URL + expectedScopes: nil, + expectError: false, + }, + + // Security test cases + { + name: "http realm rejected for security", + config: &RemoteAuthConfig{}, + authInfo: &discovery.AuthInfo{ + Type: "OAuth", + Realm: "http://insecure.example.com", // HTTP not HTTPS + }, + remoteURL: "https://server.example.com", + // Should fall through to well-known + mockServers: map[string]*httptest.Server{ + "server.example.com": createMockAuthServer(t, "https://server.example.com"), + }, + expectedIssuer: "https://server.example.com", + expectedScopes: []string{"openid", "profile"}, + expectedAuthServer: true, + expectError: false, + }, + { + name: "localhost http realm allowed", + config: &RemoteAuthConfig{}, + authInfo: &discovery.AuthInfo{ + Type: "OAuth", + Realm: "http://localhost:8080", + }, + remoteURL: "https://server.example.com", + expectedIssuer: "http://localhost:8080", + expectedScopes: nil, + expectError: false, + }, + { + name: "malformed resource metadata URL", + config: &RemoteAuthConfig{}, + authInfo: &discovery.AuthInfo{ + Type: "OAuth", + ResourceMetadata: "not-a-url", + }, + remoteURL: "https://server.example.com", + expectError: true, + errorContains: "could not determine OAuth issuer", + }, + + // Edge cases + { + name: "empty auth info", + config: &RemoteAuthConfig{}, + authInfo: &discovery.AuthInfo{ + Type: "OAuth", + }, + remoteURL: "https://server.example.com", + mockServers: map[string]*httptest.Server{ + "server.example.com": createMockAuthServer(t, "https://server.example.com"), + }, + expectedIssuer: "https://server.example.com", + expectedScopes: []string{"openid", "profile"}, + expectedAuthServer: true, + expectError: false, + }, + { + name: "all discovery methods fail", + config: &RemoteAuthConfig{}, + authInfo: &discovery.AuthInfo{ + Type: "OAuth", + }, + remoteURL: "", // Will be set from mock server + mockServers: map[string]*httptest.Server{ + "localhost": createMock404Server(t), + }, + expectedIssuer: "", // Will be set dynamically to match server URL + expectedScopes: nil, + expectError: false, + }, + { + name: "malformed remote URL", + config: &RemoteAuthConfig{}, + authInfo: &discovery.AuthInfo{ + Type: "OAuth", + }, + remoteURL: "not-a-url", + expectError: true, + errorContains: "could not determine OAuth issuer", + }, + { + name: "configured scopes used with discovered issuer", + config: &RemoteAuthConfig{ + Scopes: []string{"custom", "scopes"}, + }, + authInfo: &discovery.AuthInfo{ + Type: "OAuth", + Realm: "https://auth.example.com", + }, + remoteURL: "https://server.example.com", + expectedIssuer: "https://auth.example.com", + expectedScopes: []string{"custom", "scopes"}, + expectError: false, + }, + { + name: "resource metadata with scopes", + config: &RemoteAuthConfig{}, + authInfo: &discovery.AuthInfo{ + Type: "OAuth", + ResourceMetadata: "dynamic-scopes", // Special marker for dynamic setup + }, + remoteURL: "https://server.example.com", + mockServers: map[string]*httptest.Server{ + "dynamic": nil, // Will be created with linked servers + }, + expectedIssuer: "dynamic", // Will be set to auth server URL + expectedScopes: []string{"resource", "scopes"}, // Scopes from metadata are used + expectedAuthServer: true, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // Convert to testCase for helper functions + tc := &testCase{ + name: tt.name, + config: tt.config, + authInfo: tt.authInfo, + remoteURL: tt.remoteURL, + mockServers: tt.mockServers, + expectedIssuer: tt.expectedIssuer, + expectedScopes: tt.expectedScopes, + expectedAuthServer: tt.expectedAuthServer, + expectError: tt.expectError, + errorContains: tt.errorContains, + } + + // Process test servers using helper function + setup, authInfo, remoteURL, expectedIssuer := processTestServers(t, tc) + defer setup.cleanup() + + // Update expected issuer from processing + if expectedIssuer != "" && expectedIssuer != tt.expectedIssuer { + tt.expectedIssuer = expectedIssuer + } + + handler := &RemoteAuthHandler{ + config: tt.config, + } + + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + issuer, scopes, authServerInfo, err := handler.discoverIssuerAndScopes( + ctx, + authInfo, + remoteURL, + ) + + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expectedIssuer, issuer, "issuer mismatch") + assert.Equal(t, tt.expectedScopes, scopes, "scopes mismatch") + + if tt.expectedAuthServer { + assert.NotNil(t, authServerInfo, "expected auth server info") + if authServerInfo != nil { + assert.Equal(t, tt.expectedIssuer, authServerInfo.Issuer, "auth server issuer mismatch") + assert.NotEmpty(t, authServerInfo.AuthorizationURL, "authorization URL should not be empty") + assert.NotEmpty(t, authServerInfo.TokenURL, "token URL should not be empty") + } + } else { + assert.Nil(t, authServerInfo, "expected no auth server info") + } + }) + } +} + +// Helper functions to create mock servers + +func createMockAuthServer(t *testing.T, issuer string) *httptest.Server { + t.Helper() + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle all possible well-known paths + if strings.Contains(r.URL.Path, "/.well-known/oauth-authorization-server") || + strings.Contains(r.URL.Path, "/.well-known/openid-configuration") { + w.Header().Set("Content-Type", "application/json") + // Use the provided issuer, or if empty, use the actual server URL + actualIssuer := issuer + if actualIssuer == "" { + actualIssuer = "http://" + r.Host + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": actualIssuer, + "authorization_endpoint": actualIssuer + "/authorize", + "token_endpoint": actualIssuer + "/token", + "registration_endpoint": actualIssuer + "/register", + }) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) +} + +func createMock404Server(t *testing.T) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) +} + +func createMockResourceMetadataServer(t *testing.T, authServers []string) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == resourceMetadataPath { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "resource": "https://resource.example.com", + "authorization_servers": authServers, + }) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) +} + +func createMockResourceMetadataServerWithScopes(t *testing.T, authServers []string, scopes []string) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == resourceMetadataPath { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "resource": "https://resource.example.com", + "authorization_servers": authServers, + "scopes_supported": scopes, + }) + } else { + w.WriteHeader(http.StatusNotFound) + } + })) +} + +// Security-focused tests +func TestDiscoverIssuerAndScopes_Security(t *testing.T) { + t.Parallel() + + t.Run("prevents issuer injection via realm", func(t *testing.T) { + t.Parallel() + handler := &RemoteAuthHandler{ + config: &RemoteAuthConfig{}, + } + + // Try to inject a malicious issuer via realm + authInfo := &discovery.AuthInfo{ + Type: "OAuth", + Realm: "https://evil.com/../../legitimate.com", + } + + ctx := t.Context() + issuer, _, _, err := handler.discoverIssuerAndScopes(ctx, authInfo, "https://server.example.com") + + require.NoError(t, err) + // The path traversal should be normalized + assert.NotContains(t, issuer, "..") + }) + + t.Run("validates HTTPS for non-localhost", func(t *testing.T) { + t.Parallel() + handler := &RemoteAuthHandler{ + config: &RemoteAuthConfig{}, + } + + authInfo := &discovery.AuthInfo{ + Type: "OAuth", + Realm: "http://external.example.com", // HTTP not HTTPS + } + + mockServer := createMockAuthServer(t, "https://fallback.example.com") + defer mockServer.Close() + + ctx := t.Context() + issuer, _, _, err := handler.discoverIssuerAndScopes(ctx, authInfo, mockServer.URL) + + require.NoError(t, err) + // Should not use the insecure realm, should fall through + assert.NotEqual(t, "http://external.example.com", issuer) + }) + + t.Run("handles malicious resource metadata response", func(t *testing.T) { + t.Parallel() + maliciousServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == resourceMetadataPath { + // Send a huge response to try DoS + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"resource": "`)) + for i := 0; i < 10000000; i++ { + w.Write([]byte("A")) + } + w.Write([]byte(`"}`)) + } + })) + defer maliciousServer.Close() + + handler := &RemoteAuthHandler{ + config: &RemoteAuthConfig{}, + } + + authInfo := &discovery.AuthInfo{ + Type: "OAuth", + ResourceMetadata: maliciousServer.URL + resourceMetadataPath, + } + + ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second) + defer cancel() + + _, _, _, err := handler.discoverIssuerAndScopes(ctx, authInfo, "https://server.example.com") + + // Should timeout or fail gracefully, not hang or crash + assert.Error(t, err) + }) +} + +// Test the helper functions +func TestTryDiscoverFromWellKnown(t *testing.T) { + t.Parallel() + + t.Run("discovers actual issuer from localhost server", func(t *testing.T) { + t.Parallel() + // For localhost test servers, the issuer will be the server's HTTP URL + mockServer := createMockAuthServer(t, "") // Will use actual server URL + defer mockServer.Close() + + handler := &RemoteAuthHandler{ + config: &RemoteAuthConfig{}, + } + + ctx := t.Context() + issuer, scopes, authInfo, err := handler.tryDiscoverFromWellKnown(ctx, mockServer.URL) + + require.NoError(t, err) + assert.Equal(t, mockServer.URL, issuer) // For localhost, issuer matches server URL + assert.Equal(t, []string{"openid", "profile"}, scopes) // Default scopes + assert.NotNil(t, authInfo) + assert.Equal(t, mockServer.URL, authInfo.Issuer) + }) + + t.Run("uses configured scopes", func(t *testing.T) { + t.Parallel() + mockServer := createMockAuthServer(t, "") // Will use actual server URL + defer mockServer.Close() + + handler := &RemoteAuthHandler{ + config: &RemoteAuthConfig{ + Scopes: []string{"custom", "scopes"}, + }, + } + + ctx := t.Context() + issuer, scopes, _, err := handler.tryDiscoverFromWellKnown(ctx, mockServer.URL) + + require.NoError(t, err) + assert.Equal(t, mockServer.URL, issuer) // For localhost, issuer matches server URL + assert.Equal(t, []string{"custom", "scopes"}, scopes) + }) + + t.Run("handles discovery failure", func(t *testing.T) { + t.Parallel() + mockServer := createMock404Server(t) + defer mockServer.Close() + + handler := &RemoteAuthHandler{ + config: &RemoteAuthConfig{}, + } + + ctx := t.Context() + _, _, _, err := handler.tryDiscoverFromWellKnown(ctx, mockServer.URL) + + require.Error(t, err) + assert.Contains(t, err.Error(), "well-known discovery failed") + }) +} + +// TestDiscoveryPriorityChain tests that the discovery follows the correct priority order +func TestDiscoveryPriorityChain(t *testing.T) { + t.Parallel() + + t.Run("configured issuer takes highest priority", func(t *testing.T) { + t.Parallel() + handler := &RemoteAuthHandler{ + config: &RemoteAuthConfig{ + Issuer: "https://configured.example.com", + Scopes: []string{"custom"}, + }, + } + + authInfo := &discovery.AuthInfo{ + Type: "OAuth", + Realm: "https://realm.example.com", + ResourceMetadata: "https://metadata.example.com", + } + + ctx := context.Background() + issuer, scopes, _, err := handler.discoverIssuerAndScopes(ctx, authInfo, "https://server.example.com") + + require.NoError(t, err) + assert.Equal(t, "https://configured.example.com", issuer) + assert.Equal(t, []string{"custom"}, scopes) + }) + + t.Run("realm URL used when no configured issuer", func(t *testing.T) { + t.Parallel() + handler := &RemoteAuthHandler{ + config: &RemoteAuthConfig{}, + } + + authInfo := &discovery.AuthInfo{ + Type: "OAuth", + Realm: "https://realm.example.com/oauth", + } + + ctx := context.Background() + issuer, _, _, err := handler.discoverIssuerAndScopes(ctx, authInfo, "https://server.example.com") + + require.NoError(t, err) + assert.Equal(t, "https://realm.example.com/oauth", issuer) + }) + + t.Run("non-URL realm falls through to URL derivation", func(t *testing.T) { + t.Parallel() + handler := &RemoteAuthHandler{ + config: &RemoteAuthConfig{}, + } + + authInfo := &discovery.AuthInfo{ + Type: "OAuth", + Realm: "OAuth", // Not a URL, like Atlassian + } + + ctx := context.Background() + issuer, _, _, err := handler.discoverIssuerAndScopes(ctx, authInfo, "https://server.example.com") + + require.NoError(t, err) + // Should fall through to URL-derived issuer + assert.Equal(t, "https://server.example.com", issuer) + }) + + t.Run("empty auth info falls through to URL derivation", func(t *testing.T) { + t.Parallel() + handler := &RemoteAuthHandler{ + config: &RemoteAuthConfig{}, + } + + authInfo := &discovery.AuthInfo{ + Type: "OAuth", + } + + ctx := context.Background() + issuer, _, _, err := handler.discoverIssuerAndScopes(ctx, authInfo, "https://server.example.com/path") + + require.NoError(t, err) + assert.Equal(t, "https://server.example.com", issuer) + }) +} diff --git a/pkg/runner/remote_auth_test_helpers_test.go b/pkg/runner/remote_auth_test_helpers_test.go new file mode 100644 index 000000000..21ea825e8 --- /dev/null +++ b/pkg/runner/remote_auth_test_helpers_test.go @@ -0,0 +1,202 @@ +package runner + +import ( + "net/http/httptest" + "strings" + "testing" + + "github.com/stacklok/toolhive/pkg/auth/discovery" +) + +const ( + dynamicTestType = "dynamic" +) + +// testServerSetup holds the mock servers for a test +type testServerSetup struct { + MetadataServer *httptest.Server + AuthServer *httptest.Server + InvalidServer *httptest.Server + Servers map[string]*httptest.Server +} + +// cleanup closes all servers +func (s *testServerSetup) cleanup() { + if s.MetadataServer != nil { + s.MetadataServer.Close() + } + if s.AuthServer != nil { + s.AuthServer.Close() + } + if s.InvalidServer != nil { + s.InvalidServer.Close() + } + for _, server := range s.Servers { + if server != nil { + server.Close() + } + } +} + +// setupResourceMetadataTest creates linked mock servers for resource metadata testing +func setupResourceMetadataTest(t *testing.T, testType string) (*testServerSetup, *discovery.AuthInfo, string) { + t.Helper() + setup := &testServerSetup{ + Servers: make(map[string]*httptest.Server), + } + + // Create auth server + setup.AuthServer = createMockAuthServer(t, "") + + var authServers []string + var scopes []string + + switch testType { + case "multi-server": + // Create invalid server for multi-server test + setup.InvalidServer = createMock404Server(t) + authServers = []string{setup.InvalidServer.URL, setup.AuthServer.URL} + case "with-scopes": + authServers = []string{setup.AuthServer.URL} + scopes = []string{"resource", "scopes"} + default: + authServers = []string{setup.AuthServer.URL} + } + + // Create metadata server with proper auth server URLs + if len(scopes) > 0 { + setup.MetadataServer = createMockResourceMetadataServerWithScopes(t, authServers, scopes) + } else { + setup.MetadataServer = createMockResourceMetadataServer(t, authServers) + } + + // Create auth info with actual metadata URL + authInfo := &discovery.AuthInfo{ + Type: "OAuth", + ResourceMetadata: setup.MetadataServer.URL + resourceMetadataPath, + } + + // Return the expected issuer (auth server URL) + return setup, authInfo, setup.AuthServer.URL +} + +// processTestServers handles the server setup for a test case +func processTestServers(t *testing.T, tt *testCase) (*testServerSetup, *discovery.AuthInfo, string, string) { + t.Helper() + // Handle special dynamic test cases + if tt.authInfo != nil && tt.authInfo.ResourceMetadata != "" { + switch tt.authInfo.ResourceMetadata { + case dynamicTestType: + setup, authInfo, expectedIssuer := setupResourceMetadataTest(t, "single-server") + if tt.expectedIssuer == dynamicTestType { + tt.expectedIssuer = expectedIssuer + } + return setup, authInfo, tt.remoteURL, tt.expectedIssuer + + case "dynamic-multi": + setup, authInfo, expectedIssuer := setupResourceMetadataTest(t, "multi-server") + if tt.expectedIssuer == dynamicTestType { + tt.expectedIssuer = expectedIssuer + } + return setup, authInfo, tt.remoteURL, tt.expectedIssuer + + case "dynamic-scopes": + setup, authInfo, expectedIssuer := setupResourceMetadataTest(t, "with-scopes") + if tt.expectedIssuer == dynamicTestType { + tt.expectedIssuer = expectedIssuer + } + return setup, authInfo, tt.remoteURL, tt.expectedIssuer + } + } + + // Handle regular mock servers + setup := &testServerSetup{ + Servers: make(map[string]*httptest.Server), + } + + authInfo := tt.authInfo + remoteURL := tt.remoteURL + + // Set up mock servers from test definition + for host, server := range tt.mockServers { + if host == "localhost" && server == nil { + if containsAny(tt.name, "404", "all discovery methods fail") { + server = createMock404Server(t) + } else { + server = createMockAuthServer(t, "") + } + } + setup.Servers[host] = server + } + + // Process URLs + if len(setup.Servers) > 0 { + remoteURL, tt.expectedIssuer = processURLsForServers(tt, authInfo, remoteURL, setup.Servers) + } + + return setup, authInfo, remoteURL, tt.expectedIssuer +} + +// processURLsForServers updates URLs to use mock server addresses +func processURLsForServers(tt *testCase, authInfo *discovery.AuthInfo, remoteURL string, servers map[string]*httptest.Server) (string, string) { + expectedIssuer := tt.expectedIssuer + + // For resource metadata tests + if authInfo != nil && authInfo.ResourceMetadata != "" && !containsAny(authInfo.ResourceMetadata, "dynamic") { + for host, server := range servers { + if containsAny(authInfo.ResourceMetadata, host) { + authInfo.ResourceMetadata = replaceFirst(authInfo.ResourceMetadata, "https://"+host, server.URL) + break + } + } + } + + // For well-known discovery tests + if remoteURL == "" && servers["localhost"] != nil { + remoteURL = servers["localhost"].URL + if expectedIssuer == "" { + if containsAny(tt.name, "malformed resource metadata") { + expectedIssuer = servers["localhost"].URL + } else if containsAny(tt.name, "fallback", "all discovery") { + expectedIssuer = servers["localhost"].URL + } + } + } else { + for host, server := range servers { + if containsAny(remoteURL, host) { + remoteURL = replaceFirst(remoteURL, "https://"+host, server.URL) + break + } + } + } + + return remoteURL, expectedIssuer +} + +// Helper functions +func containsAny(s string, substrs ...string) bool { + for _, substr := range substrs { + if strings.Contains(s, substr) { + return true + } + } + return false +} + +func replaceFirst(s, old, replacement string) string { + return strings.Replace(s, old, replacement, 1) +} + +// testCase represents a single test case +type testCase struct { + name string + config *RemoteAuthConfig + authInfo *discovery.AuthInfo + remoteURL string + mockServers map[string]*httptest.Server + expectedIssuer string + expectedScopes []string + expectedAuthServer bool + expectError bool + errorContains string +}