diff --git a/build.gradle b/build.gradle index 3bab1a27..5fb302fc 100644 --- a/build.gradle +++ b/build.gradle @@ -67,6 +67,11 @@ dependencies { testImplementation "org.springframework.cloud:spring-cloud-starter-openfeign" testImplementation "com.squareup.okhttp3:mockwebserver:4.12.0" + // WebSocket + implementation 'org.springframework.boot:spring-boot-starter-websocket' + + // WebFlux + implementation 'org.springframework.boot:spring-boot-starter-webflux' } dependencyManagement { diff --git a/src/main/java/com/sofa/linkiving/domain/chat/controller/ChatApi.java b/src/main/java/com/sofa/linkiving/domain/chat/controller/ChatApi.java index e6db7769..dbcc2217 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/controller/ChatApi.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/controller/ChatApi.java @@ -7,9 +7,25 @@ import com.sofa.linkiving.global.common.BaseResponse; import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.tags.Tag; -@Tag(name = "Chat", description = "채팅 관리 API") +@Tag(name = "Chat", description = """ + AI 채팅 통합 명세 (HTTP + WebSocket) + + ### 📡 1. WebSocket 연결 정보 (필수) + 답변을 실시간으로 수신하기 위해 **반드시 소켓 연결 및 구독**이 선행되어야 합니다. + + * **Socket Endpoint:** `ws://{domain}/ws/chat` + * **Subscribe Path:** `/topic/chat/{chatId}` + * **Auth Header:** `Authorization: Bearer {accessToken}` (CONNECT 프레임 헤더) + ### 🚀 2. 동작 흐름 + 1. **소켓 연결:** 프론트엔드에서 WebSocket 연결 및 `/topic/chat/{chatId}` 구독 + 2. **질문 전송:** `/app/send/{chatId}` (STOMP)로 질문 전송 + 3. **답변 수신:** 소켓 구독 채널로 토큰 단위 답변 스트리밍 (`String` 데이터) + 4. **완료:** `END_OF_STREAM` 메시지 수신 시 스트리밍 종료 + + """) public interface ChatApi { @Operation(summary = "채팅방 목록 조회", description = "사용자의 채팅방 목록 정보(채팅방 Id, 제목)을 조회합니다.") BaseResponse getChats(Member member); @@ -19,4 +35,9 @@ BaseResponse createChat( CreateChatReq req, Member member ); + + void sendMessage(@Parameter(description = "채팅방 Id", required = true) Long chatId, + @Parameter(description = "사용자 질문 내용", required = true) String message, Member member); + + void cancelMessage(@Parameter(description = "채팅방 Id", required = true) Long chatId, Member member); } diff --git a/src/main/java/com/sofa/linkiving/domain/chat/controller/ChatController.java b/src/main/java/com/sofa/linkiving/domain/chat/controller/ChatController.java index 1d54c9b3..7ac5040e 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/controller/ChatController.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/controller/ChatController.java @@ -1,5 +1,8 @@ package com.sofa.linkiving.domain.chat.controller; +import org.springframework.messaging.handler.annotation.DestinationVariable; +import org.springframework.messaging.handler.annotation.MessageMapping; +import org.springframework.messaging.handler.annotation.Payload; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; @@ -36,4 +39,16 @@ public BaseResponse createChat(@RequestBody @Valid CreateChatReq CreateChatRes res = chatFacade.createChat(req.firstChat(), member); return BaseResponse.success(res, "채팅방 생성 완료"); } + + @MessageMapping("/send/{chatId}") + @Override + public void sendMessage(@DestinationVariable Long chatId, @Payload String message, @AuthMember Member member) { + chatFacade.generateAnswer(chatId, member, message); + } + + @MessageMapping("/cancel/{chatId}") + @Override + public void cancelMessage(@DestinationVariable Long chatId, @AuthMember Member member) { + chatFacade.cancelAnswer(chatId, member); + } } diff --git a/src/main/java/com/sofa/linkiving/domain/chat/controller/MockAiController.java b/src/main/java/com/sofa/linkiving/domain/chat/controller/MockAiController.java new file mode 100644 index 00000000..c4771a60 --- /dev/null +++ b/src/main/java/com/sofa/linkiving/domain/chat/controller/MockAiController.java @@ -0,0 +1,33 @@ +package com.sofa.linkiving.domain.chat.controller; + +import java.time.Duration; +import java.util.Map; + +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import reactor.core.publisher.Flux; + +@RestController +@RequestMapping("/mock/ai") +public class MockAiController { + + @PostMapping(value = "/generate", produces = MediaType.APPLICATION_NDJSON_VALUE) // 또는 TEXT_EVENT_STREAM_VALUE + public Flux generateAnswer(@RequestBody Map request) { + String userPrompt = request.get("prompt"); + + String fakeResponse = """ + 안녕하세요! 저는 임시 AI 봇입니다. 🤖 + 현재 AI 서버가 구축되지 않아서 테스트용 답변을 드리고 있어요. + 질문하신 내용인 "%s"에 대해 답변을 생성하는 척 하고 있습니다. + 취소 기능을 테스트하시려면 지금 바로 취소 버튼을 눌러보세요! + 타이핑 효과를 위해 천천히 답변을 보내고 있습니다... + """.formatted(userPrompt); + + return Flux.fromArray(fakeResponse.split("")) + .delayElements(Duration.ofMillis(100)); + } +} diff --git a/src/main/java/com/sofa/linkiving/domain/chat/error/ChatErrorCode.java b/src/main/java/com/sofa/linkiving/domain/chat/error/ChatErrorCode.java new file mode 100644 index 00000000..c86f8019 --- /dev/null +++ b/src/main/java/com/sofa/linkiving/domain/chat/error/ChatErrorCode.java @@ -0,0 +1,20 @@ +package com.sofa.linkiving.domain.chat.error; + +import org.springframework.http.HttpStatus; + +import com.sofa.linkiving.global.error.code.ErrorCode; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +@Getter +@RequiredArgsConstructor +public enum ChatErrorCode implements ErrorCode { + + CHAT_NOT_FOUND(HttpStatus.NOT_FOUND, "C-001", "채팅을 찾을 수 없습니다."), + ALREADY_GENERATING(HttpStatus.BAD_REQUEST, "C-002", "현재 답변이 생성 중입니다. 잠시만 기다려주세요."); + + private final HttpStatus status; + private final String code; + private final String message; +} diff --git a/src/main/java/com/sofa/linkiving/domain/chat/facade/ChatFacade.java b/src/main/java/com/sofa/linkiving/domain/chat/facade/ChatFacade.java index ecffdd12..9973ebec 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/facade/ChatFacade.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/facade/ChatFacade.java @@ -37,4 +37,15 @@ public ChatsRes getChats(Member member) { List chats = chatService.getChats(member); return ChatsRes.from(chats); } + + @Transactional + public void generateAnswer(Long chatId, Member member, String message) { + Chat chat = chatService.getChat(chatId, member); + messageService.generateAnswer(chat, message); + } + + public void cancelAnswer(Long chatId, Member member) { + Chat chat = chatService.getChat(chatId, member); + messageService.cancelAnswer(chat); + } } diff --git a/src/main/java/com/sofa/linkiving/domain/chat/manager/SubscriptionManager.java b/src/main/java/com/sofa/linkiving/domain/chat/manager/SubscriptionManager.java new file mode 100644 index 00000000..27d83a85 --- /dev/null +++ b/src/main/java/com/sofa/linkiving/domain/chat/manager/SubscriptionManager.java @@ -0,0 +1,39 @@ +package com.sofa.linkiving.domain.chat.manager; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.springframework.stereotype.Component; + +import reactor.core.Disposable; + +@Component +public class SubscriptionManager { + + private final Map activeSubscriptions = new ConcurrentHashMap<>(); + + /** + * 구독 추가 (기존 작업이 있다면 취소 후 등록) + */ + public void add(String key, Disposable subscription) { + cancel(key); // 안전하게 기존 작업 정리 + activeSubscriptions.put(key, subscription); + } + + /** + * 구독 취소 및 자원 해제 + */ + public void cancel(String key) { + Disposable subscription = activeSubscriptions.remove(key); + if (subscription != null && !subscription.isDisposed()) { + subscription.dispose(); + } + } + + /** + * 완료된 구독 제거 (자원 해제 없이 Map에서만 삭제) + */ + public void remove(String key) { + activeSubscriptions.remove(key); + } +} diff --git a/src/main/java/com/sofa/linkiving/domain/chat/repository/ChatRepository.java b/src/main/java/com/sofa/linkiving/domain/chat/repository/ChatRepository.java index 506b675c..40bc94e7 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/repository/ChatRepository.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/repository/ChatRepository.java @@ -1,6 +1,7 @@ package com.sofa.linkiving.domain.chat.repository; import java.util.List; +import java.util.Optional; import org.springframework.data.jpa.repository.JpaRepository; import org.springframework.data.jpa.repository.Query; @@ -21,4 +22,6 @@ public interface ChatRepository extends JpaRepository { ORDER BY MAX(m.createdAt) DESC """) List findAllByMemberOrderByLastMessageDesc(@Param("member") Member member); + + Optional findByIdAndMember(Long id, Member member); } diff --git a/src/main/java/com/sofa/linkiving/domain/chat/service/ChatQueryService.java b/src/main/java/com/sofa/linkiving/domain/chat/service/ChatQueryService.java index f2b43a06..3de037e0 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/service/ChatQueryService.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/service/ChatQueryService.java @@ -5,8 +5,10 @@ import org.springframework.stereotype.Service; import com.sofa.linkiving.domain.chat.entity.Chat; +import com.sofa.linkiving.domain.chat.error.ChatErrorCode; import com.sofa.linkiving.domain.chat.repository.ChatRepository; import com.sofa.linkiving.domain.member.entity.Member; +import com.sofa.linkiving.global.error.exception.BusinessException; import lombok.RequiredArgsConstructor; @@ -15,6 +17,12 @@ public class ChatQueryService { private final ChatRepository chatRepository; + public Chat findChat(Long chatId, Member member) { + return chatRepository.findByIdAndMember(chatId, member).orElseThrow( + () -> new BusinessException(ChatErrorCode.CHAT_NOT_FOUND) + ); + } + public List findAllOrderByLastMessageDesc(Member member) { return chatRepository.findAllByMemberOrderByLastMessageDesc(member); } diff --git a/src/main/java/com/sofa/linkiving/domain/chat/service/ChatService.java b/src/main/java/com/sofa/linkiving/domain/chat/service/ChatService.java index f13794f5..6b0f4fc3 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/service/ChatService.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/service/ChatService.java @@ -8,13 +8,19 @@ import com.sofa.linkiving.domain.member.entity.Member; import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +@Slf4j @Service @RequiredArgsConstructor public class ChatService { private final ChatCommandService chatCommandService; private final ChatQueryService chatQueryService; + public Chat getChat(Long chatId, Member member) { + return chatQueryService.findChat(chatId, member); + } + public List getChats(Member member) { return chatQueryService.findAllOrderByLastMessageDesc(member); } diff --git a/src/main/java/com/sofa/linkiving/domain/chat/service/MessageCommandService.java b/src/main/java/com/sofa/linkiving/domain/chat/service/MessageCommandService.java index a68f0dc5..adf0b760 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/service/MessageCommandService.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/service/MessageCommandService.java @@ -2,6 +2,7 @@ import org.springframework.stereotype.Service; +import com.sofa.linkiving.domain.chat.entity.Message; import com.sofa.linkiving.domain.chat.repository.MessageRepository; import lombok.RequiredArgsConstructor; @@ -10,4 +11,8 @@ @RequiredArgsConstructor public class MessageCommandService { private final MessageRepository messageRepository; + + public Message saveMessage(Message message) { + return messageRepository.save(message); + } } diff --git a/src/main/java/com/sofa/linkiving/domain/chat/service/MessageService.java b/src/main/java/com/sofa/linkiving/domain/chat/service/MessageService.java index 07a0f5ac..b446f015 100644 --- a/src/main/java/com/sofa/linkiving/domain/chat/service/MessageService.java +++ b/src/main/java/com/sofa/linkiving/domain/chat/service/MessageService.java @@ -1,12 +1,88 @@ package com.sofa.linkiving.domain.chat.service; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.stereotype.Service; +import org.springframework.web.reactive.function.client.WebClient; + +import com.sofa.linkiving.domain.chat.entity.Chat; +import com.sofa.linkiving.domain.chat.entity.Message; +import com.sofa.linkiving.domain.chat.enums.Type; +import com.sofa.linkiving.domain.chat.manager.SubscriptionManager; import lombok.RequiredArgsConstructor; +import reactor.core.Disposable; @Service @RequiredArgsConstructor public class MessageService { private final MessageCommandService messageCommandService; private final MessageQueryService messageQueryService; + + private final SimpMessagingTemplate messagingTemplate; + private final SubscriptionManager subscriptionManager; + + private final WebClient webClient = WebClient.create("http://localhost:8080/mock/ai"); + private final Map messageBuffers = new ConcurrentHashMap<>(); + + public void generateAnswer(Chat chat, String userMessage) { + + String roomId = chat.getId().toString(); + + if (messageBuffers.containsKey(roomId)) { + return; + } + + messageBuffers.put(roomId, new StringBuilder()); + + Disposable subscription = webClient.post() + .uri("/generate") + .bodyValue(Map.of("prompt", userMessage)) + .retrieve() + .bodyToFlux(String.class) + .doOnComplete(() -> { + String fullAnswer = messageBuffers.remove(roomId).toString(); + + saveMessage(chat, Type.USER, userMessage); + saveMessage(chat, Type.AI, fullAnswer); + + subscriptionManager.remove(roomId); + messagingTemplate.convertAndSend("/topic/chat/" + roomId, "END_OF_STREAM"); + }) + .doOnError(e -> { + subscriptionManager.remove(roomId); + messagingTemplate.convertAndSend("/topic/chat/" + roomId, "ERROR: " + e.getMessage()); + }) + .subscribe(token -> { + StringBuilder buffer = messageBuffers.get(roomId); + if (buffer != null) { + buffer.append(token); + } + + messagingTemplate.convertAndSend("/topic/chat/" + roomId, token); + }); + + subscriptionManager.add(roomId, subscription); + } + + public void cancelAnswer(Chat chat) { + String roomId = chat.getId().toString(); + + subscriptionManager.cancel(roomId); + messageBuffers.remove(roomId); + + messagingTemplate.convertAndSend("/topic/chat/" + roomId, "GENERATION_CANCELLED"); + } + + private void saveMessage(Chat chat, Type type, String content) { + Message message = Message.builder() + .chat(chat) + .type(type) + .content(content) + .build(); + + messageCommandService.saveMessage(message); + } } diff --git a/src/main/java/com/sofa/linkiving/global/config/WebMvcConfig.java b/src/main/java/com/sofa/linkiving/global/config/WebMvcConfig.java index 51b12a97..9b18bb4d 100644 --- a/src/main/java/com/sofa/linkiving/global/config/WebMvcConfig.java +++ b/src/main/java/com/sofa/linkiving/global/config/WebMvcConfig.java @@ -3,7 +3,10 @@ import java.util.List; import org.springframework.context.annotation.Configuration; +import org.springframework.core.task.AsyncTaskExecutor; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.web.method.support.HandlerMethodArgumentResolver; +import org.springframework.web.servlet.config.annotation.AsyncSupportConfigurer; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; import com.sofa.linkiving.security.resolver.AuthMemberArgumentResolver; @@ -20,4 +23,19 @@ public class WebMvcConfig implements WebMvcConfigurer { public void addArgumentResolvers(List resolvers) { resolvers.add(authMemberArgumentResolver); } + + @Override + public void configureAsyncSupport(AsyncSupportConfigurer configurer) { + configurer.setTaskExecutor(mvcTaskExecutor()); + } + + public AsyncTaskExecutor mvcTaskExecutor() { + ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); + executor.setCorePoolSize(10); + executor.setMaxPoolSize(100); + executor.setQueueCapacity(50); + executor.setThreadNamePrefix("mvc-async-"); + executor.initialize(); + return executor; + } } diff --git a/src/main/java/com/sofa/linkiving/global/config/WebSocketConfig.java b/src/main/java/com/sofa/linkiving/global/config/WebSocketConfig.java new file mode 100644 index 00000000..d107c148 --- /dev/null +++ b/src/main/java/com/sofa/linkiving/global/config/WebSocketConfig.java @@ -0,0 +1,48 @@ +package com.sofa.linkiving.global.config; + +import java.util.List; + +import org.springframework.context.annotation.Configuration; +import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; +import org.springframework.messaging.simp.config.ChannelRegistration; +import org.springframework.messaging.simp.config.MessageBrokerRegistry; +import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; +import org.springframework.web.socket.config.annotation.StompEndpointRegistry; +import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; + +import com.sofa.linkiving.security.config.StompHandler; +import com.sofa.linkiving.security.resolver.AuthMemberWebsocketArgumentResolver; + +import lombok.RequiredArgsConstructor; + +@Configuration +@EnableWebSocketMessageBroker +@RequiredArgsConstructor +public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { + + private final StompHandler stompHandler; + private final AuthMemberWebsocketArgumentResolver authMemberWebsocketArgumentResolver; + + @Override + public void configureMessageBroker(MessageBrokerRegistry config) { + config.enableSimpleBroker("/topic/chat"); + config.setApplicationDestinationPrefixes("/ws/chat"); + } + + @Override + public void registerStompEndpoints(StompEndpointRegistry registry) { + registry.addEndpoint("/ws/chat") + .setAllowedOriginPatterns("*") + .withSockJS(); + } + + @Override + public void configureClientInboundChannel(ChannelRegistration registration) { + registration.interceptors(stompHandler); + } + + @Override + public void addArgumentResolvers(List argumentResolvers) { + argumentResolvers.add(authMemberWebsocketArgumentResolver); + } +} diff --git a/src/main/java/com/sofa/linkiving/infra/redis/RedisService.java b/src/main/java/com/sofa/linkiving/infra/redis/RedisService.java index 14ceefba..4c62c4af 100644 --- a/src/main/java/com/sofa/linkiving/infra/redis/RedisService.java +++ b/src/main/java/com/sofa/linkiving/infra/redis/RedisService.java @@ -33,7 +33,7 @@ public String get(String key) { public Boolean hasNoKey(String key) { Boolean exists = redisTemplate.hasKey(key); - return !Boolean.TRUE.equals(exists); + return !exists; } public void delete(String key) { @@ -56,7 +56,7 @@ public String get(RedisKeySpec type, String... keys) { public Boolean hasNoKey(RedisKeySpec type, String... keys) { Boolean exists = redisTemplate.hasKey(type.key(keys)); - return !Boolean.TRUE.equals(exists); + return !exists; } public void delete(RedisKeySpec type, String... keys) { diff --git a/src/main/java/com/sofa/linkiving/security/config/SecurityConfig.java b/src/main/java/com/sofa/linkiving/security/config/SecurityConfig.java index 6635e826..c3bac430 100644 --- a/src/main/java/com/sofa/linkiving/security/config/SecurityConfig.java +++ b/src/main/java/com/sofa/linkiving/security/config/SecurityConfig.java @@ -39,10 +39,10 @@ public class SecurityConfig { "/h2-console/**", /* web socket */ - "/v1/chat/**", + "/ws/chat/**", /* temp */ - "/v1/member/**" + "/v1/member/**", "/mock/**" }; private static final String[] SEMI_PERMIT_URLS = { diff --git a/src/main/java/com/sofa/linkiving/security/config/StompHandler.java b/src/main/java/com/sofa/linkiving/security/config/StompHandler.java new file mode 100644 index 00000000..6d1d895d --- /dev/null +++ b/src/main/java/com/sofa/linkiving/security/config/StompHandler.java @@ -0,0 +1,62 @@ +package com.sofa.linkiving.security.config; + +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.MessagingException; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.messaging.support.MessageHeaderAccessor; +import org.springframework.security.core.Authentication; +import org.springframework.stereotype.Component; + +import com.sofa.linkiving.global.error.exception.BusinessException; +import com.sofa.linkiving.security.jwt.JwtKeys; +import com.sofa.linkiving.security.jwt.JwtTokenProvider; +import com.sofa.linkiving.security.jwt.error.JwtErrorCode; + +import lombok.RequiredArgsConstructor; + +@Component +@RequiredArgsConstructor +@Order(Ordered.HIGHEST_PRECEDENCE + 99) +public class StompHandler implements ChannelInterceptor { + + private final JwtTokenProvider jwtTokenProvider; + + @Override + public Message preSend(Message message, MessageChannel channel) { + StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); + + if (accessor != null && StompCommand.CONNECT.equals(accessor.getCommand())) { + + String authorizationHeader = accessor.getFirstNativeHeader(JwtKeys.Headers.AUTHORIZATION); + String token = null; + + if (authorizationHeader != null && authorizationHeader.startsWith(JwtKeys.Headers.BEARER_PREFIX)) { + token = authorizationHeader.substring(JwtKeys.Headers.BEARER_PREFIX.length()); + } + + try { + if (token == null) { + throw new BusinessException(JwtErrorCode.EMPTY_TOKEN); + } + + if (jwtTokenProvider.validateAccessToken(token)) { + Authentication authentication = jwtTokenProvider.getAuthentication(token); + accessor.setUser(authentication); + } + + } catch (BusinessException e) { + throw new MessagingException(e.getMessage()); + + } catch (Exception e) { + throw new MessagingException("서버 내부 오류로 연결에 실패했습니다."); + } + } + + return message; + } +} diff --git a/src/main/java/com/sofa/linkiving/security/resolver/AuthMemberWebsocketArgumentResolver.java b/src/main/java/com/sofa/linkiving/security/resolver/AuthMemberWebsocketArgumentResolver.java new file mode 100644 index 00000000..0721bc94 --- /dev/null +++ b/src/main/java/com/sofa/linkiving/security/resolver/AuthMemberWebsocketArgumentResolver.java @@ -0,0 +1,41 @@ +package com.sofa.linkiving.security.resolver; + +import java.security.Principal; + +import org.springframework.core.MethodParameter; +import org.springframework.messaging.Message; +import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.stereotype.Component; + +import com.sofa.linkiving.domain.member.entity.Member; +import com.sofa.linkiving.global.error.code.CommonErrorCode; +import com.sofa.linkiving.global.error.exception.BusinessException; +import com.sofa.linkiving.security.annotation.AuthMember; +import com.sofa.linkiving.security.userdetails.CustomMemberDetail; + +@Component +public class AuthMemberWebsocketArgumentResolver implements HandlerMethodArgumentResolver { + + @Override + public boolean supportsParameter(MethodParameter parameter) { + return parameter.hasParameterAnnotation(AuthMember.class) + && Member.class.isAssignableFrom(parameter.getParameterType()); + } + + @Override + public Object resolveArgument(MethodParameter parameter, Message message) { + Principal principal = SimpMessageHeaderAccessor.getUser(message.getHeaders()); + + if (principal instanceof AbstractAuthenticationToken authentication) { + Object userDetails = authentication.getPrincipal(); + + if (userDetails instanceof CustomMemberDetail customMemberDetail) { + return customMemberDetail.member(); + } + } + + throw new BusinessException(CommonErrorCode.UNAUTHORIZED); + } +} diff --git a/src/test/java/com/sofa/linkiving/domain/chat/integration/WebSocketChatIntegrationTest.java b/src/test/java/com/sofa/linkiving/domain/chat/integration/WebSocketChatIntegrationTest.java new file mode 100644 index 00000000..4a75433b --- /dev/null +++ b/src/test/java/com/sofa/linkiving/domain/chat/integration/WebSocketChatIntegrationTest.java @@ -0,0 +1,184 @@ +package com.sofa.linkiving.domain.chat.integration; + +import static org.assertj.core.api.Assertions.*; + +import java.lang.reflect.Type; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.web.server.LocalServerPort; +import org.springframework.messaging.converter.StringMessageConverter; +import org.springframework.messaging.simp.stomp.StompFrameHandler; +import org.springframework.messaging.simp.stomp.StompHeaders; +import org.springframework.messaging.simp.stomp.StompSession; +import org.springframework.messaging.simp.stomp.StompSessionHandlerAdapter; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.bean.override.mockito.MockitoBean; +import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.socket.WebSocketHttpHeaders; +import org.springframework.web.socket.client.standard.StandardWebSocketClient; +import org.springframework.web.socket.messaging.WebSocketStompClient; +import org.springframework.web.socket.sockjs.client.SockJsClient; +import org.springframework.web.socket.sockjs.client.Transport; +import org.springframework.web.socket.sockjs.client.WebSocketTransport; + +import com.sofa.linkiving.domain.chat.entity.Chat; +import com.sofa.linkiving.domain.chat.repository.ChatRepository; +import com.sofa.linkiving.domain.chat.service.MessageService; +import com.sofa.linkiving.domain.member.entity.Member; +import com.sofa.linkiving.domain.member.repository.MemberRepository; +import com.sofa.linkiving.infra.redis.RedisService; +import com.sofa.linkiving.security.jwt.JwtTokenProvider; + +@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) +@ActiveProfiles("test") +public class WebSocketChatIntegrationTest { + + @LocalServerPort + private int port; + + private WebSocketStompClient stompClient; + + @Autowired + private MessageService messageService; + + @Autowired + private MemberRepository memberRepository; + + @Autowired + private ChatRepository chatRepository; + + @MockitoBean + private RedisService redisService; + + @Autowired + private JwtTokenProvider jwtTokenProvider; // 실제 토큰 생성 로직 사용 (또는 MockBean) + + private Chat savedChat; + private String validToken; + + @BeforeEach + void setUp() { + // 1. WebSocket Client 설정 + StandardWebSocketClient standardWebSocketClient = new StandardWebSocketClient(); + WebSocketTransport webSocketTransport = new WebSocketTransport(standardWebSocketClient); + List transports = List.of(webSocketTransport); + SockJsClient sockJsClient = new SockJsClient(transports); + + stompClient = new WebSocketStompClient(sockJsClient); + stompClient.setMessageConverter(new StringMessageConverter()); + + // 2. MockAiController 연결을 위한 WebClient 주소 조작 (핵심) + String testUrl = "http://localhost:" + port + "/mock/ai"; + WebClient testWebClient = WebClient.create(testUrl); + ReflectionTestUtils.setField(messageService, "webClient", testWebClient); + + // 3. 테스트 데이터 생성 + String uniqueEmail = "socket_" + UUID.randomUUID().toString().substring(0, 8) + "@test.com"; + + Member savedMember = memberRepository.save(Member.builder() + .email(uniqueEmail) + .password("password") + .build()); + + savedChat = chatRepository.save(Chat.builder() + .member(savedMember) + .title("test") + .build()); + + // 4. 유효한 토큰 생성 (StompHandler 통과용) + validToken = jwtTokenProvider.createAccessToken(savedMember.getEmail()); + } + + @Test + @DisplayName("메시지 전송 시 MockAiController를 통해 스트리밍 답변 수신") + void shouldReceiveStreamingResponseWhenSendMessage() throws Exception { + // given + String wsUrl = String.format("ws://localhost:%d/ws/chat", port); + StompHeaders headers = new StompHeaders(); + headers.add("Authorization", "Bearer " + validToken); + + WebSocketHttpHeaders handshakeHeaders = new WebSocketHttpHeaders(); + + StompSession session = stompClient.connectAsync(wsUrl, handshakeHeaders, headers, + new StompSessionHandlerAdapter() { + }) + .get(5, TimeUnit.SECONDS); + + Long chatId = savedChat.getId(); + String userMessage = "테스트 질문"; + BlockingQueue queue = new LinkedBlockingQueue<>(); + + // when: 구독 (/topic/chat/{chatId}) + session.subscribe("/topic/chat/" + chatId, new StompFrameHandler() { + @NotNull + @Override + public Type getPayloadType(@NotNull StompHeaders headers) { + return String.class; + } + + @Override + public void handleFrame(@NotNull StompHeaders headers, Object payload) { + queue.add((String)payload); + } + }); + + // when: 메시지 전송 + session.send("/ws/chat/send/" + chatId, userMessage); + + // then: MockAiController가 보내는 응답 검증 + String response = queue.poll(5, TimeUnit.SECONDS); + + assertThat(response).isNotNull(); + assertThat(response).startsWith("안"); + } + + @Test + @DisplayName("취소 요청 시 GENERATION_CANCELLED 메시지 수신") + void shouldReceiveCancelledMessageWhenCancelRequest() throws Exception { + // given + String wsUrl = String.format("ws://localhost:%d/ws/chat", port); + StompHeaders headers = new StompHeaders(); + headers.add("Authorization", "Bearer " + validToken); + + WebSocketHttpHeaders handshakeHeaders = new WebSocketHttpHeaders(); + + StompSession session = stompClient.connectAsync(wsUrl, handshakeHeaders, headers, + new StompSessionHandlerAdapter() { + }) + .get(5, TimeUnit.SECONDS); + + Long chatId = savedChat.getId(); + BlockingQueue queue = new LinkedBlockingQueue<>(); + + session.subscribe("/topic/chat/" + chatId, new StompFrameHandler() { + @NotNull + @Override + public Type getPayloadType(@NotNull StompHeaders headers) { + return String.class; + } + + @Override + public void handleFrame(@NotNull StompHeaders headers, Object payload) { + queue.add((String)payload); + } + }); + + // when: 취소 요청 전송 + session.send("/ws/chat/cancel/" + chatId, ""); + + // then + String response = queue.poll(5, TimeUnit.SECONDS); + assertThat(response).isEqualTo("GENERATION_CANCELLED"); + } +} diff --git a/src/test/java/com/sofa/linkiving/domain/chat/manager/SubscriptionManagerTest.java b/src/test/java/com/sofa/linkiving/domain/chat/manager/SubscriptionManagerTest.java new file mode 100644 index 00000000..e77446c3 --- /dev/null +++ b/src/test/java/com/sofa/linkiving/domain/chat/manager/SubscriptionManagerTest.java @@ -0,0 +1,67 @@ +package com.sofa.linkiving.domain.chat.manager; + +import static org.mockito.Mockito.*; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import reactor.core.Disposable; + +@ExtendWith(MockitoExtension.class) +public class SubscriptionManagerTest { + + @InjectMocks + private SubscriptionManager subscriptionManager; + + @Mock + private Disposable disposable; + + @Test + @DisplayName("구독 추가 요청 시 기존 구독이 있다면 취소 후 등록") + void shouldDisposeOldSubscriptionWhenAdd() { + // given + String key = "chat-1"; + Disposable oldDisposable = mock(Disposable.class); + + // 먼저 하나 등록 + subscriptionManager.add(key, oldDisposable); + + // when: 같은 키로 새로운 구독 등록 + subscriptionManager.add(key, disposable); + + // then: 이전 구독은 dispose 되어야 함 + verify(oldDisposable).dispose(); + } + + @Test + @DisplayName("구독 취소 요청 시 dispose 호출 및 제거") + void shouldDisposeWhenCancel() { + // given + String key = "chat-1"; + subscriptionManager.add(key, disposable); + + // when + subscriptionManager.cancel(key); + + // then + verify(disposable).dispose(); + } + + @Test + @DisplayName("완료된 구독 제거 요청 시 dispose 없이 맵에서만 제거") + void shouldNotDisposeWhenRemove() { + // given + String key = "chat-1"; + subscriptionManager.add(key, disposable); + + // when + subscriptionManager.remove(key); + + // then: remove는 dispose를 호출하지 않음 (이미 완료된 상태 가정) + verify(disposable, never()).dispose(); + } +} diff --git a/src/test/java/com/sofa/linkiving/domain/chat/repository/ChatRepositoryTest.java b/src/test/java/com/sofa/linkiving/domain/chat/repository/ChatRepositoryTest.java index 4f41718c..24f20a19 100644 --- a/src/test/java/com/sofa/linkiving/domain/chat/repository/ChatRepositoryTest.java +++ b/src/test/java/com/sofa/linkiving/domain/chat/repository/ChatRepositoryTest.java @@ -3,6 +3,7 @@ import static org.assertj.core.api.Assertions.*; import java.util.List; +import java.util.Optional; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; @@ -22,8 +23,10 @@ @DataJpaTest @ActiveProfiles("test") public class ChatRepositoryTest { + @Autowired private ChatRepository chatRepository; + @Autowired private MemberRepository memberRepository; @Autowired @@ -91,4 +94,55 @@ void shouldReturnOnlyChatsWithMessagesOrderByLastMessageTime() throws Interrupte assertThat(result.get(0).getTitle()).isEqualTo("New Msg Chat"); assertThat(result.get(1).getTitle()).isEqualTo("Old Msg Chat"); } + + @Test + @DisplayName("내 채팅방 조회 시 정상적으로 반환되어야 한다") + void shouldReturnChatWhenMyChatExists() { + // given + Member me = memberRepository.save( + Member.builder() + .email("me@test.com") + .password("password") + .build()); + + Chat myChat = chatRepository.save( + Chat.builder() + .member(me) + .title("test") + .build()); + + // when + Optional result = chatRepository.findByIdAndMember(myChat.getId(), me); + + // then + assertThat(result).isPresent(); + } + + @Test + @DisplayName("다른 사람의 채팅방 조회 시 Empty를 반환해야 한다") + void shouldReturnEmptyWhenChatIsNotMine() { + // given + Member me = memberRepository.save( + Member.builder() + .email("me@test.com") + .password("password") + .build()); + Member other = memberRepository.save( + Member.builder() + .email("other@test.com") + .password("password") + .build()); + + Chat othersChat = chatRepository.save( + Chat.builder() + .member(other) + .title("test") + .build()); + + // when: 내 정보(me)로 남의 채팅방(othersChat) 조회 시도 + Optional result = chatRepository.findByIdAndMember(othersChat.getId(), me); + + // then + assertThat(result).isEmpty(); // 조회되면 안 됨 (보안 검증) + } } diff --git a/src/test/java/com/sofa/linkiving/domain/chat/service/ChatQueryServiceTest.java b/src/test/java/com/sofa/linkiving/domain/chat/service/ChatQueryServiceTest.java index 8a7ef177..45966612 100644 --- a/src/test/java/com/sofa/linkiving/domain/chat/service/ChatQueryServiceTest.java +++ b/src/test/java/com/sofa/linkiving/domain/chat/service/ChatQueryServiceTest.java @@ -4,6 +4,7 @@ import static org.mockito.BDDMockito.*; import java.util.List; +import java.util.Optional; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; @@ -13,8 +14,10 @@ import org.mockito.junit.jupiter.MockitoExtension; import com.sofa.linkiving.domain.chat.entity.Chat; +import com.sofa.linkiving.domain.chat.error.ChatErrorCode; import com.sofa.linkiving.domain.chat.repository.ChatRepository; import com.sofa.linkiving.domain.member.entity.Member; +import com.sofa.linkiving.global.error.exception.BusinessException; @ExtendWith(MockitoExtension.class) public class ChatQueryServiceTest { @@ -41,4 +44,33 @@ void shouldReturnChatListWhenFindAllOrderByLastMessageDesc() { assertThat(result).isEqualTo(chats); verify(chatRepository).findAllByMemberOrderByLastMessageDesc(member); } + + @Test + @DisplayName("채팅방 조회 성공 시 Chat 엔티티 반환") + void shouldReturnChatWhenChatExists() { + // given + Long chatId = 1L; + Chat chat = mock(Chat.class); + given(chatRepository.findByIdAndMember(chatId, member)).willReturn(Optional.of(chat)); + + // when + Chat result = chatQueryService.findChat(chatId, member); + + // then + assertThat(result).isEqualTo(chat); + verify(chatRepository).findByIdAndMember(chatId, member); + } + + @Test + @DisplayName("채팅방 미존재 시 BusinessException(CHAT_NOT_FOUND) 발생") + void shouldThrowExceptionWhenChatNotFound() { + // given + Long chatId = 999L; + given(chatRepository.findByIdAndMember(chatId, member)).willReturn(Optional.empty()); + + // when & then + assertThatThrownBy(() -> chatQueryService.findChat(chatId, member)) + .isInstanceOf(BusinessException.class) + .hasFieldOrPropertyWithValue("errorCode", ChatErrorCode.CHAT_NOT_FOUND); + } } diff --git a/src/test/java/com/sofa/linkiving/domain/chat/service/MessageCommandServiceTest.java b/src/test/java/com/sofa/linkiving/domain/chat/service/MessageCommandServiceTest.java new file mode 100644 index 00000000..5e9c3eac --- /dev/null +++ b/src/test/java/com/sofa/linkiving/domain/chat/service/MessageCommandServiceTest.java @@ -0,0 +1,39 @@ +package com.sofa.linkiving.domain.chat.service; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.BDDMockito.*; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import com.sofa.linkiving.domain.chat.entity.Message; +import com.sofa.linkiving.domain.chat.repository.MessageRepository; + +@ExtendWith(MockitoExtension.class) +public class MessageCommandServiceTest { + + @InjectMocks + private MessageCommandService messageCommandService; + + @Mock + private MessageRepository messageRepository; + + @Test + @DisplayName("MessageRepository.save 호출 및 저장된 Message 반환") + void shouldReturnSavedMessageWhenSaveMessage() { + // given + Message message = mock(Message.class); + given(messageRepository.save(any(Message.class))).willReturn(message); + + // when + Message result = messageCommandService.saveMessage(message); + + // then + assertThat(result).isEqualTo(message); + verify(messageRepository).save(message); + } +} diff --git a/src/test/java/com/sofa/linkiving/domain/chat/service/MessageServiceTest.java b/src/test/java/com/sofa/linkiving/domain/chat/service/MessageServiceTest.java new file mode 100644 index 00000000..c2a1a236 --- /dev/null +++ b/src/test/java/com/sofa/linkiving/domain/chat/service/MessageServiceTest.java @@ -0,0 +1,74 @@ +package com.sofa.linkiving.domain.chat.service; + +import static org.mockito.Mockito.*; + +import java.util.Map; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.messaging.simp.SimpMessagingTemplate; +import org.springframework.test.util.ReflectionTestUtils; + +import com.sofa.linkiving.domain.chat.entity.Chat; +import com.sofa.linkiving.domain.chat.manager.SubscriptionManager; + +@ExtendWith(MockitoExtension.class) +public class MessageServiceTest { + + @InjectMocks + private MessageService messageService; + + @Mock + private SimpMessagingTemplate messagingTemplate; + + @Mock + private SubscriptionManager subscriptionManager; + + @Mock + private Chat chat; + + @BeforeEach + void setUp() { + // Chat ID Mocking + lenient().when(chat.getId()).thenReturn(1L); + } + + @Test + @DisplayName("답변 취소 요청 시 구독 취소 및 취소 메시지 전송") + void shouldCancelSubscriptionAndSendMessageWhenCancelAnswer() { + // given + String roomId = "1"; + + // when + messageService.cancelAnswer(chat); + + // then + verify(subscriptionManager).cancel(roomId); + verify(messagingTemplate).convertAndSend(eq("/topic/chat/" + roomId), eq("GENERATION_CANCELLED")); + } + + @Test + @DisplayName("이미 답변 생성 중일 경우 중복 요청 무시") + void shouldIgnoreRequestWhenAlreadyGenerating() { + // given + // messageBuffers 필드에 강제로 현재 채팅방 ID를 넣어 생성 중인 상태로 만듦 + @SuppressWarnings("unchecked") + Map buffers = (Map)ReflectionTestUtils.getField(messageService, + "messageBuffers"); + Assertions.assertNotNull(buffers); + buffers.put("1", new StringBuilder()); + + // when + messageService.generateAnswer(chat, "질문"); + + // then + // WebClient 호출 로직으로 넘어가지 않아야 하므로 SubscriptionManager 호출이 없어야 함 + verify(subscriptionManager, never()).add(anyString(), any()); + } +}