diff --git a/src/app/components/chat-panel/chat-panel.component.scss b/src/app/components/chat-panel/chat-panel.component.scss index b8640c5f..72112e88 100644 --- a/src/app/components/chat-panel/chat-panel.component.scss +++ b/src/app/components/chat-panel/chat-panel.component.scss @@ -382,7 +382,8 @@ button.audio-rec-btn, button.video-rec-btn { background-color: var(--chat-card-background-color); &.recording { - background-color: var(--chat-panel-eval-fail-color); + background-color: var(--chat-panel-eval-fail-color) !important; + color: white !important; } } diff --git a/src/app/components/chat/chat.component.spec.ts b/src/app/components/chat/chat.component.spec.ts index 745e8e9d..ab86fb06 100644 --- a/src/app/components/chat/chat.component.spec.ts +++ b/src/app/components/chat/chat.component.spec.ts @@ -806,15 +806,92 @@ describe('ChatComponent', () => { describe('when bidi streaming is restarted', () => { beforeEach(() => { - component.sessionHasUsedBidi.add(component.sessionId); + component.startAudioRecording(); + component.stopAudioRecording(); component.startAudioRecording(); }); - it('should show snackbar', () => { - expect(mockSnackBar.open) - .toHaveBeenCalledWith( - 'Restarting bidirectional streaming is not currently supported. Please refresh the page or start a new session.', - OK_BUTTON_TEXT, - ); + it('should allow restart without error', () => { + expect(component.isAudioRecording).toBe(true); + expect(mockStreamChatService.startAudioChat).toHaveBeenCalledTimes(2); + }); + }); + + describe('when audio recording is stopped and restarted', () => { + beforeEach(() => { + component.startAudioRecording(); + expect(component.sessionHasUsedBidi.has(component.sessionId)).toBe(true); + component.stopAudioRecording(); + }); + it('should remove session from sessionHasUsedBidi set', () => { + expect(component.sessionHasUsedBidi.has(component.sessionId)).toBe(false); + }); + + it('should allow restarting audio recording', () => { + component.startAudioRecording(); + expect(mockSnackBar.open).not.toHaveBeenCalled(); + expect(component.isAudioRecording).toBe(true); + }); + }); + + describe('when video recording is stopped and restarted', () => { + beforeEach(() => { + component.startVideoRecording(); + expect(component.sessionHasUsedBidi.has(component.sessionId)).toBe(true); + component.stopVideoRecording(); + }); + + it('should remove session from sessionHasUsedBidi set', () => { + expect(component.sessionHasUsedBidi.has(component.sessionId)).toBe(false); + }); + + it('should allow restarting video recording', () => { + component.startVideoRecording(); + expect(mockSnackBar.open).not.toHaveBeenCalled(); + expect(component.isVideoRecording).toBe(true); + }); + }); + + describe('when trying to start concurrent bidi streams', () => { + it('should prevent starting audio while already recording', () => { + component.startAudioRecording(); + expect(component.isAudioRecording).toBe(true); + + component.startAudioRecording(); + + expect(mockSnackBar.open).toHaveBeenCalledWith( + 'Another streaming request is already in progress. Please stop it before starting a new one.', + 'OK' + ); + expect(mockStreamChatService.startAudioChat).toHaveBeenCalledTimes(1); + }); + + it('should prevent starting video while already recording', () => { + component.startVideoRecording(); + expect(component.isVideoRecording).toBe(true); + + component.startVideoRecording(); + + expect(mockSnackBar.open).toHaveBeenCalledWith( + 'Another streaming request is already in progress. Please stop it before starting a new one.', + 'OK' + ); + expect(mockStreamChatService.startVideoChat).toHaveBeenCalledTimes(1); + }); + }); + + describe('when stopping video recording without videoContainer', () => { + it('should still cleanup sessionHasUsedBidi', () => { + component.startVideoRecording(); + expect(component.sessionHasUsedBidi.has(component.sessionId)).toBe(true); + + spyOn(component, 'chatPanel').and.returnValue({ + videoContainer: undefined + } as any); + + component.stopVideoRecording(); + + expect(component.sessionHasUsedBidi.has(component.sessionId)).toBe(false); + expect(component.isVideoRecording).toBe(false); }); }); }); diff --git a/src/app/components/chat/chat.component.ts b/src/app/components/chat/chat.component.ts index cc810a94..3b78b9f7 100644 --- a/src/app/components/chat/chat.component.ts +++ b/src/app/components/chat/chat.component.ts @@ -113,7 +113,7 @@ class CustomPaginatorIntl extends MatPaginatorIntl { } const BIDI_STREAMING_RESTART_WARNING = - 'Restarting bidirectional streaming is not currently supported. Please refresh the page or start a new session.'; + 'Another streaming request is already in progress. Please stop it before starting a new one.'; @Component({ selector: 'app-chat', @@ -213,7 +213,6 @@ export class ChatComponent implements OnInit, AfterViewInit, OnDestroy { private readonly isModelThinkingSubject = new BehaviorSubject(false); protected readonly canEditSession = signal(true); - // TODO: Remove this once backend supports restarting bidi streaming. sessionHasUsedBidi = new Set(); eventData = new Map(); @@ -1018,11 +1017,14 @@ export class ChatComponent implements OnInit, AfterViewInit, OnDestroy { {role: 'bot', text: 'Speaking...'}, ]); this.sessionHasUsedBidi.add(this.sessionId); + this.changeDetectorRef.detectChanges(); } stopAudioRecording() { this.streamChatService.stopAudioChat(); this.isAudioRecording = false; + this.sessionHasUsedBidi.delete(this.sessionId); + this.changeDetectorRef.detectChanges(); } toggleVideoRecording() { @@ -1049,15 +1051,17 @@ export class ChatComponent implements OnInit, AfterViewInit, OnDestroy { this.messages.update( messages => [...messages, {role: 'user', text: 'Speaking...'}]); this.sessionHasUsedBidi.add(this.sessionId); + this.changeDetectorRef.detectChanges(); } stopVideoRecording() { const videoContainer = this.chatPanel()?.videoContainer; - if (!videoContainer) { - return; + if (videoContainer) { + this.streamChatService.stopVideoChat(videoContainer); } - this.streamChatService.stopVideoChat(videoContainer); this.isVideoRecording = false; + this.sessionHasUsedBidi.delete(this.sessionId); + this.changeDetectorRef.detectChanges(); } private getAsyncFunctionsFromParts( diff --git a/src/app/core/services/stream-chat.service.spec.ts b/src/app/core/services/stream-chat.service.spec.ts index d943876e..68c207be 100644 --- a/src/app/core/services/stream-chat.service.spec.ts +++ b/src/app/core/services/stream-chat.service.spec.ts @@ -263,4 +263,64 @@ describe('StreamChatService', () => { expect(mockWebSocketService.sendMessage).toHaveBeenCalledTimes(2); })); }); + + describe('restart audio chat', () => { + it('should allow restarting audio chat after stopping', async () => { + mockAudioRecordingService.getCombinedAudioBuffer.and.returnValue( + Uint8Array.of()); + + await service.startAudioChat({ + appName: 'fake-app-name', + userId: 'fake-user-id', + sessionId: 'fake-session-id' + }); + expect(mockWebSocketService.connect).toHaveBeenCalledTimes(1); + expect(mockAudioRecordingService.startRecording).toHaveBeenCalledTimes(1); + + service.stopAudioChat(); + expect(mockAudioRecordingService.stopRecording).toHaveBeenCalledTimes(1); + expect(mockWebSocketService.closeConnection).toHaveBeenCalledTimes(1); + + await service.startAudioChat({ + appName: 'fake-app-name', + userId: 'fake-user-id', + sessionId: 'fake-session-id' + }); + expect(mockWebSocketService.connect).toHaveBeenCalledTimes(2); + expect(mockAudioRecordingService.startRecording).toHaveBeenCalledTimes(2); + }); + }); + + describe('restart video chat', () => { + it('should allow restarting video chat after stopping', async () => { + mockAudioRecordingService.getCombinedAudioBuffer.and.returnValue( + Uint8Array.of()); + mockVideoService.getCapturedFrame.and.resolveTo(Uint8Array.of()); + + await service.startVideoChat({ + appName: 'fake-app-name', + userId: 'fake-user-id', + sessionId: 'fake-session-id', + videoContainer + }); + expect(mockWebSocketService.connect).toHaveBeenCalledTimes(1); + expect(mockAudioRecordingService.startRecording).toHaveBeenCalledTimes(1); + expect(mockVideoService.startRecording).toHaveBeenCalledTimes(1); + + service.stopVideoChat(videoContainer); + expect(mockAudioRecordingService.stopRecording).toHaveBeenCalledTimes(1); + expect(mockVideoService.stopRecording).toHaveBeenCalledTimes(1); + expect(mockWebSocketService.closeConnection).toHaveBeenCalledTimes(1); + + await service.startVideoChat({ + appName: 'fake-app-name', + userId: 'fake-user-id', + sessionId: 'fake-session-id', + videoContainer + }); + expect(mockWebSocketService.connect).toHaveBeenCalledTimes(2); + expect(mockAudioRecordingService.startRecording).toHaveBeenCalledTimes(2); + expect(mockVideoService.startRecording).toHaveBeenCalledTimes(2); + }); + }); }); diff --git a/src/app/core/services/websocket.service.spec.ts b/src/app/core/services/websocket.service.spec.ts index c3f3125d..311cd0c4 100644 --- a/src/app/core/services/websocket.service.spec.ts +++ b/src/app/core/services/websocket.service.spec.ts @@ -55,4 +55,33 @@ describe('WebSocketService', () => { expect(service.urlSafeBase64ToBase64('abcd')).toEqual('abcd'); }); }); + + describe('connection restart', () => { + it('should reset audio buffer when reconnecting', () => { + service.connect('ws://test1'); + + (service as any).audioBuffer = [new Uint8Array([1, 2, 3])]; + + service.connect('ws://test2'); + expect((service as any).audioBuffer).toEqual([]); + }); + + it('should close previous connection when reconnecting', () => { + service.connect('ws://test1'); + const firstSocket = (service as any).socket$; + spyOn(firstSocket, 'complete'); + + service.connect('ws://test2'); + expect(firstSocket.complete).toHaveBeenCalled(); + }); + + it('should clear audio interval when closing connection', () => { + service.connect('ws://test'); + const intervalId = (service as any).audioIntervalId; + expect(intervalId).not.toBeNull(); + + service.closeConnection(); + expect((service as any).audioIntervalId).toBeNull(); + }); + }); }); diff --git a/src/app/core/services/websocket.service.ts b/src/app/core/services/websocket.service.ts index 13a8a879..46d4457c 100644 --- a/src/app/core/services/websocket.service.ts +++ b/src/app/core/services/websocket.service.ts @@ -38,6 +38,12 @@ export class WebSocketService implements WebSocketServiceInterface { private closeReasonSubject = new Subject(); connect(serverUrl: string) { + // Clean up previous connection if exists + this.closeConnection(); + + // Reset audio buffer for new connection + this.audioBuffer = []; + this.socket$ = new WebSocketSubject({ url: serverUrl, serializer: (msg) => JSON.stringify(msg),