88 "net/http"
99 "net/http/httptest"
1010 "net/url"
11+ "path"
1112 "strings"
1213 "sync"
1314 "sync/atomic"
@@ -34,6 +35,13 @@ type sseSession struct {
3435// content. This can be used to inject context values from headers, for example.
3536type SSEContextFunc func (ctx context.Context , r * http.Request ) context.Context
3637
38+ // DynamicBasePathFunc allows the user to provide a function to generate the
39+ // base path for a given request and sessionID. This is useful for cases where
40+ // the base path is not known at the time of SSE server creation, such as when
41+ // using a reverse proxy or when the base path is dynamically generated. The
42+ // function should return the base path (e.g., "/mcp/tenant123").
43+ type DynamicBasePathFunc func (r * http.Request , sessionID string ) string
44+
3745func (s * sseSession ) SessionID () string {
3846 return s .sessionID
3947}
@@ -58,19 +66,19 @@ type SSEServer struct {
5866 server * MCPServer
5967 baseURL string
6068 basePath string
69+ appendQueryToMessageEndpoint bool
6170 useFullURLForMessageEndpoint bool
6271 messageEndpoint string
6372 sseEndpoint string
6473 sessions sync.Map
6574 srv * http.Server
6675 contextFunc SSEContextFunc
76+ dynamicBasePathFunc DynamicBasePathFunc
6777
6878 keepAlive bool
6979 keepAliveInterval time.Duration
7080
7181 mu sync.RWMutex
72-
73- appendQueryToMessageEndpoint bool
7482}
7583
7684// SSEOption defines a function type for configuring SSEServer
@@ -99,14 +107,25 @@ func WithBaseURL(baseURL string) SSEOption {
99107 }
100108}
101109
102- // WithBasePath adds a new option for setting base path
110+ // WithBasePath adds a new option for setting a static base path
103111func WithBasePath (basePath string ) SSEOption {
104112 return func (s * SSEServer ) {
105- // Ensure the path starts with / and doesn't end with /
106- if ! strings .HasPrefix (basePath , "/" ) {
107- basePath = "/" + basePath
113+ s .basePath = normalizeURLPath (basePath )
114+ }
115+ }
116+
117+ // WithDynamicBasePath accepts a function for generating the base path. This is
118+ // useful for cases where the base path is not known at the time of SSE server
119+ // creation, such as when using a reverse proxy or when the server is mounted
120+ // at a dynamic path.
121+ func WithDynamicBasePath (fn DynamicBasePathFunc ) SSEOption {
122+ return func (s * SSEServer ) {
123+ if fn != nil {
124+ s .dynamicBasePathFunc = func (r * http.Request , sid string ) string {
125+ bp := fn (r , sid )
126+ return normalizeURLPath (bp )
127+ }
108128 }
109- s .basePath = strings .TrimSuffix (basePath , "/" )
110129 }
111130}
112131
@@ -208,8 +227,8 @@ func (s *SSEServer) Start(addr string) error {
208227
209228 if s .srv == nil {
210229 s .srv = & http.Server {
211- Addr : addr ,
212- Handler : s ,
230+ Addr : addr ,
231+ Handler : s ,
213232 }
214233 } else {
215234 if s .srv .Addr == "" {
@@ -218,7 +237,7 @@ func (s *SSEServer) Start(addr string) error {
218237 return fmt .Errorf ("conflicting listen address: WithHTTPServer(%q) vs Start(%q)" , s .srv .Addr , addr )
219238 }
220239 }
221-
240+
222241 return s .srv .ListenAndServe ()
223242}
224243
@@ -331,7 +350,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
331350 }
332351
333352 // Send the initial endpoint event
334- endpoint := s .GetMessageEndpointForClient (sessionID )
353+ endpoint := s .GetMessageEndpointForClient (r , sessionID )
335354 if s .appendQueryToMessageEndpoint && len (r .URL .RawQuery ) > 0 {
336355 endpoint += "&" + r .URL .RawQuery
337356 }
@@ -355,13 +374,20 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
355374}
356375
357376// GetMessageEndpointForClient returns the appropriate message endpoint URL with session ID
358- // based on the useFullURLForMessageEndpoint configuration.
359- func (s * SSEServer ) GetMessageEndpointForClient (sessionID string ) string {
360- messageEndpoint := s .messageEndpoint
361- if s .useFullURLForMessageEndpoint {
362- messageEndpoint = s .CompleteMessageEndpoint ()
377+ // for the given request. This is the canonical way to compute the message endpoint for a client.
378+ // It handles both dynamic and static path modes, and honors the WithUseFullURLForMessageEndpoint flag.
379+ func (s * SSEServer ) GetMessageEndpointForClient (r * http.Request , sessionID string ) string {
380+ basePath := s .basePath
381+ if s .dynamicBasePathFunc != nil {
382+ basePath = s .dynamicBasePathFunc (r , sessionID )
383+ }
384+
385+ endpointPath := normalizeURLPath (basePath , s .messageEndpoint )
386+ if s .useFullURLForMessageEndpoint && s .baseURL != "" {
387+ endpointPath = s .baseURL + endpointPath
363388 }
364- return fmt .Sprintf ("%s?sessionId=%s" , messageEndpoint , sessionID )
389+
390+ return fmt .Sprintf ("%s?sessionId=%s" , endpointPath , sessionID )
365391}
366392
367393// handleMessage processes incoming JSON-RPC messages from clients and sends responses
@@ -479,32 +505,111 @@ func (s *SSEServer) GetUrlPath(input string) (string, error) {
479505 return parse .Path , nil
480506}
481507
482- func (s * SSEServer ) CompleteSseEndpoint () string {
483- return s .baseURL + s .basePath + s .sseEndpoint
508+ func (s * SSEServer ) CompleteSseEndpoint () (string , error ) {
509+ if s .dynamicBasePathFunc != nil {
510+ return "" , & ErrDynamicPathConfig {Method : "CompleteSseEndpoint" }
511+ }
512+
513+ path := normalizeURLPath (s .basePath , s .sseEndpoint )
514+ return s .baseURL + path , nil
484515}
485516
486517func (s * SSEServer ) CompleteSsePath () string {
487- path , err := s .GetUrlPath (s .CompleteSseEndpoint ())
518+ path , err := s .CompleteSseEndpoint ()
519+ if err != nil {
520+ return normalizeURLPath (s .basePath , s .sseEndpoint )
521+ }
522+ urlPath , err := s .GetUrlPath (path )
488523 if err != nil {
489- return s .basePath + s .sseEndpoint
524+ return normalizeURLPath ( s .basePath , s .sseEndpoint )
490525 }
491- return path
526+ return urlPath
492527}
493528
494- func (s * SSEServer ) CompleteMessageEndpoint () string {
495- return s .baseURL + s .basePath + s .messageEndpoint
529+ func (s * SSEServer ) CompleteMessageEndpoint () (string , error ) {
530+ if s .dynamicBasePathFunc != nil {
531+ return "" , & ErrDynamicPathConfig {Method : "CompleteMessageEndpoint" }
532+ }
533+ path := normalizeURLPath (s .basePath , s .messageEndpoint )
534+ return s .baseURL + path , nil
496535}
497536
498537func (s * SSEServer ) CompleteMessagePath () string {
499- path , err := s .GetUrlPath (s .CompleteMessageEndpoint ())
538+ path , err := s .CompleteMessageEndpoint ()
539+ if err != nil {
540+ return normalizeURLPath (s .basePath , s .messageEndpoint )
541+ }
542+ urlPath , err := s .GetUrlPath (path )
500543 if err != nil {
501- return s .basePath + s .messageEndpoint
544+ return normalizeURLPath ( s .basePath , s .messageEndpoint )
502545 }
503- return path
546+ return urlPath
547+ }
548+
549+ // SSEHandler returns an http.Handler for the SSE endpoint.
550+ //
551+ // This method allows you to mount the SSE handler at any arbitrary path
552+ // using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is
553+ // intended for advanced scenarios where you want to control the routing or
554+ // support dynamic segments.
555+ //
556+ // IMPORTANT: When using this handler in advanced/dynamic mounting scenarios,
557+ // you must use the WithDynamicBasePath option to ensure the correct base path
558+ // is communicated to clients.
559+ //
560+ // Example usage:
561+ //
562+ // // Advanced/dynamic:
563+ // sseServer := NewSSEServer(mcpServer,
564+ // WithDynamicBasePath(func(r *http.Request, sessionID string) string {
565+ // tenant := r.PathValue("tenant")
566+ // return "/mcp/" + tenant
567+ // }),
568+ // WithBaseURL("http://localhost:8080")
569+ // )
570+ // mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler())
571+ // mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler())
572+ //
573+ // For non-dynamic cases, use ServeHTTP method instead.
574+ func (s * SSEServer ) SSEHandler () http.Handler {
575+ return http .HandlerFunc (s .handleSSE )
576+ }
577+
578+ // MessageHandler returns an http.Handler for the message endpoint.
579+ //
580+ // This method allows you to mount the message handler at any arbitrary path
581+ // using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is
582+ // intended for advanced scenarios where you want to control the routing or
583+ // support dynamic segments.
584+ //
585+ // IMPORTANT: When using this handler in advanced/dynamic mounting scenarios,
586+ // you must use the WithDynamicBasePath option to ensure the correct base path
587+ // is communicated to clients.
588+ //
589+ // Example usage:
590+ //
591+ // // Advanced/dynamic:
592+ // sseServer := NewSSEServer(mcpServer,
593+ // WithDynamicBasePath(func(r *http.Request, sessionID string) string {
594+ // tenant := r.PathValue("tenant")
595+ // return "/mcp/" + tenant
596+ // }),
597+ // WithBaseURL("http://localhost:8080")
598+ // )
599+ // mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler())
600+ // mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler())
601+ //
602+ // For non-dynamic cases, use ServeHTTP method instead.
603+ func (s * SSEServer ) MessageHandler () http.Handler {
604+ return http .HandlerFunc (s .handleMessage )
504605}
505606
506607// ServeHTTP implements the http.Handler interface.
507608func (s * SSEServer ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
609+ if s .dynamicBasePathFunc != nil {
610+ http .Error (w , (& ErrDynamicPathConfig {Method : "ServeHTTP" }).Error (), http .StatusInternalServerError )
611+ return
612+ }
508613 path := r .URL .Path
509614 // Use exact path matching rather than Contains
510615 ssePath := s .CompleteSsePath ()
@@ -520,3 +625,21 @@ func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
520625
521626 http .NotFound (w , r )
522627}
628+
629+ // normalizeURLPath joins path elements like path.Join but ensures the
630+ // result always starts with a leading slash and never ends with a slash
631+ func normalizeURLPath (elem ... string ) string {
632+ joined := path .Join (elem ... )
633+
634+ // Ensure leading slash
635+ if ! strings .HasPrefix (joined , "/" ) {
636+ joined = "/" + joined
637+ }
638+
639+ // Remove trailing slash if not just "/"
640+ if len (joined ) > 1 && strings .HasSuffix (joined , "/" ) {
641+ joined = joined [:len (joined )- 1 ]
642+ }
643+
644+ return joined
645+ }
0 commit comments