diff --git a/src/main/java/com/kuit/agarang/domain/ai/service/MusicGenService.java b/src/main/java/com/kuit/agarang/domain/ai/service/MusicGenService.java index 2149e79b..5be95f90 100644 --- a/src/main/java/com/kuit/agarang/domain/ai/service/MusicGenService.java +++ b/src/main/java/com/kuit/agarang/domain/ai/service/MusicGenService.java @@ -5,6 +5,7 @@ import com.kuit.agarang.domain.ai.utils.MusicGenClientUtil; import com.kuit.agarang.domain.memory.model.entity.Memory; import com.kuit.agarang.domain.memory.repository.MemoryRepository; +import com.kuit.agarang.domain.notification.service.SseService; import com.kuit.agarang.global.common.exception.exception.OpenAPIException; import com.kuit.agarang.global.common.model.dto.BaseResponseStatus; import com.kuit.agarang.global.s3.model.dto.S3File; @@ -28,14 +29,15 @@ public class MusicGenService { private final MusicGenClientUtil musicGenClientUtil; private final S3Util s3Util; private final MemoryRepository memoryRepository; + private final SseService sseService; private static final String WEBHOOK_URI = "/api/ai/music-gen/webhook"; private static final Integer MUSIC_DURATION = 40; public String getMusic(String prompt) { MusicGenResponse response - = musicGenClientUtil.post( - new MusicGenRequest(version, baseUrl + WEBHOOK_URI, prompt, MUSIC_DURATION), MusicGenResponse.class); + = musicGenClientUtil.post( + new MusicGenRequest(version, baseUrl + WEBHOOK_URI, prompt, MUSIC_DURATION), MusicGenResponse.class); log.info("created musicgen id : {}", response.getId()); return response.getId(); } @@ -53,6 +55,9 @@ public void saveMusic(MusicGenResponse response) { Memory memory = optionalMemory.get(); memory.setMusicUrl(s3File.getObjectUrl()); memoryRepository.save(memory); + + String message = "MusicGen Complete!"; + sseService.sendNotification(memory.getMember().getId(), message); } private static void checkStatus(MusicGenResponse response) { diff --git a/src/main/java/com/kuit/agarang/domain/notification/controller/SseController.java b/src/main/java/com/kuit/agarang/domain/notification/controller/SseController.java new file mode 100644 index 00000000..b5cdc9be --- /dev/null +++ b/src/main/java/com/kuit/agarang/domain/notification/controller/SseController.java @@ -0,0 +1,31 @@ +package com.kuit.agarang.domain.notification.controller; + +import com.kuit.agarang.domain.login.model.dto.CustomOAuth2User; +import com.kuit.agarang.domain.notification.service.SseService; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.MediaType; +import org.springframework.security.core.annotation.AuthenticationPrincipal; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +@Slf4j +@RestController +@RequestMapping("/sse") +@RequiredArgsConstructor +public class SseController { + + private final SseService sseService; + + @GetMapping(value = "/connect", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public SseEmitter connect(@AuthenticationPrincipal CustomOAuth2User details) { + return sseService.connect(details.getMemberId()); + } + + @GetMapping(value = "/disconnect") + public void disconnect(@AuthenticationPrincipal CustomOAuth2User details) { + sseService.disconnect(details.getMemberId()); + } +} diff --git a/src/main/java/com/kuit/agarang/domain/notification/service/SseService.java b/src/main/java/com/kuit/agarang/domain/notification/service/SseService.java new file mode 100644 index 00000000..08c6eae2 --- /dev/null +++ b/src/main/java/com/kuit/agarang/domain/notification/service/SseService.java @@ -0,0 +1,71 @@ +package com.kuit.agarang.domain.notification.service; + + +import com.kuit.agarang.global.common.exception.exception.BusinessException; +import com.kuit.agarang.global.common.model.dto.BaseResponseStatus; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.MediaType; +import org.springframework.stereotype.Service; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.io.IOException; +import java.util.concurrent.ConcurrentHashMap; + +@Slf4j +@Service +public class SseService { + + private final ConcurrentHashMap emitters = new ConcurrentHashMap<>(); + + public SseEmitter connect(Long memberId) { + + SseEmitter emitter = new SseEmitter(Long.MAX_VALUE); + + try { + emitter.send(SseEmitter.event() + .name("connect") + .data("connected!")); // 503에러 방지를 위한 더미 + } catch (IOException e) { + emitter.completeWithError(e); + throw new BusinessException(BaseResponseStatus.FAIL_CREATE_EMITTER); + } + + emitter.onError(e -> { + log.error("Error on SSE connection for memberId: {}", memberId, e); + emitters.remove(memberId); + }); + + emitter.onCompletion(() -> { + log.info("onCompletion callback for memberId: {}", memberId); + emitters.remove(memberId); + }); + + emitter.onTimeout(() -> { + log.info("onTimeout callback for memberId: {}", memberId); + emitter.complete(); + }); + + emitters.put(memberId, emitter); + return emitter; + } + + public void sendNotification(Long memberId, String message) { + SseEmitter emitter = emitters.get(memberId); + log.info("memberId : {} , sendNotification", memberId); + if (emitter != null) { + try { + emitter.send(SseEmitter.event() + .name("notification") + .data(message, MediaType.TEXT_PLAIN)); + } catch (IOException e) { + log.error("Error on send notification for memberId: {}", memberId, e); + emitters.remove(memberId); + } + } + } + + public void disconnect(Long memberId) { + log.info("memberId : {} , disconnect", memberId); + emitters.remove(memberId); + } +} diff --git a/src/main/java/com/kuit/agarang/global/common/model/dto/BaseResponseStatus.java b/src/main/java/com/kuit/agarang/global/common/model/dto/BaseResponseStatus.java index a4343b1f..0e968b70 100644 --- a/src/main/java/com/kuit/agarang/global/common/model/dto/BaseResponseStatus.java +++ b/src/main/java/com/kuit/agarang/global/common/model/dto/BaseResponseStatus.java @@ -39,8 +39,8 @@ public enum BaseResponseStatus { FAIL_REDIS_CONNECTION(false, HttpStatus.SERVICE_UNAVAILABLE, 5002, "레디스 서버에 연결 실패했습니다."), FAIL_S3_UPLOAD(false, HttpStatus.SERVICE_UNAVAILABLE, 5003, "S3 파일 서버 업로드에 실패했습니다."), INVALID_GPT_RESPONSE(false, HttpStatus.INTERNAL_SERVER_ERROR, 5004, "GPT 응답이 유효하지 않습니다."), - FAIL_CREATE_MUSIC(false, HttpStatus.INTERNAL_SERVER_ERROR, 5005, "music gen 음악 생성에 실패했습니다."); - + FAIL_CREATE_MUSIC(false, HttpStatus.INTERNAL_SERVER_ERROR, 5005, "music gen 음악 생성에 실패했습니다."), + FAIL_CREATE_EMITTER(false, HttpStatus.INTERNAL_SERVER_ERROR, 5006, "SseEmitter 생성에 실패했습니다."); private final boolean isSuccess; @JsonIgnore