From 1370f756a019149dc1f6319b05524ce0e84331f7 Mon Sep 17 00:00:00 2001 From: jombi Date: Thu, 17 Oct 2024 00:28:14 +0900 Subject: [PATCH] feat :: Image generation on final chat request --- build.gradle.kts | 2 +- .../teamapi/palette/config/WebClientConfig.kt | 16 ++ .../teamapi/palette/service/ChatService.kt | 145 +++++++++++++----- .../service/infra/GenerativeImageService.kt | 21 ++- .../service/infra/comfy/GenerateRequest.kt | 12 ++ .../service/infra/comfy/GenerateResponse.kt | 6 + src/main/resources/application.yml | 7 + 7 files changed, 165 insertions(+), 44 deletions(-) create mode 100644 src/main/kotlin/com/teamapi/palette/config/WebClientConfig.kt create mode 100644 src/main/kotlin/com/teamapi/palette/service/infra/comfy/GenerateRequest.kt create mode 100644 src/main/kotlin/com/teamapi/palette/service/infra/comfy/GenerateResponse.kt diff --git a/build.gradle.kts b/build.gradle.kts index b6b16d0..fd5453d 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -52,7 +52,7 @@ dependencies { implementation("org.springframework.ai:spring-ai-azure-openai-spring-boot-starter") implementation(group = "io.netty", name = "netty-resolver-dns-native-macos", classifier = "osx-aarch_64") -// implementation("com.azure.spring:spring-cloud-azure-starter-storage-blob") + implementation("com.azure.spring:spring-cloud-azure-starter-storage-blob") // implementation("com.azure.spring:spring-cloud-azure-starter-storage") implementation("org.springframework.boot:spring-boot-starter-validation") diff --git a/src/main/kotlin/com/teamapi/palette/config/WebClientConfig.kt b/src/main/kotlin/com/teamapi/palette/config/WebClientConfig.kt new file mode 100644 index 0000000..3090ef9 --- /dev/null +++ b/src/main/kotlin/com/teamapi/palette/config/WebClientConfig.kt @@ -0,0 +1,16 @@ +package com.teamapi.palette.config + +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.Configuration +import org.springframework.web.reactive.function.client.ExchangeStrategies +import org.springframework.web.reactive.function.client.WebClient + +@Configuration +class WebClientConfig { + @Bean + fun client() = WebClient.builder().exchangeStrategies( + ExchangeStrategies.builder() + .codecs { it.defaultCodecs().maxInMemorySize(30 * 1024 * 1024) } + .build() + ).build() +} diff --git a/src/main/kotlin/com/teamapi/palette/service/ChatService.kt b/src/main/kotlin/com/teamapi/palette/service/ChatService.kt index 78e67fd..c77e827 100644 --- a/src/main/kotlin/com/teamapi/palette/service/ChatService.kt +++ b/src/main/kotlin/com/teamapi/palette/service/ChatService.kt @@ -1,9 +1,9 @@ package com.teamapi.palette.service -import com.azure.ai.openai.models.ChatCompletions -import com.azure.ai.openai.models.ChatCompletionsOptions -import com.azure.ai.openai.models.ChatRequestSystemMessage -import com.azure.ai.openai.models.ChatRequestUserMessage +import com.azure.ai.openai.models.* +import com.azure.core.util.BinaryData +import com.azure.storage.blob.BlobServiceAsyncClient +import com.teamapi.palette.service.infra.comfy.GenerateRequest import com.teamapi.palette.dto.response.chat.ChatResponse import com.teamapi.palette.entity.chat.Chat import com.teamapi.palette.entity.consts.ChatState @@ -17,14 +17,19 @@ import com.teamapi.palette.response.ErrorCode import com.teamapi.palette.response.exception.CustomException import com.teamapi.palette.service.infra.ChatEmitService import com.teamapi.palette.service.infra.GenerativeChatService +import com.teamapi.palette.service.infra.GenerativeImageService import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async +import kotlinx.coroutines.reactor.awaitSingle import kotlinx.datetime.Instant import org.slf4j.LoggerFactory import org.springframework.data.domain.PageRequest import org.springframework.stereotype.Service import reactor.core.publisher.Mono +import java.util.* +import kotlin.io.encoding.Base64 +import kotlin.io.encoding.ExperimentalEncodingApi @Service class ChatService( @@ -34,6 +39,8 @@ class ChatService( private val sessionHolder: SessionHolder, private val roomRepository: RoomRepository, private val generativeChatService: GenerativeChatService, + private val generativeImageService: GenerativeImageService, + private val blob: BlobServiceAsyncClient ) { private val log = LoggerFactory.getLogger(ChatService::class.java) @@ -41,6 +48,7 @@ class ChatService( return qnA.qna.find { it.answer == null } } + @OptIn(ExperimentalEncodingApi::class) suspend fun createChat(roomId: Long, message: T) { val userId = sessionHolder.me() val room = roomRepository.findById(roomId) ?: throw CustomException(ErrorCode.ROOM_NOT_FOUND) @@ -77,7 +85,11 @@ class ChatService( val maxSize = grid.question.maxCount val exceeds = message.choice.filter { it >= gridPossibleMax } if (exceeds.isNotEmpty()) - throw CustomException(ErrorCode.QNA_INVALID_GRID_CHOICES, exceeds.joinToString(", "), gridPossibleMax - 1) + throw CustomException( + ErrorCode.QNA_INVALID_GRID_CHOICES, + exceeds.joinToString(", "), + gridPossibleMax - 1 + ) if (message.choice.size > maxSize) throw CustomException(ErrorCode.QNA_INVALID_GRID_ABOVE_MAX, maxSize) @@ -128,30 +140,82 @@ class ChatService( val pendingQnAs = toBeResolved.qna.filter { it.answer == null } if (pendingQnAs.isEmpty()) { // TODO: Handle Image processing + val release = toBeResolved.qna + + val title = release.find { it.promptName == "title" }!!.answer as ChatAnswer.UserInputAnswer + val explain = release.find { it.promptName == "product_explanation" }!!.answer as ChatAnswer.UserInputAnswer + val aspectQnA = release.find { it.promptName == "aspect_ratio" }!! as PromptData.Selectable + val aspectAns = aspectQnA.answer!! + val grid = release.find { it.promptName == "title_position" }!!.answer as ChatAnswer.GridAnswer + + val userReturn = createUserReturn(explain.input).awaitSingle() chatEmitService.emitChat( Chat( + resource = ChatState.CHAT, roomId = room.id, userId = userId, isAi = true, - message = "원하는 질문이 다 채워졌따. 만드느라 수고햇다. 이제 서버에서 이미지 생성 다만들기를 기다려" + message = userReturn.choices.random().message.content ) ) - return@async - } - val addResponse = pendingQnAs.first() - val generated = generativeChatService.roomPromptMessage(addResponse.promptName) + val (width, height) = when (aspectAns.choiceId) { + "DISPLAY" -> 1820 to 1024 + "PAPER" -> 1444 to 1024 + "SQUARE" -> 1024 to 1024 + else -> 1365 to 1024 // TABLET + } - chatEmitService.emitChat( - Chat( - resource = ChatState.PROMPT, - roomId = room.id, - userId = userId, - isAi = true, - message = generated.choices.random().message.content, - promptId = addResponse.id + val prompt = createPrompt(explain.input).awaitSingle() + val generated = generativeImageService.draw( + GenerateRequest( + title.input, + grid.choice[0], + width, + height, + prompt.choices.random().message.content + ) ) - ) + if (!generated.result) { + chatEmitService.emitChat( + Chat( + resource = ChatState.CHAT, + roomId = room.id, + userId = userId, + isAi = true, + message = "이미지를 생성하는 도중 오류가 발생하였어요. ;.;" + ) + ) + } else { + val space = blob.getBlobContainerAsyncClient("palette") + val blobClient = space.getBlobAsyncClient("${UUID.randomUUID()}.png") + blobClient.upload(BinaryData.fromBytes(Base64.decode(generated.image!!))).awaitSingle() + + chatEmitService.emitChat( + Chat( + resource = ChatState.IMAGE, + roomId = room.id, + userId = userId, + isAi = true, + message = blobClient.blobUrl + ) + ) + } + } else { + val addResponse = pendingQnAs.first() + val generated = generativeChatService.roomPromptMessage(addResponse.promptName) + + chatEmitService.emitChat( + Chat( + resource = ChatState.PROMPT, + roomId = room.id, + userId = userId, + isAi = true, + message = generated.choices.random().message.content, + promptId = addResponse.id + ) + ) + } }.invokeOnCompletion { it?.let { log.error("error", it) @@ -260,24 +324,31 @@ class ChatService( ) ) - // fun createPrompt(text: String) = chatCompletion( -// ChatCompletionsOptions( -// listOf( -// ChatRequestSystemMessage( -// "You must enter a sentence in Korean or English and extract the keywords for the sentence. All words should be in English and the words should be separated by a semicolon (',') and an underscore ('_') if there is a space in a word. Also, there should be no words other than a semicolon and any words. Produce a few more related words if the number of extracted words is less than 5. Map to the 'drawable' words to help generate posters. Give all words lowercase." -// ), -// ChatRequestUserMessage( -// "내가 만든 오렌지 주스를 광고하고 싶어. 오렌지 과즙이 주변에 터졌으면 좋겠고, 오렌지 주스가 담긴 컵과 오렌지 주스가 있었으면 좋겠어. 배경은 집 안이였으면 좋겠어." -// ), -// ChatRequestAssistantMessage( -// "orange, orange_juice, in_house, a_cup_with_orange_juice, juice" -// ), -// ChatRequestUserMessage( -// text -// ) -// ) -// ) -// ) + fun createPrompt(text: String) = generativeChatService.chatCompletion( + ChatCompletionsOptions( + listOf( + ChatRequestSystemMessage( + "You must enter a sentence in Korean or English and extract the keywords for the sentence. All words should be in English and the words should be separated by a semicolon (',') and an underscore ('_') if there is a space in a word. Also, there should be no words other than a semicolon and any words. Produce a few more related words if the number of extracted words is less than 5. Map to the 'drawable' words to help generate posters. Give all words lowercase." + ), + ChatRequestUserMessage( + "내가 만든 오렌지 주스를 광고하고 싶어. 오렌지 과즙이 주변에 터졌으면 좋겠고, 오렌지 주스가 담긴 컵과 오렌지 주스가 있었으면 좋겠어. 배경은 집 안이였으면 좋겠어." + ), + ChatRequestAssistantMessage( + "orange, orange_juice, in_house, a_cup_with_orange_juice, juice" + ), + ChatRequestUserMessage( + "우리 상수 목공방에서 자랑하는 멋진 핑크색 상수목재를 홍보하고 싶다. 제목은 금색으로 해줘" + ), + ChatRequestAssistantMessage( + "pink planks with forest background, modern and simple design, highlight color with gold" + ), + ChatRequestUserMessage( + text + ) + ) + ) + ) + suspend fun getMyImage(pageNumber: Int, pageSize: Int): List { val page = PageRequest.of(pageNumber, pageSize) val userId = sessionHolder.me() diff --git a/src/main/kotlin/com/teamapi/palette/service/infra/GenerativeImageService.kt b/src/main/kotlin/com/teamapi/palette/service/infra/GenerativeImageService.kt index b1d407d..c10b20d 100644 --- a/src/main/kotlin/com/teamapi/palette/service/infra/GenerativeImageService.kt +++ b/src/main/kotlin/com/teamapi/palette/service/infra/GenerativeImageService.kt @@ -1,16 +1,25 @@ package com.teamapi.palette.service.infra -import com.azure.ai.openai.OpenAIAsyncClient -import com.azure.ai.openai.models.ImageGenerationOptions +import com.teamapi.palette.service.infra.comfy.GenerateRequest +import com.teamapi.palette.service.infra.comfy.GenerateResponse +import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json import org.springframework.stereotype.Service +import org.springframework.web.reactive.function.client.WebClient +import org.springframework.web.reactive.function.client.awaitBody +import org.springframework.web.reactive.function.client.awaitExchange @Service class GenerativeImageService( - private val azure: OpenAIAsyncClient, + private val client: WebClient, private val mapper: Json, ) { - fun draw(originalText: String) = - azure.getImageGenerations("Dalle3", ImageGenerationOptions(originalText)) - .handleAzureError(mapper) + suspend fun draw(prompt: GenerateRequest): GenerateResponse { + return client.post() + .uri("https://comfy.paletteapp.xyz/gen") + .bodyValue(mapper.encodeToString(prompt)) + .header("content-type", "application/json") + .awaitExchange { it.awaitBody() } + } + } diff --git a/src/main/kotlin/com/teamapi/palette/service/infra/comfy/GenerateRequest.kt b/src/main/kotlin/com/teamapi/palette/service/infra/comfy/GenerateRequest.kt new file mode 100644 index 0000000..7da20ef --- /dev/null +++ b/src/main/kotlin/com/teamapi/palette/service/infra/comfy/GenerateRequest.kt @@ -0,0 +1,12 @@ +package com.teamapi.palette.service.infra.comfy + +import kotlinx.serialization.Serializable + +@Serializable +data class GenerateRequest( + val title: String, + val pos: Int, + val width: Int, + val height: Int, + val prompt: String +) diff --git a/src/main/kotlin/com/teamapi/palette/service/infra/comfy/GenerateResponse.kt b/src/main/kotlin/com/teamapi/palette/service/infra/comfy/GenerateResponse.kt new file mode 100644 index 0000000..6a1b57d --- /dev/null +++ b/src/main/kotlin/com/teamapi/palette/service/infra/comfy/GenerateResponse.kt @@ -0,0 +1,6 @@ +package com.teamapi.palette.service.infra.comfy + +import kotlinx.serialization.Serializable + +@Serializable +data class GenerateResponse(val result: Boolean, val image: String? = null, val error: String? = null) diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index 5cebd8c..8920f2d 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -33,6 +33,13 @@ spring: enable: true jooq: sql-dialect: ${R2DBC_DIALECT} + cloud: + azure: + storage: + blob: + endpoint: ${AZURE_STORAGE_ENDPOINT} + connection-string: ${AZURE_STORAGE_CONNECTION_STRING} + springdoc: enable-kotlin: true enable-spring-security: true