Skip to content

Commit

Permalink
add default value support
Browse files Browse the repository at this point in the history
  • Loading branch information
klahap committed Dec 17, 2024
1 parent cb5055e commit 473b126
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 27 deletions.
1 change: 0 additions & 1 deletion src/main/kotlin/Main.kt
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ private fun generateSpec(config: Config) {
.map { if (it is Table.Column.Type.NonPrimitive.Array) it.elementType else it }
.filterIsInstance<Table.Column.Type.NonPrimitive.Enum>().map { it.name }.toSet()
val enums = dbService.getEnums(enumNames)
// TODO add/try view tables
PgenSpec(tables = tables, enums = enums)
}
}
Expand Down
1 change: 1 addition & 0 deletions src/main/kotlin/model/sql/Table.kt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ data class Table(
val name: ColumnName,
val type: Type,
val isNullable: Boolean = false,
val default: String? = null,
) {
val prettyName get() = name.pretty

Expand Down
2 changes: 2 additions & 0 deletions src/main/kotlin/service/DbService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class DbService(
c.udt_name AS column_type_name,
c.numeric_precision AS numeric_precision,
c.numeric_scale AS numeric_scale,
c.column_default AS column_default,
ty.typcategory AS column_type_category,
tye.typcategory AS column_element_type_category
FROM information_schema.columns AS c
Expand All @@ -137,6 +138,7 @@ class DbService(
name = Table.ColumnName(resultSet.getString("column_name")!!),
type = resultSet.getColumnType(),
isNullable = resultSet.getBoolean("is_nullable"),
default = resultSet.getString("column_default")
)
}.groupBy({ it.first }, { it.second })
}
Expand Down
83 changes: 59 additions & 24 deletions src/main/kotlin/util/codegen/Column.kt
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,41 @@ fun Table.Column.Type.getTypeName(): TypeName {
}
}

private fun Table.Column.getDefaultExpression(): Pair<String, List<Any>>? = when (type) {
Table.Column.Type.Primitive.TIMESTAMP_WITH_TIMEZONE -> when (default) {
"now()" -> ".defaultExpression(%T)" to listOf(Poet.defaultExpTimestampZ)
else -> null
}

Table.Column.Type.Primitive.UUID -> when (default) {
"gen_random_uuid()" -> ".defaultExpression(%T(%S, %T()))" to
listOf(Poet.customFunction, "gen_random_uuid", Poet.uuidColumnType)

else -> null
}

Table.Column.Type.Primitive.BOOL -> when (default) {
"false" -> ".default(false)" to emptyList()
"true" -> ".default(true)" to emptyList()
else -> null
}

else -> null
}

context(CodeGenContext)
fun PropertySpec.Builder.initializer(column: Table.Column, postFix: String, vararg postArgs: Any) {
fun PropertySpec.Builder.initializer(column: Table.Column, postfix: String, postArgs: List<Any>) {
val columnName = column.name.value
var postfix = postfix
var postArgs = postArgs.toTypedArray()
column.getDefaultExpression()?.let {
postfix = it.first + postfix
postArgs = (it.second + postArgs.toList()).toTypedArray()
}

when (val type = column.type) {
is Table.Column.Type.NonPrimitive.Array -> initializer(
"array<%T>(name = %S)$postFix",
"array<%T>(name = %S)$postfix",
type.getTypeName(), columnName, *postArgs
)

Expand All @@ -54,68 +83,74 @@ fun PropertySpec.Builder.initializer(column: Table.Column, postFix: String, vara
sql = %S,
fromDb = { %T<%T>(it as String) },
toDb = { it.toPgObject() },
)$postFix""".trimIndent(),
)$postfix""".trimIndent(),
columnName,
type.name.name,
typeNameGetPgEnumByLabel,
type.name.typeName,
*postArgs
)

Table.Column.Type.Primitive.INT8 -> initializer("long(name = %S)$postFix", columnName, *postArgs)
Table.Column.Type.Primitive.BOOL -> initializer("bool(name = %S)$postFix", columnName, *postArgs)
Table.Column.Type.Primitive.BINARY -> initializer("binary(name = %S)$postFix", columnName, *postArgs)
Table.Column.Type.Primitive.VARCHAR -> initializer("text(name = %S)$postFix", columnName, *postArgs)
Table.Column.Type.Primitive.DATE -> initializer("%T(name = %S)$postFix", Poet.date, columnName, *postArgs)
Table.Column.Type.Primitive.INTERVAL -> initializer("duration(name = %S)$postFix", columnName, *postArgs)
Table.Column.Type.Primitive.INT8 -> initializer("long(name = %S)$postfix", columnName, *postArgs)
Table.Column.Type.Primitive.BOOL -> initializer("bool(name = %S)$postfix", columnName, *postArgs)
Table.Column.Type.Primitive.BINARY -> initializer("binary(name = %S)$postfix", columnName, *postArgs)
Table.Column.Type.Primitive.VARCHAR -> initializer("text(name = %S)$postfix", columnName, *postArgs)
Table.Column.Type.Primitive.DATE -> initializer("%T(name = %S)$postfix", Poet.date, columnName, *postArgs)
Table.Column.Type.Primitive.INTERVAL -> initializer("duration(name = %S)$postfix", columnName, *postArgs)
Table.Column.Type.Primitive.INT4RANGE -> initializer(
"registerColumn(name = %S, type = %T())$postFix",
"registerColumn(name = %S, type = %T())$postfix",
columnName, typeNameInt4RangeColumnType, *postArgs
)

Table.Column.Type.Primitive.INT8RANGE -> initializer(
"registerColumn(name = %S, type = %T())$postFix",
"registerColumn(name = %S, type = %T())$postfix",
columnName, typeNameInt8RangeColumnType, *postArgs
)

Table.Column.Type.Primitive.INT4MULTIRANGE -> initializer(
"registerColumn(name = %S, type = %T())$postFix",
"registerColumn(name = %S, type = %T())$postfix",
columnName, typeNameInt4MultiRangeColumnType, *postArgs
)

Table.Column.Type.Primitive.INT8MULTIRANGE -> initializer(
"registerColumn(name = %S, type = %T())$postFix",
"registerColumn(name = %S, type = %T())$postfix",
columnName, typeNameInt8MultiRangeColumnType, *postArgs
)

Table.Column.Type.Primitive.INT4 -> initializer("integer(name = %S)$postFix", columnName, *postArgs)
Table.Column.Type.Primitive.INT4 -> initializer("integer(name = %S)$postfix", columnName, *postArgs)
Table.Column.Type.Primitive.JSON -> initializer(
"%T<%T>(name = %S, serialize = %T)$postFix",
"%T<%T>(name = %S, serialize = %T)$postfix",
Poet.jsonColumn, Poet.jsonElement, columnName, Poet.json, *postArgs
)

Table.Column.Type.Primitive.JSONB -> initializer(
"%T<%T>(name = %S, jsonConfig = %T)$postFix",
"%T<%T>(name = %S, jsonConfig = %T)$postfix",
Poet.jsonColumn, Poet.jsonElement, columnName, Poet.json, *postArgs
)

is Table.Column.Type.NonPrimitive.Numeric -> initializer(
"decimal(name = %S, precision = ${type.precision}, scale = ${type.scale})$postFix",
"decimal(name = %S, precision = ${type.precision}, scale = ${type.scale})$postfix",
columnName, *postArgs
)

Table.Column.Type.Primitive.INT2 -> initializer("short(name = %S)$postFix", columnName, *postArgs)
Table.Column.Type.Primitive.TEXT -> initializer("text(name = %S)$postFix", columnName, *postArgs)
Table.Column.Type.Primitive.TIME -> initializer("%T(name = %S)$postFix", Poet.time, columnName, *postArgs)
Table.Column.Type.Primitive.TIMESTAMP -> initializer("%T(name = %S)$postFix", Poet.timestamp, columnName, *postArgs)
Table.Column.Type.Primitive.INT2 -> initializer("short(name = %S)$postfix", columnName, *postArgs)
Table.Column.Type.Primitive.TEXT -> initializer("text(name = %S)$postfix", columnName, *postArgs)
Table.Column.Type.Primitive.TIME -> initializer("%T(name = %S)$postfix", Poet.time, columnName, *postArgs)
Table.Column.Type.Primitive.TIMESTAMP -> initializer(
"%T(name = %S)$postfix",
Poet.timestamp,
columnName,
*postArgs
)

Table.Column.Type.Primitive.TIMESTAMP_WITH_TIMEZONE -> initializer(
"%T(name = %S)$postFix",
"%T(name = %S)$postfix",
Poet.timestampWithTimeZone, columnName, *postArgs
)

Table.Column.Type.Primitive.UUID -> initializer("uuid(name = %S)$postFix", columnName, *postArgs)
Table.Column.Type.Primitive.UUID -> initializer("uuid(name = %S)$postfix", columnName, *postArgs)
Table.Column.Type.Primitive.UNCONSTRAINED_NUMERIC -> initializer(
"registerColumn(name = %S, type = %T())$postFix",
"registerColumn(name = %S, type = %T())$postfix",
columnName, typeNameUnconstrainedNumericColumnType, *postArgs
)
}
Expand Down
4 changes: 4 additions & 0 deletions src/main/kotlin/util/codegen/Common.kt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ object Poet {
val localTime = ClassName("kotlinx.datetime", "LocalTime")
val localDate = ClassName("kotlinx.datetime", "LocalDate")
val offsetDateTime = OffsetDateTime::class.asTypeName()

val defaultExpTimestampZ = ClassName("org.jetbrains.exposed.sql.kotlin.datetime", "CurrentTimestampWithTimeZone")
val customFunction = ClassName("org.jetbrains.exposed.sql", "CustomFunction")
val uuidColumnType = ClassName("org.jetbrains.exposed.sql", "UUIDColumnType")
}

context(CodeGenContext)
Expand Down
4 changes: 2 additions & 2 deletions src/main/kotlin/util/codegen/Table.kt
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ internal fun Table.toTypeSpecInternal() = buildObject(this@toTypeSpecInternal.na
),
) {
val postArgs = mutableListOf<Any>()
val postFix = buildString {
val postfix = buildString {
foreignKeysSingle[column.name]?.let { foreignKey ->
append(".references(%T.${foreignKey.second.pretty})")
postArgs.add(foreignKey.first.typeName)
}
if (column.isNullable)
append(".nullable()")
}
initializer(column, postFix = postFix, postArgs = postArgs.toTypedArray())
initializer(column, postfix = postfix, postArgs = postArgs)
}
}
if (this@toTypeSpecInternal.primaryKey != null) {
Expand Down

0 comments on commit 473b126

Please sign in to comment.