Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions cmd/thv/app/run_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,14 @@ func buildRunnerConfig(
transportType = serverMetadata.GetTransport()
}

// Determine server name for telemetry (similar to validateConfig logic)
// This ensures telemetry middleware gets the correct server name
imageMetadata, _ := serverMetadata.(*registry.ImageMetadata)
serverName := runFlags.Name
if serverName == "" && imageMetadata != nil {
serverName = imageMetadata.Name
}

// set default options
opts := []runner.RunConfigBuilderOption{
runner.WithRuntime(rt),
Expand Down Expand Up @@ -443,6 +451,7 @@ func buildRunnerConfig(

opts = append(opts, runner.WithToolsOverride(toolsOverride))
// Configure middleware from flags
// Use computed serverName and transportType for correct telemetry labels
opts = append(
opts,
runner.WithMiddlewareFromFlags(
Expand All @@ -453,8 +462,8 @@ func buildRunnerConfig(
runFlags.AuthzConfig,
runFlags.EnableAudit,
runFlags.AuditConfig,
runFlags.Name,
runFlags.Transport,
serverName,
transportType,
),
)

Expand Down Expand Up @@ -490,14 +499,9 @@ func buildRunnerConfig(
),
runner.WithToolsFilter(runFlags.ToolsFilter))

imageMetadata, _ := serverMetadata.(*registry.ImageMetadata)
// Process environment files
var err error
if runFlags.EnvFile != "" {
opts = append(opts, runner.WithEnvFile(runFlags.EnvFile))
if err != nil {
return nil, fmt.Errorf("failed to process env file %s: %v", runFlags.EnvFile, err)
}
}
if runFlags.EnvFileDir != "" {
opts = append(opts, runner.WithEnvFilesFromDirectory(runFlags.EnvFileDir))
Expand Down
92 changes: 92 additions & 0 deletions cmd/thv/app/run_flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package app
import (
"os"
"path/filepath"
"strings"
"testing"

"github.com/spf13/cobra"
Expand Down Expand Up @@ -216,6 +217,97 @@ func TestBuildRunnerConfig_TelemetryProcessing(t *testing.T) {
}
}

func TestTelemetryMiddlewareParameterComputation(t *testing.T) {
// This test validates the telemetry middleware parameter computation
// by testing the logic that computes server name and transport type
// before calling WithMiddlewareFromFlags
t.Parallel()

logger.Initialize()

tests := []struct {
name string
runFlags *RunFlags
serverOrImage string
expectedServer string
expectedTransport string
}{
{
name: "explicit name and transport should use provided values",
runFlags: &RunFlags{
Name: "custom-server",
Transport: "http",
},
serverOrImage: "custom-server",
expectedServer: "custom-server",
expectedTransport: "http",
},
{
name: "empty name should be computed from image name",
runFlags: &RunFlags{
Transport: "sse",
},
serverOrImage: "docker://registry.test/my-test-server:latest",
expectedServer: "my-test-server", // Extracted from image name
expectedTransport: "sse",
},
{
name: "empty transport should use default",
runFlags: &RunFlags{
Name: "named-server",
},
serverOrImage: "named-server",
expectedServer: "named-server",
expectedTransport: "streamable-http", // Default from constant
},
{
name: "both empty should compute name and use default transport",
runFlags: &RunFlags{},
serverOrImage: "docker://example.com/path/server-name:v1.0",
expectedServer: "server-name", // Extracted from image
expectedTransport: "streamable-http", // Default
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

// Test the server name computation logic that was fixed
// This simulates the logic in BuildRunnerConfig before WithMiddlewareFromFlags

// 1. Test transport type computation (this was already working)
transportType := tt.runFlags.Transport
if transportType == "" {
transportType = defaultTransportType // "streamable-http"
}
assert.Equal(t, tt.expectedTransport, transportType, "Transport type should match expected")

// 2. Test server name computation
serverName := tt.runFlags.Name
if serverName == "" {
// This simulates the image metadata extraction logic
if strings.HasPrefix(tt.serverOrImage, "docker://") {
imagePath := strings.TrimPrefix(tt.serverOrImage, "docker://")
parts := strings.Split(imagePath, "/")
imageName := parts[len(parts)-1]
if colonIndex := strings.Index(imageName, ":"); colonIndex != -1 {
imageName = imageName[:colonIndex]
}
serverName = imageName
} else {
serverName = tt.serverOrImage
}
}
assert.Equal(t, tt.expectedServer, serverName, "Server name should match expected")

// 3. Verify both parameters are non-empty for proper middleware function
assert.NotEmpty(t, serverName, "Server name should never be empty for middleware")
assert.NotEmpty(t, transportType, "Transport type should never be empty for middleware")
})
}
}

func TestBuildRunnerConfig_TelemetryProcessing_Integration(t *testing.T) {
t.Parallel()
// This is a more complete integration test that tests telemetry processing
Expand Down
12 changes: 10 additions & 2 deletions pkg/api/v1/workload_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ func (s *WorkloadService) BuildFullRunConfig(ctx context.Context, req *createReq
var remoteAuthConfig *runner.RemoteAuthConfig
var imageURL string
var imageMetadata *registry.ImageMetadata
var serverMetadata registry.ServerMetadata

if req.URL != "" {
// Configure remote authentication if OAuth config is provided
Expand All @@ -131,7 +132,6 @@ func (s *WorkloadService) BuildFullRunConfig(ctx context.Context, req *createReq
return nil, err
}
} else {
var serverMetadata registry.ServerMetadata
// Create a dedicated context with longer timeout for image retrieval
imageCtx, cancel := context.WithTimeout(ctx, imageRetrievalTimeout)
defer cancel()
Expand Down Expand Up @@ -216,6 +216,14 @@ func (s *WorkloadService) BuildFullRunConfig(ctx context.Context, req *createReq
runner.WithTelemetryConfig("", false, false, false, "", 0.0, nil, false, nil),
}

// Determine transport type
transportType := "streamable-http"
if req.Transport != "" {
transportType = req.Transport
} else if serverMetadata != nil {
transportType = serverMetadata.GetTransport()
}

// Configure middleware from flags
options = append(options,
runner.WithMiddlewareFromFlags(
Expand All @@ -227,7 +235,7 @@ func (s *WorkloadService) BuildFullRunConfig(ctx context.Context, req *createReq
false,
"",
req.Name,
req.Transport,
transportType,
),
)

Expand Down
32 changes: 20 additions & 12 deletions pkg/telemetry/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,9 @@ func (m *HTTPMiddleware) addMCPAttributes(ctx context.Context, span trace.Span,
span.SetAttributes(attribute.String("mcp.server.name", serverName))

// Determine backend transport type
// Note: ToolHive always serves SSE to clients, but backends can be stdio or sse
// Note: ToolHive supports multiple transport types including stdio, sse, streamable-http
// The transport should never be empty as both CLI and API have fallbacks to "streamable-http"
// If transport is still empty, it indicates a configuration issue in middleware construction
backendTransport := m.extractBackendTransport(r)
span.SetAttributes(attribute.String("mcp.transport", backendTransport))

Expand Down Expand Up @@ -281,25 +283,31 @@ func (m *HTTPMiddleware) addMethodSpecificAttributes(span trace.Span, parsedMCP
}
}

// extractServerName extracts the MCP server name from the HTTP request using multiple fallback strategies.
// It first checks for the X-MCP-Server-Name header, then extracts from URL path segments
// (skipping common prefixes like "sse", "messages", "api", "v1"), and finally falls back
// to the middleware's configured server name.
// extractServerName extracts the MCP server name from the HTTP request.
// It checks for an explicit X-MCP-Server-Name header first, then falls back to the
// configured server name. This approach is more reliable than parsing URL paths since
// the server name is already known during middleware construction.
func (m *HTTPMiddleware) extractServerName(r *http.Request) string {
// Check for explicit server name header (for advanced routing scenarios)
if serverName := r.Header.Get("X-MCP-Server-Name"); serverName != "" {
return serverName
}
pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
for _, part := range pathParts {
if part != "" && part != "sse" && part != "messages" && part != "api" && part != "v1" {
return part
}
}

// Always use the configured server name - this is the correct server name
// that was passed during middleware construction and doesn't depend on URL structure
//
// NOTE: Previously this function attempted to parse server names from URL paths by
// splitting r.URL.Path and filtering out known endpoint segments like "sse", "messages",
// "api", "v1", etc. This approach was fundamentally flawed because:
// 1. It incorrectly treated endpoint names like "message" as server names
// 2. It made assumptions about URL structure that don't always hold
// 3. The actual server name is already available via m.serverName
// Adding more exclusions (like "message") would just be treating symptoms, not the root cause.
return m.serverName
}

// extractBackendTransport determines the backend transport type.
// ToolHive always serves SSE to clients, but backends can be stdio or sse.
// ToolHive supports multiple transport types: stdio, sse, streamable-http.
func (m *HTTPMiddleware) extractBackendTransport(r *http.Request) string {
// Try to get transport info from custom headers (if set by proxy)
if transport := r.Header.Get("X-MCP-Transport"); transport != "" {
Expand Down
12 changes: 7 additions & 5 deletions pkg/telemetry/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,9 @@ func TestHTTPMiddleware_FormatRequestID(t *testing.T) {
func TestHTTPMiddleware_ExtractServerName(t *testing.T) {
t.Parallel()

middleware := &HTTPMiddleware{}
middleware := &HTTPMiddleware{
serverName: "test-server", // Set a configured server name for testing
}

tests := []struct {
name string
Expand All @@ -456,23 +458,23 @@ func TestHTTPMiddleware_ExtractServerName(t *testing.T) {
{
name: "from path",
path: "/api/v1/github/messages",
expected: "github",
expected: "test-server", // Now uses configured server name instead of path parsing
},
{
name: "from path with sse",
path: "/sse/weather/messages",
expected: "weather",
expected: "test-server", // Now uses configured server name instead of path parsing
},
{
name: "fallback to serverName",
path: "/messages",
query: "session_id=abc123",
expected: "", // Falls back to m.serverName which is empty in test
expected: "test-server", // Uses configured server name
},
{
name: "unknown",
path: "/health",
expected: "health", // "health" is not in the skip list, so it's extracted from path
expected: "test-server", // Now uses configured server name instead of path parsing
},
}

Expand Down
Loading
Loading