Skip to content

Commit

Permalink
feat :: Image generation on final chat request
Browse files Browse the repository at this point in the history
  • Loading branch information
jombidev committed Oct 16, 2024
1 parent bd81dbe commit 1370f75
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 44 deletions.
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ dependencies {

implementation("org.springframework.ai:spring-ai-azure-openai-spring-boot-starter")
implementation(group = "io.netty", name = "netty-resolver-dns-native-macos", classifier = "osx-aarch_64")
// implementation("com.azure.spring:spring-cloud-azure-starter-storage-blob")
implementation("com.azure.spring:spring-cloud-azure-starter-storage-blob")
// implementation("com.azure.spring:spring-cloud-azure-starter-storage")

implementation("org.springframework.boot:spring-boot-starter-validation")
Expand Down
16 changes: 16 additions & 0 deletions src/main/kotlin/com/teamapi/palette/config/WebClientConfig.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.teamapi.palette.config

import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.web.reactive.function.client.ExchangeStrategies
import org.springframework.web.reactive.function.client.WebClient

@Configuration
class WebClientConfig {
@Bean
fun client() = WebClient.builder().exchangeStrategies(
ExchangeStrategies.builder()
.codecs { it.defaultCodecs().maxInMemorySize(30 * 1024 * 1024) }
.build()
).build()
}
145 changes: 108 additions & 37 deletions src/main/kotlin/com/teamapi/palette/service/ChatService.kt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package com.teamapi.palette.service

import com.azure.ai.openai.models.ChatCompletions
import com.azure.ai.openai.models.ChatCompletionsOptions
import com.azure.ai.openai.models.ChatRequestSystemMessage
import com.azure.ai.openai.models.ChatRequestUserMessage
import com.azure.ai.openai.models.*
import com.azure.core.util.BinaryData
import com.azure.storage.blob.BlobServiceAsyncClient
import com.teamapi.palette.service.infra.comfy.GenerateRequest
import com.teamapi.palette.dto.response.chat.ChatResponse
import com.teamapi.palette.entity.chat.Chat
import com.teamapi.palette.entity.consts.ChatState
Expand All @@ -17,14 +17,19 @@ import com.teamapi.palette.response.ErrorCode
import com.teamapi.palette.response.exception.CustomException
import com.teamapi.palette.service.infra.ChatEmitService
import com.teamapi.palette.service.infra.GenerativeChatService
import com.teamapi.palette.service.infra.GenerativeImageService
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.reactor.awaitSingle
import kotlinx.datetime.Instant
import org.slf4j.LoggerFactory
import org.springframework.data.domain.PageRequest
import org.springframework.stereotype.Service
import reactor.core.publisher.Mono
import java.util.*
import kotlin.io.encoding.Base64
import kotlin.io.encoding.ExperimentalEncodingApi

@Service
class ChatService(
Expand All @@ -34,13 +39,16 @@ class ChatService(
private val sessionHolder: SessionHolder,
private val roomRepository: RoomRepository,
private val generativeChatService: GenerativeChatService,
private val generativeImageService: GenerativeImageService,
private val blob: BlobServiceAsyncClient
) {
private val log = LoggerFactory.getLogger(ChatService::class.java)

private fun getPendingQuestion(qnA: QnA): PromptData? {
return qnA.qna.find { it.answer == null }
}

@OptIn(ExperimentalEncodingApi::class)
suspend fun <T : ChatAnswer> createChat(roomId: Long, message: T) {
val userId = sessionHolder.me()
val room = roomRepository.findById(roomId) ?: throw CustomException(ErrorCode.ROOM_NOT_FOUND)
Expand Down Expand Up @@ -77,7 +85,11 @@ class ChatService(
val maxSize = grid.question.maxCount
val exceeds = message.choice.filter { it >= gridPossibleMax }
if (exceeds.isNotEmpty())
throw CustomException(ErrorCode.QNA_INVALID_GRID_CHOICES, exceeds.joinToString(", "), gridPossibleMax - 1)
throw CustomException(
ErrorCode.QNA_INVALID_GRID_CHOICES,
exceeds.joinToString(", "),
gridPossibleMax - 1
)

if (message.choice.size > maxSize)
throw CustomException(ErrorCode.QNA_INVALID_GRID_ABOVE_MAX, maxSize)
Expand Down Expand Up @@ -128,30 +140,82 @@ class ChatService(
val pendingQnAs = toBeResolved.qna.filter { it.answer == null }
if (pendingQnAs.isEmpty()) {
// TODO: Handle Image processing
val release = toBeResolved.qna

val title = release.find { it.promptName == "title" }!!.answer as ChatAnswer.UserInputAnswer
val explain = release.find { it.promptName == "product_explanation" }!!.answer as ChatAnswer.UserInputAnswer
val aspectQnA = release.find { it.promptName == "aspect_ratio" }!! as PromptData.Selectable
val aspectAns = aspectQnA.answer!!
val grid = release.find { it.promptName == "title_position" }!!.answer as ChatAnswer.GridAnswer

val userReturn = createUserReturn(explain.input).awaitSingle()
chatEmitService.emitChat(
Chat(
resource = ChatState.CHAT,
roomId = room.id,
userId = userId,
isAi = true,
message = "원하는 질문이 다 채워졌따. 만드느라 수고햇다. 이제 서버에서 이미지 생성 다만들기를 기다려"
message = userReturn.choices.random().message.content
)
)
return@async
}

val addResponse = pendingQnAs.first()
val generated = generativeChatService.roomPromptMessage(addResponse.promptName)
val (width, height) = when (aspectAns.choiceId) {
"DISPLAY" -> 1820 to 1024
"PAPER" -> 1444 to 1024
"SQUARE" -> 1024 to 1024
else -> 1365 to 1024 // TABLET
}

chatEmitService.emitChat(
Chat(
resource = ChatState.PROMPT,
roomId = room.id,
userId = userId,
isAi = true,
message = generated.choices.random().message.content,
promptId = addResponse.id
val prompt = createPrompt(explain.input).awaitSingle()
val generated = generativeImageService.draw(
GenerateRequest(
title.input,
grid.choice[0],
width,
height,
prompt.choices.random().message.content
)
)
)
if (!generated.result) {
chatEmitService.emitChat(
Chat(
resource = ChatState.CHAT,
roomId = room.id,
userId = userId,
isAi = true,
message = "이미지를 생성하는 도중 오류가 발생하였어요. ;.;"
)
)
} else {
val space = blob.getBlobContainerAsyncClient("palette")
val blobClient = space.getBlobAsyncClient("${UUID.randomUUID()}.png")
blobClient.upload(BinaryData.fromBytes(Base64.decode(generated.image!!))).awaitSingle()

chatEmitService.emitChat(
Chat(
resource = ChatState.IMAGE,
roomId = room.id,
userId = userId,
isAi = true,
message = blobClient.blobUrl
)
)
}
} else {
val addResponse = pendingQnAs.first()
val generated = generativeChatService.roomPromptMessage(addResponse.promptName)

chatEmitService.emitChat(
Chat(
resource = ChatState.PROMPT,
roomId = room.id,
userId = userId,
isAi = true,
message = generated.choices.random().message.content,
promptId = addResponse.id
)
)
}
}.invokeOnCompletion {
it?.let {
log.error("error", it)
Expand Down Expand Up @@ -260,24 +324,31 @@ class ChatService(
)
)

// fun createPrompt(text: String) = chatCompletion(
// ChatCompletionsOptions(
// listOf(
// ChatRequestSystemMessage(
// "You must enter a sentence in Korean or English and extract the keywords for the sentence. All words should be in English and the words should be separated by a semicolon (',') and an underscore ('_') if there is a space in a word. Also, there should be no words other than a semicolon and any words. Produce a few more related words if the number of extracted words is less than 5. Map to the 'drawable' words to help generate posters. Give all words lowercase."
// ),
// ChatRequestUserMessage(
// "내가 만든 오렌지 주스를 광고하고 싶어. 오렌지 과즙이 주변에 터졌으면 좋겠고, 오렌지 주스가 담긴 컵과 오렌지 주스가 있었으면 좋겠어. 배경은 집 안이였으면 좋겠어."
// ),
// ChatRequestAssistantMessage(
// "orange, orange_juice, in_house, a_cup_with_orange_juice, juice"
// ),
// ChatRequestUserMessage(
// text
// )
// )
// )
// )
fun createPrompt(text: String) = generativeChatService.chatCompletion(
ChatCompletionsOptions(
listOf(
ChatRequestSystemMessage(
"You must enter a sentence in Korean or English and extract the keywords for the sentence. All words should be in English and the words should be separated by a semicolon (',') and an underscore ('_') if there is a space in a word. Also, there should be no words other than a semicolon and any words. Produce a few more related words if the number of extracted words is less than 5. Map to the 'drawable' words to help generate posters. Give all words lowercase."
),
ChatRequestUserMessage(
"내가 만든 오렌지 주스를 광고하고 싶어. 오렌지 과즙이 주변에 터졌으면 좋겠고, 오렌지 주스가 담긴 컵과 오렌지 주스가 있었으면 좋겠어. 배경은 집 안이였으면 좋겠어."
),
ChatRequestAssistantMessage(
"orange, orange_juice, in_house, a_cup_with_orange_juice, juice"
),
ChatRequestUserMessage(
"우리 상수 목공방에서 자랑하는 멋진 핑크색 상수목재를 홍보하고 싶다. 제목은 금색으로 해줘"
),
ChatRequestAssistantMessage(
"pink planks with forest background, modern and simple design, highlight color with gold"
),
ChatRequestUserMessage(
text
)
)
)
)

suspend fun getMyImage(pageNumber: Int, pageSize: Int): List<String> {
val page = PageRequest.of(pageNumber, pageSize)
val userId = sessionHolder.me()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
package com.teamapi.palette.service.infra

import com.azure.ai.openai.OpenAIAsyncClient
import com.azure.ai.openai.models.ImageGenerationOptions
import com.teamapi.palette.service.infra.comfy.GenerateRequest
import com.teamapi.palette.service.infra.comfy.GenerateResponse
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import org.springframework.stereotype.Service
import org.springframework.web.reactive.function.client.WebClient
import org.springframework.web.reactive.function.client.awaitBody
import org.springframework.web.reactive.function.client.awaitExchange

@Service
class GenerativeImageService(
private val azure: OpenAIAsyncClient,
private val client: WebClient,
private val mapper: Json,
) {
fun draw(originalText: String) =
azure.getImageGenerations("Dalle3", ImageGenerationOptions(originalText))
.handleAzureError(mapper)
suspend fun draw(prompt: GenerateRequest): GenerateResponse {
return client.post()
.uri("https://comfy.paletteapp.xyz/gen")
.bodyValue(mapper.encodeToString(prompt))
.header("content-type", "application/json")
.awaitExchange { it.awaitBody<GenerateResponse>() }
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.teamapi.palette.service.infra.comfy

import kotlinx.serialization.Serializable

@Serializable
data class GenerateRequest(
val title: String,
val pos: Int,
val width: Int,
val height: Int,
val prompt: String
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package com.teamapi.palette.service.infra.comfy

import kotlinx.serialization.Serializable

@Serializable
data class GenerateResponse(val result: Boolean, val image: String? = null, val error: String? = null)
7 changes: 7 additions & 0 deletions src/main/resources/application.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ spring:
enable: true
jooq:
sql-dialect: ${R2DBC_DIALECT}
cloud:
azure:
storage:
blob:
endpoint: ${AZURE_STORAGE_ENDPOINT}
connection-string: ${AZURE_STORAGE_CONNECTION_STRING}

springdoc:
enable-kotlin: true
enable-spring-security: true
Expand Down

0 comments on commit 1370f75

Please sign in to comment.