Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: EXPOSED-561 Restructure code in MigrationUtils and SchemaUtils to avoid calling currentDialect.tableColumns(*tables) in MigrationUtils.statementsRequiredForDatabaseMigration twice #2279

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion exposed-core/api/exposed-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,13 @@ public final class org/jetbrains/exposed/sql/Avg : org/jetbrains/exposed/sql/Fun
public fun toQueryBuilder (Lorg/jetbrains/exposed/sql/QueryBuilder;)V
}

public abstract class org/jetbrains/exposed/sql/BaseSchemaUtils {
public fun <init> ()V
protected final fun addMissingColumnsStatements ([Lorg/jetbrains/exposed/sql/Table;Ljava/util/Map;Z)Ljava/util/List;
public static synthetic fun addMissingColumnsStatements$default (Lorg/jetbrains/exposed/sql/BaseSchemaUtils;[Lorg/jetbrains/exposed/sql/Table;Ljava/util/Map;ZILjava/lang/Object;)Ljava/util/List;
protected final fun logTimeSpent (Ljava/lang/String;ZLkotlin/jvm/functions/Function0;)Ljava/lang/Object;
}

public class org/jetbrains/exposed/sql/BasicBinaryColumnType : org/jetbrains/exposed/sql/ColumnType {
public fun <init> ()V
public synthetic fun nonNullValueToString (Ljava/lang/Object;)Ljava/lang/String;
Expand Down Expand Up @@ -2170,7 +2177,7 @@ public final class org/jetbrains/exposed/sql/Schema {
public fun toString ()Ljava/lang/String;
}

public final class org/jetbrains/exposed/sql/SchemaUtils {
public final class org/jetbrains/exposed/sql/SchemaUtils : org/jetbrains/exposed/sql/BaseSchemaUtils {
public static final field INSTANCE Lorg/jetbrains/exposed/sql/SchemaUtils;
public final fun addMissingColumnsStatements ([Lorg/jetbrains/exposed/sql/Table;Z)Ljava/util/List;
public static synthetic fun addMissingColumnsStatements$default (Lorg/jetbrains/exposed/sql/SchemaUtils;[Lorg/jetbrains/exposed/sql/Table;ZILjava/lang/Object;)Ljava/util/List;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,321 @@
package org.jetbrains.exposed.sql

import org.jetbrains.exposed.sql.SchemaUtils.createFKey
import org.jetbrains.exposed.sql.SchemaUtils.createIndex
import org.jetbrains.exposed.sql.SqlExpressionBuilder.asLiteral
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.jetbrains.exposed.sql.vendors.*
import java.math.BigDecimal

/** Base class housing shared code between [SchemaUtils] and [MigrationUtils]. */
abstract class BaseSchemaUtils {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is one another idea of how common part could be extracted.

At the current moment the whole SchemaUtils looks a bit overloaded. It has methods to generate DDL, and has methods to execute commands to manipulate DB.

We could try to split it into the way to extract DDL generation part into another class, let's call it for now DDLUtils; and keep inside SchemaUtils responsibility to execute these DDLs on the database.

In this case we will have common part with DDL statements generation, and two consumers of that class. SchemaUtils that directly performs it on the DB, and MigrationUtils that prepare it as a migration scripts.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@obabichevjb In that case, how would we make the following function visible only to SchemaUtils and MigrationUtils? We don't want it to be public.

fun addMissingColumnsStatements(vararg tables: Table, existingTablesColumns: Map<Table, List<ColumnMetadata>>, withLogs: Boolean = true): List<String>

protected inline fun <R> logTimeSpent(message: String, withLogs: Boolean, block: () -> R): R {
return if (withLogs) {
val start = System.currentTimeMillis()
val answer = block()
exposedLogger.info(message + " took " + (System.currentTimeMillis() - start) + "ms")
answer
} else {
block()
}
}

protected fun addMissingColumnsStatements(vararg tables: Table, existingTablesColumns: Map<Table, List<ColumnMetadata>>, withLogs: Boolean = true): List<String> {
val statements = ArrayList<String>()

val existingPrimaryKeys = logTimeSpent("Extracting primary keys", withLogs) {
currentDialect.existingPrimaryKeys(*tables)
}

val dbSupportsAlterTableWithAddColumn = TransactionManager.current().db.supportsAlterTableWithAddColumn

tables.forEach { table ->
// create columns
val thisTableExistingColumns = existingTablesColumns[table].orEmpty()
val existingTableColumns = table.columns.mapNotNull { column ->
val existingColumn = thisTableExistingColumns.find { column.nameUnquoted().equals(it.name, true) }
if (existingColumn != null) column to existingColumn else null
}.toMap()
val missingTableColumns = table.columns.filter { it !in existingTableColumns }

missingTableColumns.flatMapTo(statements) { it.ddl }

if (dbSupportsAlterTableWithAddColumn) {
// create indexes with new columns
table.indices.filter { index ->
index.columns.any {
missingTableColumns.contains(it)
}
}.forEach { statements.addAll(createIndex(it)) }

// sync existing columns
val dataTypeProvider = currentDialect.dataTypeProvider
val redoColumns = existingTableColumns.mapValues { (col, existingCol) ->
val columnType = col.columnType
val colNullable = if (col.dbDefaultValue?.let { currentDialect.isAllowedAsColumnDefault(it) } == false) {
true // Treat a disallowed default value as null because that is what Exposed does with it
} else {
columnType.nullable
}
val incorrectNullability = existingCol.nullable != colNullable

val incorrectAutoInc = isIncorrectAutoInc(existingCol, col)

val incorrectDefaults = isIncorrectDefault(dataTypeProvider, existingCol, col)

val incorrectCaseSensitiveName = existingCol.name.inProperCase() != col.nameUnquoted().inProperCase()

val incorrectSizeOrScale = isIncorrectSizeOrScale(existingCol, columnType)

ColumnDiff(incorrectNullability, incorrectAutoInc, incorrectDefaults, incorrectCaseSensitiveName, incorrectSizeOrScale)
}.filterValues { it.hasDifferences() }

redoColumns.flatMapTo(statements) { (col, changedState) -> col.modifyStatements(changedState) }

// add missing primary key
val missingPK = table.primaryKey?.takeIf { pk -> pk.columns.none { it in missingTableColumns } }
if (missingPK != null && existingPrimaryKeys[table] == null) {
val missingPKName = missingPK.name.takeIf { table.isCustomPKNameDefined() }
statements.add(
currentDialect.addPrimaryKey(table, missingPKName, pkColumns = missingPK.columns)
)
}
}
}

if (dbSupportsAlterTableWithAddColumn) {
statements.addAll(addMissingColumnConstraints(*tables, withLogs = withLogs))
}

return statements
}

private fun isIncorrectAutoInc(columnMetadata: ColumnMetadata, column: Column<*>): Boolean = when {
!columnMetadata.autoIncrement && column.columnType.isAutoInc && column.autoIncColumnType?.sequence == null ->
true
columnMetadata.autoIncrement && column.columnType.isAutoInc && column.autoIncColumnType?.sequence != null ->
true
columnMetadata.autoIncrement && !column.columnType.isAutoInc -> true
else -> false
}

/**
* For DDL purposes we do not segregate the cases when the default value was not specified, and when it
* was explicitly set to `null`.
*/
private fun isIncorrectDefault(dataTypeProvider: DataTypeProvider, columnMeta: ColumnMetadata, column: Column<*>): Boolean {
val isExistingColumnDefaultNull = columnMeta.defaultDbValue == null
val isDefinedColumnDefaultNull = column.dbDefaultValue?.takeIf { currentDialect.isAllowedAsColumnDefault(it) } == null ||
(column.dbDefaultValue is LiteralOp<*> && (column.dbDefaultValue as? LiteralOp<*>)?.value == null)

return when {
// Both values are null-like, no DDL update is needed
isExistingColumnDefaultNull && isDefinedColumnDefaultNull -> false
// Only one of the values is null-like, DDL update is needed
isExistingColumnDefaultNull != isDefinedColumnDefaultNull -> true

else -> {
val columnDefaultValue = column.dbDefaultValue?.let {
dataTypeProvider.dbDefaultToString(column, it)
}
columnMeta.defaultDbValue != columnDefaultValue
}
}
}

private fun isIncorrectSizeOrScale(columnMeta: ColumnMetadata, columnType: IColumnType<*>): Boolean {
// ColumnMetadata.scale can only be non-null if ColumnMetadata.size is non-null
if (columnMeta.size == null) return false

return when (columnType) {
is DecimalColumnType -> columnType.precision != columnMeta.size || columnType.scale != columnMeta.scale
is CharColumnType -> columnType.colLength != columnMeta.size
is VarCharColumnType -> columnType.colLength != columnMeta.size
is BinaryColumnType -> columnType.length != columnMeta.size
else -> false
}
}

private fun addMissingColumnConstraints(vararg tables: Table, withLogs: Boolean): List<String> {
val existingColumnConstraint = logTimeSpent("Extracting column constraints", withLogs) {
currentDialect.columnConstraints(*tables)
}

val foreignKeyConstraints = tables.flatMap { table ->
table.foreignKeys.map { it to existingColumnConstraint[table to it.from]?.firstOrNull() }
}

val statements = ArrayList<String>()

for ((foreignKey, existingConstraint) in foreignKeyConstraints) {
if (existingConstraint == null) {
statements.addAll(createFKey(foreignKey))
continue
}

val noForeignKey = existingConstraint.targetTable != foreignKey.targetTable
val deleteRuleMismatch = foreignKey.deleteRule != existingConstraint.deleteRule
val updateRuleMismatch = foreignKey.updateRule != existingConstraint.updateRule

if (noForeignKey || deleteRuleMismatch || updateRuleMismatch) {
statements.addAll(existingConstraint.dropStatement())
statements.addAll(createFKey(foreignKey))
}
}

return statements
}

@Suppress("NestedBlockDepth", "ComplexMethod", "LongMethod")
private fun DataTypeProvider.dbDefaultToString(column: Column<*>, exp: Expression<*>): String {
return when (exp) {
is LiteralOp<*> -> {
val dialect = currentDialect
when (val value = exp.value) {
is Boolean -> when (dialect) {
is MysqlDialect -> if (value) "1" else "0"
is PostgreSQLDialect -> value.toString()
else -> booleanToStatementString(value)
}

is String -> when {
dialect is PostgreSQLDialect -> when (column.columnType) {
is VarCharColumnType -> "'$value'::character varying"
is TextColumnType -> "'$value'::text"
else -> processForDefaultValue(exp)
}

dialect is OracleDialect || dialect.h2Mode == H2Dialect.H2CompatibilityMode.Oracle -> when {
column.columnType is VarCharColumnType && value == "" -> "NULL"
column.columnType is TextColumnType && value == "" -> "NULL"
else -> value
}

else -> value
}

is Enum<*> -> when (exp.columnType) {
is EnumerationNameColumnType<*> -> when (dialect) {
is PostgreSQLDialect -> "'${value.name}'::character varying"
else -> value.name
}

else -> processForDefaultValue(exp)
}

is BigDecimal -> when (dialect) {
is MysqlDialect -> value.setScale((exp.columnType as DecimalColumnType).scale).toString()
else -> processForDefaultValue(exp)
}

is Byte -> when {
dialect is PostgreSQLDialect && value < 0 -> "'${processForDefaultValue(exp)}'::integer"
else -> processForDefaultValue(exp)
}

is Short -> when {
dialect is PostgreSQLDialect && value < 0 -> "'${processForDefaultValue(exp)}'::integer"
else -> processForDefaultValue(exp)
}

is Int -> when {
dialect is PostgreSQLDialect && value < 0 -> "'${processForDefaultValue(exp)}'::integer"
else -> processForDefaultValue(exp)
}

is Long -> when {
currentDialect is SQLServerDialect && (value < 0 || value > Int.MAX_VALUE.toLong()) ->
"${processForDefaultValue(exp)}."
currentDialect is PostgreSQLDialect && (value < 0 || value > Int.MAX_VALUE.toLong()) ->
"'${processForDefaultValue(exp)}'::bigint"
else -> processForDefaultValue(exp)
}

is UInt -> when {
dialect is SQLServerDialect && value > Int.MAX_VALUE.toUInt() -> "${processForDefaultValue(exp)}."
dialect is PostgreSQLDialect && value > Int.MAX_VALUE.toUInt() -> "'${processForDefaultValue(exp)}'::bigint"
else -> processForDefaultValue(exp)
}

is ULong -> when {
currentDialect is SQLServerDialect && value > Int.MAX_VALUE.toULong() -> "${processForDefaultValue(exp)}."
currentDialect is PostgreSQLDialect && value > Int.MAX_VALUE.toULong() -> "'${processForDefaultValue(exp)}'::bigint"
else -> processForDefaultValue(exp)
}

else -> {
when {
column.columnType is JsonColumnMarker -> {
val processed = processForDefaultValue(exp)
when (dialect) {
is PostgreSQLDialect -> {
if (column.columnType.usesBinaryFormat) {
processed.replace(Regex("(\"|})(:|,)(\\[|\\{|\")"), "$1$2 $3")
} else {
processed
}
}

is MariaDBDialect -> processed.trim('\'')
is MysqlDialect -> "_utf8mb4\\'${processed.trim('(', ')', '\'')}\\'"
else -> when {
processed.startsWith('\'') && processed.endsWith('\'') -> processed.trim('\'')
else -> processed
}
}
}

column.columnType is ArrayColumnType<*, *> && dialect is PostgreSQLDialect -> {
(value as List<*>)
.takeIf { it.isNotEmpty() }
?.run {
val delegateColumnType = column.columnType.delegate as IColumnType<Any>
val delegateColumn = (column as Column<Any?>).withColumnType(delegateColumnType)
val processed = map {
if (delegateColumn.columnType is StringColumnType) {
"'$it'::text"
} else {
dbDefaultToString(delegateColumn, delegateColumn.asLiteral(it))
}
}
"ARRAY$processed"
} ?: processForDefaultValue(exp)
}

column.columnType is IDateColumnType -> {
val processed = processForDefaultValue(exp)
if (processed.startsWith('\'') && processed.endsWith('\'')) {
processed.trim('\'')
} else {
processed
}
}

else -> processForDefaultValue(exp)
}
}
}
}

is Function<*> -> {
var processed = processForDefaultValue(exp)
if (exp.columnType is IDateColumnType) {
if (processed.startsWith("CURRENT_TIMESTAMP") || processed == "GETDATE()") {
when (currentDialect) {
is SQLServerDialect -> processed = "getdate"
is MariaDBDialect -> processed = processed.lowercase()
}
}
if (processed.trim('(').startsWith("CURRENT_DATE")) {
when (currentDialect) {
is MysqlDialect -> processed = "curdate()"
}
}
}
processed
}

else -> processForDefaultValue(exp)
}
}
}
Loading
Loading