Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] SSE 알림 서비스 구현 #122

Merged
merged 5 commits into from
Sep 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.kuit.agarang.domain.ai.utils.MusicGenClientUtil;
import com.kuit.agarang.domain.memory.model.entity.Memory;
import com.kuit.agarang.domain.memory.repository.MemoryRepository;
import com.kuit.agarang.domain.notification.service.SseService;
import com.kuit.agarang.global.common.exception.exception.OpenAPIException;
import com.kuit.agarang.global.common.model.dto.BaseResponseStatus;
import com.kuit.agarang.global.s3.model.dto.S3File;
Expand All @@ -28,14 +29,15 @@ public class MusicGenService {
private final MusicGenClientUtil musicGenClientUtil;
private final S3Util s3Util;
private final MemoryRepository memoryRepository;
private final SseService sseService;

private static final String WEBHOOK_URI = "/api/ai/music-gen/webhook";
private static final Integer MUSIC_DURATION = 40;

public String getMusic(String prompt) {
MusicGenResponse response
= musicGenClientUtil.post(
new MusicGenRequest(version, baseUrl + WEBHOOK_URI, prompt, MUSIC_DURATION), MusicGenResponse.class);
= musicGenClientUtil.post(
new MusicGenRequest(version, baseUrl + WEBHOOK_URI, prompt, MUSIC_DURATION), MusicGenResponse.class);
log.info("created musicgen id : {}", response.getId());
return response.getId();
}
Expand All @@ -53,6 +55,9 @@ public void saveMusic(MusicGenResponse response) {
Memory memory = optionalMemory.get();
memory.setMusicUrl(s3File.getObjectUrl());
memoryRepository.save(memory);

String message = "MusicGen Complete!";
sseService.sendNotification(memory.getMember().getId(), message);
}

private static void checkStatus(MusicGenResponse response) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.kuit.agarang.domain.notification.controller;

import com.kuit.agarang.domain.login.model.dto.CustomOAuth2User;
import com.kuit.agarang.domain.notification.service.SseService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.security.core.annotation.AuthenticationPrincipal;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

@Slf4j
@RestController
@RequestMapping("/sse")
@RequiredArgsConstructor
public class SseController {

private final SseService sseService;

@GetMapping(value = "/connect", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter connect(@AuthenticationPrincipal CustomOAuth2User details) {
return sseService.connect(details.getMemberId());
}

@GetMapping(value = "/disconnect")
public void disconnect(@AuthenticationPrincipal CustomOAuth2User details) {
sseService.disconnect(details.getMemberId());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package com.kuit.agarang.domain.notification.service;


import com.kuit.agarang.global.common.exception.exception.BusinessException;
import com.kuit.agarang.global.common.model.dto.BaseResponseStatus;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;

@Slf4j
@Service
public class SseService {

private final ConcurrentHashMap<Long, SseEmitter> emitters = new ConcurrentHashMap<>();

public SseEmitter connect(Long memberId) {

SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);

try {
emitter.send(SseEmitter.event()
.name("connect")
.data("connected!")); // 503에러 방지를 위한 더미
} catch (IOException e) {
emitter.completeWithError(e);
throw new BusinessException(BaseResponseStatus.FAIL_CREATE_EMITTER);
}

emitter.onError(e -> {
log.error("Error on SSE connection for memberId: {}", memberId, e);
emitters.remove(memberId);
});

emitter.onCompletion(() -> {
log.info("onCompletion callback for memberId: {}", memberId);
emitters.remove(memberId);
});

emitter.onTimeout(() -> {
log.info("onTimeout callback for memberId: {}", memberId);
emitter.complete();
});

emitters.put(memberId, emitter);
return emitter;
}

public void sendNotification(Long memberId, String message) {
SseEmitter emitter = emitters.get(memberId);
log.info("memberId : {} , sendNotification", memberId);
if (emitter != null) {
try {
emitter.send(SseEmitter.event()
.name("notification")
.data(message, MediaType.TEXT_PLAIN));
} catch (IOException e) {
log.error("Error on send notification for memberId: {}", memberId, e);
emitters.remove(memberId);
}
}
}

public void disconnect(Long memberId) {
log.info("memberId : {} , disconnect", memberId);
emitters.remove(memberId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ public enum BaseResponseStatus {
FAIL_REDIS_CONNECTION(false, HttpStatus.SERVICE_UNAVAILABLE, 5002, "레디스 서버에 연결 실패했습니다."),
FAIL_S3_UPLOAD(false, HttpStatus.SERVICE_UNAVAILABLE, 5003, "S3 파일 서버 업로드에 실패했습니다."),
INVALID_GPT_RESPONSE(false, HttpStatus.INTERNAL_SERVER_ERROR, 5004, "GPT 응답이 유효하지 않습니다."),
FAIL_CREATE_MUSIC(false, HttpStatus.INTERNAL_SERVER_ERROR, 5005, "music gen 음악 생성에 실패했습니다.");

FAIL_CREATE_MUSIC(false, HttpStatus.INTERNAL_SERVER_ERROR, 5005, "music gen 음악 생성에 실패했습니다."),
FAIL_CREATE_EMITTER(false, HttpStatus.INTERNAL_SERVER_ERROR, 5006, "SseEmitter 생성에 실패했습니다.");

private final boolean isSuccess;
@JsonIgnore
Expand Down