Skip to content

Commit

Permalink
Specialized model function interpretations (#101)
Browse files Browse the repository at this point in the history
* Initial function interpretation specialization

* Move function interpretation to separate files

* Invalidate solver models

* Initial evaluator optimization

* Specialized evaluator

* Fix model serialization

* Track vars usage in Z3 model

* Rebase on Bitwuzla model fix
  • Loading branch information
Saloed authored Apr 27, 2023
1 parent 7e88851 commit 2fd0701
Show file tree
Hide file tree
Showing 17 changed files with 811 additions and 326 deletions.
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
package org.ksmt.solver.bitwuzla

import org.ksmt.KContext
import org.ksmt.decl.KConstDecl
import org.ksmt.decl.KDecl
import org.ksmt.expr.KExpr
import org.ksmt.expr.KUninterpretedSortValue
import org.ksmt.solver.model.KFuncInterp
import org.ksmt.solver.model.KFuncInterpEntryVarsFree
import org.ksmt.solver.model.KFuncInterpEntryVarsFreeOneAry
import org.ksmt.solver.model.KFuncInterpVarsFree
import org.ksmt.solver.KModel
import org.ksmt.solver.KSolverUnsupportedFeatureException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm
import org.ksmt.solver.bitwuzla.bindings.FunValue
import org.ksmt.solver.bitwuzla.bindings.Native
import org.ksmt.solver.model.KFuncInterpWithVars
import org.ksmt.solver.model.KModelEvaluator
import org.ksmt.solver.model.KModelImpl
import org.ksmt.sort.KArraySort
Expand All @@ -36,8 +40,18 @@ open class KBitwuzlaModel(
private val evaluatorWithModelCompletion by lazy { KModelEvaluator(ctx, this, isComplete = true) }
private val evaluatorWithoutModelCompletion by lazy { KModelEvaluator(ctx, this, isComplete = false) }

override fun <T : KSort> eval(expr: KExpr<T>, isComplete: Boolean): KExpr<T> {
private var isValid: Boolean = true

fun markInvalid() {
isValid = false
}

private fun ensureModelValid() {
bitwuzlaCtx.ensureActive()
check(isValid) { "The model is no longer valid" }
}

override fun <T : KSort> eval(expr: KExpr<T>, isComplete: Boolean): KExpr<T> {
ctx.ensureContextMatch(expr)

val evaluator = if (isComplete) evaluatorWithModelCompletion else evaluatorWithoutModelCompletion
Expand Down Expand Up @@ -68,11 +82,11 @@ open class KBitwuzlaModel(
uninterpretedSortValueContext.currentSortUniverse(sort)
}

private val interpretations: MutableMap<KDecl<*>, KModel.KFuncInterp<*>> = hashMapOf()
private val interpretations: MutableMap<KDecl<*>, KFuncInterp<*>> = hashMapOf()

override fun <T : KSort> interpretation(decl: KDecl<T>): KModel.KFuncInterp<T>? {
override fun <T : KSort> interpretation(decl: KDecl<T>): KFuncInterp<T>? {
ensureModelValid()
ctx.ensureContextMatch(decl)
bitwuzlaCtx.ensureActive()

if (decl !in modelDeclarations) return null

Expand All @@ -90,7 +104,7 @@ open class KBitwuzlaModel(
private fun <T : KSort> getInterpretationSafe(
decl: KDecl<T>,
term: BitwuzlaTerm
): KModel.KFuncInterp<T> = bitwuzlaCtx.bitwuzlaTry {
): KFuncInterp<T> = bitwuzlaCtx.bitwuzlaTry {
handleModelIsUnsupportedWithQuantifiers {
getInterpretation(decl, term)
}
Expand All @@ -99,16 +113,15 @@ open class KBitwuzlaModel(
private fun <T : KSort> getInterpretation(
decl: KDecl<T>,
term: BitwuzlaTerm
): KModel.KFuncInterp<T> = converter.withUninterpretedSortValueContext(uninterpretedSortValueContext) {
): KFuncInterp<T> = converter.withUninterpretedSortValueContext(uninterpretedSortValueContext) {
when {
Native.bitwuzlaTermIsArray(term) -> arrayInterpretation(decl, term)
Native.bitwuzlaTermIsFun(term) -> functionInterpretation(decl, term)
else -> {
val value = Native.bitwuzlaGetValue(bitwuzlaCtx.bitwuzla, term)
val convertedValue = with(converter) { value.convertExpr(decl.sort) }
KModel.KFuncInterp(
KFuncInterpVarsFree(
decl = decl,
vars = emptyList(),
entries = emptyList(),
default = convertedValue
)
Expand All @@ -119,11 +132,11 @@ open class KBitwuzlaModel(
private fun <T : KSort> functionInterpretation(
decl: KDecl<T>,
term: BitwuzlaTerm
): KModel.KFuncInterp<T> {
): KFuncInterp<T> {
val interp = Native.bitwuzlaGetFunValue(bitwuzlaCtx.bitwuzla, term)
return if (interp.size != 0) {
handleArrayFunctionDecl(decl) { functionDecl, vars ->
functionValueInterpretation(functionDecl, vars, interp)
handleArrayFunctionDecl(decl) { functionDecl ->
functionValueInterpretation(functionDecl, interp)
}
} else {
/**
Expand All @@ -136,21 +149,19 @@ open class KBitwuzlaModel(

private fun <T : KSort> KBitwuzlaExprConverter.functionValueInterpretation(
decl: KDecl<T>,
vars: List<KConstDecl<*>>,
interp: FunValue
): KModel.KFuncInterp<T> {
val entries = mutableListOf<KModel.KFuncInterpEntry<T>>()
): KFuncInterpVarsFree<T> {
val entries = mutableListOf<KFuncInterpEntryVarsFree<T>>()

for (i in 0 until interp.size) {
// Don't substitute vars since arguments in Bitwuzla model are always constants
val args = interp.args!![i].zip(decl.argSorts) { arg, sort -> arg.convertExpr(sort) }
val value = interp.values!![i].convertExpr(decl.sort)
entries += KModel.KFuncInterpEntry(args, value)
entries += KFuncInterpEntryVarsFree.create(args, value)
}

return KModel.KFuncInterp(
return KFuncInterpVarsFree(
decl = decl,
vars = vars,
entries = entries,
default = null
)
Expand All @@ -159,7 +170,7 @@ open class KBitwuzlaModel(
private fun <T : KSort> KBitwuzlaExprConverter.retrieveFunctionValue(
decl: KDecl<T>,
functionTerm: BitwuzlaTerm
): KModel.KFuncInterp<T> = handleArrayFunctionInterpretation(decl) { arraySort ->
): KFuncInterp<T> = handleArrayFunctionInterpretation(decl) { arraySort ->
// We expect lambda expression here. Therefore, we convert function interpretation as array.
val functionValue = Native.bitwuzlaGetValue(bitwuzlaCtx.bitwuzla, functionTerm)
functionValue.convertExpr(arraySort)
Expand All @@ -168,51 +179,45 @@ open class KBitwuzlaModel(
private fun <T : KSort> arrayInterpretation(
decl: KDecl<T>,
term: BitwuzlaTerm
): KModel.KFuncInterp<T> = handleArrayFunctionDecl(decl) { arrayFunctionDecl, vars ->
): KFuncInterp<T> = handleArrayFunctionDecl(decl) { arrayFunctionDecl ->
val sort: KArraySort<KSort, KSort> = decl.sort.uncheckedCast()
val entries = mutableListOf<KModel.KFuncInterpEntry<KSort>>()
val entries = mutableListOf<KFuncInterpEntryVarsFree<KSort>>()
val interp = Native.bitwuzlaGetArrayValue(bitwuzlaCtx.bitwuzla, term)

for (i in 0 until interp.size) {
val index = interp.indices!![i].convertExpr(sort.domain)
val value = interp.values!![i].convertExpr(sort.range)
entries += KModel.KFuncInterpEntry(listOf(index), value)
entries += KFuncInterpEntryVarsFreeOneAry(index, value)
}

val default = interp.defaultValue.takeIf { it != 0L }?.convertExpr(sort.range)

KModel.KFuncInterp(
KFuncInterpVarsFree(
decl = arrayFunctionDecl,
vars = vars,
entries = entries,
default = default
)
}

private inline fun <T : KSort> handleArrayFunctionDecl(
decl: KDecl<T>,
body: KBitwuzlaExprConverter.(KDecl<KSort>, List<KConstDecl<*>>) -> KModel.KFuncInterp<*>
): KModel.KFuncInterp<T> = with(ctx) {
body: KBitwuzlaExprConverter.(KDecl<KSort>) -> KFuncInterp<*>
): KFuncInterp<T> = with(ctx) {
val sort = decl.sort

if (sort !is KArraySortBase<*>) {
val vars = decl.argSorts.mapIndexed { i, s -> s.mkFreshConstDecl("x!$i") }
return converter.body(decl.uncheckedCast(), vars).uncheckedCast()
return converter.body(decl.uncheckedCast()).uncheckedCast()
}

check(decl.argSorts.isEmpty()) { "Unexpected function with array range" }

val arrayInterpDecl = mkFreshFuncDecl("array", sort.range, sort.domainSorts)
val arrayInterpIndicesDecls = sort.domainSorts.mapIndexed { i, s ->
s.mkFreshConstDecl("idx!$i")
}

modelDeclarations += arrayInterpDecl
interpretations[arrayInterpDecl] = converter.body(arrayInterpDecl, arrayInterpIndicesDecls)
interpretations[arrayInterpDecl] = converter.body(arrayInterpDecl)

KModel.KFuncInterp(
KFuncInterpVarsFree(
decl = decl,
vars = emptyList(),
entries = emptyList(),
default = mkFunctionAsArray(sort.uncheckedCast(), arrayInterpDecl).uncheckedCast()
)
Expand All @@ -221,14 +226,13 @@ open class KBitwuzlaModel(
private inline fun <T : KSort> KBitwuzlaExprConverter.handleArrayFunctionInterpretation(
decl: KDecl<T>,
convertInterpretation: (KArraySortBase<*>) -> KExpr<KArraySortBase<*>>
): KModel.KFuncInterp<T> {
): KFuncInterp<T> {
val sort = decl.sort

if (sort is KArraySortBase<*> && decl.argSorts.isEmpty()) {
val arrayInterpretation = convertInterpretation(sort)
return KModel.KFuncInterp(
return KFuncInterpVarsFree(
decl = decl,
vars = emptyList(),
entries = emptyList(),
default = arrayInterpretation.uncheckedCast()
)
Expand All @@ -242,7 +246,7 @@ open class KBitwuzlaModel(
val functionVars = decl.argSorts.mapIndexed { i, s -> s.mkFreshConstDecl("x!$i") }
val functionValue = ctx.mkAnyArraySelect(arrayInterpretation, functionVars.map { it.apply() })

return KModel.KFuncInterp(
return KFuncInterpWithVars(
decl = decl,
vars = functionVars,
entries = emptyList(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ open class KBitwuzlaSolver(private val ctx: KContext) : KSolver<KBitwuzlaSolverC
bitwuzlaCtx.bitwuzlaTry {
ctx.ensureContextMatch(assumptions)

invalidatePreviousModel()
lastAssumptions.clear()

trackVars.forEach {
Expand All @@ -119,13 +120,25 @@ open class KBitwuzlaSolver(private val ctx: KContext) : KSolver<KBitwuzlaSolverC
Native.bitwuzlaCheckSatTimeoutResult(bitwuzlaCtx.bitwuzla, timeout.inWholeMilliseconds)
}

private var lastModel: KBitwuzlaModel? = null

/**
* Bitwuzla model is only valid until the next check-sat call.
* */
private fun invalidatePreviousModel() {
lastModel?.markInvalid()
lastModel = null
}

override fun model(): KModel = bitwuzlaCtx.bitwuzlaTry {
require(lastCheckStatus == KSolverStatus.SAT) { "Model are only available after SAT checks" }
return KBitwuzlaModel(
val model = lastModel ?: KBitwuzlaModel(
ctx, bitwuzlaCtx, exprConverter,
bitwuzlaCtx.declarations(),
bitwuzlaCtx.uninterpretedSortsWithRelevantDecls()
)
lastModel = model
model
}

override fun unsatCore(): List<KExpr<KBoolSort>> = bitwuzlaCtx.bitwuzlaTry {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class Example {

solver.close()

assertFailsWith(IllegalStateException::class) { model.eval(a) }
assertFailsWith(IllegalStateException::class) { model.interpretation(b) }
assertEquals(aValue, detachedModel.eval(a))
assertEquals(cValue, detachedModel.eval(c))
}
Expand Down
42 changes: 1 addition & 41 deletions ksmt-core/src/main/kotlin/org/ksmt/solver/KModel.kt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package org.ksmt.solver

import org.ksmt.decl.KDecl
import org.ksmt.decl.KFuncDecl
import org.ksmt.expr.KExpr
import org.ksmt.expr.KUninterpretedSortValue
import org.ksmt.solver.model.KFuncInterp
import org.ksmt.sort.KSort
import org.ksmt.sort.KUninterpretedSort

Expand All @@ -22,44 +22,4 @@ interface KModel {
fun uninterpretedSortUniverse(sort: KUninterpretedSort): Set<KUninterpretedSortValue>?

fun detach(): KModel

data class KFuncInterp<T : KSort>(
val decl: KDecl<T>,
val vars: List<KDecl<*>>,
val entries: List<KFuncInterpEntry<T>>,
val default: KExpr<T>?
) {
init {
if (decl is KFuncDecl<T>) {
require(decl.argSorts.size == vars.size) {
"Function $decl has ${decl.argSorts.size} arguments but ${vars.size} were provided"
}
}
require(entries.all { it.args.size == vars.size }) {
"Function interpretation arguments mismatch"
}
}

val sort: T
get() = decl.sort

override fun toString(): String {
if (entries.isEmpty()) return default.toString()
return buildString {
appendLine('{')
entries.forEach { appendLine(it) }
append("else -> ")
appendLine(default)
append('}')
}
}
}

data class KFuncInterpEntry<T : KSort>(
val args: List<KExpr<*>>,
val value: KExpr<T>
) {
override fun toString(): String =
args.joinToString(prefix = "(", postfix = ") -> $value")
}
}
Loading

0 comments on commit 2fd0701

Please sign in to comment.