@@ -9,9 +9,10 @@ import (
99 "testing"
1010 "time"
1111
12- "github.com/mark3labs/mcp-go/mcp"
1312 "github.com/stretchr/testify/assert"
1413 "github.com/stretchr/testify/require"
14+
15+ "github.com/mark3labs/mcp-go/mcp"
1516)
1617
1718// sessionTestClient implements the basic ClientSession interface for testing
@@ -99,12 +100,49 @@ func (f *sessionTestClientWithTools) SetSessionTools(tools map[string]ServerTool
99100 f .sessionTools = toolsCopy
100101}
101102
103+ // sessionTestClientWithClientInfo implements the SessionWithClientInfo interface for testing
104+ type sessionTestClientWithClientInfo struct {
105+ sessionID string
106+ notificationChannel chan mcp.JSONRPCNotification
107+ initialized bool
108+ clientInfo atomic.Value
109+ }
110+
111+ func (f * sessionTestClientWithClientInfo ) SessionID () string {
112+ return f .sessionID
113+ }
114+
115+ func (f * sessionTestClientWithClientInfo ) NotificationChannel () chan <- mcp.JSONRPCNotification {
116+ return f .notificationChannel
117+ }
118+
119+ func (f * sessionTestClientWithClientInfo ) Initialize () {
120+ f .initialized = true
121+ }
122+
123+ func (f * sessionTestClientWithClientInfo ) Initialized () bool {
124+ return f .initialized
125+ }
126+
127+ func (f * sessionTestClientWithClientInfo ) GetClientInfo () mcp.Implementation {
128+ if value := f .clientInfo .Load (); value != nil {
129+ if clientInfo , ok := value .(mcp.Implementation ); ok {
130+ return clientInfo
131+ }
132+ }
133+ return mcp.Implementation {}
134+ }
135+
136+ func (f * sessionTestClientWithClientInfo ) SetClientInfo (clientInfo mcp.Implementation ) {
137+ f .clientInfo .Store (clientInfo )
138+ }
139+
102140// sessionTestClientWithTools implements the SessionWithLogging interface for testing
103141type sessionTestClientWithLogging struct {
104142 sessionID string
105143 notificationChannel chan mcp.JSONRPCNotification
106144 initialized bool
107- loggingLevel atomic.Value
145+ loggingLevel atomic.Value
108146}
109147
110148func (f * sessionTestClientWithLogging ) SessionID () string {
@@ -136,9 +174,10 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel {
136174
137175// Verify that all implementations satisfy their respective interfaces
138176var (
139- _ ClientSession = (* sessionTestClient )(nil )
140- _ SessionWithTools = (* sessionTestClientWithTools )(nil )
141- _ SessionWithLogging = (* sessionTestClientWithLogging )(nil )
177+ _ ClientSession = (* sessionTestClient )(nil )
178+ _ SessionWithTools = (* sessionTestClientWithTools )(nil )
179+ _ SessionWithLogging = (* sessionTestClientWithLogging )(nil )
180+ _ SessionWithClientInfo = (* sessionTestClientWithClientInfo )(nil )
142181)
143182
144183func TestSessionWithTools_Integration (t * testing.T ) {
@@ -1041,4 +1080,49 @@ func TestMCPServer_SetLevel(t *testing.T) {
10411080 if session .GetLogLevel () != mcp .LoggingLevelCritical {
10421081 t .Errorf ("Expected critical level, got %v" , session .GetLogLevel ())
10431082 }
1044- }
1083+ }
1084+
1085+ func TestSessionWithClientInfo_Integration (t * testing.T ) {
1086+ server := NewMCPServer ("test-server" , "1.0.0" )
1087+
1088+ session := & sessionTestClientWithClientInfo {
1089+ sessionID : "session-1" ,
1090+ notificationChannel : make (chan mcp.JSONRPCNotification , 10 ),
1091+ initialized : false ,
1092+ }
1093+
1094+ err := server .RegisterSession (context .Background (), session )
1095+ require .NoError (t , err )
1096+
1097+ clientInfo := mcp.Implementation {
1098+ Name : "test-client" ,
1099+ Version : "1.0.0" ,
1100+ }
1101+
1102+ initRequest := mcp.InitializeRequest {}
1103+ initRequest .Params .ClientInfo = clientInfo
1104+ initRequest .Params .ProtocolVersion = mcp .LATEST_PROTOCOL_VERSION
1105+ initRequest .Params .Capabilities = mcp.ClientCapabilities {}
1106+
1107+ sessionCtx := server .WithContext (context .Background (), session )
1108+
1109+ // Retrieve the session from context
1110+ retrievedSession := ClientSessionFromContext (sessionCtx )
1111+ require .NotNil (t , retrievedSession , "Session should be available from context" )
1112+ assert .Equal (t , session .SessionID (), retrievedSession .SessionID (), "Session ID should match" )
1113+
1114+ result , reqErr := server .handleInitialize (sessionCtx , 1 , initRequest )
1115+ require .Nil (t , reqErr )
1116+ require .NotNil (t , result )
1117+
1118+ // Check if the session can be cast to SessionWithClientInfo
1119+ sessionWithClientInfo , ok := retrievedSession .(SessionWithClientInfo )
1120+ require .True (t , ok , "Session should implement SessionWithClientInfo" )
1121+
1122+ assert .True (t , sessionWithClientInfo .Initialized (), "Session should be initialized" )
1123+
1124+ storedClientInfo := sessionWithClientInfo .GetClientInfo ()
1125+
1126+ assert .Equal (t , clientInfo .Name , storedClientInfo .Name , "Client name should match" )
1127+ assert .Equal (t , clientInfo .Version , storedClientInfo .Version , "Client version should match" )
1128+ }
0 commit comments