Skip to content

Commit

Permalink
feat :: prepare qna type chatting - db spec changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jombidev committed Sep 25, 2024
1 parent 9410779 commit e269a71
Show file tree
Hide file tree
Showing 22 changed files with 325 additions and 78 deletions.
11 changes: 9 additions & 2 deletions build.gradle.kts
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile

plugins {
id("org.springframework.boot") version "3.3.0"
id("io.spring.dependency-management") version "1.1.4"
kotlin("jvm") version "1.9.23"
kotlin("plugin.spring") version "1.9.23"
kotlin("plugin.serialization") version "1.9.23"

id("org.springframework.boot") version "3.3.0"
id("io.spring.dependency-management") version "1.1.4"

id("org.jooq.jooq-codegen-gradle") version "3.19.11"
}

Expand Down Expand Up @@ -40,6 +43,10 @@ dependencies {
implementation("org.springframework.session:spring-session-data-redis")
implementation("org.springframework.boot:spring-boot-starter-data-redis-reactive")
implementation("org.springframework.boot:spring-boot-starter-data-mongodb-reactive")
implementation("org.mongodb:mongodb-driver-kotlin-coroutine:5.0.1")
implementation("org.mongodb:bson-kotlinx:5.0.1")
implementation("org.jetbrains.kotlinx:kotlinx-serialization-core")
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json")

implementation("org.springframework.ai:spring-ai-azure-openai-spring-boot-starter")
implementation(group = "io.netty", name = "netty-resolver-dns-native-macos", classifier = "osx-aarch_64")
Expand Down
50 changes: 43 additions & 7 deletions src/main/kotlin/com/teamapi/palette/config/DatabaseConfig.kt
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
package com.teamapi.palette.config

import com.mongodb.MongoClientSettings
import com.mongodb.kotlin.client.coroutine.MongoClient
import com.mongodb.kotlin.client.coroutine.MongoDatabase
import com.teamapi.palette.util.Jsr310CodecProvider
import io.r2dbc.spi.ConnectionFactory
import org.bson.codecs.configuration.CodecRegistries
import org.jooq.impl.DefaultConfiguration
import org.jooq.impl.DefaultDSLContext
import org.springframework.boot.autoconfigure.mongo.MongoClientSettingsBuilderCustomizer
import org.springframework.boot.autoconfigure.mongo.MongoConnectionDetails
import org.springframework.boot.autoconfigure.mongo.MongoProperties
import org.springframework.boot.context.properties.EnableConfigurationProperties
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.data.mongodb.repository.config.EnableMongoRepositories
Expand All @@ -11,22 +20,49 @@ import org.springframework.data.redis.core.ReactiveRedisTemplate
import org.springframework.data.redis.serializer.RedisSerializationContext
import org.springframework.data.redis.serializer.RedisSerializer
import org.springframework.r2dbc.connection.TransactionAwareConnectionFactoryProxy
import com.mongodb.reactivestreams.client.MongoClient as JMongoClient

@Configuration
@EnableMongoRepositories
@EnableConfigurationProperties(MongoProperties::class)
class DatabaseConfig {
@Bean
fun redisTemplateForRepository(redisConnectionFactory: ReactiveRedisConnectionFactory): ReactiveRedisTemplate<String, Any> {
return ReactiveRedisTemplate(redisConnectionFactory, RedisSerializationContext
.newSerializationContext<String, Any>()
.key(RedisSerializer.string())
.hashKey(RedisSerializer.string())
.value(RedisSerializer.java())
.hashValue(RedisSerializer.java())
.build()
return ReactiveRedisTemplate(
redisConnectionFactory,
RedisSerializationContext
.newSerializationContext<String, Any>()
.key(RedisSerializer.string())
.hashKey(RedisSerializer.string())
.value(RedisSerializer.java())
.hashValue(RedisSerializer.java())
.build()
)
}

@Bean
fun customSetting() = MongoClientSettingsBuilderCustomizer {
it.codecRegistry(
CodecRegistries.fromRegistries(
MongoClientSettings.getDefaultCodecRegistry(),
CodecRegistries.fromProviders(Jsr310CodecProvider)
)
)
}

@Bean
fun mongoClient(client: JMongoClient): MongoClient = MongoClient(client)

@Bean
fun coroutineMongoTemplate(
client: MongoClient, props: MongoProperties,
connectionDetails: MongoConnectionDetails
): MongoDatabase = client.getDatabase(
props.database
?: connectionDetails.connectionString.database
?: throw RuntimeException("Error while constructing mongo: No database received.")
)

@Bean
fun dslContext(connectionFactory: ConnectionFactory) = DefaultDSLContext(
DefaultConfiguration().set(TransactionAwareConnectionFactoryProxy(connectionFactory))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import io.swagger.v3.oas.annotations.Parameter
import org.springdoc.core.converters.models.PageableAsQueryParam
import org.springframework.http.ResponseEntity
import org.springframework.web.bind.annotation.*
import java.time.OffsetDateTime

@RestController
@RequestMapping("/chat")
Expand All @@ -30,7 +31,7 @@ class ChatController(
@GetMapping("/{roomId}")
suspend fun getChatList(
@PathVariable("roomId") roomId: Long,
@RequestParam(required = false) before: Long = System.currentTimeMillis(),
@RequestParam(required = false) before: String = OffsetDateTime.now().toString(),
@RequestParam(required = false) size: Long = 25
): ResponseEntity<ResponseBody<List<ChatResponse>>> {
val data = chatService.getChatList(
Expand Down
13 changes: 8 additions & 5 deletions src/main/kotlin/com/teamapi/palette/dto/chat/ChatResponse.kt
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package com.teamapi.palette.dto.chat

import java.time.LocalDateTime
import com.teamapi.palette.entity.chat.PromptData
import com.teamapi.palette.entity.consts.ChatState
import java.time.ZonedDateTime

data class ChatResponse(
val id: Long?,
val message: String,
val datetime: LocalDateTime,
val id: String?,
val message: String?,
val resource: ChatState,
val datetime: ZonedDateTime,
val roomId: Long,
val userId: Long,
val isAi: Boolean,
val resource: String
val data: PromptData?,
)
7 changes: 5 additions & 2 deletions src/main/kotlin/com/teamapi/palette/entity/Room.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ import org.springframework.data.relational.core.mapping.Table
data class Room(
@Id
val id: Long? = null,
val title: String? = "New Chat",
@Column("user_id")
val userId: Long
val userId: Long,
val title: String? = "New Chat",
@Column("base_prompt")
val basePrompt: String? = null,
val aspect: String? = null, // will be enum
) {
suspend fun validateUser(sessionHolder: SessionHolder): Room {
val me = sessionHolder.me()
Expand Down
32 changes: 18 additions & 14 deletions src/main/kotlin/com/teamapi/palette/entity/chat/Chat.kt
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
package com.teamapi.palette.entity.chat

import com.teamapi.palette.dto.chat.ChatResponse
import org.springframework.data.annotation.Id
import com.teamapi.palette.entity.consts.ChatState
import com.teamapi.palette.repository.mongo.MongoDatabases
import org.bson.codecs.pojo.annotations.BsonId
import org.bson.types.ObjectId
import org.springframework.data.mongodb.core.mapping.Document
import org.springframework.data.relational.core.mapping.Column
import java.time.LocalDateTime
import java.time.ZonedDateTime

@Document("chats")
@Document(MongoDatabases.CHAT)
data class Chat(
@Id
val id: Long? = null,
val message: String,
val resource: String = "CHAT",
val datetime: LocalDateTime,
@Column("room_id")
@BsonId
val id: ObjectId = ObjectId.get(),
val datetime: ZonedDateTime, // MUST-INCLUDED
val resource: ChatState = ChatState.CHAT,

// default property
val message: String? = null,
val roomId: Long,
@Column("user_id")
val userId: Long,
@Column("is_ai")
val isAi: Boolean
val isAi: Boolean,

// additional prompt data
val data: PromptData? = null
) {
fun toDto() = ChatResponse(id, message, datetime, roomId, userId, isAi, resource)
fun toDto() = ChatResponse(id.toString(), message, resource, datetime, roomId, userId, isAi, data)
}
14 changes: 14 additions & 0 deletions src/main/kotlin/com/teamapi/palette/entity/chat/PromptData.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.teamapi.palette.entity.chat

import com.teamapi.palette.entity.consts.PromptType
import kotlinx.serialization.Serializable

@Serializable
sealed class PromptData(val type: PromptType) {
@Serializable
data class Selectable(val choice: List<String>) : PromptData(PromptType.SELECTABLE)
@Serializable
data class Grid(val xSize: Int, val ySize: Int) : PromptData(PromptType.GRID)
@Serializable
data object UserInput : PromptData(PromptType.USER_INPUT)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.teamapi.palette.entity.consts

enum class ChatState {
CHAT,
IMAGE,
PROMPT_USER_INPUT,
PROMPT_SELECT,
PROMPT_GRID,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.teamapi.palette.entity.consts

enum class PromptType {
USER_INPUT,
SELECTABLE,
GRID
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package com.teamapi.palette.repository

import com.teamapi.palette.entity.VerifyCode
import com.teamapi.palette.repository.redis.ReactiveRedisCrudRepository
import kotlinx.coroutines.reactive.awaitFirstOrNull
import org.springframework.data.redis.core.*
import org.springframework.stereotype.Repository

Expand All @@ -22,7 +21,6 @@ class VerifyCodeRepository(private val redisOperations: ReactiveRedisTemplate<St

override suspend fun findById(id: Long): VerifyCode? {
val kv = redisOperations.opsForValue()
println(kv.get("$namespace$id").awaitFirstOrNull())
return kv.getAndAwait("$namespace$id") as? VerifyCode
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.teamapi.palette.repository.chat

import com.teamapi.palette.dto.chat.ChatResponse
import org.springframework.data.domain.Pageable
import java.time.ZonedDateTime

interface ChatQueryRepository {
suspend fun getImagesByUserId(userId: Long, pageable: Pageable): List<String>
suspend fun getMessageByRoomId(roomId: Long, offset: ZonedDateTime, size: Long): List<ChatResponse>
suspend fun getLatestMessageMapById(roomIds: List<Long>): Map<Long, String?>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.teamapi.palette.repository.chat

import com.mongodb.client.model.Aggregates
import com.mongodb.client.model.Projections
import com.mongodb.kotlin.client.coroutine.MongoDatabase
import com.teamapi.palette.dto.chat.ChatResponse
import com.teamapi.palette.entity.chat.Chat
import com.teamapi.palette.entity.consts.ChatState
import com.teamapi.palette.repository.mongo.*
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.toList
import org.bson.codecs.pojo.annotations.BsonId
import org.springframework.data.domain.Pageable
import java.time.LocalDateTime
import java.time.ZonedDateTime

class ChatQueryRepositoryImpl(
private val mongoDb: MongoDatabase
) : ChatQueryRepository {
// used internally
internal data class ImageFoundResult(val message: String)
internal data class LastMessageResult(@BsonId val id: Long, val lastMessage: String?)

override suspend fun getImagesByUserId(userId: Long, pageable: Pageable): List<String> {
val collection = mongoDb.getCollection<Chat>(MongoDatabases.CHAT)

return collection
.find<ImageFoundResult>((Chat::userId eq userId) and (Chat::resource eq ChatState.IMAGE))
.sort(Chat::datetime.desc())
.skip(pageable.offset.toInt())
.limit(pageable.pageSize)
.projection(Projections.include(Chat::message.name))
.map { it.message }
.toList()
}

override suspend fun getMessageByRoomId(roomId: Long, offset: ZonedDateTime, size: Long): List<ChatResponse> {
val collection = mongoDb.getCollection<Chat>(MongoDatabases.CHAT)

return collection
.find(Chat::datetime lt offset and (Chat::roomId eq roomId))
.sort(Chat::datetime.desc())
.limit(size.toInt())
.toList()
.map { it.toDto() }
}

override suspend fun getLatestMessageMapById(roomIds: List<Long>): Map<Long, String?> {
val collection = mongoDb.getCollection<Chat>(MongoDatabases.CHAT)

return collection.aggregate<LastMessageResult>(
Aggregates.match((Chat::roomId `in` roomIds) and (Chat::resource ne ChatState.IMAGE)),
Aggregates.group(Chat::roomId.literal, Chat::message.getLastAs("lastMessage")),
).toList().associate { it.id to it.lastMessage }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ package com.teamapi.palette.repository.chat
import org.springframework.stereotype.Repository

@Repository
interface ChatRepository : ChatR2dbcRepository
interface ChatRepository : ChatR2dbcRepository, ChatQueryRepository
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.teamapi.palette.repository.mongo

data object MongoDatabases {
const val CHAT = "chats"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package com.teamapi.palette.repository.mongo

import com.mongodb.client.model.Accumulators
import com.mongodb.client.model.BsonField
import com.mongodb.client.model.Filters
import com.mongodb.client.model.Sorts
import com.mongodb.kotlin.client.coroutine.MongoCollection
import org.bson.conversions.Bson
import org.springframework.data.mapping.toDotPath
import kotlin.reflect.KProperty1

infix fun <R, T : Any?> KProperty1<R, T>.eq(other: T?): Bson = Filters.eq(toDotPath(), other)
infix fun <R, T> KProperty1<R, T>.ne(other: T?): Bson = Filters.ne(toDotPath(), other)
infix fun <R, T : Any> KProperty1<R, T>.lt(other: T): Bson = Filters.lt(toDotPath(), other)
infix fun <R, T : Any> KProperty1<R, T>.lte(other: T): Bson = Filters.lte(toDotPath(), other)
infix fun <R, T : Any> KProperty1<R, T>.gt(other: T): Bson = Filters.gt(toDotPath(), other)
infix fun <R, T : Any> KProperty1<R, T>.gte(other: T): Bson = Filters.gte(toDotPath(), other)
infix fun <R, T : Any?> KProperty1<R, T>.`in`(others: List<T>): Bson = Filters.`in`(toDotPath(), others)

//fun <T : Any> MongoCollection<T>.aggregate(vararg pipeline: Bson) = aggregate(pipeline.toList())
inline fun <reified T : Any> MongoCollection<*>.aggregate(vararg pipeline: Bson) = aggregate<T>(pipeline.toList())

infix fun Bson.and(bson: Bson) = Filters.and(this, bson)

fun KProperty1<*, *>.getFirstAs(name: String): BsonField = Accumulators.first(name, literal)
fun KProperty1<*, *>.getLastAs(name: String): BsonField = Accumulators.last(name, literal)
val KProperty1<*, *>.literal: String get() = "\$${toDotPath()}"

fun KProperty1<*, *>.asc() = Sorts.ascending(name)
fun KProperty1<*, *>.desc() = Sorts.descending(name)
Loading

0 comments on commit e269a71

Please sign in to comment.