From 0acf6575865d692baea6c5a398a49d5069b519ee Mon Sep 17 00:00:00 2001 From: Roman Pozharskiy Date: Tue, 17 Sep 2024 16:05:24 +0300 Subject: [PATCH] Move `TrieNode` from `kotlinx.collections` (#194) * move immutable collections from kotlinx * use `UPersistentHashMap`; introduce ownership * add ownership in memory * numeric constraints: reorder remove and get * linter warnings fixes * fix ownership flow; add ownership in type constraints * use default ownership in composer memory * use memory ownership in memcpy, union, merge * fix ownership inheritance * various fixes * add ownership in composer * fixes * change pc.ownership in clone * remove scope.ownership * fixes * fix set iterator * fix iterator * iterate sequence only once in JcStaticFieldsRegion * fix mutable iterators in NumericConstraints * add equals, hashCode, toString * drop RegionTree * fix hashCode, equals; fix union in map, set * move some logic in extensions * toList in equalityConstrints * fix compile err * add toList in some iterators * style --- .../src/main/kotlin/org/usvm/Composition.kt | 4 +- usvm-core/src/main/kotlin/org/usvm/Context.kt | 9 +- usvm-core/src/main/kotlin/org/usvm/Mocks.kt | 42 +- usvm-core/src/main/kotlin/org/usvm/State.kt | 5 + .../src/main/kotlin/org/usvm/StateForker.kt | 14 +- .../src/main/kotlin/org/usvm/UComponents.kt | 5 +- .../src/main/kotlin/org/usvm/api/MemoryApi.kt | 5 +- .../src/main/kotlin/org/usvm/api/MockApi.kt | 4 +- .../org/usvm/collection/array/ArrayRegion.kt | 59 +- .../usvm/collection/array/ArrayRegionApi.kt | 11 +- .../usvm/collection/array/USymbolicArrayId.kt | 4 +- .../array/length/ArrayLengthRegion.kt | 16 +- .../array/length/USymbolicArrayLengthId.kt | 2 +- .../org/usvm/collection/field/FieldsRegion.kt | 16 +- .../usvm/collection/field/USymbolicFieldId.kt | 2 +- .../collection/map/length/UMapLengthRegion.kt | 16 +- .../map/length/USymbolicMapLengthId.kt | 2 +- .../collection/map/primitive/UMapRegion.kt | 42 +- .../collection/map/primitive/UMapRegionApi.kt | 4 +- .../map/primitive/USymbolicMapId.kt | 4 +- .../usvm/collection/map/ref/URefMapRegion.kt | 83 +- .../collection/map/ref/URefMapRegionApi.kt | 2 +- .../collection/map/ref/USymbolicRefMapId.kt | 6 +- .../collection/set/USetCollectionDecoder.kt | 11 +- .../collection/set/primitive/USetRegion.kt | 40 +- .../collection/set/primitive/USetRegionApi.kt | 2 +- .../set/primitive/USymbolicSetId.kt | 4 +- .../usvm/collection/set/ref/URefSetRegion.kt | 84 +- .../collection/set/ref/URefSetRegionApi.kt | 2 +- .../collection/set/ref/USymbolicRefSetId.kt | 6 +- .../usvm/constraints/EqualityConstraints.kt | 184 ++-- .../org/usvm/constraints/PathConstraints.kt | 68 +- .../org/usvm/constraints/TypeConstraints.kt | 90 +- .../usvm/constraints/ULogicalConstraints.kt | 38 +- .../usvm/constraints/UNumericConstraints.kt | 251 +++-- .../org/usvm/memory/HeapRefSplitting.kt | 3 +- .../src/main/kotlin/org/usvm/memory/Memory.kt | 71 +- .../kotlin/org/usvm/memory/RegistersStack.kt | 2 + .../org/usvm/memory/USymbolicCollection.kt | 15 +- .../main/kotlin/org/usvm/merging/Merging.kt | 16 + .../org/usvm/merging/MergingPathSelector.kt | 2 +- .../src/main/kotlin/org/usvm/model/Model.kt | 6 +- .../kotlin/org/usvm/model/ModelRegions.kt | 12 +- .../kotlin/org/usvm/model/UModelEvaluator.kt | 20 +- .../src/main/kotlin/org/usvm/solver/Solver.kt | 2 +- .../test/kotlin/org/usvm/CompositionTest.kt | 59 +- .../src/test/kotlin/org/usvm/TestUtil.kt | 13 +- .../collections/SymbolicCollectionTestBase.kt | 25 +- .../constraints/EqualityConstraintsTests.kt | 14 +- .../constraints/NumericConstraintsTests.kt | 51 +- .../kotlin/org/usvm/memory/HeapMemCpyTest.kt | 9 +- .../kotlin/org/usvm/memory/HeapMemsetTest.kt | 9 +- .../kotlin/org/usvm/memory/HeapRefEqTest.kt | 7 +- .../org/usvm/memory/HeapRefSplittingTest.kt | 9 +- .../org/usvm/memory/MemoryRegionTest.kt | 13 +- .../kotlin/org/usvm/memory/SetEntriesTest.kt | 9 +- .../usvm/merging/CloseStatesSearcherTest.kt | 7 +- .../org/usvm/merging/MemoryMergingTest.kt | 28 +- .../merging/PathConstraintsMergingTest.kt | 17 +- .../org/usvm/model/ModelCompositionTest.kt | 52 +- .../org/usvm/model/ModelDecodingTest.kt | 13 +- .../org/usvm/solver/SoftConstraintsTest.kt | 19 +- .../kotlin/org/usvm/solver/TranslationTest.kt | 45 +- .../kotlin/org/usvm/types/TypeSolverTest.kt | 25 +- .../kotlin/org/usvm/machine/JcComponents.kt | 6 +- .../kotlin/org/usvm/machine/JcTransformer.kt | 4 +- .../machine/interpreter/JcCallSiteRegion.kt | 14 +- .../machine/interpreter/JcExprResolver.kt | 9 +- .../usvm/machine/interpreter/JcInterpreter.kt | 5 +- .../statics/JcStaticFieldsRegion.kt | 42 +- .../kotlin/org/usvm/machine/mocks/JcMocker.kt | 2 +- .../kotlin/org/usvm/machine/state/JcState.kt | 36 +- .../main/kotlin/org/usvm/machine/Mocking.kt | 6 +- .../main/kotlin/org/usvm/machine/PyMachine.kt | 8 +- .../org/usvm/machine/PyPathConstraints.kt | 19 +- .../main/kotlin/org/usvm/machine/PyState.kt | 13 +- .../org/usvm/runner/PrintingResultReceiver.kt | 2 +- .../kotlin/org/usvm/machine/SampleMachine.kt | 3 +- .../kotlin/org/usvm/machine/SampleState.kt | 33 +- .../src/main/kotlin/org/usvm/TSInterpreter.kt | 3 +- .../src/main/kotlin/org/usvm/state/TSState.kt | 18 +- .../algorithms/PersistentMultiMapBuilder.kt | 95 -- .../org/usvm/algorithms/SeparationUtils.kt | 34 +- .../usvm/algorithms/UPersistentMultiMap.kt | 86 ++ .../usvm/collections/immutable/extensions.kt | 70 ++ .../immutableMap/TrieIterator.kt | 180 ++++ .../implementations/immutableMap/TrieNode.kt | 935 ++++++++++++++++++ .../implementations/immutableSet/TrieNode.kt | 807 +++++++++++++++ .../UPersistentHashSetIterator.kt | 120 +++ .../immutable/internal/ForEachOneBit.kt | 18 + .../immutable/internal/MutabilityOwnership.kt | 13 + 91 files changed, 3431 insertions(+), 860 deletions(-) delete mode 100644 usvm-util/src/main/kotlin/org/usvm/algorithms/PersistentMultiMapBuilder.kt create mode 100644 usvm-util/src/main/kotlin/org/usvm/algorithms/UPersistentMultiMap.kt create mode 100644 usvm-util/src/main/kotlin/org/usvm/collections/immutable/extensions.kt create mode 100644 usvm-util/src/main/kotlin/org/usvm/collections/immutable/implementations/immutableMap/TrieIterator.kt create mode 100644 usvm-util/src/main/kotlin/org/usvm/collections/immutable/implementations/immutableMap/TrieNode.kt create mode 100644 usvm-util/src/main/kotlin/org/usvm/collections/immutable/implementations/immutableSet/TrieNode.kt create mode 100644 usvm-util/src/main/kotlin/org/usvm/collections/immutable/implementations/immutableSet/UPersistentHashSetIterator.kt create mode 100644 usvm-util/src/main/kotlin/org/usvm/collections/immutable/internal/ForEachOneBit.kt create mode 100644 usvm-util/src/main/kotlin/org/usvm/collections/immutable/internal/MutabilityOwnership.kt diff --git a/usvm-core/src/main/kotlin/org/usvm/Composition.kt b/usvm-core/src/main/kotlin/org/usvm/Composition.kt index ecb44bf966..574b91908d 100644 --- a/usvm-core/src/main/kotlin/org/usvm/Composition.kt +++ b/usvm-core/src/main/kotlin/org/usvm/Composition.kt @@ -15,6 +15,7 @@ import org.usvm.collection.set.primitive.UInputSetReading import org.usvm.collection.set.ref.UAllocatedRefSetWithInputElementsReading import org.usvm.collection.set.ref.UInputRefSetWithAllocatedElementsReading import org.usvm.collection.set.ref.UInputRefSetWithInputElementsReading +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.memory.UReadOnlyMemory import org.usvm.memory.USymbolicCollectionId import org.usvm.regions.Region @@ -22,7 +23,8 @@ import org.usvm.regions.Region @Suppress("MemberVisibilityCanBePrivate") open class UComposer( ctx: UContext, - val memory: UReadOnlyMemory + val memory: UReadOnlyMemory, + val ownership: MutabilityOwnership ) : UExprTransformer(ctx) { open fun compose(expr: UExpr): UExpr = apply(expr) diff --git a/usvm-core/src/main/kotlin/org/usvm/Context.kt b/usvm-core/src/main/kotlin/org/usvm/Context.kt index dfda6180df..47d688b98a 100644 --- a/usvm-core/src/main/kotlin/org/usvm/Context.kt +++ b/usvm-core/src/main/kotlin/org/usvm/Context.kt @@ -42,6 +42,7 @@ import org.usvm.collection.set.ref.UInputRefSetWithAllocatedElements import org.usvm.collection.set.ref.UInputRefSetWithAllocatedElementsReading import org.usvm.collection.set.ref.UInputRefSetWithInputElements import org.usvm.collection.set.ref.UInputRefSetWithInputElementsReading +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.memory.UAddressCounter import org.usvm.memory.UReadOnlyMemory import org.usvm.memory.splitUHeapRef @@ -61,11 +62,12 @@ open class UContext( private val solver by lazy { components.mkSolver(this) } private val typeSystem by lazy { components.mkTypeSystem(this) } private val softConstraintsProvider by lazy { components.mkSoftConstraintsProvider(this) } - private val composerBuilder: (UReadOnlyMemory<*>) -> UComposer<*, USizeSort> by lazy { + private val composerBuilder: (UReadOnlyMemory<*>, MutabilityOwnership) -> UComposer<*, USizeSort> by lazy { @Suppress("UNCHECKED_CAST") - components.mkComposer(this) as (UReadOnlyMemory<*>) -> UComposer<*, USizeSort> + components.mkComposer(this) as (UReadOnlyMemory<*>, MutabilityOwnership) -> UComposer<*, USizeSort> } + val defaultOwnership = MutabilityOwnership() val sizeExprs by lazy { components.mkSizeExprProvider(this) } val statesForkProvider by lazy { components.mkStatesForkProvider() } @@ -86,7 +88,8 @@ open class UContext( fun softConstraintsProvider(): USoftConstraintsProvider = softConstraintsProvider.cast() - fun composer(memory: UReadOnlyMemory): UComposer = composerBuilder(memory).cast() + fun composer(memory: UReadOnlyMemory, ownership: MutabilityOwnership): UComposer = + composerBuilder(memory, ownership).cast() val addressSort: UAddressSort = mkUninterpretedSort("Address") val nullRef: UNullRef = UNullRef(this) diff --git a/usvm-core/src/main/kotlin/org/usvm/Mocks.kt b/usvm-core/src/main/kotlin/org/usvm/Mocks.kt index c689ba2725..599bb0c3a4 100644 --- a/usvm-core/src/main/kotlin/org/usvm/Mocks.kt +++ b/usvm-core/src/main/kotlin/org/usvm/Mocks.kt @@ -2,9 +2,11 @@ package org.usvm import io.ksmt.utils.cast import kotlinx.collections.immutable.PersistentList -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.persistentHashMapOf import kotlinx.collections.immutable.persistentListOf +import org.usvm.collections.immutable.getOrDefault +import org.usvm.collections.immutable.persistentHashMapOf +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.merging.MergeGuard import org.usvm.merging.UMergeable @@ -19,10 +21,15 @@ interface UMocker : UMockEvaluator { method: Method, args: Sequence>, sort: Sort, + ownership: MutabilityOwnership, ): UMockSymbol - val trackedLiterals: Collection + val trackedLiterals: Sequence - fun createMockSymbol(trackedLiteral: TrackedLiteral?, sort: Sort): UExpr + fun createMockSymbol( + trackedLiteral: TrackedLiteral?, + sort: Sort, + ownership: MutabilityOwnership, + ): UExpr fun getTrackedExpression(trackedLiteral: TrackedLiteral): UExpr @@ -30,37 +37,42 @@ interface UMocker : UMockEvaluator { } class UIndexedMocker( - private var methodMockClauses: PersistentMap>> = persistentHashMapOf(), - private var trackedSymbols: PersistentMap> = persistentHashMapOf(), + private var methodMockClauses: UPersistentHashMap>> = persistentHashMapOf(), + private var trackedSymbols: UPersistentHashMap> = persistentHashMapOf(), private var untrackedSymbols: PersistentList> = persistentListOf(), ) : UMocker, UMergeable, MergeGuard> { override fun call( method: Method, args: Sequence>, sort: Sort, + ownership: MutabilityOwnership, ): UMockSymbol { val currentClauses = methodMockClauses.getOrDefault(method, persistentListOf()) val index = currentClauses.size val const = sort.uctx.mkIndexedMethodReturnValue(method, index, sort) - methodMockClauses = methodMockClauses.put(method, currentClauses.add(const)) + methodMockClauses = methodMockClauses.put(method, currentClauses.add(const), ownership) return const } override fun eval(symbol: UMockSymbol): UExpr = symbol - override val trackedLiterals: Collection + override val trackedLiterals: Sequence get() = trackedSymbols.keys /** * Creates a mock symbol. If [trackedLiteral] is not null, created expression * can be retrieved later by this [trackedLiteral] using [getTrackedExpression] method. */ - override fun createMockSymbol(trackedLiteral: TrackedLiteral?, sort: Sort): UExpr { + override fun createMockSymbol( + trackedLiteral: TrackedLiteral?, + sort: Sort, + ownership: MutabilityOwnership, + ): UExpr { val const = sort.uctx.mkTrackedSymbol(sort) if (trackedLiteral != null) { - trackedSymbols = trackedSymbols.put(trackedLiteral, const) + trackedSymbols = trackedSymbols.put(trackedLiteral, const, ownership) } else { untrackedSymbols = untrackedSymbols.add(const) } @@ -71,10 +83,11 @@ class UIndexedMocker( override fun getTrackedExpression(trackedLiteral: TrackedLiteral): UExpr { if (trackedLiteral !in trackedSymbols) error("Access by unregistered track literal $trackedLiteral") - return trackedSymbols.getValue(trackedLiteral).cast() + return trackedSymbols[trackedLiteral]!!.cast() } - override fun clone(): UIndexedMocker = UIndexedMocker(methodMockClauses, trackedSymbols, untrackedSymbols) + override fun clone(): UIndexedMocker = + UIndexedMocker(methodMockClauses, trackedSymbols, untrackedSymbols) /** * Check if this [UIndexedMocker] can be merged with [other] indexed mocker. @@ -83,7 +96,10 @@ class UIndexedMocker( * * @return the merged indexed mocker. */ - override fun mergeWith(other: UIndexedMocker, by: MergeGuard): UIndexedMocker? { + override fun mergeWith( + other: UIndexedMocker, + by: MergeGuard, + ): UIndexedMocker? { if (methodMockClauses !== other.methodMockClauses || trackedSymbols !== other.trackedSymbols || untrackedSymbols !== other.untrackedSymbols diff --git a/usvm-core/src/main/kotlin/org/usvm/State.kt b/usvm-core/src/main/kotlin/org/usvm/State.kt index 886ca65053..2e82fe4ddb 100644 --- a/usvm-core/src/main/kotlin/org/usvm/State.kt +++ b/usvm-core/src/main/kotlin/org/usvm/State.kt @@ -1,5 +1,6 @@ package org.usvm +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UPathConstraints import org.usvm.memory.UMemory import org.usvm.merging.UMergeable @@ -12,6 +13,7 @@ typealias StateId = UInt abstract class UState( // TODO: add interpreter-specific information val ctx: Context, + initOwnership: MutabilityOwnership, open val callStack: UCallStack, open val pathConstraints: UPathConstraints, open val memory: UMemory, @@ -33,6 +35,9 @@ abstract class UState( */ val id: StateId = ctx.getNextStateId() + open var ownership = initOwnership + protected set + /** * Creates new state structurally identical to this. * If [newConstraints] is null, clones [pathConstraints]. Otherwise, uses [newConstraints] in cloned state. diff --git a/usvm-core/src/main/kotlin/org/usvm/StateForker.kt b/usvm-core/src/main/kotlin/org/usvm/StateForker.kt index 0f50bb88ae..ec634b97c4 100644 --- a/usvm-core/src/main/kotlin/org/usvm/StateForker.kt +++ b/usvm-core/src/main/kotlin/org/usvm/StateForker.kt @@ -1,5 +1,6 @@ package org.usvm +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.model.UModelBase import org.usvm.solver.USatResult import org.usvm.solver.UUnknownResult @@ -155,7 +156,11 @@ object WithSolverStateForker : StateForker { val satResult = solver.check(constraintsToCheck) return when (satResult) { - is UUnsatResult -> null + is UUnsatResult -> { + // rollback previous ownership + state.pathConstraints.changeOwnership(state.ownership) + null + } is USatResult -> { // Note that we cannot extract common code here due to @@ -177,6 +182,8 @@ object WithSolverStateForker : StateForker { } is UUnknownResult -> { + // rollback previous ownership + state.pathConstraints.changeOwnership(state.ownership) state.pathConstraints += if (stateToCheck) newConstraintToOriginalState else newConstraintToForkedState null @@ -192,11 +199,11 @@ object NoSolverStateForker : StateForker { ): ForkResult { val (trueModels, falseModels, _) = splitModelsByCondition(state.models, condition) val notCondition = state.ctx.mkNot(condition) - val clonedPathConstraints = state.pathConstraints.clone() clonedPathConstraints += condition val (posState, negState) = if (clonedPathConstraints.isFalse) { + // changing ownership is unnecessary state.pathConstraints += notCondition state.models = falseModels @@ -225,8 +232,7 @@ object NoSolverStateForker : StateForker { val result = mutableListOf() for (condition in conditions) { val (trueModels, _) = splitModelsByCondition(curState.models, condition) - - val clonedConstraints = curState.pathConstraints.clone() + val clonedConstraints = curState.pathConstraints.clone(MutabilityOwnership(), MutabilityOwnership()) clonedConstraints += condition if (clonedConstraints.isFalse) { diff --git a/usvm-core/src/main/kotlin/org/usvm/UComponents.kt b/usvm-core/src/main/kotlin/org/usvm/UComponents.kt index 8dc74a92ba..386b9e7002 100644 --- a/usvm-core/src/main/kotlin/org/usvm/UComponents.kt +++ b/usvm-core/src/main/kotlin/org/usvm/UComponents.kt @@ -1,5 +1,6 @@ package org.usvm +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.memory.UReadOnlyMemory import org.usvm.model.ULazyModelDecoder import org.usvm.model.UModelDecoder @@ -34,8 +35,8 @@ interface UComponents { fun > mkComposer( ctx: Context, - ): (UReadOnlyMemory) -> UComposer = - { memory: UReadOnlyMemory -> UComposer(ctx, memory) } + ): (UReadOnlyMemory, MutabilityOwnership) -> UComposer = + { memory: UReadOnlyMemory, ownership: MutabilityOwnership -> UComposer(ctx, memory, ownership) } fun mkStatesForkProvider(): StateForker = if (useSolverForForks) WithSolverStateForker else NoSolverStateForker diff --git a/usvm-core/src/main/kotlin/org/usvm/api/MemoryApi.kt b/usvm-core/src/main/kotlin/org/usvm/api/MemoryApi.kt index 89f259894f..c178d905fb 100644 --- a/usvm-core/src/main/kotlin/org/usvm/api/MemoryApi.kt +++ b/usvm-core/src/main/kotlin/org/usvm/api/MemoryApi.kt @@ -6,11 +6,12 @@ import org.usvm.UContext import org.usvm.UExpr import org.usvm.UHeapRef import org.usvm.USort -import org.usvm.memory.UMemory import org.usvm.memory.UReadOnlyMemory import org.usvm.memory.UWritableMemory import org.usvm.collection.array.UArrayIndexLValue import org.usvm.collection.array.length.UArrayLengthLValue +import org.usvm.collection.array.memcpy as memcpyInternal +import org.usvm.collection.array.memset as memsetInternal import org.usvm.collection.field.UFieldLValue import org.usvm.collection.set.primitive.USetEntryLValue import org.usvm.collection.set.ref.URefSetEntryLValue @@ -22,8 +23,6 @@ import org.usvm.regions.Region import org.usvm.types.UTypeStream import org.usvm.uctx import org.usvm.withSizeSort -import org.usvm.collection.array.memcpy as memcpyInternal -import org.usvm.collection.array.memset as memsetInternal import org.usvm.collection.array.allocateArray as allocateArrayInternal import org.usvm.collection.array.allocateArrayInitialized as allocateArrayInitializedInternal diff --git a/usvm-core/src/main/kotlin/org/usvm/api/MockApi.kt b/usvm-core/src/main/kotlin/org/usvm/api/MockApi.kt index 33e1f2de20..864d583cd2 100644 --- a/usvm-core/src/main/kotlin/org/usvm/api/MockApi.kt +++ b/usvm-core/src/main/kotlin/org/usvm/api/MockApi.kt @@ -15,7 +15,7 @@ fun UState<*, Method, *, *, *, *>.makeSymbolicPrimitive( sort: T ): UExpr { check(sort != sort.uctx.addressSort) { "$sort is not primitive" } - return memory.mocker.createMockSymbol(trackedLiteral = null, sort) + return memory.mocker.createMockSymbol(trackedLiteral = null, sort, ownership) } fun StepScope.makeSymbolicRef( @@ -34,7 +34,7 @@ fun StepScope.makeNullableSymbolicRefWi mockSymbolicRef { ctx.mkOr(objectTypeEquals(it, representative), ctx.mkEq(it, ctx.nullRef)) } fun UState<*, Method, *, *, *, *>.makeSymbolicRefUntyped(): UHeapRef = - memory.mocker.createMockSymbol(trackedLiteral = null, ctx.addressSort) + memory.mocker.createMockSymbol(trackedLiteral = null, ctx.addressSort, ownership) private inline fun StepScope.mockSymbolicRef( crossinline mkTypeConstraint: State.(UHeapRef) -> UBoolExpr diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/array/ArrayRegion.kt b/usvm-core/src/main/kotlin/org/usvm/collection/array/ArrayRegion.kt index 4c7dc1b96b..ba2c9475c7 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/array/ArrayRegion.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/array/ArrayRegion.kt @@ -1,12 +1,15 @@ package org.usvm.collection.array -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.persistentHashMapOf import org.usvm.UBoolExpr import org.usvm.UConcreteHeapAddress import org.usvm.UExpr import org.usvm.UHeapRef import org.usvm.USort +import org.usvm.collections.immutable.getOrPut +import org.usvm.uctx +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.persistentHashMapOf import org.usvm.memory.ULValue import org.usvm.memory.UMemoryRegion import org.usvm.memory.UMemoryRegionId @@ -50,6 +53,7 @@ interface UArrayRegion : UMemoryRegi fromDstIdx: UExpr, toDstIdx: UExpr, operationGuard: UBoolExpr, + ownership: MutabilityOwnership, ): UArrayRegion fun initializeAllocatedArray( @@ -57,30 +61,33 @@ interface UArrayRegion : UMemoryRegi arrayType: ArrayType, sort: Sort, content: Map, UExpr>, - operationGuard: UBoolExpr + operationGuard: UBoolExpr, + ownership: MutabilityOwnership, ): UArrayRegion } internal class UArrayMemoryRegion( - private var allocatedArrays: PersistentMap> = persistentHashMapOf(), - private var inputArray: UInputArray? = null + private var allocatedArrays: UPersistentHashMap> = persistentHashMapOf(), + private var inputArray: UInputArray? = null, ) : UArrayRegion { private fun getAllocatedArray( arrayType: ArrayType, sort: Sort, - address: UConcreteHeapAddress + address: UConcreteHeapAddress, ): UAllocatedArray { - var collection = allocatedArrays[address] - if (collection == null) { - collection = UAllocatedArrayId<_, _, USizeSort>(arrayType, sort, address).emptyRegion() - allocatedArrays = allocatedArrays.put(address, collection) + val (updatedArrays, collection) = allocatedArrays.getOrPut(address, sort.uctx.defaultOwnership) { + UAllocatedArrayId<_, _, USizeSort>(arrayType, sort, address).emptyRegion() } + allocatedArrays = updatedArrays return collection } - private fun updateAllocatedArray(ref: UConcreteHeapAddress, updated: UAllocatedArray) = - UArrayMemoryRegion(allocatedArrays.put(ref, updated), inputArray) + private fun updateAllocatedArray( + ref: UConcreteHeapAddress, + updated: UAllocatedArray, + ownership: MutabilityOwnership, + ) = UArrayMemoryRegion(allocatedArrays.put(ref, updated, ownership), inputArray) private fun getInputArray(arrayType: ArrayType, sort: Sort): UInputArray { if (inputArray == null) @@ -91,27 +98,29 @@ internal class UArrayMemoryRegion( private fun updateInput(updated: UInputArray) = UArrayMemoryRegion(allocatedArrays, updated) - override fun read(key: UArrayIndexLValue): UExpr = key.ref.mapWithStaticAsSymbolic( - concreteMapper = { concreteRef -> getAllocatedArray(key.arrayType, key.sort, concreteRef.address).read(key.index) }, - symbolicMapper = { symbolicRef -> getInputArray(key.arrayType, key.sort).read(symbolicRef to key.index) } - ) + override fun read(key: UArrayIndexLValue): UExpr = + key.ref.mapWithStaticAsSymbolic( + concreteMapper = { concreteRef -> getAllocatedArray(key.arrayType, key.sort, concreteRef.address).read(key.index) }, + symbolicMapper = { symbolicRef -> getInputArray(key.arrayType, key.sort).read(symbolicRef to key.index) } + ) override fun write( key: UArrayIndexLValue, value: UExpr, - guard: UBoolExpr + guard: UBoolExpr, + ownership: MutabilityOwnership, ): UMemoryRegion, Sort> = foldHeapRefWithStaticAsSymbolic( key.ref, initial = this, initialGuard = guard, blockOnConcrete = { region, (concreteRef, innerGuard) -> val oldRegion = region.getAllocatedArray(key.arrayType, key.sort, concreteRef.address) - val newRegion = oldRegion.write(key.index, value, innerGuard) - region.updateAllocatedArray(concreteRef.address, newRegion) + val newRegion = oldRegion.write(key.index, value, innerGuard, ownership) + region.updateAllocatedArray(concreteRef.address, newRegion, ownership) }, blockOnSymbolic = { region, (symbolicRef, innerGuard) -> val oldRegion = region.getInputArray(key.arrayType, key.sort) - val newRegion = oldRegion.write(symbolicRef to key.index, value, innerGuard) + val newRegion = oldRegion.write(symbolicRef to key.index, value, innerGuard, ownership) region.updateInput(newRegion) } ) @@ -125,6 +134,7 @@ internal class UArrayMemoryRegion( fromDstIdx: UExpr, toDstIdx: UExpr, operationGuard: UBoolExpr, + ownership: MutabilityOwnership, ) = foldHeapRef2( ref0 = srcRef, ref1 = dstRef, @@ -137,7 +147,7 @@ internal class UArrayMemoryRegion( fromSrcIdx, fromDstIdx, toDstIdx, USizeExprKeyInfo() ) val newDstCollection = dstCollection.copyRange(srcCollection, adapter, guard) - region.updateAllocatedArray(dstConcrete.address, newDstCollection) + region.updateAllocatedArray(dstConcrete.address, newDstCollection, ownership) }, blockOnConcrete0Symbolic1 = { region, srcConcrete, dstSymbolic, guard -> @@ -162,7 +172,7 @@ internal class UArrayMemoryRegion( USizeExprKeyInfo() ) val newDstCollection = dstCollection.copyRange(srcCollection, adapter, guard) - region.updateAllocatedArray(dstConcrete.address, newDstCollection) + region.updateAllocatedArray(dstConcrete.address, newDstCollection, ownership) }, blockOnSymbolic0Symbolic1 = { region, srcSymbolic, dstSymbolic, guard -> val srcCollection = region.getInputArray(type, elementSort) @@ -183,10 +193,11 @@ internal class UArrayMemoryRegion( arrayType: ArrayType, sort: Sort, content: Map, UExpr>, - operationGuard: UBoolExpr + operationGuard: UBoolExpr, + ownership: MutabilityOwnership, ): UArrayMemoryRegion { val arrayId = UAllocatedArrayId<_, _, USizeSort>(arrayType, sort, address) val newCollection = arrayId.initializedArray(content, operationGuard) - return UArrayMemoryRegion(allocatedArrays.put(address, newCollection), inputArray) + return UArrayMemoryRegion(allocatedArrays.put(address, newCollection, ownership), inputArray) } } diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/array/ArrayRegionApi.kt b/usvm-core/src/main/kotlin/org/usvm/collection/array/ArrayRegionApi.kt index ee54051857..8941daf8a7 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/array/ArrayRegionApi.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/array/ArrayRegionApi.kt @@ -28,7 +28,7 @@ internal fun UWritableMemory<*>.mem "memcpy is not applicable to $region" } - val newRegion = region.memcpy(srcRef, dstRef, type, elementSort, fromSrcIdx, fromDstIdx, toDstIdx, guard) + val newRegion = region.memcpy(srcRef, dstRef, type, elementSort, fromSrcIdx, fromDstIdx, toDstIdx, guard, ownership) setRegion(regionId, newRegion) } @@ -51,7 +51,14 @@ internal fun UWritableMemory internal con return key.uctx.withSizeSort().mkAllocatedArrayReading(collection, key) } - val memory = composer.memory.toWritableMemory() + val memory = composer.memory.toWritableMemory(composer.ownership) collection.applyTo(memory, key, composer) return memory.read(mkLValue(key)) } @@ -129,7 +129,7 @@ class UInputArrayId internal constru return sort.uctx.withSizeSort().mkInputArrayReading(collection, key.first, key.second) } - val memory = composer.memory.toWritableMemory() + val memory = composer.memory.toWritableMemory(composer.ownership) collection.applyTo(memory, key, composer) return memory.read(mkLValue(key)) } diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/array/length/ArrayLengthRegion.kt b/usvm-core/src/main/kotlin/org/usvm/collection/array/length/ArrayLengthRegion.kt index 693b237371..a02d7ee2f8 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/array/length/ArrayLengthRegion.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/array/length/ArrayLengthRegion.kt @@ -1,12 +1,13 @@ package org.usvm.collection.array.length -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.persistentHashMapOf import org.usvm.UBoolExpr import org.usvm.UConcreteHeapAddress import org.usvm.UExpr import org.usvm.UHeapRef import org.usvm.USort +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.persistentHashMapOf import org.usvm.memory.ULValue import org.usvm.memory.UMemoryRegion import org.usvm.memory.UMemoryRegionId @@ -42,11 +43,11 @@ interface UArrayLengthsRegion : UMemoryRegion( private val sort: USizeSort, private val arrayType: ArrayType, - private val allocatedLengths: PersistentMap> = persistentHashMapOf(), + private val allocatedLengths: UPersistentHashMap> = persistentHashMapOf(), private var inputLengths: UInputArrayLengths? = null ) : UArrayLengthsRegion { - private fun updateAllocated(updated: PersistentMap>) = + private fun updateAllocated(updated: UPersistentHashMap>) = UArrayLengthsMemoryRegion(sort, arrayType, updated, inputLengths) private fun getInputLength(ref: UArrayLengthLValue): UInputArrayLengths { @@ -66,20 +67,21 @@ internal class UArrayLengthsMemoryRegion( override fun write( key: UArrayLengthLValue, value: UExpr, - guard: UBoolExpr + guard: UBoolExpr, + ownership: MutabilityOwnership, ) = foldHeapRefWithStaticAsSymbolic( key.ref, initial = this, initialGuard = guard, blockOnConcrete = { region, (concreteRef, innerGuard) -> - val newRegion = region.allocatedLengths.guardedWrite(concreteRef.address, value, innerGuard) { + val newRegion = region.allocatedLengths.guardedWrite(concreteRef.address, value, innerGuard, ownership) { sort.sampleUValue() } region.updateAllocated(newRegion) }, blockOnSymbolic = { region, (symbolicRef, innerGuard) -> val oldRegion = region.getInputLength(key) - val newRegion = oldRegion.write(symbolicRef, value, innerGuard) + val newRegion = oldRegion.write(symbolicRef, value, innerGuard, ownership) region.updatedInput(newRegion) } ) diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/array/length/USymbolicArrayLengthId.kt b/usvm-core/src/main/kotlin/org/usvm/collection/array/length/USymbolicArrayLengthId.kt index 996f3ab879..76e7ca5bb1 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/array/length/USymbolicArrayLengthId.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/array/length/USymbolicArrayLengthId.kt @@ -35,7 +35,7 @@ class UInputArrayLengthId internal constructor( return key.uctx.withSizeSort().mkInputArrayLengthReading(collection, key) } - val memory = composer.memory.toWritableMemory() + val memory = composer.memory.toWritableMemory(composer.ownership) collection.applyTo(memory, key, composer) return memory.read(mkLValue(key)) } diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/field/FieldsRegion.kt b/usvm-core/src/main/kotlin/org/usvm/collection/field/FieldsRegion.kt index 400496f117..bd802d8f16 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/field/FieldsRegion.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/field/FieldsRegion.kt @@ -1,12 +1,13 @@ package org.usvm.collection.field -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.persistentHashMapOf import org.usvm.UBoolExpr import org.usvm.UConcreteHeapAddress import org.usvm.UExpr import org.usvm.UHeapRef import org.usvm.USort +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.persistentHashMapOf import org.usvm.memory.ULValue import org.usvm.memory.UMemoryRegion import org.usvm.memory.UMemoryRegionId @@ -39,11 +40,11 @@ interface UFieldsRegion : UMemoryRegion( private val sort: Sort, private val field: Field, - private val allocatedFields: PersistentMap> = persistentHashMapOf(), + private val allocatedFields: UPersistentHashMap> = persistentHashMapOf(), private var inputFields: UInputFields? = null ) : UFieldsRegion { - private fun updateAllocated(updated: PersistentMap>) = + private fun updateAllocated(updated: UPersistentHashMap>) = UFieldsMemoryRegion(sort, field, updated, inputFields) private fun getInputFields(ref: UFieldLValue): UInputFields { @@ -63,20 +64,21 @@ internal class UFieldsMemoryRegion( override fun write( key: UFieldLValue, value: UExpr, - guard: UBoolExpr + guard: UBoolExpr, + ownership: MutabilityOwnership, ): UMemoryRegion, Sort> = foldHeapRefWithStaticAsSymbolic( key.ref, initial = this, initialGuard = guard, blockOnConcrete = { region, (concreteRef, innerGuard) -> - val newRegion = region.allocatedFields.guardedWrite(concreteRef.address, value, innerGuard) { + val newRegion = region.allocatedFields.guardedWrite(concreteRef.address, value, innerGuard, ownership) { sort.sampleUValue() } region.updateAllocated(newRegion) }, blockOnSymbolic = { region, (symbolicRef, innerGuard) -> val oldRegion = region.getInputFields(key) - val newRegion = oldRegion.write(symbolicRef, value, innerGuard) + val newRegion = oldRegion.write(symbolicRef, value, innerGuard, ownership) region.updateInput(newRegion) } ) diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/field/USymbolicFieldId.kt b/usvm-core/src/main/kotlin/org/usvm/collection/field/USymbolicFieldId.kt index 910f41d32c..62df6895e7 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/field/USymbolicFieldId.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/field/USymbolicFieldId.kt @@ -35,7 +35,7 @@ class UInputFieldId internal constructor( return key.uctx.mkInputFieldReading(collection, key) } - val memory = composer.memory.toWritableMemory() + val memory = composer.memory.toWritableMemory(composer.ownership) collection.applyTo(memory, key, composer) return memory.read(mkLValue(key)) } diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/map/length/UMapLengthRegion.kt b/usvm-core/src/main/kotlin/org/usvm/collection/map/length/UMapLengthRegion.kt index 0ceacf1974..9fc287bf44 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/map/length/UMapLengthRegion.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/map/length/UMapLengthRegion.kt @@ -1,12 +1,13 @@ package org.usvm.collection.map.length -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.persistentHashMapOf import org.usvm.UBoolExpr import org.usvm.UConcreteHeapAddress import org.usvm.UExpr import org.usvm.UHeapRef import org.usvm.USort +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.persistentHashMapOf import org.usvm.memory.ULValue import org.usvm.memory.UMemoryRegion import org.usvm.memory.UMemoryRegionId @@ -43,11 +44,11 @@ interface UMapLengthRegion : UMemoryRegion( private val sort: USizeSort, private val mapType: MapType, - private val allocatedLengths: PersistentMap> = persistentHashMapOf(), + private val allocatedLengths: UPersistentHashMap> = persistentHashMapOf(), private var inputLengths: UInputMapLength? = null ) : UMapLengthRegion { - private fun updateAllocated(updated: PersistentMap>) = + private fun updateAllocated(updated: UPersistentHashMap>) = UMapLengthMemoryRegion(sort, mapType, updated, inputLengths) private fun getInputLength(ref: UMapLengthLValue): UInputMapLength { @@ -67,20 +68,21 @@ internal class UMapLengthMemoryRegion( override fun write( key: UMapLengthLValue, value: UExpr, - guard: UBoolExpr + guard: UBoolExpr, + ownership: MutabilityOwnership, ) = foldHeapRefWithStaticAsSymbolic( ref = key.ref, initial = this, initialGuard = guard, blockOnConcrete = { region, (concreteRef, innerGuard) -> - val newRegion = region.allocatedLengths.guardedWrite(concreteRef.address, value, innerGuard) { + val newRegion = region.allocatedLengths.guardedWrite(concreteRef.address, value, innerGuard, ownership) { sort.sampleUValue() } region.updateAllocated(newRegion) }, blockOnSymbolic = { region, (symbolicRef, innerGuard) -> val oldRegion = region.getInputLength(key) - val newRegion = oldRegion.write(symbolicRef, value, innerGuard) + val newRegion = oldRegion.write(symbolicRef, value, innerGuard, ownership) region.updateInput(newRegion) } ) diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/map/length/USymbolicMapLengthId.kt b/usvm-core/src/main/kotlin/org/usvm/collection/map/length/USymbolicMapLengthId.kt index a992e4bdc9..6c2af2f65f 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/map/length/USymbolicMapLengthId.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/map/length/USymbolicMapLengthId.kt @@ -32,7 +32,7 @@ class UInputMapLengthId internal constructor( return sort.uctx.withSizeSort().mkInputMapLengthReading(collection, key) } - val memory = composer.memory.toWritableMemory() + val memory = composer.memory.toWritableMemory(composer.ownership) collection.applyTo(memory, key, composer) return memory.read(mkLValue(key)) } diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/map/primitive/UMapRegion.kt b/usvm-core/src/main/kotlin/org/usvm/collection/map/primitive/UMapRegion.kt index 44a1707ddd..7ad19693f4 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/map/primitive/UMapRegion.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/map/primitive/UMapRegion.kt @@ -1,13 +1,15 @@ package org.usvm.collection.map.primitive -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.persistentHashMapOf import org.usvm.UBoolExpr import org.usvm.UExpr import org.usvm.UHeapRef import org.usvm.USort import org.usvm.collection.map.USymbolicMapKey import org.usvm.collection.set.primitive.USetRegion +import org.usvm.collections.immutable.getOrPut +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.persistentHashMapOf import org.usvm.memory.ULValue import org.usvm.memory.UMemoryRegion import org.usvm.memory.UMemoryRegionId @@ -58,7 +60,8 @@ interface UMapRegion, - initialGuard: UBoolExpr + initialGuard: UBoolExpr, + ownership: MutabilityOwnership, ): UMapRegion } @@ -67,9 +70,12 @@ internal class UMapMemoryRegion, Reg>, - private var allocatedMaps: PersistentMap, UAllocatedMap> = persistentHashMapOf(), + private var allocatedMaps: UPersistentHashMap, UAllocatedMap> = persistentHashMapOf(), private var inputMap: UInputMap? = null, ) : UMapRegion { + + private val defaultOwnership = valueSort.uctx.defaultOwnership + init { check(keySort != keySort.uctx.addressSort) { "Ref map must be used to handle maps with ref keys" @@ -79,23 +85,21 @@ internal class UMapMemoryRegion ): UAllocatedMap { - var collection = allocatedMaps[id] - if (collection == null) { - collection = id.emptyRegion() - allocatedMaps = allocatedMaps.put(id, collection) - } + val (updatesMaps, collection) = allocatedMaps.getOrPut(id, defaultOwnership) { id.emptyRegion() } + allocatedMaps = updatesMaps return collection } private fun updateAllocatedMap( id: UAllocatedMapId, - updatedMap: UAllocatedMap + updatedMap: UAllocatedMap, + ownership: MutabilityOwnership, ) = UMapMemoryRegion( keySort, valueSort, mapType, keyInfo, - allocatedMaps.put(id, updatedMap), + allocatedMaps.put(id, updatedMap, ownership), inputMap ) @@ -128,7 +132,8 @@ internal class UMapMemoryRegion, value: UExpr, - guard: UBoolExpr + guard: UBoolExpr, + ownership: MutabilityOwnership, ) = foldHeapRefWithStaticAsSymbolic( ref = key.mapRef, initial = this, @@ -136,12 +141,12 @@ internal class UMapMemoryRegion val id = UAllocatedMapId(keySort, valueSort, mapType, keyInfo, concreteRef.address) val map = region.getAllocatedMap(id) - val newMap = map.write(key.mapKey, value, guard) - region.updateAllocatedMap(id, newMap) + val newMap = map.write(key.mapKey, value, guard, ownership) + region.updateAllocatedMap(id, newMap, ownership) }, blockOnSymbolic = { region, (symbolicRef, guard) -> val map = region.getInputMap() - val newMap = map.write(symbolicRef to key.mapKey, value, guard) + val newMap = map.write(symbolicRef to key.mapKey, value, guard, ownership) region.updateInputMap(newMap) } ) @@ -151,7 +156,8 @@ internal class UMapMemoryRegion, - initialGuard: UBoolExpr + initialGuard: UBoolExpr, + ownership: MutabilityOwnership, ) = foldHeapRef2( ref0 = srcRef, ref1 = dstRef, @@ -167,7 +173,7 @@ internal class UMapMemoryRegion val srcId = UAllocatedMapId(keySort, valueSort, mapType, keyInfo, srcConcrete.address) @@ -188,7 +194,7 @@ internal class UMapMemoryRegion val srcCollection = region.getInputMap() diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/map/primitive/UMapRegionApi.kt b/usvm-core/src/main/kotlin/org/usvm/collection/map/primitive/UMapRegionApi.kt index 4d6b2adc84..f11f031c78 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/map/primitive/UMapRegionApi.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/map/primitive/UMapRegionApi.kt @@ -18,7 +18,7 @@ internal fun > UW sort: ValueSort, keyInfo: USymbolicCollectionKeyInfo, Reg>, keySet: USetRegionId, - guard: UBoolExpr + guard: UBoolExpr, ) { val regionId = UMapRegionId(keySort, sort, mapType, keyInfo) val region = getRegion(regionId) @@ -32,6 +32,6 @@ internal fun > UW "mapMerge is not applicable to set $region" } - val newRegion = region.merge(srcRef, dstRef, mapType, keySetRegion, guard) + val newRegion = region.merge(srcRef, dstRef, mapType, keySetRegion, guard, ownership) setRegion(regionId, newRegion) } diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/map/primitive/USymbolicMapId.kt b/usvm-core/src/main/kotlin/org/usvm/collection/map/primitive/USymbolicMapId.kt index e99d115798..ca8663c730 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/map/primitive/USymbolicMapId.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/map/primitive/USymbolicMapId.kt @@ -66,7 +66,7 @@ class UAllocatedMapId return sort.uctx.mkInputMapReading(collection, key.first, key.second) } - val memory = composer.memory.toWritableMemory() + val memory = composer.memory.toWritableMemory(composer.ownership) collection.applyTo(memory, key, composer) return memory.read(mkLValue(key)) } diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/map/ref/URefMapRegion.kt b/usvm-core/src/main/kotlin/org/usvm/collection/map/ref/URefMapRegion.kt index a842d980d1..b8b6f47da4 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/map/ref/URefMapRegion.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/map/ref/URefMapRegion.kt @@ -1,7 +1,5 @@ package org.usvm.collection.map.ref -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.persistentHashMapOf import org.usvm.UAddressSort import org.usvm.UBoolExpr import org.usvm.UConcreteHeapAddress @@ -12,6 +10,10 @@ import org.usvm.USort import org.usvm.collection.map.USymbolicMapKey import org.usvm.collection.set.ref.URefSetEntryLValue import org.usvm.collection.set.ref.URefSetRegion +import org.usvm.collections.immutable.getOrPut +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.persistentHashMapOf import org.usvm.memory.ULValue import org.usvm.memory.UMemoryRegion import org.usvm.memory.UMemoryRegionId @@ -52,7 +54,8 @@ interface URefMapRegion mapType: MapType, sort: ValueSort, keySet: URefSetRegion, - operationGuard: UBoolExpr + operationGuard: UBoolExpr, + ownership: MutabilityOwnership, ): URefMapRegion } @@ -73,14 +76,16 @@ internal data class UAllocatedRefMapWithAllocatedKeysId( internal class URefMapMemoryRegion( private val valueSort: ValueSort, private val mapType: MapType, - private var allocatedMapWithAllocatedKeys: PersistentMap> = persistentHashMapOf(), - private var inputMapWithAllocatedKeys: PersistentMap, UInputRefMapWithAllocatedKeys> = persistentHashMapOf(), - private var allocatedMapWithInputKeys: PersistentMap, UAllocatedRefMapWithInputKeys> = persistentHashMapOf(), + private var allocatedMapWithAllocatedKeys: UPersistentHashMap> = persistentHashMapOf(), + private var inputMapWithAllocatedKeys: UPersistentHashMap, UInputRefMapWithAllocatedKeys> = persistentHashMapOf(), + private var allocatedMapWithInputKeys: UPersistentHashMap, UAllocatedRefMapWithInputKeys> = persistentHashMapOf(), private var inputMapWithInputKeys: UInputRefMap? = null, ) : URefMapRegion { + private val defaultOwnership = valueSort.uctx.defaultOwnership + private fun updateAllocatedMapWithAllocatedKeys( - updated: PersistentMap> + updated: UPersistentHashMap> ) = URefMapMemoryRegion( valueSort, mapType, @@ -96,22 +101,20 @@ internal class URefMapMemoryRegion( private fun getInputMapWithAllocatedKeys( id: UInputRefMapWithAllocatedKeysId ): UInputRefMapWithAllocatedKeys { - var collection = inputMapWithAllocatedKeys[id] - if (collection == null) { - collection = id.emptyRegion() - inputMapWithAllocatedKeys = inputMapWithAllocatedKeys.put(id, collection) - } + val (updatedMap, collection) = inputMapWithAllocatedKeys.getOrPut(id, defaultOwnership) { id.emptyRegion() } + inputMapWithAllocatedKeys = updatedMap return collection } private fun updateInputMapWithAllocatedKeys( id: UInputRefMapWithAllocatedKeysId, - updatedMap: UInputRefMapWithAllocatedKeys + updatedMap: UInputRefMapWithAllocatedKeys, + ownership: MutabilityOwnership, ) = URefMapMemoryRegion( valueSort, mapType, allocatedMapWithAllocatedKeys, - inputMapWithAllocatedKeys.put(id, updatedMap), + inputMapWithAllocatedKeys.put(id, updatedMap, ownership), allocatedMapWithInputKeys, inputMapWithInputKeys ) @@ -122,23 +125,21 @@ internal class URefMapMemoryRegion( private fun getAllocatedMapWithInputKeys( id: UAllocatedRefMapWithInputKeysId ): UAllocatedRefMapWithInputKeys { - var collection = allocatedMapWithInputKeys[id] - if (collection == null) { - collection = id.emptyRegion() - allocatedMapWithInputKeys = allocatedMapWithInputKeys.put(id, collection) - } + val (updatedMap, collection) = allocatedMapWithInputKeys.getOrPut(id, defaultOwnership) { id.emptyRegion() } + allocatedMapWithInputKeys = updatedMap return collection } private fun updateAllocatedMapWithInputKeys( id: UAllocatedRefMapWithInputKeysId, - updatedMap: UAllocatedRefMapWithInputKeys + updatedMap: UAllocatedRefMapWithInputKeys, + ownership: MutabilityOwnership, ) = URefMapMemoryRegion( valueSort, mapType, allocatedMapWithAllocatedKeys, inputMapWithAllocatedKeys, - allocatedMapWithInputKeys.put(id, updatedMap), + allocatedMapWithInputKeys.put(id, updatedMap, ownership), inputMapWithInputKeys ) @@ -190,7 +191,8 @@ internal class URefMapMemoryRegion( override fun write( key: URefMapEntryLValue, value: UExpr, - guard: UBoolExpr + guard: UBoolExpr, + ownership: MutabilityOwnership, ) = foldHeapRefWithStaticAsSymbolic( ref = key.mapRef, initial = this, @@ -202,7 +204,7 @@ internal class URefMapMemoryRegion( initialGuard = mapGuard, blockOnConcrete = { region, (concreteKeyRef, guard) -> val id = UAllocatedRefMapWithAllocatedKeysId(concreteMapRef.address, concreteKeyRef.address) - val newMap = region.allocatedMapWithAllocatedKeys.guardedWrite(id, value, guard) { + val newMap = region.allocatedMapWithAllocatedKeys.guardedWrite(id, value, guard, ownership) { valueSort.sampleUValue() } region.updateAllocatedMapWithAllocatedKeys(newMap) @@ -210,8 +212,8 @@ internal class URefMapMemoryRegion( blockOnSymbolic = { region, (symbolicKeyRef, guard) -> val id = allocatedMapWithInputKeyId(concreteMapRef.address) val newMap = region.getAllocatedMapWithInputKeys(id) - .write(symbolicKeyRef, value, guard) - region.updateAllocatedMapWithInputKeys(id, newMap) + .write(symbolicKeyRef, value, guard, ownership) + region.updateAllocatedMapWithInputKeys(id, newMap, ownership) } ) }, @@ -223,12 +225,12 @@ internal class URefMapMemoryRegion( blockOnConcrete = { region, (concreteKeyRef, guard) -> val id = inputMapWithAllocatedKeyId(concreteKeyRef.address) val newMap = region.getInputMapWithAllocatedKeys(id) - .write(symbolicMapRef, value, guard) - region.updateInputMapWithAllocatedKeys(id, newMap) + .write(symbolicMapRef, value, guard, ownership) + region.updateInputMapWithAllocatedKeys(id, newMap, ownership) }, blockOnSymbolic = { region, (symbolicKeyRef, guard) -> val newMap = region.getInputMapWithInputKeys() - .write(symbolicMapRef to symbolicKeyRef, value, guard) + .write(symbolicMapRef to symbolicKeyRef, value, guard, ownership) region.updateInputMapWithInputKeys(newMap) } ) @@ -255,7 +257,8 @@ internal class URefMapMemoryRegion( mapType: MapType, sort: ValueSort, keySet: URefSetRegion, - operationGuard: UBoolExpr + operationGuard: UBoolExpr, + ownership: MutabilityOwnership, ) = foldHeapRef2( ref0 = srcRef, ref1 = dstRef, @@ -271,7 +274,7 @@ internal class URefMapMemoryRegion( read = { initialAllocatedMapState[it] ?: valueSort.sampleUValue() }, mkDstKeyId = { UAllocatedRefMapWithAllocatedKeysId(dstConcrete.address, it) }, write = { result, dstKeyId, value, g -> - result.guardedWrite(dstKeyId, value, g) { valueSort.sampleUValue() } + result.guardedWrite(dstKeyId, value, g, ownership) { valueSort.sampleUValue() } } ) val updatedRegion = region.updateAllocatedMapWithAllocatedKeys(updatedAllocatedMap) @@ -285,7 +288,7 @@ internal class URefMapMemoryRegion( val adapter = UAllocatedToAllocatedSymbolicRefMapMergeAdapter(srcKeys) val updatedDstCollection = dstInputKeysCollection.copyRange(srcInputKeysCollection, adapter, guard) - updatedRegion.updateAllocatedMapWithInputKeys(dstInputKeysId, updatedDstCollection) + updatedRegion.updateAllocatedMapWithInputKeys(dstInputKeysId, updatedDstCollection, ownership) }, blockOnConcrete0Symbolic1 = { region, srcConcrete, dstSymbolic, guard -> val initialAllocatedMapState = region.allocatedMapWithAllocatedKeys @@ -295,8 +298,8 @@ internal class URefMapMemoryRegion( mkDstKeyId = { inputMapWithAllocatedKeyId(it) }, write = { result, dstKeyId, value, g -> val newMap = result.getInputMapWithAllocatedKeys(dstKeyId) - .write(dstSymbolic, value, g) - result.updateInputMapWithAllocatedKeys(dstKeyId, newMap) + .write(dstSymbolic, value, g, ownership) + result.updateInputMapWithAllocatedKeys(dstKeyId, newMap, ownership) } ) @@ -317,7 +320,7 @@ internal class URefMapMemoryRegion( read = { region.getInputMapWithAllocatedKeys(it).read(srcSymbolic) }, mkDstKeyId = { UAllocatedRefMapWithAllocatedKeysId(dstConcrete.address, it) }, write = { result, dstKeyId, value, g -> - result.guardedWrite(dstKeyId, value, g) { sort.sampleUValue() } + result.guardedWrite(dstKeyId, value, g, ownership) { sort.sampleUValue() } } ) val updatedRegion = region.updateAllocatedMapWithAllocatedKeys(updatedAllocatedMap) @@ -330,7 +333,7 @@ internal class URefMapMemoryRegion( val adapter = UInputToAllocatedSymbolicRefMapMergeAdapter(srcSymbolic, srcKeys) val updatedDstCollection = dstInputKeysCollection.copyRange(srcInputKeysCollection, adapter, guard) - updatedRegion.updateAllocatedMapWithInputKeys(dstInputKeysId, updatedDstCollection) + updatedRegion.updateAllocatedMapWithInputKeys(dstInputKeysId, updatedDstCollection, ownership) }, blockOnSymbolic0Symbolic1 = { region, srcSymbolic, dstSymbolic, guard -> val updatedRegion = region.mergeInputMapAllocatedKeys( @@ -339,8 +342,8 @@ internal class URefMapMemoryRegion( mkDstKeyId = { inputMapWithAllocatedKeyId(it) }, write = { result, dstKeyId, value, g -> val newMap = result.getInputMapWithAllocatedKeys(dstKeyId) - .write(dstSymbolic, value, g) - result.updateInputMapWithAllocatedKeys(dstKeyId, newMap) + .write(dstSymbolic, value, g, ownership) + result.updateInputMapWithAllocatedKeys(dstKeyId, newMap, ownership) } ) val srcKeys = keySet.inputSetWithInputElements() @@ -364,7 +367,7 @@ internal class URefMapMemoryRegion( write: (R, DstKeyId, UExpr, UBoolExpr) -> R ) = mergeAllocatedKeys( initial, - inputMapWithAllocatedKeys.keys, + inputMapWithAllocatedKeys.keys.toList(), guard, keySet, srcMapRef, @@ -384,7 +387,7 @@ internal class URefMapMemoryRegion( write: (R, DstKeyId, UExpr, UBoolExpr) -> R ) = mergeAllocatedKeys( initial, - allocatedMapWithAllocatedKeys.keys.filter { it.mapAddress == srcMapRef.address }, + allocatedMapWithAllocatedKeys.keys.filterTo(mutableListOf()) { it.mapAddress == srcMapRef.address }, guard, keySet, srcMapRef, @@ -396,7 +399,7 @@ internal class URefMapMemoryRegion( private inline fun mergeAllocatedKeys( initial: R, - keys: Iterable, + keys: List, guard: UBoolExpr, keySet: URefSetRegion, srcMapRef: UHeapRef, diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/map/ref/URefMapRegionApi.kt b/usvm-core/src/main/kotlin/org/usvm/collection/map/ref/URefMapRegionApi.kt index b66dba4dc4..1eb6bb396a 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/map/ref/URefMapRegionApi.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/map/ref/URefMapRegionApi.kt @@ -27,6 +27,6 @@ internal fun UWritableMemory<*>.refMapMerge( "refMapMerge is not applicable to set $region" } - val newRegion = region.merge(srcRef, dstRef, mapType, sort, keySet, guard) + val newRegion = region.merge(srcRef, dstRef, mapType, sort, keySet, guard, ownership) setRegion(regionId, newRegion) } diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/map/ref/USymbolicRefMapId.kt b/usvm-core/src/main/kotlin/org/usvm/collection/map/ref/USymbolicRefMapId.kt index b762a242c2..1e69cd05a9 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/map/ref/USymbolicRefMapId.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/map/ref/USymbolicRefMapId.kt @@ -61,7 +61,7 @@ class UAllocatedRefMapWithInputKeysId( return key.uctx.mkAllocatedRefMapWithInputKeysReading(collection, key) } - val memory = composer.memory.toWritableMemory() + val memory = composer.memory.toWritableMemory(composer.ownership) collection.applyTo(memory, key, composer) return memory.read(mkLValue(key)) } @@ -126,7 +126,7 @@ class UInputRefMapWithAllocatedKeysId( return key.uctx.mkInputRefMapWithAllocatedKeysReading(collection, key) } - val memory = composer.memory.toWritableMemory() + val memory = composer.memory.toWritableMemory(composer.ownership) collection.applyTo(memory, key, composer) return memory.read(mkLValue(key)) } @@ -185,7 +185,7 @@ class UInputRefMapWithInputKeysId( return sort.uctx.mkInputRefMapWithInputKeysReading(collection, key.first, key.second) } - val memory = composer.memory.toWritableMemory() + val memory = composer.memory.toWritableMemory(composer.ownership) collection.applyTo(memory, key, composer) return memory.read(mkLValue(key)) } diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/set/USetCollectionDecoder.kt b/usvm-core/src/main/kotlin/org/usvm/collection/set/USetCollectionDecoder.kt index 5af60da358..fd13b604df 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/set/USetCollectionDecoder.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/set/USetCollectionDecoder.kt @@ -5,7 +5,7 @@ import io.ksmt.expr.KExpr import io.ksmt.expr.KFunctionApp import io.ksmt.sort.KBoolSort import io.ksmt.utils.uncheckedCast -import kotlinx.collections.immutable.persistentHashMapOf +import org.usvm.collections.immutable.persistentHashMapOf import org.usvm.UAddressSort import org.usvm.UBoolExpr import org.usvm.UBoolSort @@ -41,7 +41,7 @@ abstract class USetCollectionDecoder { val usedSetKeys = hashSetOf>() assertions.flatMapTo(usedSetKeys) { appCollector.applyVisitor(it) } - val entries = persistentHashMapOf, UBoolExpr>().builder() + var entries = persistentHashMapOf, UBoolExpr>() for (key in usedSetKeys) { val keyInSet = model.eval(key, isComplete = false) if (!keyInSet.isTrue) continue @@ -49,13 +49,14 @@ abstract class USetCollectionDecoder { val (rawSetRef, rawElement) = key.args val setRef: UHeapRef = rawSetRef.uncheckedCast() val element: UExpr = rawElement.uncheckedCast() - val setRefModel = model.eval(setRef, isComplete = true).mapAddress(mapping) val elementModel = model.eval(element, isComplete = true).mapAddress(mapping) - entries[setRefModel to elementModel] = inputFunction.ctx.trueExpr + entries = entries.put( + setRefModel to elementModel, inputFunction.ctx.trueExpr, evaluator.ctx.defaultOwnership + ) } - return UMemory2DArray(entries.build(), constValue = inputFunction.ctx.falseExpr) + return UMemory2DArray(entries, constValue = inputFunction.ctx.falseExpr) } } diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/set/primitive/USetRegion.kt b/usvm-core/src/main/kotlin/org/usvm/collection/set/primitive/USetRegion.kt index 74ec7eeed3..2268378f87 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/set/primitive/USetRegion.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/set/primitive/USetRegion.kt @@ -1,7 +1,5 @@ package org.usvm.collection.set.primitive -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.persistentHashMapOf import org.usvm.UBoolExpr import org.usvm.UBoolSort import org.usvm.UConcreteHeapAddress @@ -11,6 +9,10 @@ import org.usvm.USort import org.usvm.collection.set.USymbolicSetEntries import org.usvm.collection.set.USymbolicSetElement import org.usvm.collection.set.USymbolicSetElementsCollector +import org.usvm.collections.immutable.getOrPut +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.persistentHashMapOf +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.memory.ULValue import org.usvm.memory.UMemoryRegion import org.usvm.memory.UMemoryRegionId @@ -77,6 +79,7 @@ interface USetRegion> : srcRef: UHeapRef, dstRef: UHeapRef, operationGuard: UBoolExpr, + ownership: MutabilityOwnership, ): USetRegion } @@ -84,9 +87,11 @@ internal class USetMemoryRegion> private val setType: SetType, private val elementSort: ElementSort, private val elementInfo: USymbolicCollectionKeyInfo, Reg>, - private var allocatedSets: PersistentMap, UAllocatedSet> = persistentHashMapOf(), + private var allocatedSets: UPersistentHashMap, UAllocatedSet> = persistentHashMapOf(), private var inputSet: UInputSet? = null, ) : USetRegion { + + private val defaultOwnership = elementSort.uctx.defaultOwnership init { check(elementSort != elementSort.uctx.addressSort) { "Ref set must be used to handle sets with ref elements" @@ -99,18 +104,16 @@ internal class USetMemoryRegion> private fun getAllocatedSet( id: UAllocatedSetId ): UAllocatedSet { - var collection = allocatedSets[id] - if (collection == null) { - collection = id.emptyRegion() - allocatedSets = allocatedSets.put(id, collection) - } + val (updatesSets, collection) = allocatedSets.getOrPut(id, defaultOwnership) { id.emptyRegion() } + allocatedSets = updatesSets return collection } private fun updateAllocatedSet( id: UAllocatedSetId, - updated: UAllocatedSet - ) = USetMemoryRegion(setType, elementSort, elementInfo, allocatedSets.put(id, updated), inputSet) + updated: UAllocatedSet, + ownership: MutabilityOwnership, + ) = USetMemoryRegion(setType, elementSort, elementInfo, allocatedSets.put(id, updated, ownership), inputSet) override fun inputSetElements(): UInputSet { if (inputSet == null) { @@ -127,11 +130,11 @@ internal class USetMemoryRegion> { concreteRef -> allocatedSetElements(concreteRef.address).read(key.setElement) }, { symbolicRef -> inputSetElements().read(symbolicRef to key.setElement) } ) - override fun write( key: USetEntryLValue, value: UExpr, - guard: UBoolExpr + guard: UBoolExpr, + ownership: MutabilityOwnership, ) = foldHeapRefWithStaticAsSymbolic( ref = key.setRef, initial = this, @@ -139,12 +142,12 @@ internal class USetMemoryRegion> blockOnConcrete = { region, (concreteRef, guard) -> val id = UAllocatedSetId(concreteRef.address, elementSort, setType, elementInfo) val newCollection = region.getAllocatedSet(id) - .write(key.setElement, value, guard) - region.updateAllocatedSet(id, newCollection) + .write(key.setElement, value, guard, ownership) + region.updateAllocatedSet(id, newCollection, ownership) }, blockOnSymbolic = { region, (symbolicRef, guard) -> val newCollection = region.inputSetElements() - .write(symbolicRef to key.setElement, value, guard) + .write(symbolicRef to key.setElement, value, guard, ownership) region.updateInputSet(newCollection) } ) @@ -152,7 +155,8 @@ internal class USetMemoryRegion> override fun union( srcRef: UHeapRef, dstRef: UHeapRef, - operationGuard: UBoolExpr + operationGuard: UBoolExpr, + ownership: MutabilityOwnership, ) = foldHeapRef2( ref0 = srcRef, ref1 = dstRef, @@ -167,7 +171,7 @@ internal class USetMemoryRegion> val adapter = UAllocatedToAllocatedSymbolicSetUnionAdapter(srcCollection) val updated = dstCollection.copyRange(srcCollection, adapter, guard) - region.updateAllocatedSet(dstId, updated) + region.updateAllocatedSet(dstId, updated, ownership) }, blockOnConcrete0Symbolic1 = { region, srcConcrete, dstSymbolic, guard -> val srcId = UAllocatedSetId(srcConcrete.address, elementSort, setType, elementInfo) @@ -187,7 +191,7 @@ internal class USetMemoryRegion> val adapter = UInputToAllocatedSymbolicSetUnionAdapter(srcSymbolic, srcCollection) val updated = dstCollection.copyRange(srcCollection, adapter, guard) - region.updateAllocatedSet(dstId, updated) + region.updateAllocatedSet(dstId, updated, ownership) }, blockOnSymbolic0Symbolic1 = { region, srcSymbolic, dstSymbolic, guard -> val srcCollection = region.inputSetElements() diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/set/primitive/USetRegionApi.kt b/usvm-core/src/main/kotlin/org/usvm/collection/set/primitive/USetRegionApi.kt index 5d463a20f8..91d8f1c0a7 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/set/primitive/USetRegionApi.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/set/primitive/USetRegionApi.kt @@ -24,7 +24,7 @@ internal fun > UWritableMemory<*>.se "setUnion is not applicable to $region" } - val newRegion = region.union(srcRef, dstRef, guard) + val newRegion = region.union(srcRef, dstRef, guard, ownership) setRegion(regionId, newRegion) } diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/set/primitive/USymbolicSetId.kt b/usvm-core/src/main/kotlin/org/usvm/collection/set/primitive/USymbolicSetId.kt index 19e8bb62d6..bf1f0f50d8 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/set/primitive/USymbolicSetId.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/set/primitive/USymbolicSetId.kt @@ -57,7 +57,7 @@ class UAllocatedSetId>( return sort.uctx.mkAllocatedSetReading(collection, key) } - val memory = composer.memory.toWritableMemory() + val memory = composer.memory.toWritableMemory(composer.ownership) collection.applyTo(memory, key, composer) return memory.read(mkLValue(key)) } @@ -133,7 +133,7 @@ class UInputSetId>( return sort.uctx.mkInputSetReading(collection, key.first, key.second) } - val memory = composer.memory.toWritableMemory() + val memory = composer.memory.toWritableMemory(composer.ownership) collection.applyTo(memory, key, composer) return memory.read(mkLValue(key)) } diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/set/ref/URefSetRegion.kt b/usvm-core/src/main/kotlin/org/usvm/collection/set/ref/URefSetRegion.kt index 27de317e53..a8c3403d92 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/set/ref/URefSetRegion.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/set/ref/URefSetRegion.kt @@ -1,7 +1,5 @@ package org.usvm.collection.set.ref -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.persistentHashMapOf import org.usvm.UAddressSort import org.usvm.UBoolExpr import org.usvm.UBoolSort @@ -10,6 +8,10 @@ import org.usvm.UHeapRef import org.usvm.collection.set.USymbolicSetEntries import org.usvm.collection.set.USymbolicSetElement import org.usvm.collection.set.USymbolicSetElementsCollector +import org.usvm.collections.immutable.getOrPut +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.persistentHashMapOf import org.usvm.memory.ULValue import org.usvm.memory.UMemoryRegion import org.usvm.memory.UMemoryRegionId @@ -76,19 +78,23 @@ interface URefSetRegion : srcRef: UHeapRef, dstRef: UHeapRef, operationGuard: UBoolExpr, + ownership: MutabilityOwnership, ): URefSetRegion } internal class URefSetMemoryRegion( private val setType: SetType, private val sort: UBoolSort, - private var allocatedSetWithAllocatedElements: PersistentMap = persistentHashMapOf(), - private var allocatedSetWithInputElements: PersistentMap, UAllocatedRefSetWithInputElements> = persistentHashMapOf(), - private var inputSetWithAllocatedElements: PersistentMap, UInputRefSetWithAllocatedElements> = persistentHashMapOf(), + private var allocatedSetWithAllocatedElements: UPersistentHashMap = persistentHashMapOf(), + private var allocatedSetWithInputElements: UPersistentHashMap, UAllocatedRefSetWithInputElements> = persistentHashMapOf(), + private var inputSetWithAllocatedElements: UPersistentHashMap, UInputRefSetWithAllocatedElements> = persistentHashMapOf(), private var inputSetWithInputElements: UInputRefSetWithInputElements? = null ) : URefSetRegion { + + private val defaultOwnership = sort.uctx.defaultOwnership + private fun updateAllocatedSetWithAllocatedElements( - updated: PersistentMap + updated: UPersistentHashMap ) = URefSetMemoryRegion( setType, sort, updated, @@ -101,13 +107,10 @@ internal class URefSetMemoryRegion( UAllocatedRefSetWithInputElementsId(setAddress, setType, sort) private fun getAllocatedSetWithInputElements( - id: UAllocatedRefSetWithInputElementsId + id: UAllocatedRefSetWithInputElementsId, ): UAllocatedRefSetWithInputElements { - var collection = allocatedSetWithInputElements[id] - if (collection == null) { - collection = id.emptyRegion() - allocatedSetWithInputElements = allocatedSetWithInputElements.put(id, collection) - } + val (updatedSet, collection) = allocatedSetWithInputElements.getOrPut(id, defaultOwnership) { id.emptyRegion() } + allocatedSetWithInputElements = updatedSet return collection } @@ -116,11 +119,12 @@ internal class URefSetMemoryRegion( private fun updateAllocatedSetWithInputElements( id: UAllocatedRefSetWithInputElementsId, - updatedSet: UAllocatedRefSetWithInputElements + updatedSet: UAllocatedRefSetWithInputElements, + ownership: MutabilityOwnership, ) = URefSetMemoryRegion( setType, sort, allocatedSetWithAllocatedElements, - allocatedSetWithInputElements.put(id, updatedSet), + allocatedSetWithInputElements.put(id, updatedSet, ownership), inputSetWithAllocatedElements, inputSetWithInputElements ) @@ -131,22 +135,20 @@ internal class URefSetMemoryRegion( private fun getInputSetWithAllocatedElements( id: UInputRefSetWithAllocatedElementsId ): UInputRefSetWithAllocatedElements { - var collection = inputSetWithAllocatedElements[id] - if (collection == null) { - collection = id.emptyRegion() - inputSetWithAllocatedElements = inputSetWithAllocatedElements.put(id, collection) - } + val (updatedMap, collection) = inputSetWithAllocatedElements.getOrPut(id, defaultOwnership) { id.emptyRegion() } + inputSetWithAllocatedElements = updatedMap return collection } private fun updateInputSetWithAllocatedElements( id: UInputRefSetWithAllocatedElementsId, - updatedSet: UInputRefSetWithAllocatedElements + updatedSet: UInputRefSetWithAllocatedElements, + ownership: MutabilityOwnership, ) = URefSetMemoryRegion( setType, sort, allocatedSetWithAllocatedElements, allocatedSetWithInputElements, - inputSetWithAllocatedElements.put(id, updatedSet), + inputSetWithAllocatedElements.put(id, updatedSet, ownership), inputSetWithInputElements ) @@ -195,7 +197,8 @@ internal class URefSetMemoryRegion( override fun write( key: URefSetEntryLValue, value: UBoolExpr, - guard: UBoolExpr + guard: UBoolExpr, + ownership: MutabilityOwnership, ) = foldHeapRefWithStaticAsSymbolic( ref = key.setRef, initial = this, @@ -207,7 +210,7 @@ internal class URefSetMemoryRegion( initialGuard = setGuard, blockOnConcrete = { region, (concreteElemRef, guard) -> val id = UAllocatedRefSetWithAllocatedElementId(concreteSetRef.address, concreteElemRef.address) - val newMap = region.allocatedSetWithAllocatedElements.guardedWrite(id, value, guard) { + val newMap = region.allocatedSetWithAllocatedElements.guardedWrite(id, value, guard, ownership) { sort.uctx.falseExpr } region.updateAllocatedSetWithAllocatedElements(newMap) @@ -215,8 +218,8 @@ internal class URefSetMemoryRegion( blockOnSymbolic = { region, (symbolicElemRef, guard) -> val id = allocatedSetWithInputElementsId(concreteSetRef.address) val newMap = region.getAllocatedSetWithInputElements(id) - .write(symbolicElemRef, value, guard) - region.updateAllocatedSetWithInputElements(id, newMap) + .write(symbolicElemRef, value, guard, ownership) + region.updateAllocatedSetWithInputElements(id, newMap, ownership) } ) }, @@ -228,12 +231,12 @@ internal class URefSetMemoryRegion( blockOnConcrete = { region, (concreteElemRef, guard) -> val id = inputSetWithAllocatedElementsId(concreteElemRef.address) val newMap = region.getInputSetWithAllocatedElements(id) - .write(symbolicSetRef, value, guard) - region.updateInputSetWithAllocatedElements(id, newMap) + .write(symbolicSetRef, value, guard, ownership) + region.updateInputSetWithAllocatedElements(id, newMap, ownership) }, blockOnSymbolic = { region, (symbolicElemRef, guard) -> val newMap = region.inputSetWithInputElements() - .write(symbolicSetRef to symbolicElemRef, value, guard) + .write(symbolicSetRef to symbolicElemRef, value, guard, ownership) region.updateInputSetWithInputElements(newMap) } ) @@ -243,7 +246,8 @@ internal class URefSetMemoryRegion( override fun union( srcRef: UHeapRef, dstRef: UHeapRef, - operationGuard: UBoolExpr + operationGuard: UBoolExpr, + ownership: MutabilityOwnership, ) = foldHeapRef2( ref0 = srcRef, ref1 = dstRef, @@ -258,7 +262,7 @@ internal class URefSetMemoryRegion( read = { initialAllocatedSetState[it] ?: sort.uctx.falseExpr }, mkDstKeyId = { UAllocatedRefSetWithAllocatedElementId(dstConcrete.address, it) }, write = { result, dstKeyId, value, g -> - result.guardedWrite(dstKeyId, value, g) { sort.uctx.falseExpr } + result.guardedWrite(dstKeyId, value, g, ownership) { sort.uctx.falseExpr } } ) val updatedRegion = region.updateAllocatedSetWithAllocatedElements(updatedAllocatedSet) @@ -271,7 +275,7 @@ internal class URefSetMemoryRegion( val adapter = UAllocatedToAllocatedSymbolicRefSetUnionAdapter(srcCollection) val updated = dstCollection.copyRange(srcCollection, adapter, guard) - updatedRegion.updateAllocatedSetWithInputElements(dstId, updated) + updatedRegion.updateAllocatedSetWithInputElements(dstId, updated, ownership) }, blockOnConcrete0Symbolic1 = { region, srcConcrete, dstSymbolic, guard -> val initialAllocatedSetState = region.allocatedSetWithAllocatedElements @@ -281,8 +285,8 @@ internal class URefSetMemoryRegion( mkDstKeyId = { inputSetWithAllocatedElementsId(it) }, write = { result, dstKeyId, value, g -> val newMap = result.getInputSetWithAllocatedElements(dstKeyId) - .write(dstSymbolic, value, g) - result.updateInputSetWithAllocatedElements(dstKeyId, newMap) + .write(dstSymbolic, value, g, ownership) + result.updateInputSetWithAllocatedElements(dstKeyId, newMap, ownership) } ) @@ -302,7 +306,7 @@ internal class URefSetMemoryRegion( read = { region.getInputSetWithAllocatedElements(it).read(srcSymbolic) }, mkDstKeyId = { UAllocatedRefSetWithAllocatedElementId(dstConcrete.address, it) }, write = { result, dstKeyId, value, g -> - result.guardedWrite(dstKeyId, value, g) { sort.uctx.falseExpr } + result.guardedWrite(dstKeyId, value, g, ownership) { sort.uctx.falseExpr } } ) val updatedRegion = region.updateAllocatedSetWithAllocatedElements(updatedAllocatedSet) @@ -314,7 +318,7 @@ internal class URefSetMemoryRegion( val adapter = UInputToAllocatedSymbolicRefSetUnionAdapter(srcSymbolic, srcCollection) val updated = dstCollection.copyRange(srcCollection, adapter, guard) - updatedRegion.updateAllocatedSetWithInputElements(dstId, updated) + updatedRegion.updateAllocatedSetWithInputElements(dstId, updated, ownership) }, blockOnSymbolic0Symbolic1 = { region, srcSymbolic, dstSymbolic, guard -> val updatedRegion = region.unionInputSetAllocatedElements( @@ -323,8 +327,8 @@ internal class URefSetMemoryRegion( mkDstKeyId = { inputSetWithAllocatedElementsId(it) }, write = { result, dstKeyId, value, g -> val newMap = result.getInputSetWithAllocatedElements(dstKeyId) - .write(dstSymbolic, value, g) - result.updateInputSetWithAllocatedElements(dstKeyId, newMap) + .write(dstSymbolic, value, g, ownership) + result.updateInputSetWithAllocatedElements(dstKeyId, newMap, ownership) } ) val srcCollection = updatedRegion.inputSetWithInputElements() @@ -344,7 +348,7 @@ internal class URefSetMemoryRegion( write: (R, DstKeyId, UBoolExpr, UBoolExpr) -> R ) = unionAllocatedElements( initial, - inputSetWithAllocatedElements.keys, + inputSetWithAllocatedElements.keys.toList(), guard, read, { mkDstKeyId(it.elementAddress) }, @@ -360,7 +364,7 @@ internal class URefSetMemoryRegion( write: (R, DstKeyId, UBoolExpr, UBoolExpr) -> R ) = unionAllocatedElements( initial, - allocatedSetWithAllocatedElements.keys.filter { it.setAddress == srcAddress }, + allocatedSetWithAllocatedElements.keys.filterTo(mutableListOf()) { it.setAddress == srcAddress }, guard, read, { mkDstKeyId(it.elementAddress) }, @@ -369,7 +373,7 @@ internal class URefSetMemoryRegion( private inline fun unionAllocatedElements( initial: R, - keys: Iterable, + keys: List, guard: UBoolExpr, read: (SrcKeyId) -> UBoolExpr, mkDstKeyId: (SrcKeyId) -> DstKeyId, diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/set/ref/URefSetRegionApi.kt b/usvm-core/src/main/kotlin/org/usvm/collection/set/ref/URefSetRegionApi.kt index ff77cf7e7b..5f84710bc5 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/set/ref/URefSetRegionApi.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/set/ref/URefSetRegionApi.kt @@ -19,7 +19,7 @@ internal fun UWritableMemory<*>.refSetUnion( "setUnion is not applicable to $region" } - val newRegion = region.union(srcRef, dstRef, guard) + val newRegion = region.union(srcRef, dstRef, guard, ownership) setRegion(regionId, newRegion) } diff --git a/usvm-core/src/main/kotlin/org/usvm/collection/set/ref/USymbolicRefSetId.kt b/usvm-core/src/main/kotlin/org/usvm/collection/set/ref/USymbolicRefSetId.kt index 791b4c1f7a..6aafe2909f 100644 --- a/usvm-core/src/main/kotlin/org/usvm/collection/set/ref/USymbolicRefSetId.kt +++ b/usvm-core/src/main/kotlin/org/usvm/collection/set/ref/USymbolicRefSetId.kt @@ -54,7 +54,7 @@ class UAllocatedRefSetWithInputElementsId( return key.uctx.mkAllocatedRefSetWithInputElementsReading(collection, key) } - val memory = composer.memory.toWritableMemory() + val memory = composer.memory.toWritableMemory(composer.ownership) collection.applyTo(memory, key, composer) return memory.read(mkLValue(key)) } @@ -128,7 +128,7 @@ class UInputRefSetWithAllocatedElementsId( return key.uctx.mkInputRefSetWithAllocatedElementsReading(collection, key) } - val memory = composer.memory.toWritableMemory() + val memory = composer.memory.toWritableMemory(composer.ownership) collection.applyTo(memory, key, composer) return memory.read(mkLValue(key)) } @@ -197,7 +197,7 @@ class UInputRefSetWithInputElementsId( return sort.uctx.mkInputRefSetWithInputElementsReading(collection, key.first, key.second) } - val memory = composer.memory.toWritableMemory() + val memory = composer.memory.toWritableMemory(composer.ownership) collection.applyTo(memory, key, composer) return memory.read(mkLValue(key)) } diff --git a/usvm-core/src/main/kotlin/org/usvm/constraints/EqualityConstraints.kt b/usvm-core/src/main/kotlin/org/usvm/constraints/EqualityConstraints.kt index 14d62eeefd..e63d44f7ba 100644 --- a/usvm-core/src/main/kotlin/org/usvm/constraints/EqualityConstraints.kt +++ b/usvm-core/src/main/kotlin/org/usvm/constraints/EqualityConstraints.kt @@ -1,20 +1,27 @@ package org.usvm.constraints -import kotlinx.collections.immutable.PersistentSet -import kotlinx.collections.immutable.persistentHashMapOf -import kotlinx.collections.immutable.persistentHashSetOf +import org.usvm.collections.immutable.persistentHashMapOf +import org.usvm.collections.immutable.persistentHashSetOf import org.usvm.UBoolExpr import org.usvm.UConcreteHeapRef import org.usvm.UContext import org.usvm.UHeapRef import org.usvm.UNullRef import org.usvm.USymbolicHeapRef +import org.usvm.algorithms.addToSet +import org.usvm.algorithms.addAll import org.usvm.algorithms.DisjointSets -import org.usvm.algorithms.PersistentMultiMap -import org.usvm.algorithms.PersistentMultiMapBuilder +import org.usvm.algorithms.UPersistentMultiMap +import org.usvm.algorithms.containsValue +import org.usvm.algorithms.removeValue +import org.usvm.algorithms.removeAllValues +import org.usvm.algorithms.multiMapIterator +import org.usvm.collections.immutable.implementations.immutableSet.UPersistentHashSet +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.isEmpty import org.usvm.isStaticHeapRef import org.usvm.merging.MutableMergeGuard -import org.usvm.merging.UMergeable +import org.usvm.merging.UOwnedMergeable import org.usvm.solver.UExprTranslator /** @@ -33,28 +40,24 @@ import org.usvm.solver.UExprTranslator */ class UEqualityConstraints private constructor( internal val ctx: UContext<*>, + private var ownership: MutabilityOwnership, private val equalReferences: DisjointSets, - persistentDistinctReferences: PersistentSet, - persistentReferenceDisequalities: PersistentMultiMap, - persistentNullableDisequalities: PersistentMultiMap, -) : UMergeable { - constructor(ctx: UContext<*>) : this( + var distinctReferences: UPersistentHashSet, + var referenceDisequalities: UPersistentMultiMap, + var nullableDisequalities: UPersistentMultiMap, +) : UOwnedMergeable { + constructor(ctx: UContext<*>, ownership: MutabilityOwnership) : this( ctx, + ownership, DisjointSets(representativeSelector = RefsRepresentativeSelector), - persistentHashSetOf(ctx.nullRef), + persistentHashSetOf().add(ctx.nullRef, ownership), persistentHashMapOf(), persistentHashMapOf(), ) - private val mutableDistinctReferences = persistentDistinctReferences.builder() - private val mutableReferenceDisequalities = PersistentMultiMapBuilder(persistentReferenceDisequalities) - private val mutableNullableDisequalities = PersistentMultiMapBuilder(persistentNullableDisequalities) - - internal val distinctReferences: Set get() = mutableDistinctReferences.build() - - internal val referenceDisequalities: Map> get() = mutableReferenceDisequalities.build() - - internal val nullableDisequalities: Map> get() = mutableNullableDisequalities.build() + fun changeOwnership(ownership: MutabilityOwnership) { + this.ownership = ownership + } /** * Determines whether a static ref could be assigned to a symbolic, according to additional information. @@ -71,16 +74,16 @@ class UEqualityConstraints private constructor( private fun contradiction() { isContradicting = true equalReferences.clear() - mutableDistinctReferences.clear() - mutableReferenceDisequalities.clear() - mutableNullableDisequalities.clear() + distinctReferences = distinctReferences.clear() + referenceDisequalities = referenceDisequalities.clear() + nullableDisequalities = nullableDisequalities.clear() } private fun containsReferenceDisequality(ref1: UHeapRef, ref2: UHeapRef): Boolean = - mutableReferenceDisequalities.containsValue(ref1, ref2) + referenceDisequalities.containsValue(ref1, ref2) private fun containsNullableDisequality(ref1: UHeapRef, ref2: UHeapRef) = - mutableNullableDisequalities.containsValue(ref1, ref2) + nullableDisequalities.containsValue(ref1, ref2) /** * Returns if [ref1] is identical to [ref2] in *all* models. @@ -95,7 +98,7 @@ class UEqualityConstraints private constructor( return false } - val distinctByClique = mutableDistinctReferences.contains(repr1) && mutableDistinctReferences.contains(repr2) + val distinctByClique = distinctReferences.contains(repr1) && distinctReferences.contains(repr2) return distinctByClique || containsReferenceDisequality(repr1, repr2) } @@ -141,16 +144,15 @@ class UEqualityConstraints private constructor( * [from] and merging its disequality constraints into [to]. */ private fun rename(to: UHeapRef, from: UHeapRef) { - if (mutableDistinctReferences.contains(from)) { - if (mutableDistinctReferences.contains(to)) { + if (distinctReferences.contains(from)) { + if (distinctReferences.contains(to)) { contradiction() return } - mutableDistinctReferences.remove(from) - mutableDistinctReferences.add(to) + distinctReferences = distinctReferences.remove(from, ownership).add(to, ownership) } - val fromDiseqs = mutableReferenceDisequalities[from] + val fromDiseqs = referenceDisequalities[from] if (fromDiseqs != null) { if (fromDiseqs.contains(to)) { @@ -158,9 +160,9 @@ class UEqualityConstraints private constructor( return } - mutableReferenceDisequalities.remove(from) - fromDiseqs.forEach { - mutableReferenceDisequalities.removeValue(it, from) + referenceDisequalities = referenceDisequalities.remove(from, ownership) + fromDiseqs.toList().forEach { + referenceDisequalities = referenceDisequalities.removeValue(it, from, ownership) makeRefNonEqual(to, it) } } @@ -168,51 +170,53 @@ class UEqualityConstraints private constructor( val nullRepr = findRepresentative(ctx.nullRef) if (to == nullRepr) { // x == null satisfies nullable disequality (x !== y || (x == null && y == null)) - val removedFrom = mutableNullableDisequalities.remove(from) - val removedTo = mutableNullableDisequalities.remove(to) + val (mapWithRemovedFrom, removedFrom) = nullableDisequalities.removeAndGetValue(from, ownership) + val (mapWithRemovedTo, removedTo) = mapWithRemovedFrom.removeAndGetValue(to, ownership) + nullableDisequalities = mapWithRemovedTo removedFrom?.forEach { - mutableNullableDisequalities.removeValue(it, from) + nullableDisequalities = nullableDisequalities.removeValue(it, from, ownership) } removedTo?.forEach { - mutableNullableDisequalities.removeValue(it, to) + nullableDisequalities = nullableDisequalities.removeValue(it, to, ownership) } } else if (containsNullableDisequality(from, to)) { // If x === y, nullable disequality can hold only if both references are null makeRefEqual(to, nullRepr) } else { - val removedFrom = mutableNullableDisequalities.remove(from) + val (mapWithRemovedFrom, removedFrom) = nullableDisequalities.removeAndGetValue(from, ownership) + nullableDisequalities = mapWithRemovedFrom removedFrom?.forEach { - mutableNullableDisequalities.removeValue(it, from) + nullableDisequalities = nullableDisequalities.removeValue(it, from, ownership) makeRefNonEqualOrBothNull(to, it) } } } private fun addDisequalityUnguarded(repr1: UHeapRef, repr2: UHeapRef) { - when (mutableDistinctReferences.size) { + when (distinctReferences.calculateSize()) { 0 -> { - require(mutableReferenceDisequalities.isEmpty()) + require(referenceDisequalities.isEmpty()) // Init clique with {repr1, repr2} - mutableDistinctReferences.add(repr1) - mutableDistinctReferences.add(repr2) + distinctReferences = distinctReferences.add(repr1, ownership) + distinctReferences = distinctReferences.add(repr2, ownership) return } 1 -> { - val onlyRef = mutableDistinctReferences.single() + val onlyRef = distinctReferences.single() if (repr1 == onlyRef) { - mutableDistinctReferences.add(repr2) + distinctReferences = distinctReferences.add(repr2, ownership) return } if (repr2 == onlyRef) { - mutableDistinctReferences.add(repr1) + distinctReferences = distinctReferences.add(repr1, ownership) return } } } - val ref1InClique = mutableDistinctReferences.contains(repr1) - val ref2InClique = mutableDistinctReferences.contains(repr2) + val ref1InClique = distinctReferences.contains(repr1) + val ref2InClique = distinctReferences.contains(repr2) if (ref1InClique && ref2InClique) { return @@ -226,21 +230,22 @@ class UEqualityConstraints private constructor( val refInClique = if (ref1InClique) repr1 else repr2 val refNotInClique = if (ref1InClique) repr2 else repr1 - if (mutableDistinctReferences.all { it == refInClique || containsReferenceDisequality(refNotInClique, it) }) { + if (distinctReferences.all { it == refInClique || containsReferenceDisequality(refNotInClique, it) }) { // Ref is not in clique and disjoint from all refs in clique. Thus, we can join it to clique... - mutableReferenceDisequalities.removeAllValues(refNotInClique, mutableDistinctReferences) + referenceDisequalities = + referenceDisequalities.removeAllValues(refNotInClique, distinctReferences, ownership) - for (ref in mutableDistinctReferences) { - mutableReferenceDisequalities.removeValue(ref, refNotInClique) + for (ref in distinctReferences) { + referenceDisequalities = referenceDisequalities.removeValue(ref, refNotInClique, ownership) } - mutableDistinctReferences.add(refNotInClique) + distinctReferences = distinctReferences.add(refNotInClique, ownership) return } } - mutableReferenceDisequalities.add(repr1, repr2) - mutableReferenceDisequalities.add(repr2, repr1) + referenceDisequalities = referenceDisequalities.addToSet(repr1, repr2, ownership) + referenceDisequalities = referenceDisequalities.addToSet(repr2, repr1, ownership) } /** @@ -321,14 +326,14 @@ class UEqualityConstraints private constructor( return } - mutableNullableDisequalities.add(repr1, repr2) - mutableNullableDisequalities.add(repr2, repr1) + nullableDisequalities = nullableDisequalities.addToSet(repr1, repr2, ownership) + nullableDisequalities = nullableDisequalities.addToSet(repr2, repr1, ownership) } private fun removeNullableDisequality(repr1: UHeapRef, repr2: UHeapRef) { if (containsNullableDisequality(repr1, repr2)) { - mutableNullableDisequalities.removeValue(repr1, repr2) - mutableNullableDisequalities.removeValue(repr2, repr1) + nullableDisequalities = nullableDisequalities.removeValue(repr1, repr2, ownership) + nullableDisequalities = nullableDisequalities.removeValue(repr2, repr1, ownership) } } @@ -351,46 +356,47 @@ class UEqualityConstraints private constructor( /** * Given a newly allocated static ref [allocatedStaticRef], updates [distinctReferences] and - * [mutableReferenceDisequalities] in the following way - removes all symbolic refs that may be equal to the + * [referenceDisequalities] in the following way - removes all symbolic refs that may be equal to the * [allocatedStaticRef] (according to the [isStaticRefAssignableToSymbolic]) from the [distinctReferences] and - * moves the information about disequality for these refs to the [mutableReferenceDisequalities]. - * After that, adds the [allocatedStaticRef] to the [mutableDistinctReferences]. + * moves the information about disequality for these refs to the [referenceDisequalities]. + * After that, adds the [allocatedStaticRef] to the [distinctReferences]. */ internal fun updateDisequality(allocatedStaticRef: UConcreteHeapRef) { if (!isStaticHeapRef(allocatedStaticRef)) { return } - // Move from the clique to the [mutableDistinctReferences] + // Move from the clique to the [distinctReferences] // all symbolic refs that are typely compatible with this static ref - val referencesToRemove = mutableDistinctReferences.filter { + val referencesToRemove = distinctReferences.filter { it is USymbolicHeapRef && it !is UNullRef && isStaticRefAssignableToSymbolic(allocatedStaticRef, it) } // Here we need to save a copy of distinct refs to use all of them except the single ref from removed references - val oldDistinctRefs = mutableDistinctReferences.toSet() + val oldDistinctRefs = distinctReferences.toSet() - for (ref in referencesToRemove) { + for (ref in referencesToRemove.toList()) { val otherDistinctRefs = oldDistinctRefs - ref - mutableDistinctReferences.remove(ref) + distinctReferences = distinctReferences.remove(ref, ownership) - mutableReferenceDisequalities.addAll(ref, otherDistinctRefs) + referenceDisequalities = referenceDisequalities.addAll(ref, otherDistinctRefs, ownership) otherDistinctRefs.forEach { - mutableReferenceDisequalities.add(it, ref) + referenceDisequalities = referenceDisequalities.addToSet(it, ref, ownership) } } - mutableDistinctReferences += allocatedStaticRef + distinctReferences = distinctReferences.add(allocatedStaticRef, ownership) } /** * Creates a mutable copy of this structure. * Note that current subscribers get unsubscribed! */ - fun clone(): UEqualityConstraints { + fun clone(thisOwnership: MutabilityOwnership, cloneOwnership: MutabilityOwnership): UEqualityConstraints { + this.ownership = thisOwnership if (isContradicting) { val result = UEqualityConstraints( - ctx, DisjointSets(), + ctx, cloneOwnership, DisjointSets(), persistentHashSetOf(), persistentHashMapOf(), persistentHashMapOf() @@ -401,10 +407,11 @@ class UEqualityConstraints private constructor( return UEqualityConstraints( ctx, + cloneOwnership, equalReferences.clone(), - mutableDistinctReferences.build(), - mutableReferenceDisequalities.build(), - mutableNullableDisequalities.build(), + distinctReferences, + referenceDisequalities, + nullableDisequalities, ) } @@ -421,23 +428,30 @@ class UEqualityConstraints private constructor( * * @return the merged equality constraints. */ - override fun mergeWith(other: UEqualityConstraints, by: MutableMergeGuard): UEqualityConstraints? { + override fun mergeWith( + other: UEqualityConstraints, + by: MutableMergeGuard, + thisOwnership: MutabilityOwnership, + otherOwnership: MutabilityOwnership, + mergedOwnership: MutabilityOwnership + ): UEqualityConstraints? { // TODO: refactor it - if (mutableDistinctReferences != other.mutableDistinctReferences) { + if (distinctReferences != other.distinctReferences) { return null } - if (mutableReferenceDisequalities != other.mutableReferenceDisequalities) { + if (referenceDisequalities != other.referenceDisequalities) { return null } - if (mutableNullableDisequalities != other.mutableNullableDisequalities) { + if (nullableDisequalities != other.nullableDisequalities) { return null } if (equalReferences != other.equalReferences) { return null } + other.ownership = otherOwnership // Clone because of mutable [isStaticRefAssignableToSymbolic] - return clone() + return clone(thisOwnership, mergedOwnership) } fun constraints(translator: UExprTranslator<*, *>): Sequence { @@ -446,7 +460,7 @@ class UEqualityConstraints private constructor( val result = mutableListOf() val nullRepr = findRepresentative(ctx.nullRef) - for (ref in mutableDistinctReferences) { + for (ref in distinctReferences) { // Static refs are already translated as a values of an uninterpreted sort if (isStaticHeapRef(ref)) { continue @@ -466,7 +480,7 @@ class UEqualityConstraints private constructor( val processedConstraints = mutableSetOf>() - for ((ref1, ref2) in mutableReferenceDisequalities) { + for ((ref1, ref2) in referenceDisequalities.multiMapIterator()) { if (!processedConstraints.contains(ref2 to ref1)) { processedConstraints.add(ref1 to ref2) val translatedRef1 = translator.translate(ref1) @@ -477,7 +491,7 @@ class UEqualityConstraints private constructor( processedConstraints.clear() val translatedNull = translator.transform(ctx.nullRef) - for ((ref1, ref2) in mutableNullableDisequalities) { + for ((ref1, ref2) in nullableDisequalities.multiMapIterator()) { if (!processedConstraints.contains(ref2 to ref1)) { processedConstraints.add(ref1 to ref2) val translatedRef1 = translator.translate(ref1) diff --git a/usvm-core/src/main/kotlin/org/usvm/constraints/PathConstraints.kt b/usvm-core/src/main/kotlin/org/usvm/constraints/PathConstraints.kt index 93f62be9b6..f31d00494a 100644 --- a/usvm-core/src/main/kotlin/org/usvm/constraints/PathConstraints.kt +++ b/usvm-core/src/main/kotlin/org/usvm/constraints/PathConstraints.kt @@ -11,10 +11,11 @@ import org.usvm.UIsSupertypeExpr import org.usvm.UNotExpr import org.usvm.UOrExpr import org.usvm.USymbolicHeapRef +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.isStaticHeapRef import org.usvm.isSymbolicHeapRef import org.usvm.merging.MutableMergeGuard -import org.usvm.merging.UMergeable +import org.usvm.merging.UOwnedMergeable import org.usvm.solver.UExprTranslator import org.usvm.uctx @@ -23,35 +24,48 @@ import org.usvm.uctx */ open class UPathConstraints( protected val ctx: UContext<*>, + protected open var ownership: MutabilityOwnership, protected val logicalConstraints: ULogicalConstraints = ULogicalConstraints.empty(), /** * Specially represented equalities and disequalities between objects, used in various part of constraints management. */ - protected val equalityConstraints: UEqualityConstraints = UEqualityConstraints(ctx), + protected val equalityConstraints: UEqualityConstraints = UEqualityConstraints(ctx, ownership), /** * Constraints solved by type solver. */ val typeConstraints: UTypeConstraints = UTypeConstraints( + ownership, ctx.typeSystem(), equalityConstraints ), /** * Specially represented numeric constraints (e.g. >, <, >=, ...). */ - protected val numericConstraints: UNumericConstraints = UNumericConstraints(ctx, sort = ctx.bv32Sort), -) : UMergeable, MutableMergeGuard> { + protected val numericConstraints: UNumericConstraints = + UNumericConstraints(ctx, sort = ctx.bv32Sort, ownership) +) : UOwnedMergeable, MutableMergeGuard> { init { // Use the information from the type constraints to check whether any static ref is assignable to any symbolic ref equalityConstraints.setTypesCheck(typeConstraints::canStaticRefBeEqualToSymbolic) } + /** + * Recursively changes ownership for all nested data structures that use persistent maps. + */ + fun changeOwnership(ownership: MutabilityOwnership) { + this.ownership = ownership + numericConstraints.changeOwnership(ownership) + equalityConstraints.changeOwnership(ownership) + typeConstraints.changeOwnership(ownership) + } + /** * Constraints solved by SMT solver. */ val softConstraintsSourceSequence: Sequence get() = logicalConstraints.asSequence() + numericConstraints.constraints() - constructor(ctx: UContext<*>) : this(ctx, ULogicalConstraints.empty()) + constructor(ctx: UContext<*>, ownership: MutabilityOwnership) : this(ctx, ownership, ULogicalConstraints.empty()) val isFalse: Boolean get() = equalityConstraints.isContradicting || @@ -150,23 +164,28 @@ open class UPathConstraints( notConstraint is UOrExpr -> notConstraint.args.forEach { plusAssign(ctx.mkNot(it)) } - else -> logicalConstraints += constraint + else -> logicalConstraints.add(constraint, ownership) } } logicalConstraints.contains(constraint.not()) -> contradiction(ctx) - else -> logicalConstraints += constraint + else -> logicalConstraints.add(constraint, ownership) } } - open fun clone(): UPathConstraints { + open fun clone( + thisOwnership: MutabilityOwnership = MutabilityOwnership(), + cloneOwnership: MutabilityOwnership = MutabilityOwnership(), // ownerships must be fresh new because of plus assign operations + ): UPathConstraints { val clonedLogicalConstraints = logicalConstraints.clone() - val clonedEqualityConstraints = equalityConstraints.clone() - val clonedTypeConstraints = typeConstraints.clone(clonedEqualityConstraints) - val clonedNumericConstraints = numericConstraints.clone() + val clonedEqualityConstraints = equalityConstraints.clone(thisOwnership, cloneOwnership) + val clonedTypeConstraints = typeConstraints.clone(clonedEqualityConstraints, thisOwnership, cloneOwnership) + val clonedNumericConstraints = numericConstraints.clone(thisOwnership, cloneOwnership) + this.ownership = thisOwnership return UPathConstraints( ctx = ctx, + ownership = cloneOwnership, logicalConstraints = clonedLogicalConstraints, equalityConstraints = clonedEqualityConstraints, typeConstraints = clonedTypeConstraints, @@ -175,7 +194,7 @@ open class UPathConstraints( } private fun contradiction(ctx: UContext<*>) { - logicalConstraints.contradiction(ctx) + logicalConstraints.contradiction(ctx, ownership) } /** @@ -192,17 +211,30 @@ open class UPathConstraints( * * @return the merged path constraints. */ - override fun mergeWith(other: UPathConstraints, by: MutableMergeGuard): UPathConstraints? { + override fun mergeWith( + other: UPathConstraints, + by: MutableMergeGuard, + thisOwnership: MutabilityOwnership, + otherOwnership: MutabilityOwnership, + mergedOwnership: MutabilityOwnership, + ): UPathConstraints? { // TODO: elaborate on some merge parameters here - val mergedLogicalConstraints = logicalConstraints.mergeWith(other.logicalConstraints, by) - val mergedEqualityConstraints = equalityConstraints.mergeWith(other.equalityConstraints, by) ?: return null + val mergedLogicalConstraints = + logicalConstraints.mergeWith(other.logicalConstraints, by, thisOwnership, otherOwnership, mergedOwnership) + val mergedEqualityConstraints = + equalityConstraints.mergeWith(other.equalityConstraints, by, thisOwnership, otherOwnership, mergedOwnership) + ?: return null val mergedTypeConstraints = typeConstraints - .clone(mergedEqualityConstraints) - .mergeWith(other.typeConstraints, by) ?: return null - val mergedNumericConstraints = numericConstraints.mergeWith(other.numericConstraints, by) + .clone(mergedEqualityConstraints, thisOwnership, otherOwnership) + .mergeWith(other.typeConstraints, by, thisOwnership, otherOwnership, mergedOwnership) ?: return null + val mergedNumericConstraints = + numericConstraints.mergeWith(other.numericConstraints, by, thisOwnership, otherOwnership, mergedOwnership) + this.changeOwnership(thisOwnership) + other.changeOwnership(otherOwnership) return UPathConstraints( ctx, + mergedOwnership, mergedLogicalConstraints, mergedEqualityConstraints, mergedTypeConstraints, diff --git a/usvm-core/src/main/kotlin/org/usvm/constraints/TypeConstraints.kt b/usvm-core/src/main/kotlin/org/usvm/constraints/TypeConstraints.kt index 0cf8921575..924a03d413 100644 --- a/usvm-core/src/main/kotlin/org/usvm/constraints/TypeConstraints.kt +++ b/usvm-core/src/main/kotlin/org/usvm/constraints/TypeConstraints.kt @@ -1,19 +1,22 @@ package org.usvm.constraints -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.persistentHashMapOf import org.usvm.UBoolExpr import org.usvm.UConcreteHeapAddress import org.usvm.UConcreteHeapRef import org.usvm.UContext import org.usvm.UHeapRef -import org.usvm.UNullRef import org.usvm.USymbolicHeapRef +import org.usvm.UNullRef +import org.usvm.collections.immutable.getOrDefault import org.usvm.isStatic import org.usvm.isStaticHeapRef +import org.usvm.collections.immutable.persistentHashMapOf +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.toMutableMap import org.usvm.memory.mapWithStaticAsConcrete import org.usvm.merging.MutableMergeGuard -import org.usvm.merging.UMergeable +import org.usvm.merging.UOwnedMergeable import org.usvm.solver.UExprTranslator import org.usvm.types.USingleTypeStream import org.usvm.types.UTypeRegion @@ -21,6 +24,7 @@ import org.usvm.types.UTypeStream import org.usvm.types.UTypeSystem import org.usvm.uctx + interface UTypeEvaluator { /** @@ -51,20 +55,26 @@ interface UTypeEvaluator { * precisely, thus we can evaluate the subtyping constraints for them concretely (modulo generic type variables). */ class UTypeConstraints( + private var ownership: MutabilityOwnership, private val typeSystem: UTypeSystem, private val equalityConstraints: UEqualityConstraints, - private var concreteRefToType: PersistentMap = persistentHashMapOf(), - private var symbolicRefToTypeRegion: PersistentMap> = persistentHashMapOf(), -) : UTypeEvaluator, UMergeable, MutableMergeGuard> { + private var concreteRefToType: UPersistentHashMap = persistentHashMapOf(), + private var symbolicRefToTypeRegion: UPersistentHashMap> = persistentHashMapOf(), +) : UTypeEvaluator, UOwnedMergeable, MutableMergeGuard> { private val ctx: UContext<*> get() = equalityConstraints.ctx + fun changeOwnership(ownership: MutabilityOwnership) { + this.ownership = ownership + } + init { equalityConstraints.subscribeEquality(::intersectRegions) } - val inputRefToTypeRegion: Map> get(): Map> { - val inputTypeRegions: MutableMap> = symbolicRefToTypeRegion.toMutableMap() + @Suppress("UNCHECKED_CAST") + val inputTypeRegions: MutableMap> = + symbolicRefToTypeRegion.toMutableMap() as MutableMap> // Add all static refs for ((address, type) in concreteRefToType) { @@ -99,7 +109,7 @@ class UTypeConstraints( * Binds concrete heap address [ref] to its [type]. */ fun allocate(ref: UConcreteHeapAddress, type: Type) { - concreteRefToType = concreteRefToType.put(ref, type) + concreteRefToType = concreteRefToType.put(ref, type, ownership) equalityConstraints.updateDisequality(ctx.mkConcreteHeapRef(ref)) } @@ -146,7 +156,7 @@ class UTypeConstraints( } ?: topTypeRegion } - return symbolicRefToTypeRegion[representative] ?: topTypeRegion + return symbolicRefToTypeRegion.getOrDefault(representative as USymbolicHeapRef, topTypeRegion) } @@ -157,7 +167,7 @@ class UTypeConstraints( return } - symbolicRefToTypeRegion = symbolicRefToTypeRegion.put(representative as USymbolicHeapRef, value) + symbolicRefToTypeRegion = symbolicRefToTypeRegion.put(representative as USymbolicHeapRef, value, ownership) } /** @@ -173,7 +183,7 @@ class UTypeConstraints( is UNullRef -> return is UConcreteHeapRef -> { - val concreteType = concreteRefToType.getValue(ref.address) + val concreteType = concreteRefToType[ref.address]!! if (!typeSystem.isSupertype(supertype, concreteType)) { contradiction() } @@ -200,7 +210,7 @@ class UTypeConstraints( is UNullRef -> contradiction() // the [ref] can't be equal to null is UConcreteHeapRef -> { - val concreteType = concreteRefToType.getValue(ref.address) + val concreteType = concreteRefToType[ref.address]!! if (typeSystem.isSupertype(supertype, concreteType)) { contradiction() } @@ -227,7 +237,7 @@ class UTypeConstraints( is UNullRef -> contradiction() is UConcreteHeapRef -> { - val concreteType = concreteRefToType.getValue(ref.address) + val concreteType = concreteRefToType[ref.address]!! if (!typeSystem.isSupertype(concreteType, subtype)) { contradiction() } @@ -255,7 +265,7 @@ class UTypeConstraints( is UNullRef -> return is UConcreteHeapRef -> { - val concreteType = concreteRefToType.getValue(ref.address) + val concreteType = concreteRefToType[ref.address]!! if (typeSystem.isSupertype(concreteType, subtype)) { contradiction() } @@ -303,8 +313,8 @@ class UTypeConstraints( to is UConcreteHeapRef && from is UConcreteHeapRef -> { // For both concrete refs we need to check types are the same - val toType = concreteRefToType.getValue(to.address) - val fromType = concreteRefToType.getValue(from.address) + val toType = concreteRefToType[to.address] + val fromType = concreteRefToType[from.address] if (toType != fromType) { contradiction() @@ -313,7 +323,7 @@ class UTypeConstraints( to is UConcreteHeapRef -> { // Here we have a pair of symbolic-concrete refs - val concreteToType = concreteRefToType.getValue(to.address) + val concreteToType = concreteRefToType[to.address]!! val symbolicFromType = getTypeRegion(from as USymbolicHeapRef, useRepresentative = false) if (symbolicFromType.addSupertype(concreteToType).isEmpty) { @@ -323,7 +333,7 @@ class UTypeConstraints( from is UConcreteHeapRef -> { // Here to is symbolic and from is concrete - val concreteType = concreteRefToType.getValue(from.address) + val concreteType = concreteRefToType[from.address]!! val symbolicType = getTypeRegion(to as USymbolicHeapRef) // We need to set only the concrete type instead of all these symbolic types - make it using both subtype/supertype val regionFromConcreteType = symbolicType.addSubtype(concreteType).addSupertype(concreteType) @@ -350,7 +360,7 @@ class UTypeConstraints( contradiction() return } - for ((key, value) in symbolicRefToTypeRegion.entries) { + for ((key, value) in symbolicRefToTypeRegion) { // TODO: cache intersections? if (key != ref && value.intersect(newRegion).isEmpty) { // If we have two inputs of incomparable reference types, then they are non equal @@ -372,7 +382,7 @@ class UTypeConstraints( if (newRegion.isEmpty) { equalityConstraints.makeEqual(ref, ref.uctx.nullRef) } - for ((key, value) in symbolicRefToTypeRegion.entries) { + for ((key, value) in symbolicRefToTypeRegion) { // TODO: cache intersections? if (key != ref && value.intersect(newRegion).isEmpty) { // If we have two inputs of incomparable reference types, then they are non equal or both null @@ -388,7 +398,7 @@ class UTypeConstraints( override fun evalIsSubtype(ref: UHeapRef, supertype: Type): UBoolExpr = ref.mapWithStaticAsConcrete( concreteMapper = { concreteRef -> - val concreteType = concreteRefToType.getValue(concreteRef.address) + val concreteType = concreteRefToType[concreteRef.address]!! if (typeSystem.isSupertype(supertype, concreteType)) { concreteRef.ctx.trueExpr } else { @@ -414,7 +424,7 @@ class UTypeConstraints( override fun evalIsSupertype(ref: UHeapRef, subtype: Type): UBoolExpr = ref.mapWithStaticAsConcrete( concreteMapper = { concreteRef -> - val concreteType = concreteRefToType.getValue(concreteRef.address) + val concreteType = concreteRefToType[concreteRef.address]!! if (typeSystem.isSupertype(concreteType, subtype)) { concreteRef.ctx.trueExpr } else { @@ -440,13 +450,17 @@ class UTypeConstraints( /** * Creates a mutable copy of these constraints connected to new instance of [equalityConstraints]. */ - fun clone(equalityConstraints: UEqualityConstraints) = - UTypeConstraints( - typeSystem, - equalityConstraints, - concreteRefToType, - symbolicRefToTypeRegion - ) + fun clone( + equalityConstraints: UEqualityConstraints, + thisOwnership: MutabilityOwnership, + cloneOwnership: MutabilityOwnership + ) = UTypeConstraints( + cloneOwnership, + typeSystem, + equalityConstraints, + concreteRefToType, + symbolicRefToTypeRegion + ).also { this.ownership = thisOwnership } /** * Check if this [UTypeConstraints] can be merged with [other] type constraints. @@ -456,13 +470,21 @@ class UTypeConstraints( * * @return the merged type constraints. */ - override fun mergeWith(other: UTypeConstraints, by: MutableMergeGuard): UTypeConstraints? { + override fun mergeWith( + other: UTypeConstraints, + by: MutableMergeGuard, + thisOwnership: MutabilityOwnership, + otherOwnership: MutabilityOwnership, + mergedOwnership: MutabilityOwnership + ): UTypeConstraints? { // TODO: should we check equality constraints? if (symbolicRefToTypeRegion != other.symbolicRefToTypeRegion) { return null } - val mergedConcreteRefs = concreteRefToType.builder().apply { putAll(other.concreteRefToType) }.build() - return UTypeConstraints(typeSystem, equalityConstraints, mergedConcreteRefs, symbolicRefToTypeRegion) + val mergedConcreteRefs = concreteRefToType.putAll(other.concreteRefToType, mergedOwnership) + this.ownership = thisOwnership + other.ownership = otherOwnership + return UTypeConstraints(mergedOwnership, typeSystem, equalityConstraints, mergedConcreteRefs, symbolicRefToTypeRegion) } diff --git a/usvm-core/src/main/kotlin/org/usvm/constraints/ULogicalConstraints.kt b/usvm-core/src/main/kotlin/org/usvm/constraints/ULogicalConstraints.kt index 62c0680659..04853d4556 100644 --- a/usvm-core/src/main/kotlin/org/usvm/constraints/ULogicalConstraints.kt +++ b/usvm-core/src/main/kotlin/org/usvm/constraints/ULogicalConstraints.kt @@ -1,25 +1,29 @@ package org.usvm.constraints -import kotlinx.collections.immutable.PersistentSet -import kotlinx.collections.immutable.persistentHashSetOf -import kotlinx.collections.immutable.persistentSetOf +import io.ksmt.expr.KExpr +import org.usvm.collections.immutable.persistentHashSetOf import org.usvm.UBoolExpr +import org.usvm.UBoolSort import org.usvm.UContext import org.usvm.algorithms.separate +import org.usvm.collections.immutable.containsAll +import org.usvm.collections.immutable.implementations.immutableSet.UPersistentHashSet +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.isEmpty import org.usvm.isFalse import org.usvm.merging.MutableMergeGuard -import org.usvm.merging.UMergeable +import org.usvm.merging.UOwnedMergeable class ULogicalConstraints private constructor( - private var constraints: PersistentSet, -) : Set, UMergeable { - operator fun plusAssign(expr: UBoolExpr) { - constraints = constraints.add(expr) + private var constraints: UPersistentHashSet, +) : Set, UOwnedMergeable { + fun add(expr: UBoolExpr, ownership: MutabilityOwnership) { + constraints = constraints.add(expr, ownership) } fun clone(): ULogicalConstraints = ULogicalConstraints(constraints) override val size: Int - get() = constraints.size + get() = constraints.calculateSize() override fun isEmpty(): Boolean = constraints.isEmpty() @@ -32,8 +36,8 @@ class ULogicalConstraints private constructor( val isContradicting: Boolean get() = constraints.any(UBoolExpr::isFalse) - fun contradiction(ctx: UContext<*>) { - constraints = persistentHashSetOf(ctx.falseExpr) + fun contradiction(ctx: UContext<*>, ownership: MutabilityOwnership) { + constraints = persistentHashSetOf().add(ctx.falseExpr, ownership) } /** @@ -43,8 +47,14 @@ class ULogicalConstraints private constructor( * * @return the logical constraints. */ - override fun mergeWith(other: ULogicalConstraints, by: MutableMergeGuard): ULogicalConstraints { - val (overlap, uniqueThis, uniqueOther) = constraints.separate(other.constraints) + override fun mergeWith( + other: ULogicalConstraints, + by: MutableMergeGuard, + thisOwnership: MutabilityOwnership, + otherOwnership: MutabilityOwnership, + mergedOwnership: MutabilityOwnership + ): ULogicalConstraints { + val (overlap, uniqueThis, uniqueOther) = constraints.separate(other.constraints, mergedOwnership) by.appendThis(uniqueThis.asSequence()) by.appendOther(uniqueOther.asSequence()) return ULogicalConstraints(overlap) @@ -53,4 +63,4 @@ class ULogicalConstraints private constructor( companion object { fun empty() = ULogicalConstraints(persistentHashSetOf()) } -} \ No newline at end of file +} diff --git a/usvm-core/src/main/kotlin/org/usvm/constraints/UNumericConstraints.kt b/usvm-core/src/main/kotlin/org/usvm/constraints/UNumericConstraints.kt index ae841b47ca..6e1f00521e 100644 --- a/usvm-core/src/main/kotlin/org/usvm/constraints/UNumericConstraints.kt +++ b/usvm-core/src/main/kotlin/org/usvm/constraints/UNumericConstraints.kt @@ -25,17 +25,19 @@ import io.ksmt.utils.BvUtils.signedLess import io.ksmt.utils.BvUtils.signedLessOrEqual import io.ksmt.utils.asExpr import io.ksmt.utils.uncheckedCast -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.PersistentSet -import kotlinx.collections.immutable.persistentHashMapOf -import kotlinx.collections.immutable.persistentHashSetOf import org.usvm.UBoolExpr import org.usvm.UBvSort import org.usvm.UContext import org.usvm.UExpr +import org.usvm.algorithms.UPersistentMultiMap +import org.usvm.algorithms.addToSet +import org.usvm.algorithms.removeValue import org.usvm.algorithms.separate +import org.usvm.collections.immutable.* +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.merging.MutableMergeGuard -import org.usvm.merging.UMergeable +import org.usvm.merging.UOwnedMergeable import org.usvm.regions.IntIntervalsRegion import org.usvm.solver.UExprTranslator @@ -60,13 +62,21 @@ private typealias ConstraintTerms = UExpr class UNumericConstraints private constructor( private val ctx: UContext<*>, val sort: Sort, - persistentNumericConstraints: PersistentMap, Constraint>, - persistentConstraintWatchList: PersistentMap, PersistentSet>>, -) : UMergeable, MutableMergeGuard> { - constructor(ctx: UContext<*>, sort: Sort) : this(ctx, sort, persistentHashMapOf(), persistentHashMapOf()) + private var ownership: MutabilityOwnership, + private var numericConstraints: UPersistentHashMap, Constraint>, + private var constraintWatchList: UPersistentMultiMap, ConstraintTerms>, +) : UOwnedMergeable, MutableMergeGuard> { + constructor(ctx: UContext<*>, sort: Sort, ownership: MutabilityOwnership) : this( + ctx, + sort, + ownership, + persistentHashMapOf(), + persistentHashMapOf() + ) - private val numericConstraints = persistentNumericConstraints.builder() - private val constraintWatchList = persistentConstraintWatchList.builder() + fun changeOwnership(ownership: MutabilityOwnership) { + this.ownership = ownership + } private val constraintPropagationQueue = arrayListOf>() @@ -104,7 +114,7 @@ class UNumericConstraints private constructor( return sequenceOf(ctx.falseExpr) } - return numericConstraints.entries.asSequence() + return numericConstraints.asSequence() .flatMap { it.value.mkExpressions() } } @@ -269,14 +279,13 @@ class UNumericConstraints private constructor( private val KBitVecValue.intValue: Int get() = (this as KBitVec32Value).intValue - fun clone(): UNumericConstraints { + fun clone(thisOwnership: MutabilityOwnership, cloneOwnership: MutabilityOwnership): UNumericConstraints { if (this.isContradicting) { return this } - return UNumericConstraints( - ctx, sort, numericConstraints.build(), constraintWatchList.build() - ) + this.ownership = thisOwnership + return UNumericConstraints(ctx, sort, cloneOwnership, numericConstraints, constraintWatchList) } private fun constraintUpdated(update: ConstraintUpdateEvent) { @@ -352,7 +361,7 @@ class UNumericConstraints private constructor( ) private fun updateConstraint(constraint: Constraint) { - numericConstraints[constraint.constrainedTerms] = constraint + numericConstraints = numericConstraints.put(constraint.constrainedTerms, constraint, ownership) } private fun constraintAddDependency(terms: ConstraintTerms, dependency: ConstraintTerms) { @@ -370,12 +379,12 @@ class UNumericConstraints private constructor( } val currentWatchList = watchList ?: persistentHashSetOf() - val updatedWatchList = currentWatchList.add(terms) + val updatedWatchList = currentWatchList.add(terms, ownership) if (updatedWatchList === watchList) { return } - constraintWatchList[dependency] = updatedWatchList + constraintWatchList = constraintWatchList.put(dependency, updatedWatchList, ownership) } private fun propagateConstraints() { @@ -934,7 +943,7 @@ class UNumericConstraints private constructor( val current = concreteLowerBounds[bias] if (current != null && current.value.signedGreaterOrEqual(bound)) return this val isPrimary = current?.isPrimary ?: false - return modifyConcreteLowerBounds(bias, ValueConstraint(bound, isPrimary)) + return modifyConcreteLowerBounds(bias, ValueConstraint(bound, isPrimary), ownership) } private fun BoundsConstraint.refineUpperBound( @@ -944,7 +953,7 @@ class UNumericConstraints private constructor( val current = concreteUpperBounds[bias] if (current != null && current.value.signedLessOrEqual(bound)) return this val isPrimary = current?.isPrimary ?: false - return modifyConcreteUpperBounds(bias, ValueConstraint(bound, isPrimary)) + return modifyConcreteUpperBounds(bias, ValueConstraint(bound, isPrimary), ownership) } private fun BoundsConstraint.updateConcreteLowerBound( @@ -967,7 +976,7 @@ class UNumericConstraints private constructor( // Replace with primary constraint. Constraint value remains unchanged if (it.value == value) { return if (isPrimary && !it.isPrimary) { - modifyConcreteLowerBounds(bias, ValueConstraint(value, isPrimary = true)) + modifyConcreteLowerBounds(bias, ValueConstraint(value, isPrimary = true), ownership) } else { this } @@ -997,7 +1006,7 @@ class UNumericConstraints private constructor( // Replace with primary constraint. Constraint value remains unchanged if (it.value == value) { return if (isPrimary && !it.isPrimary) { - modifyConcreteUpperBounds(bias, ValueConstraint(value, isPrimary = true)) + modifyConcreteUpperBounds(bias, ValueConstraint(value, isPrimary = true), ownership) } else { this } @@ -1035,7 +1044,7 @@ class UNumericConstraints private constructor( return this } - return modifyConcreteDisequalitites(bias, ValueConstraint(value, isPrimary)) + return modifyConcreteDisequalitites(bias, ValueConstraint(value, isPrimary), ownership) } private fun BoundsConstraint.excludedPoints( @@ -1148,7 +1157,7 @@ class UNumericConstraints private constructor( } val constraint = TermsConstraint(rhs.constrainedTerms, rhsBias, isStrict = true) - val modifiedDiseq = termDisequalities.addTermConstraint(lhsBias, constraint) + val modifiedDiseq = termDisequalities.addTermConstraint(lhsBias, constraint, ownership) return modifyTermDisequalities(modifiedDiseq) } @@ -1160,17 +1169,17 @@ class UNumericConstraints private constructor( var updatedReplacement = replacement // this + bias >= bound <=> replacement + (bias - replacementBias) >= bound - updatedReplacement = concreteLowerBounds.entries.fold(updatedReplacement) { result, (bias, bound) -> + updatedReplacement = concreteLowerBounds.fold(updatedReplacement) { result, (bias, bound) -> result.addConcreteLowerBound(sub(bias, replacementBias), bound.value, bound.isPrimary) } // this + bias <= bound <=> replacement + (bias - replacementBias) <= bound - updatedReplacement = concreteUpperBounds.entries.fold(updatedReplacement) { result, (bias, bound) -> + updatedReplacement = concreteUpperBounds.fold(updatedReplacement) { result, (bias, bound) -> result.addConcreteUpperBound(sub(bias, replacementBias), bound.value, bound.isPrimary) } // this + bias != bound <=> replacement + (bias - replacementBias) != bound - updatedReplacement = concreteDisequalitites.entries.fold(updatedReplacement) { result, (bias, bound) -> + updatedReplacement = concreteDisequalitites.fold(updatedReplacement) { result, (bias, bound) -> result.addConcreteDisequality(sub(bias, replacementBias), bound.value, bound.isPrimary) } @@ -1195,7 +1204,8 @@ class UNumericConstraints private constructor( updateConstraint(updatedReplacement) val dependencies = constraintWatchList[constrainedTerms] - dependencies?.forEach { dependentTerms -> + // toList fixes [dependencies] because it can be mutated in foreach body + dependencies?.toList()?.forEach { dependentTerms -> withConstraint( terms = dependentTerms, bounds = { dependencyConstraint, _ -> @@ -1276,7 +1286,8 @@ class UNumericConstraints private constructor( ) val dependencies = constraintWatchList[constrainedTerms] - dependencies?.forEach { dependentTerms -> + // toList fixes [dependencies] because it can be mutated in foreach body + dependencies?.toList()?.forEach { dependentTerms -> withConstraint( terms = dependentTerms, bounds = { dependencyConstraint, _ -> @@ -1387,9 +1398,8 @@ class UNumericConstraints private constructor( Boolean, ) -> BoundsConstraint, ): BoundsConstraint { - var result = updateBounds(initialConstraint, bounds.dropTermsConstraints(terms)) - - val constraints = bounds.termDependency[terms] ?: emptySet() + val constraints = bounds.termDependency.getOrDefault(terms, persistentHashSetOf()) + var result = initialConstraint for (constraint in constraints) { val biasedConstraint = add(termsValue, constraint.bias) val biases = bounds.termConstraints[constraint] ?: continue @@ -1400,6 +1410,7 @@ class UNumericConstraints private constructor( } } + result = updateBounds(result, bounds.dropTermsConstraints(terms, ownership)) return result } @@ -1418,10 +1429,9 @@ class UNumericConstraints private constructor( Boolean, ) -> BoundsConstraint, ): BoundsConstraint { - var result = updateBounds(initialConstraint, bounds.dropTermsConstraints(terms)) - - val constraints = bounds.termDependency[terms] ?: emptySet() - for (constraint in constraints) { + val constraints = bounds.termDependency.getOrDefault(terms, persistentHashSetOf()) + var result = initialConstraint + for (constraint in constraints.toList()) { val biases = bounds.termConstraints[constraint] ?: continue for (bias in biases) { // this + bias (op) terms + constraint.bias && terms = replacement + replacementBias @@ -1434,6 +1444,7 @@ class UNumericConstraints private constructor( } } + result = updateBounds(result, bounds.dropTermsConstraints(terms, ownership)) return result } @@ -1451,7 +1462,7 @@ class UNumericConstraints private constructor( ): BoundsConstraint { var result = target - for ((constraint, biases) in bounds.termConstraints) { + for ((constraint, biases) in bounds.termConstraints.toList()) { // this + bias (op) constraint <=> target + (bias - targetBias) (op) constraint for (bias in biases) { result = addConstraint( @@ -1471,7 +1482,7 @@ class UNumericConstraints private constructor( BoundsConstraint, KBitVecValue, KBitVecValue, Boolean, ) -> BoundsConstraint, ) { - for ((constraint, biases) in bounds.termConstraints) { + for ((constraint, biases) in bounds.termConstraints.toList()) { withConstraint( terms = constraint.terms, bounds = { boundsConstraint, initialConstraintBias -> @@ -1555,7 +1566,7 @@ class UNumericConstraints private constructor( boundsConstraint.inferredTermLowerBounds.findBiasesWithConstraint(constraint) }, removeConstraintForBias = { bc, biasToRemove -> - val modifiedBounds = bc.inferredTermLowerBounds.removeTermConstraint(biasToRemove, constraint) + val modifiedBounds = bc.inferredTermLowerBounds.removeTermConstraint(biasToRemove, constraint, ownership) bc.modifyTermLowerBounds(modifiedBounds) }, removeOppositeConstraintForBias = { bc, biasToRemove -> @@ -1564,7 +1575,7 @@ class UNumericConstraints private constructor( biasToRemove, constraint.isStrict ) - bc.removeTermUpperBound(constraint.bias, oppositeConstraint) + bc.removeTermUpperBound(constraint.bias, oppositeConstraint, ownership) }, cont = { cont(it) } ) @@ -1610,7 +1621,7 @@ class UNumericConstraints private constructor( boundsConstraint.termUpperBounds.findBiasesWithConstraint(constraint) }, removeConstraintForBias = { bc, biasToRemove -> - val modifiedBounds = bc.termUpperBounds.removeTermConstraint(biasToRemove, constraint) + val modifiedBounds = bc.termUpperBounds.removeTermConstraint(biasToRemove, constraint, ownership) bc.modifyTermUpperBounds(modifiedBounds) }, removeOppositeConstraintForBias = { bc, biasToRemove -> @@ -1619,7 +1630,7 @@ class UNumericConstraints private constructor( biasToRemove, constraint.isStrict ) - bc.removeTermLowerBound(constraint.bias, oppositeConstraint) + bc.removeTermLowerBound(constraint.bias, oppositeConstraint, ownership) }, cont = { cont(it) } ) @@ -1640,7 +1651,7 @@ class UNumericConstraints private constructor( .mapNotNull { (bias, c) -> bias.takeIf { c == constraint } } }, removeConstraintForBias = { bc, biasToRemove -> - bc.removeConcreteUpperBound(biasToRemove) + bc.removeConcreteUpperBound(biasToRemove, ownership) }, removeOppositeConstraintForBias = { bc, _ -> bc }, cont = { cont(it) } @@ -1662,7 +1673,7 @@ class UNumericConstraints private constructor( .mapNotNull { (bias, c) -> bias.takeIf { c == constraint } } }, removeConstraintForBias = { bc, biasToRemove -> - bc.removeConcreteLowerBound(biasToRemove) + bc.removeConcreteLowerBound(biasToRemove, ownership) }, removeOppositeConstraintForBias = { bc, _ -> bc }, cont = { cont(it) } @@ -1757,7 +1768,7 @@ class UNumericConstraints private constructor( ): BoundsConstraint = eliminateTermLowerBound(this, bias, rhs, constraint, rhsLB) { boundsConstraint -> constraintAddDependency(boundsConstraint.constrainedTerms, constraint.terms) - val updatedBounds = boundsConstraint.inferredTermLowerBounds.addTermConstraint(bias, constraint) + val updatedBounds = boundsConstraint.inferredTermLowerBounds.addTermConstraint(bias, constraint, ownership) val result = boundsConstraint.modifyTermLowerBounds(updatedBounds) postProcessConstraint(result) } @@ -1771,7 +1782,7 @@ class UNumericConstraints private constructor( ): BoundsConstraint = eliminateTermUpperBound(this, bias, rhs, constraint, rhsUB) { boundsConstraint -> constraintAddDependency(boundsConstraint.constrainedTerms, constraint.terms) - val updatedBounds = boundsConstraint.termUpperBounds.addTermConstraint(bias, constraint) + val updatedBounds = boundsConstraint.termUpperBounds.addTermConstraint(bias, constraint, ownership) val result = boundsConstraint.modifyTermUpperBounds(updatedBounds) postProcessConstraint(result) } @@ -1784,7 +1795,7 @@ class UNumericConstraints private constructor( ) { boundsConstraint -> constraintUpdated(ConstraintUpdateEvent(constrainedTerms, bias, BoundsUpdateKind.UPPER)) return boundsConstraint - .modifyConcreteUpperBounds(bias, constraint) + .modifyConcreteUpperBounds(bias, constraint, ownership) .refineGroundConstraint(bias) } @@ -1796,7 +1807,7 @@ class UNumericConstraints private constructor( ) { boundsConstraint -> constraintUpdated(ConstraintUpdateEvent(constrainedTerms, bias, BoundsUpdateKind.LOWER)) return boundsConstraint - .modifyConcreteLowerBounds(bias, constraint) + .modifyConcreteLowerBounds(bias, constraint, ownership) .refineGroundConstraint(bias) } @@ -1940,9 +1951,9 @@ class UNumericConstraints private constructor( * */ class BoundsConstraint( constrainedTerms: ConstraintTerms, - val concreteLowerBounds: PersistentMap, ValueConstraint>, - val concreteUpperBounds: PersistentMap, ValueConstraint>, - val concreteDisequalitites: PersistentMap, ValueConstraint>, + val concreteLowerBounds: UPersistentHashMap, ValueConstraint>, + val concreteUpperBounds: UPersistentHashMap, ValueConstraint>, + val concreteDisequalitites: UPersistentHashMap, ValueConstraint>, val inferredTermLowerBounds: TermConstraintSet, val termUpperBounds: TermConstraintSet, val termDisequalities: TermConstraintSet, @@ -1958,9 +1969,9 @@ class UNumericConstraints private constructor( ) fun size(): Int = - inferredTermLowerBounds.termConstraints.size + - termUpperBounds.termConstraints.size + - termDisequalities.termConstraints.size + inferredTermLowerBounds.size + + termUpperBounds.size + + termDisequalities.size fun lowerBound(bias: KBitVecValue): ValueConstraint? = concreteLowerBounds[bias] @@ -1971,8 +1982,9 @@ class UNumericConstraints private constructor( fun modifyConcreteLowerBounds( bias: KBitVecValue, bound: ValueConstraint, + ownership: MutabilityOwnership, ): BoundsConstraint { - val modified = concreteLowerBounds.put(bias, bound) + val modified = concreteLowerBounds.put(bias, bound, ownership) if (modified === this.concreteLowerBounds) { return this } @@ -1986,8 +1998,9 @@ class UNumericConstraints private constructor( fun modifyConcreteUpperBounds( bias: KBitVecValue, bound: ValueConstraint, + ownership: MutabilityOwnership, ): BoundsConstraint { - val modified = concreteUpperBounds.put(bias, bound) + val modified = concreteUpperBounds.put(bias, bound, ownership) if (modified === this.concreteUpperBounds) { return this } @@ -1998,8 +2011,8 @@ class UNumericConstraints private constructor( ) } - fun removeConcreteUpperBound(bias: KBitVecValue): BoundsConstraint { - val modified = concreteUpperBounds.remove(bias) + fun removeConcreteUpperBound(bias: KBitVecValue, ownership: MutabilityOwnership): BoundsConstraint { + val modified = concreteUpperBounds.remove(bias, ownership) if (modified === this.concreteUpperBounds) { return this } @@ -2010,8 +2023,8 @@ class UNumericConstraints private constructor( ) } - fun removeConcreteLowerBound(bias: KBitVecValue): BoundsConstraint { - val modified = concreteLowerBounds.remove(bias) + fun removeConcreteLowerBound(bias: KBitVecValue, ownership: MutabilityOwnership): BoundsConstraint { + val modified = concreteLowerBounds.remove(bias, ownership) if (modified === this.concreteLowerBounds) { return this } @@ -2025,8 +2038,9 @@ class UNumericConstraints private constructor( fun modifyConcreteDisequalitites( bias: KBitVecValue, bound: ValueConstraint, + ownership: MutabilityOwnership, ): BoundsConstraint { - val modified = concreteDisequalitites.put(bias, bound) + val modified = concreteDisequalitites.put(bias, bound, ownership) if (modified === this.concreteDisequalitites) { return this } @@ -2070,13 +2084,21 @@ class UNumericConstraints private constructor( ) } - fun removeTermLowerBound(bias: KBitVecValue, constraint: TermsConstraint): BoundsConstraint { - val updatedBounds = inferredTermLowerBounds.removeTermConstraint(bias, constraint) + fun removeTermLowerBound( + bias: KBitVecValue, + constraint: TermsConstraint, + ownership: MutabilityOwnership, + ): BoundsConstraint { + val updatedBounds = inferredTermLowerBounds.removeTermConstraint(bias, constraint, ownership) return modifyTermLowerBounds(updatedBounds) } - fun removeTermUpperBound(bias: KBitVecValue, constraint: TermsConstraint): BoundsConstraint { - val updatedBounds = termUpperBounds.removeTermConstraint(bias, constraint) + fun removeTermUpperBound( + bias: KBitVecValue, + constraint: TermsConstraint, + ownership: MutabilityOwnership, + ): BoundsConstraint { + val updatedBounds = termUpperBounds.removeTermConstraint(bias, constraint, ownership) return modifyTermUpperBounds(updatedBounds) } @@ -2117,7 +2139,7 @@ class UNumericConstraints private constructor( } private inline fun mapPrimaryConcrete( - concrete: PersistentMap, ValueConstraint>, + concrete: UPersistentHashMap, ValueConstraint>, crossinline body: (KBitVecValue, UExpr) -> T, ): Sequence = concrete.asSequence().mapNotNull { (bias, constraint) -> @@ -2163,14 +2185,20 @@ class UNumericConstraints private constructor( } class TermConstraintSet( - val termConstraints: PersistentMap, PersistentSet>>, - val termDependency: PersistentMap, PersistentSet>>, + val termConstraints: UPersistentMultiMap, KBitVecValue>, + val termDependency: UPersistentMultiMap, TermsConstraint>, + val size: Int ) { - constructor() : this(persistentHashMapOf(), persistentHashMapOf()) + constructor() : this(persistentHashMapOf(), persistentHashMapOf(), 0) - fun addTermConstraint(bias: KBitVecValue, constraint: TermsConstraint): TermConstraintSet { - val constraints = termConstraints[constraint] ?: persistentHashSetOf() - val updatedConstraints = constraints.add(bias) + fun addTermConstraint( + bias: KBitVecValue, + constraint: TermsConstraint, + ownership: MutabilityOwnership, + ): TermConstraintSet { + var newSize = size + val constraints = termConstraints[constraint].also { if (it == null) newSize++ } ?: persistentHashSetOf() + val updatedConstraints = constraints.add(bias, ownership) if (updatedConstraints === constraints) { return this } @@ -2178,56 +2206,57 @@ class UNumericConstraints private constructor( val updatedTermDependency = if (constraints.isNotEmpty()) { termDependency } else { - val dependency = termDependency[constraint.terms] ?: persistentHashSetOf() - val updatedDependency = dependency.add(constraint) - termDependency.put(constraint.terms, updatedDependency) + termDependency.addToSet(constraint.terms, constraint, ownership) } return TermConstraintSet( - termConstraints.put(constraint, updatedConstraints), - updatedTermDependency + termConstraints.put(constraint, updatedConstraints, ownership), + updatedTermDependency, + newSize ) } - fun removeTermConstraint(bias: KBitVecValue, constraint: TermsConstraint): TermConstraintSet { + fun removeTermConstraint( + bias: KBitVecValue, + constraint: TermsConstraint, + ownership: MutabilityOwnership, + ): TermConstraintSet { val constraints = termConstraints[constraint] ?: return this - val updatedConstraints = constraints.remove(bias) + val updatedConstraints = constraints.remove(bias, ownership) if (updatedConstraints === constraints) { return this } if (updatedConstraints.isEmpty()) { - val dependency = termDependency[constraint.terms] ?: persistentHashSetOf() - val updatedDependency = dependency.remove(constraint) - val updatedTermDependency = if (updatedDependency.isEmpty()) { - termDependency.remove(constraint.terms) - } else { - termDependency.put(constraint.terms, updatedDependency) - } - return TermConstraintSet( - termConstraints.remove(constraint), - updatedTermDependency + termConstraints.remove(constraint, ownership), + termDependency.removeValue(constraint.terms, constraint, ownership), + size - 1 ) } return TermConstraintSet( - termConstraints.put(constraint, updatedConstraints), - termDependency + termConstraints.put(constraint, updatedConstraints, ownership), + termDependency, + size ) } - fun dropTermsConstraints(terms: ConstraintTerms): TermConstraintSet { + fun dropTermsConstraints(terms: ConstraintTerms, ownership: MutabilityOwnership): TermConstraintSet { val constraints = termDependency[terms] ?: return this var updatedConstraints = termConstraints + var updatedSize = size for (constraint in constraints) { - updatedConstraints = updatedConstraints.remove(constraint) + val (newUpdatedConstraints, hasChanged) = updatedConstraints.removeWithChangeInfo(constraint, ownership) + updatedConstraints = newUpdatedConstraints + if (hasChanged) updatedSize-- } return TermConstraintSet( updatedConstraints, - termDependency.remove(terms) + termDependency.remove(terms, ownership), + updatedSize ) } @@ -2414,17 +2443,25 @@ class UNumericConstraints private constructor( * * @return the numeric constraints. */ - override fun mergeWith(other: UNumericConstraints, by: MutableMergeGuard): UNumericConstraints { - val (overlap, thisUnique, otherUnique) = this.numericConstraints.build() - .separate(other.numericConstraints.build()) - - for (constraint in thisUnique.values) { - by.appendThis(constraint.mkExpressions()) - } - for (constraint in otherUnique.values) { - by.appendOther(constraint.mkExpressions()) - } - - return UNumericConstraints(ctx, sort, overlap, constraintWatchList.build()) + override fun mergeWith( + other: UNumericConstraints, + by: MutableMergeGuard, + thisOwnership: MutabilityOwnership, + otherOwnership: MutabilityOwnership, + mergedOwnership: MutabilityOwnership, + ): UNumericConstraints { + val (overlap, thisUnique, otherUnique) = this.numericConstraints + .separate(other.numericConstraints, mergedOwnership) + + for (entry in thisUnique) { + by.appendThis(entry.value.mkExpressions()) + } + for (entry in otherUnique) { + by.appendOther(entry.value.mkExpressions()) + } + + this.ownership = thisOwnership + other.ownership = otherOwnership + return UNumericConstraints(ctx, sort, mergedOwnership, overlap, constraintWatchList) } } diff --git a/usvm-core/src/main/kotlin/org/usvm/memory/HeapRefSplitting.kt b/usvm-core/src/main/kotlin/org/usvm/memory/HeapRefSplitting.kt index 3fbdf42a2a..c73ea2ea01 100644 --- a/usvm-core/src/main/kotlin/org/usvm/memory/HeapRefSplitting.kt +++ b/usvm-core/src/main/kotlin/org/usvm/memory/HeapRefSplitting.kt @@ -125,7 +125,8 @@ inline fun foldHeapRef( ref is USymbolicHeapRef -> blockOnSymbolic(initial, ref with initialGuard) ref is UIteExpr -> { val (concreteHeapRefs, symbolicHeapRefs) = splitUHeapRef( - ref, initialGuard, + ref, + initialGuard, collapseHeapRefs = collapseHeapRefs, staticIsConcrete = staticIsConcrete ) diff --git a/usvm-core/src/main/kotlin/org/usvm/memory/Memory.kt b/usvm-core/src/main/kotlin/org/usvm/memory/Memory.kt index 3b704321fa..251ce742e4 100644 --- a/usvm-core/src/main/kotlin/org/usvm/memory/Memory.kt +++ b/usvm-core/src/main/kotlin/org/usvm/memory/Memory.kt @@ -1,7 +1,5 @@ package org.usvm.memory -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.persistentHashMapOf import org.usvm.INITIAL_CONCRETE_ADDRESS import org.usvm.INITIAL_STATIC_ADDRESS import org.usvm.UBoolExpr @@ -14,10 +12,14 @@ import org.usvm.UIndexedMocker import org.usvm.UMockEvaluator import org.usvm.UMocker import org.usvm.USort +import org.usvm.collections.immutable.getOrPut +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.persistentHashMapOf import org.usvm.constraints.UTypeConstraints import org.usvm.constraints.UTypeEvaluator import org.usvm.merging.MergeGuard -import org.usvm.merging.UMergeable +import org.usvm.merging.UOwnedMergeable interface UMemoryRegionId { val sort: Sort @@ -30,7 +32,7 @@ interface UReadOnlyMemoryRegion { } interface UMemoryRegion : UReadOnlyMemoryRegion { - fun write(key: Key, value: UExpr, guard: UBoolExpr): UMemoryRegion + fun write(key: Key, value: UExpr, guard: UBoolExpr, ownership: MutabilityOwnership): UMemoryRegion } interface ULValue { @@ -62,6 +64,7 @@ class UAddressCounter { } interface UReadOnlyMemory { + val ownership: MutabilityOwnership val stack: UReadOnlyRegistersStack val mocker: UMockEvaluator val types: UTypeEvaluator @@ -77,7 +80,7 @@ interface UReadOnlyMemory { fun nullRef(): UHeapRef - fun toWritableMemory(): UWritableMemory + fun toWritableMemory(ownership: MutabilityOwnership): UWritableMemory } interface UWritableMemory : UReadOnlyMemory { @@ -92,12 +95,12 @@ interface UWritableMemory : UReadOnlyMemory { @Suppress("MemberVisibilityCanBePrivate") class UMemory( internal val ctx: UContext<*>, + override var ownership: MutabilityOwnership, override val types: UTypeConstraints, override val stack: URegistersStack = URegistersStack(), private val mocks: UIndexedMocker = UIndexedMocker(), - persistentRegions: PersistentMap, UMemoryRegion<*, *>> = persistentHashMapOf(), -) : UWritableMemory, UMergeable, MergeGuard> { - private val regions = persistentRegions.builder() + private var regions: UPersistentHashMap, UMemoryRegion<*, *>> = persistentHashMapOf(), +) : UWritableMemory, UOwnedMergeable, MergeGuard> { override val mocker: UMocker get() = mocks @@ -106,9 +109,9 @@ class UMemory( override fun getRegion(regionId: UMemoryRegionId): UMemoryRegion { if (regionId is URegisterStackId) return stack as UMemoryRegion - return regions.getOrPut(regionId) { - regionId.emptyRegion() - } as UMemoryRegion + val (updatedRegions, region) = regions.getOrPut(regionId, ownership) { regionId.emptyRegion() } + regions = updatedRegions + return region as UMemoryRegion } override fun setRegion( @@ -119,7 +122,7 @@ class UMemory( check(newRegion === stack) { "Stack is mutable" } return } - regions[regionId] = newRegion + regions = regions.put(regionId, newRegion, ownership) } override fun write(lvalue: ULValue, rvalue: UExpr, guard: UBoolExpr) = @@ -132,7 +135,7 @@ class UMemory( guard: UBoolExpr ) { val region = getRegion(regionId) - val newRegion = region.write(key, value, guard) + val newRegion = region.write(key, value, guard, ownership) setRegion(regionId, newRegion) } @@ -152,13 +155,22 @@ class UMemory( override fun nullRef(): UHeapRef = ctx.nullRef - fun clone(typeConstraints: UTypeConstraints): UMemory = - UMemory(ctx, typeConstraints, stack.clone(), mocks.clone(), regions.build()) + fun clone( + typeConstraints: UTypeConstraints, + thisOwnership: MutabilityOwnership, + cloneOwnership: MutabilityOwnership, + ): UMemory = + UMemory( + ctx, cloneOwnership, typeConstraints, stack.clone(), mocks.clone(), regions + ).also { ownership = thisOwnership } - override fun toWritableMemory() = - // To be perfectly rigorous, we should clone stack and types here. - // But in fact they should not be used, so to optimize things up, we don't touch them. - UMemory(ctx, types, stack, mocks, regions.build()) + override fun toWritableMemory(ownership: MutabilityOwnership) = + /* NOTE 1: To be perfectly rigorous, we should clone stack and types here. + But in fact they should not be used, so to optimize things up, we don't touch them. + NOTE 2: method returns *temporary* copy of this [UMemory], so write operations on the copy will not + affect this [UMemory], while write operations on this [UMemory] *can* affect the copy. + */ + UMemory(ctx, ownership, types, stack, mocks, regions) /** @@ -172,9 +184,15 @@ class UMemory( * * @return the merged memory. */ - override fun mergeWith(other: UMemory, by: MergeGuard): UMemory? { - val ids = regions.keys - val otherIds = other.regions.keys + override fun mergeWith( + other: UMemory, + by: MergeGuard, + thisOwnership: MutabilityOwnership, + otherOwnership: MutabilityOwnership, + mergedOwnership: MutabilityOwnership, + ): UMemory? { + val ids = regions.keys.toList() + val otherIds = other.regions.keys.toList() if (ids != otherIds) { return null } @@ -189,10 +207,13 @@ class UMemory( } } - val mergedRegions = regions.build() + val mergedRegions = regions val mergedStack = stack.mergeWith(other.stack, by) ?: return null - val mergedMocks = mocks.mergeWith(other.mocks, by) ?: return null + val mergedMocks = mocks.mergeWith(other.mocks, by) + ?: return null - return UMemory(ctx, types, mergedStack, mergedMocks, mergedRegions) + this.ownership = thisOwnership + other.ownership = otherOwnership + return UMemory(ctx, mergedOwnership, types, mergedStack, mergedMocks, mergedRegions) } } diff --git a/usvm-core/src/main/kotlin/org/usvm/memory/RegistersStack.kt b/usvm-core/src/main/kotlin/org/usvm/memory/RegistersStack.kt index f5375d2149..1920a358fb 100644 --- a/usvm-core/src/main/kotlin/org/usvm/memory/RegistersStack.kt +++ b/usvm-core/src/main/kotlin/org/usvm/memory/RegistersStack.kt @@ -5,6 +5,7 @@ import io.ksmt.utils.asExpr import org.usvm.UBoolExpr import org.usvm.UExpr import org.usvm.USort +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.isTrue import org.usvm.merging.MergeGuard import org.usvm.merging.UMergeable @@ -54,6 +55,7 @@ class URegistersStack( key: URegisterStackLValue<*>, value: UExpr, guard: UBoolExpr, + ownership: MutabilityOwnership, ): UMemoryRegion, USort> { check(guard.isTrue) { "Guarded writes are not supported for register" } writeRegister(key.idx, value) diff --git a/usvm-core/src/main/kotlin/org/usvm/memory/USymbolicCollection.kt b/usvm-core/src/main/kotlin/org/usvm/memory/USymbolicCollection.kt index 921e1d1f88..289becb2b7 100644 --- a/usvm-core/src/main/kotlin/org/usvm/memory/USymbolicCollection.kt +++ b/usvm-core/src/main/kotlin/org/usvm/memory/USymbolicCollection.kt @@ -1,12 +1,13 @@ package org.usvm.memory import io.ksmt.utils.asExpr -import kotlinx.collections.immutable.PersistentMap import org.usvm.UBoolExpr import org.usvm.UComposer import org.usvm.UConcreteHeapRef import org.usvm.UExpr import org.usvm.USort +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.isFalse import org.usvm.isTrue import org.usvm.uctx @@ -116,7 +117,8 @@ data class USymbolicCollection, - guard: UBoolExpr + guard: UBoolExpr, + ownership: MutabilityOwnership, ): USymbolicCollection { assert(value.sort == sort) @@ -245,16 +247,17 @@ class GuardBuilder(nonMatchingUpdates: UBoolExpr) { get() = nonMatchingUpdatesGuard.isFalse } -inline fun PersistentMap>.guardedWrite( +inline fun UPersistentHashMap>.guardedWrite( key: K, value: UExpr, guard: UBoolExpr, - defaultValue: () -> UExpr -): PersistentMap> { + ownership: MutabilityOwnership, + defaultValue: () -> UExpr, +): UPersistentHashMap> { val guardedValue = guard.uctx.mkIte( guard, { value }, { get(key) ?: defaultValue() } ) - return put(key, guardedValue) + return put(key, guardedValue, ownership) } diff --git a/usvm-core/src/main/kotlin/org/usvm/merging/Merging.kt b/usvm-core/src/main/kotlin/org/usvm/merging/Merging.kt index a5b2f08468..d7e0e46ac2 100644 --- a/usvm-core/src/main/kotlin/org/usvm/merging/Merging.kt +++ b/usvm-core/src/main/kotlin/org/usvm/merging/Merging.kt @@ -1,8 +1,24 @@ package org.usvm.merging +import org.usvm.collections.immutable.internal.MutabilityOwnership + interface UMergeable { /** * @return Merged entity or `null` if `this` and [other] are non-mergeable. */ fun mergeWith(other: Entity, by: By): Entity? } + +interface UOwnedMergeable { + /** + * @return Merged entity with [mergedOwnership] as ownership or `null` if `this` and [other] are non-mergeable. + * Changes [this] and [other] ownerships to [thisOwnership] and [otherOwnership] respectively. + */ + fun mergeWith( + other: Entity, + by: By, + thisOwnership: MutabilityOwnership, + otherOwnership: MutabilityOwnership, + mergedOwnership: MutabilityOwnership, + ): Entity? +} diff --git a/usvm-core/src/main/kotlin/org/usvm/merging/MergingPathSelector.kt b/usvm-core/src/main/kotlin/org/usvm/merging/MergingPathSelector.kt index 4787883b91..d1e7f5716e 100644 --- a/usvm-core/src/main/kotlin/org/usvm/merging/MergingPathSelector.kt +++ b/usvm-core/src/main/kotlin/org/usvm/merging/MergingPathSelector.kt @@ -79,4 +79,4 @@ class MergingPathSelector>( underlyingPathSelector.remove(state) closeStatesSearcher.remove(state) } -} \ No newline at end of file +} diff --git a/usvm-core/src/main/kotlin/org/usvm/model/Model.kt b/usvm-core/src/main/kotlin/org/usvm/model/Model.kt index 59feb5405f..2e23a4bcd5 100644 --- a/usvm-core/src/main/kotlin/org/usvm/model/Model.kt +++ b/usvm-core/src/main/kotlin/org/usvm/model/Model.kt @@ -10,6 +10,7 @@ import org.usvm.UExpr import org.usvm.UHeapRef import org.usvm.UMockEvaluator import org.usvm.USort +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.memory.ULValue import org.usvm.memory.UMemoryRegion import org.usvm.memory.UMemoryRegionId @@ -37,9 +38,10 @@ open class UModelBase( override val mocker: UMockEvaluator, val regions: Map, UReadOnlyMemoryRegion<*, *>>, val nullRef: UConcreteHeapRef, + override val ownership: MutabilityOwnership = MutabilityOwnership(), ) : UModel, UWritableMemory { @Suppress("LeakingThis") - protected open val composer = ctx.composer(this) + protected open val composer = ctx.composer(this, ownership) /** * The evaluator supports only expressions with symbols inheriting [org.usvm.USymbol]. @@ -61,7 +63,7 @@ open class UModelBase( override fun nullRef(): UHeapRef = nullRef - override fun toWritableMemory(): UWritableMemory = this + override fun toWritableMemory(ownership: MutabilityOwnership): UWritableMemory = this override fun setRegion( regionId: UMemoryRegionId, diff --git a/usvm-core/src/main/kotlin/org/usvm/model/ModelRegions.kt b/usvm-core/src/main/kotlin/org/usvm/model/ModelRegions.kt index 5604612f43..8c926146d8 100644 --- a/usvm-core/src/main/kotlin/org/usvm/model/ModelRegions.kt +++ b/usvm-core/src/main/kotlin/org/usvm/model/ModelRegions.kt @@ -1,10 +1,11 @@ package org.usvm.model import io.ksmt.expr.KExpr -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.persistentHashMapOf import org.usvm.UExpr import org.usvm.USort +import org.usvm.collections.immutable.getOrDefault +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.persistentHashMapOf import org.usvm.memory.UMemoryRegion import org.usvm.memory.UReadOnlyMemoryRegion @@ -13,7 +14,7 @@ import org.usvm.memory.UReadOnlyMemoryRegion * A specific [UMemoryRegion] for one-dimensional regions generalized by a single expression of a [KeySort]. */ class UMemory1DArray internal constructor( - private val values: PersistentMap, UExpr>, + private val values: UPersistentHashMap, UExpr>, private val constValue: UExpr, ) : UReadOnlyMemoryRegion, Sort> { @@ -25,7 +26,8 @@ class UMemory1DArray internal constructor( mappedConstValue: UExpr, ) : this(persistentHashMapOf(), mappedConstValue) - override fun read(key: KExpr): UExpr = values.getOrDefault(key, constValue) + override fun read(key: KExpr): UExpr = + values.getOrDefault(key, constValue) } /** @@ -33,7 +35,7 @@ class UMemory1DArray internal constructor( * of two expressions with [Key1Sort] and [Key2Sort] sorts. */ class UMemory2DArray internal constructor( - val values: PersistentMap, UExpr>, UExpr>, + val values: UPersistentHashMap, UExpr>, UExpr>, val constValue: UExpr, ) : UReadOnlyMemoryRegion, KExpr>, Sort> { /** diff --git a/usvm-core/src/main/kotlin/org/usvm/model/UModelEvaluator.kt b/usvm-core/src/main/kotlin/org/usvm/model/UModelEvaluator.kt index be68857e07..0bb7e7e2c0 100644 --- a/usvm-core/src/main/kotlin/org/usvm/model/UModelEvaluator.kt +++ b/usvm-core/src/main/kotlin/org/usvm/model/UModelEvaluator.kt @@ -27,7 +27,7 @@ import io.ksmt.sort.KSort import io.ksmt.sort.KSortVisitor import io.ksmt.sort.KUninterpretedSort import io.ksmt.utils.uncheckedCast -import kotlinx.collections.immutable.persistentHashMapOf +import org.usvm.collections.immutable.persistentHashMapOf import org.usvm.NULL_ADDRESS import org.usvm.UContext import org.usvm.UExpr @@ -114,20 +114,20 @@ open class UModelEvaluator( ): UMemory1DArray { val interpretation = model.interpretation(translated) - val stores = persistentHashMapOf, UExpr>().builder() + var stores = persistentHashMapOf, UExpr>() val defaultValue = interpretation?.let { traverse1DArrayEntries(interpretation) { idx, value -> - stores[idx.mapAddress(addressesMapping)] = value.mapAddress(addressesMapping) + stores = stores.put(idx.mapAddress(addressesMapping), value.mapAddress(addressesMapping), ctx.defaultOwnership) } } if (defaultValue != null) { - return UMemory1DArray(stores.build(), defaultValue.mapAddress(addressesMapping)) + return UMemory1DArray(stores, defaultValue.mapAddress(addressesMapping)) } return completed1DArrays.getOrPut(translated) { val completedDefault = translated.sort.range.accept(this) - UMemory1DArray(stores.build(), completedDefault.uncheckedCast()) + UMemory1DArray(stores, completedDefault.uncheckedCast()) }.uncheckedCast() } @@ -136,26 +136,26 @@ open class UModelEvaluator( * Complete model according to the array range (value) sort if array is free in the model. * */ open fun evalAndCompleteArray2DMemoryRegion( - translated: KDecl>, + translated: KDecl> ): UMemory2DArray { val interpretation = model.interpretation(translated) - val stores = persistentHashMapOf, UExpr>, UExpr>().builder() + var stores = persistentHashMapOf, UExpr>, UExpr>() val defaultValue = interpretation?.let { traverse2DArrayEntries(interpretation) { idx1, idx2, value -> val mappedIdx1 = idx1.mapAddress(addressesMapping) val mappedIdx2 = idx2.mapAddress(addressesMapping) - stores[mappedIdx1 to mappedIdx2] = value.mapAddress(addressesMapping) + stores = stores.put(mappedIdx1 to mappedIdx2, value.mapAddress(addressesMapping), ctx.defaultOwnership) } } if (defaultValue != null) { - return UMemory2DArray(stores.build(), defaultValue.mapAddress(addressesMapping)) + return UMemory2DArray(stores, defaultValue.mapAddress(addressesMapping)) } return completed2DArrays.getOrPut(translated) { val completedDefault = translated.sort.range.accept(this) - UMemory2DArray(stores.build(), completedDefault.uncheckedCast()) + UMemory2DArray(stores, completedDefault.uncheckedCast()) }.uncheckedCast() } diff --git a/usvm-core/src/main/kotlin/org/usvm/solver/Solver.kt b/usvm-core/src/main/kotlin/org/usvm/solver/Solver.kt index 827b95d67d..b0f50aca0c 100644 --- a/usvm-core/src/main/kotlin/org/usvm/solver/Solver.kt +++ b/usvm-core/src/main/kotlin/org/usvm/solver/Solver.kt @@ -149,7 +149,7 @@ open class USolverBase( } fun emptyModel(): UModelBase = - (check(UPathConstraints(ctx)) as USatResult>).model + (check(UPathConstraints(ctx, ctx.defaultOwnership)) as USatResult>).model override fun close() { smtSolver.close() diff --git a/usvm-core/src/test/kotlin/org/usvm/CompositionTest.kt b/usvm-core/src/test/kotlin/org/usvm/CompositionTest.kt index 62e12db253..27b016e41e 100644 --- a/usvm-core/src/test/kotlin/org/usvm/CompositionTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/CompositionTest.kt @@ -27,6 +27,7 @@ import org.usvm.collection.array.USymbolicArrayInputToInputCopyAdapter import org.usvm.collection.array.length.UInputArrayLengthId import org.usvm.collection.field.UFieldLValue import org.usvm.collection.field.UInputFieldId +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UTypeEvaluator import org.usvm.memory.UFlatUpdates import org.usvm.memory.UMemory @@ -54,6 +55,7 @@ internal class CompositionTest { private lateinit var memory: UReadOnlyMemory private lateinit var ctx: UContext + private lateinit var ownership: MutabilityOwnership private lateinit var concreteNull: UConcreteHeapRef private lateinit var composer: UComposer @@ -63,8 +65,9 @@ internal class CompositionTest { every { components.mkTypeSystem(any()) } returns mockk() ctx = UContext(components) + ownership = MutabilityOwnership() every { components.mkSizeExprProvider(any()) } answers { UBv32SizeExprProvider(ctx) } - every { components.mkComposer(ctx) } answers { { memory: UReadOnlyMemory -> UComposer(ctx, memory) } } + every { components.mkComposer(ctx) } answers { { memory: UReadOnlyMemory, ownership: MutabilityOwnership -> UComposer(ctx, memory, ownership) } } concreteNull = ctx.mkConcreteHeapRef(NULL_ADDRESS) stackEvaluator = mockk() @@ -76,7 +79,7 @@ internal class CompositionTest { every { memory.stack } returns stackEvaluator every { memory.mocker } returns mockEvaluator - composer = UComposer(ctx, memory) + composer = UComposer(ctx, memory, ownership) } @Test @@ -255,12 +258,12 @@ internal class CompositionTest { val fstValueFromHeap = 42.toBv() val sndValueFromHeap = 43.toBv() - val heapToComposeWith = UMemory>, Any>(ctx, mockk()) + val heapToComposeWith = UMemory>, Any>(ctx, MutabilityOwnership(), mockk()) heapToComposeWith.writeArrayLength(fstConcreteAddress, fstValueFromHeap, arrayType, sizeSort) heapToComposeWith.writeArrayLength(sndConcreteAddress, sndValueFromHeap, arrayType, sizeSort) - val composer = UComposer(ctx, heapToComposeWith) + val composer = UComposer(ctx, heapToComposeWith, MutabilityOwnership()) every { fstAddress.accept(composer) } returns fstConcreteAddress every { sndAddress.accept(composer) } returns sndConcreteAddress @@ -310,7 +313,8 @@ internal class CompositionTest { val answer = 43.toBv() - val composer = UComposer(ctx, UMemory, Any>(ctx, mockk())) // TODO replace with jacoDB type + // TODO replace with jacoDB type + val composer = UComposer(ctx, UMemory, Any>(ctx, ownership, mockk()), MutabilityOwnership()) every { fstAddress.accept(composer) } returns sndAddress every { fstIndex.accept(composer) } returns sndIndex @@ -340,17 +344,17 @@ internal class CompositionTest { // create a reading from the region val fstArrayIndexReading = mkInputArrayReading(region, fstAddress, fstIndex) - val sndMemory = UMemory, Any>(ctx, mockk(), mockk()) + val sndMemory = UMemory, Any>(ctx, MutabilityOwnership(), mockk(), mockk()) // create a heap with a record: (sndAddress, sndIndex) = 2 sndMemory.writeArrayIndex(sndAddress, sndIndex, arrayType, mkBv32Sort(), 2.toBv(), mkTrue()) - val sndComposer = UComposer(ctx, sndMemory) + val sndComposer = UComposer(ctx, sndMemory, MutabilityOwnership()) - val fstMemory = UMemory, Any>(ctx, mockk(), mockk()) + val fstMemory = UMemory, Any>(ctx, ownership, mockk(), mockk()) // create a heap with a record: (fstAddress, fstIndex) = 1 fstMemory.writeArrayIndex(fstAddress, fstIndex, arrayType, mkBv32Sort(), 1.toBv(), mkTrue()) - val fstComposer = UComposer(ctx, fstMemory) // TODO replace with jacoDB type + val fstComposer = UComposer(ctx, fstMemory, MutabilityOwnership()) // TODO replace with jacoDB type // Both heaps leave everything untouched every { sndAddress.accept(sndComposer) } returns sndAddress @@ -416,7 +420,7 @@ internal class CompositionTest { val fstValue = 42.toBv() val sndValue = 43.toBv() - val heapToComposeWith = UMemory>, Any>(ctx, mockk()) + val heapToComposeWith = UMemory>, Any>(ctx, MutabilityOwnership(), mockk()) heapToComposeWith.writeArrayIndex( fstAddressForCompose, concreteIndex, arrayType, regionArray.sort, fstValue, guard = trueExpr @@ -425,7 +429,7 @@ internal class CompositionTest { sndAddressForCompose, concreteIndex, arrayType, regionArray.sort, sndValue, guard = trueExpr ) - val composer = UComposer(ctx, heapToComposeWith) + val composer = UComposer(ctx, heapToComposeWith, MutabilityOwnership()) every { fstSymbolicIndex.accept(composer) } returns concreteIndex every { sndSymbolicIndex.accept(composer) } returns concreteIndex @@ -448,8 +452,8 @@ internal class CompositionTest { val regionArray = UAllocatedArrayId<_, _, USizeSort>(arrayType, addressSort, 0) .emptyRegion() - .write(mkBv(0), symbolicAddress, trueExpr) - .write(mkBv(1), mkConcreteHeapRef(1), trueExpr) + .write(mkBv(0), symbolicAddress, trueExpr, ownership) + .write(mkBv(1), mkConcreteHeapRef(1), trueExpr, ownership) val reading = mkAllocatedArrayReading(regionArray, symbolicIndex) @@ -459,7 +463,7 @@ internal class CompositionTest { ctx, mockk(), mockk(), mockk(), emptyMap(), concreteNullRef ) - val composer = spyk(UComposer(ctx, heapToComposeWith)) + val composer = spyk(UComposer(ctx, heapToComposeWith, MutabilityOwnership())) every { symbolicIndex.accept(composer) } returns mkBv(2) every { symbolicAddress.accept(composer) } returns mkConcreteHeapRef(-1) @@ -487,7 +491,7 @@ internal class CompositionTest { val region = USymbolicCollection( UInputFieldId(field, bv32Sort), updates, - ).write(aAddress, 43.toBv(), guard = trueExpr) + ).write(aAddress, 43.toBv(), guard = trueExpr, ownership) every { aAddress.accept(any()) } returns aAddress every { bAddress.accept(any()) } returns aAddress @@ -505,10 +509,10 @@ internal class CompositionTest { val answer = 43.toBv() - val composeMemory = UMemory(ctx, mockk()) + val composeMemory = UMemory(ctx, MutabilityOwnership(), mockk()) composeMemory.writeField(aAddress, field, bv32Sort, 42.toBv(), guard = trueExpr) - val composer = UComposer(ctx, composeMemory) + val composer = UComposer(ctx, composeMemory, MutabilityOwnership()) val composedExpression = composer.compose(expression) @@ -523,7 +527,7 @@ internal class CompositionTest { ) val model = UModelBase(ctx, stackModel, mockk(), mockk(), emptyMap(), concreteNull) - val composer = UComposer(this, model) + val composer = UComposer(this, model, MutabilityOwnership()) val heapRefEvalEq = mkHeapRefEq(mkRegisterReading(0, addressSort), mkRegisterReading(1, addressSort)) @@ -539,7 +543,7 @@ internal class CompositionTest { every { composedMemory.nullRef() } returns concreteNull every { composedMemory.stack } returns stackModel - val composer = UComposer(this, composedMemory) + val composer = UComposer(this, composedMemory, MutabilityOwnership()) val heapRefEvalEq = mkHeapRefEq(mkRegisterReading(0, addressSort), nullRef) @@ -556,7 +560,7 @@ internal class CompositionTest { val symbolicRef2 = mkRegisterReading(2, addressSort) as UHeapRef val composedSymbolicHeapRef = mkConcreteHeapRef(1) - val composeMemory = UMemory(ctx, mockk()) + val composeMemory = UMemory(ctx, MutabilityOwnership(), mockk()) composeMemory.writeArrayIndex(composedSymbolicHeapRef, mkBv(3), arrayType, bv32Sort, mkBv(1337), trueExpr) @@ -566,11 +570,11 @@ internal class CompositionTest { composeMemory.stack.writeRegister(2, composedSymbolicHeapRef) composeMemory.stack.writeRegister(3, mkRegisterReading(3, bv32Sort)) - val composer = UComposer(ctx, composeMemory) + val composer = UComposer(ctx, composeMemory, MutabilityOwnership()) val fromRegion0 = UInputArrayId<_, _, USizeSort>(arrayType, bv32Sort) .emptyRegion() - .write(symbolicRef0 to mkBv(0), mkBv(42), trueExpr) + .write(symbolicRef0 to mkBv(0), mkBv(42), trueExpr, ownership) val adapter1 = USymbolicArrayInputToInputCopyAdapter( symbolicRef0 to mkSizeExpr(0), @@ -624,7 +628,7 @@ internal class CompositionTest { val stackModel = URegistersStackEagerModel(concreteNull, mapOf(0 to mkBv(0), 1 to mkBv(0), 2 to mkBv(2))) every { composedMemory.stack } returns stackModel - val composer = UComposer(this, composedMemory) + val composer = UComposer(this, composedMemory, MutabilityOwnership()) val region = UAllocatedArrayId<_, _, USizeSort>(mockk(), addressSort, 1).emptyRegion() val reading = region.read(mkRegisterReading(0, sizeSort)) @@ -645,12 +649,13 @@ internal class CompositionTest { val region = UInputFieldId(field, bv32Sort) .emptyRegion() - .write(ref0, mkBv(0), trueExpr) - .write(ref1, mkBv(1), trueExpr) + .write(ref0, mkBv(0), trueExpr, ownership) + .write(ref1, mkBv(1), trueExpr, ownership) val reading = region.read(ref2) - val composer = spyk(UComposer(this, composedMemory)) + val composerOwnership = MutabilityOwnership() + val composer = spyk(UComposer(this, composedMemory, composerOwnership)) val writableMemory: UWritableMemory = mockk() @@ -658,7 +663,7 @@ internal class CompositionTest { every { composer.transform(ref1) } returns ref1 every { composer.transform(ref2) } returns ref2 - every { composedMemory.toWritableMemory() } returns writableMemory + every { composedMemory.toWritableMemory(composerOwnership) } returns writableMemory every { val lvalue = UFieldLValue(bv32Sort, ref1, field) diff --git a/usvm-core/src/test/kotlin/org/usvm/TestUtil.kt b/usvm-core/src/test/kotlin/org/usvm/TestUtil.kt index 833ffaa5f0..e35f5a1938 100644 --- a/usvm-core/src/test/kotlin/org/usvm/TestUtil.kt +++ b/usvm-core/src/test/kotlin/org/usvm/TestUtil.kt @@ -3,6 +3,7 @@ package org.usvm import io.mockk.every import io.mockk.mockk import io.mockk.spyk +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UPathConstraints import org.usvm.memory.UMemory import org.usvm.memory.USymbolicCollectionKeyInfo @@ -39,12 +40,14 @@ internal class TestTarget(method: TestMethod, offset: Int) : UTarget, - callStack: UCallStack, pathConstraints: UPathConstraints, + ownership: MutabilityOwnership, + callStack: UCallStack, + pathConstraints: UPathConstraints, memory: UMemory, models: List>, pathLocation: PathNode, targetTrees: UTargetsSet = UTargetsSet.empty(), override val entrypoint: TestMethod = "" -) : UState, TestTarget, TestState>(ctx, callStack, pathConstraints, memory, models, pathLocation, PathNode.root(), targetTrees) { +) : UState, TestTarget, TestState>(ctx, ownership, callStack, pathConstraints, memory, models, pathLocation, PathNode.root(), targetTrees) { override fun clone(newConstraints: UPathConstraints?): TestState = this @@ -71,7 +74,11 @@ internal fun mockState(id: StateId, startMethod: TestMethod, startInstruction: I val ctxMock = mockk>() every { ctxMock.getNextStateId() } returns id val callStack = UCallStack(startMethod) - val spyk = spyk(TestState(ctxMock, callStack, mockk(), mockk(), emptyList(), mockk(), UTargetsSet.from(targets))) + val spyk = spyk( + TestState( + ctxMock, MutabilityOwnership(), callStack, mockk(), mockk(), emptyList(), mockk(), UTargetsSet.from(targets) + ) + ) every { spyk.currentStatement } returns TestInstruction(startMethod, startInstruction) return spyk } diff --git a/usvm-core/src/test/kotlin/org/usvm/api/collections/SymbolicCollectionTestBase.kt b/usvm-core/src/test/kotlin/org/usvm/api/collections/SymbolicCollectionTestBase.kt index 703449e3d6..6564912ec2 100644 --- a/usvm-core/src/test/kotlin/org/usvm/api/collections/SymbolicCollectionTestBase.kt +++ b/usvm-core/src/test/kotlin/org/usvm/api/collections/SymbolicCollectionTestBase.kt @@ -19,6 +19,7 @@ import org.usvm.UExpr import org.usvm.USizeSort import org.usvm.UState import org.usvm.WithSolverStateForker +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UPathConstraints import org.usvm.forkblacklists.UForkBlackList import org.usvm.memory.UMemory @@ -35,6 +36,7 @@ import kotlin.time.Duration.Companion.INFINITE abstract class SymbolicCollectionTestBase { lateinit var ctx: UContext + lateinit var ownership: MutabilityOwnership lateinit var pathConstraints: UPathConstraints lateinit var memory: UMemory lateinit var scope: StepScope> @@ -47,9 +49,10 @@ abstract class SymbolicCollectionTestBase { every { components.mkTypeSystem(any()) } returns mockk() every { components.mkSolver(any()) } answers { uSolver.uncheckedCast() } ctx = UContext(components) + ownership = MutabilityOwnership() every { components.mkComposer(ctx) } answers { - { memory: UReadOnlyMemory -> UComposer(ctx, memory) } + { memory: UReadOnlyMemory, ownership: MutabilityOwnership -> UComposer(ctx, memory, ownership) } } val translator = UExprTranslator(ctx) @@ -61,24 +64,32 @@ abstract class SymbolicCollectionTestBase { every { components.mkStatesForkProvider() } answers { WithSolverStateForker } - pathConstraints = UPathConstraints(ctx) - memory = UMemory(ctx, pathConstraints.typeConstraints) - scope = StepScope(StateStub(ctx, pathConstraints, memory), UForkBlackList.createDefault()) + pathConstraints = UPathConstraints(ctx, ownership) + memory = UMemory(ctx, ownership, pathConstraints.typeConstraints) + scope = StepScope(StateStub(ctx, ownership, pathConstraints, memory), UForkBlackList.createDefault()) } class TargetStub : UTarget() class StateStub( ctx: UContext, + ownership: MutabilityOwnership, pathConstraints: UPathConstraints, memory: UMemory, ) : UState, TargetStub, StateStub>( - ctx, UCallStack(), + ctx, ownership, UCallStack(), pathConstraints, memory, emptyList(), PathNode.root(), PathNode.root(), UTargetsSet.empty() ) { override fun clone(newConstraints: UPathConstraints?): StateStub { - val clonedConstraints = newConstraints ?: pathConstraints.clone() - return StateStub(ctx, clonedConstraints, memory.clone(clonedConstraints.typeConstraints)) + val thisOwnership = MutabilityOwnership() + val cloneOwnership = MutabilityOwnership() + val clonedConstraints = newConstraints ?: pathConstraints.clone(thisOwnership, cloneOwnership) + return StateStub( + ctx, + cloneOwnership, + clonedConstraints, + memory.clone(clonedConstraints.typeConstraints, thisOwnership, cloneOwnership) + ) } override val isExceptional: Boolean diff --git a/usvm-core/src/test/kotlin/org/usvm/constraints/EqualityConstraintsTests.kt b/usvm-core/src/test/kotlin/org/usvm/constraints/EqualityConstraintsTests.kt index 2143106f41..1254fcdb8f 100644 --- a/usvm-core/src/test/kotlin/org/usvm/constraints/EqualityConstraintsTests.kt +++ b/usvm-core/src/test/kotlin/org/usvm/constraints/EqualityConstraintsTests.kt @@ -8,11 +8,14 @@ import org.junit.jupiter.api.Test import org.usvm.UComponents import org.usvm.UContext import org.usvm.USizeSort +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.isEmpty import kotlin.test.assertSame import kotlin.test.assertTrue class EqualityConstraintsTests { private lateinit var ctx: UContext + private lateinit var ownership: MutabilityOwnership private lateinit var constraints: UEqualityConstraints @BeforeEach @@ -20,7 +23,8 @@ class EqualityConstraintsTests { val components: UComponents<*, USizeSort> = mockk() every { components.mkTypeSystem(any()) } returns mockk() ctx = UContext(components) - constraints = UEqualityConstraints(ctx) + ownership = MutabilityOwnership() + constraints = UEqualityConstraints(ctx, ownership) } @Test @@ -41,13 +45,13 @@ class EqualityConstraintsTests { // Add ref2 != ref3 constraints.makeNonEqual(ref2, ref3) // ref1 still can be equal to ref3 - assertSame(3, constraints.distinctReferences.size) + assertSame(3, constraints.distinctReferences.calculateSize()) assertTrue(constraints.referenceDisequalities[ref2]!!.contains(ref3)) assertTrue(constraints.referenceDisequalities[ref3]!!.contains(ref2)) constraints.makeNonEqual(ref1, ref3) // Now ref1, ref2 and ref3 are guaranteed to be distinct - assertSame(4, constraints.distinctReferences.size) + assertSame(4, constraints.distinctReferences.calculateSize()) assertTrue(constraints.referenceDisequalities.all { it.value.isEmpty() }) // Adding some entry into referenceDisequalities @@ -107,10 +111,10 @@ class EqualityConstraintsTests { // (3) ref4 != null // (4) ref3 != ref4 || (ref3 == ref4 == null) // These two should be automatically simplified to ref3 != ref4. - assertSame(2, constraints.distinctReferences.size) + assertSame(2, constraints.distinctReferences.calculateSize()) constraints.makeNonEqual(ref3, ctx.nullRef) // Now we have obtained that null, ref3 and ref4 are 3 distinct references. This should be represented as clique // constraint... - assertSame(3, constraints.distinctReferences.size) + assertSame(3, constraints.distinctReferences.calculateSize()) } } diff --git a/usvm-core/src/test/kotlin/org/usvm/constraints/NumericConstraintsTests.kt b/usvm-core/src/test/kotlin/org/usvm/constraints/NumericConstraintsTests.kt index ef7c97c862..b8b0feec27 100644 --- a/usvm-core/src/test/kotlin/org/usvm/constraints/NumericConstraintsTests.kt +++ b/usvm-core/src/test/kotlin/org/usvm/constraints/NumericConstraintsTests.kt @@ -17,6 +17,7 @@ import org.usvm.UContext import org.usvm.UExpr import org.usvm.UNotExpr import org.usvm.USizeSort +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.isFalse import org.usvm.logger import org.usvm.regions.IntIntervalsRegion @@ -26,6 +27,7 @@ import kotlin.test.assertTrue class NumericConstraintsTests { private lateinit var ctx: UContext + private lateinit var ownership: MutabilityOwnership private lateinit var bvSort: UBvSort private lateinit var constraints: UNumericConstraints private var previousConstraints: UNumericConstraints? = null @@ -36,13 +38,14 @@ class NumericConstraintsTests { val components: UComponents<*, USizeSort> = mockk() every { components.mkTypeSystem(any()) } returns mockk() ctx = UContext(components) + ownership = MutabilityOwnership() bvSort = ctx.mkBvSort(sizeBits = 8u) resetConstraints() } private fun resetConstraints() { - constraints = UNumericConstraints(ctx, bvSort) + constraints = UNumericConstraints(ctx, bvSort, ownership) previousConstraints = null unsimplifiedConstraints = mutableListOf() @@ -80,7 +83,7 @@ class NumericConstraintsTests { @Test fun testLinearPatternConstraintPropagation(): Unit = KZ3Solver(ctx).use { solver -> bvSort = ctx.mkBvSort(sizeBits = 32u) - constraints = UNumericConstraints(ctx, bvSort) + constraints = UNumericConstraints(ctx, bvSort, ownership) val bound = ctx.mkConst("bound", bvSort) var x: UExpr = ctx.mkConst("x", bvSort) @@ -102,7 +105,7 @@ class NumericConstraintsTests { fun testConcreteBoundsSimplification(): Unit = with(ctx) { KZ3Solver(ctx).use { solver -> bvSort = ctx.mkBvSort(sizeBits = 4u) - constraints = UNumericConstraints(ctx, bvSort) + constraints = UNumericConstraints(ctx, bvSort, ownership) val x by bvSort val zero = mkBv(0, bvSort) @@ -117,10 +120,44 @@ class NumericConstraintsTests { } } + @Test + fun test() = with(ctx) { + val a = ctx.mkConst("a", bvSort) + val b = ctx.mkConst("b", bvSort) + val c = ctx.mkConst("c", bvSort) + + val constraintsArray = arrayOf( + mkNotNoSimplify(mkBvSignedLessOrEqualExprNoSimplify( + mkBvAddExprNoSimplify(mkBvNegationExprNoSimplify(c), mkBvNegationExprNoSimplify(b)), + mkBvAddExprNoSimplify(c, mkBv(0xEE, bvSort)))), + mkNotNoSimplify(mkBvSignedGreaterOrEqualExprNoSimplify(mkBvAddExprNoSimplify(b, a), mkBv(0x3E, bvSort))), + mkNotNoSimplify(mkBvSignedGreaterOrEqualExprNoSimplify( + mkBvAddExprNoSimplify(c, mkBv(0xF8, bvSort)), + mkBvAddExprNoSimplify(mkBvNegationExprNoSimplify(b), mkBvNegationExprNoSimplify(a)))), + mkBvSignedGreaterOrEqualExprNoSimplify(mkBv(0x79, bvSort), + mkBvAddExprNoSimplify(mkBvNegationExprNoSimplify(b), mkBv(0x48, bvSort))), + mkEqNoSimplify(mkBv(0x05, bvSort), mkBvAddExprNoSimplify( + mkBvNegationExprNoSimplify(b), mkBvNegationExprNoSimplify(a))), + mkNotNoSimplify(mkBvSignedGreaterOrEqualExprNoSimplify( + mkBvAddExprNoSimplify(mkBv(0xAD, bvSort), a), mkBv(0x65, bvSort))), + mkNotNoSimplify(mkBvSignedLessOrEqualExprNoSimplify( + mkBvAddExprNoSimplify(mkBv(0x6E, bvSort), mkBvNegationExprNoSimplify(c)), + mkBvAddExprNoSimplify(mkBv(0xE0, bvSort), mkBvNegationExprNoSimplify(b))) + ) + ) + + KYicesSolver(ctx).use { solver -> + for (constraint in constraintsArray) { + addConstraint(constraint) + solver.checkConstraints(0) + } + } + } + @Test fun testEvalInterval(): Unit = with(ctx) { bvSort = ctx.mkBvSort(sizeBits = 32u) - constraints = UNumericConstraints(ctx, bvSort) + constraints = UNumericConstraints(ctx, bvSort, ownership) val x by bvSort // x in [-5, -1] U [1, 5] @@ -199,7 +236,7 @@ class NumericConstraintsTests { } private fun addConstraint(expr: UBoolExpr) { - previousConstraints = constraints.clone() + previousConstraints = constraints.clone(ownership, MutabilityOwnership()) constraints.addConstraint(expr) unsimplifiedConstraints.add(expr) } @@ -245,8 +282,8 @@ class NumericConstraintsTests { } val lastExpr = unsimplifiedConstraints.last() - logger.debug { "Incorrect state after add: $lastExpr" } - logger.debug { "Unsatisfied statements: $failedStatements" } + logger.error { "Incorrect state after add: $lastExpr" } + logger.error { "Unsatisfied statements: $failedStatements" } previousConstraints?.addConstraint(lastExpr) } diff --git a/usvm-core/src/test/kotlin/org/usvm/memory/HeapMemCpyTest.kt b/usvm-core/src/test/kotlin/org/usvm/memory/HeapMemCpyTest.kt index 9b61c1424c..cbfd870413 100644 --- a/usvm-core/src/test/kotlin/org/usvm/memory/HeapMemCpyTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/memory/HeapMemCpyTest.kt @@ -8,6 +8,7 @@ import org.usvm.api.allocateArray import org.usvm.api.memcpy import org.usvm.api.readArrayIndex import org.usvm.api.writeArrayIndex +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UEqualityConstraints import org.usvm.constraints.UTypeConstraints import kotlin.test.Test @@ -15,6 +16,7 @@ import kotlin.test.assertEquals class HeapMemCpyTest { private lateinit var ctx: UContext + private lateinit var ownership: MutabilityOwnership private lateinit var heap: UMemory private lateinit var arrayType: Type private lateinit var arrayValueSort: USizeSort @@ -24,10 +26,11 @@ class HeapMemCpyTest { val components: UComponents = mockk() every { components.mkTypeSystem(any()) } returns mockk() ctx = UContext(components) + ownership = MutabilityOwnership() every { components.mkSizeExprProvider(any()) } answers { UBv32SizeExprProvider(ctx) } - val eqConstraints = UEqualityConstraints(ctx) - val typeConstraints = UTypeConstraints(components.mkTypeSystem(ctx), eqConstraints) - heap = UMemory(ctx, typeConstraints) + val eqConstraints = UEqualityConstraints(ctx, ownership) + val typeConstraints = UTypeConstraints(ownership, components.mkTypeSystem(ctx), eqConstraints) + heap = UMemory(ctx, ownership, typeConstraints) arrayType = mockk() arrayValueSort = ctx.sizeSort } diff --git a/usvm-core/src/test/kotlin/org/usvm/memory/HeapMemsetTest.kt b/usvm-core/src/test/kotlin/org/usvm/memory/HeapMemsetTest.kt index f001390809..2ae6dfff8a 100644 --- a/usvm-core/src/test/kotlin/org/usvm/memory/HeapMemsetTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/memory/HeapMemsetTest.kt @@ -14,6 +14,7 @@ import org.usvm.api.allocateArrayInitialized import org.usvm.api.memset import org.usvm.api.readArrayIndex import org.usvm.api.readArrayLength +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UEqualityConstraints import org.usvm.constraints.UTypeConstraints import org.usvm.mkSizeExpr @@ -25,6 +26,7 @@ import kotlin.test.assertTrue class HeapMemsetTest { private lateinit var ctx: UContext + private lateinit var ownership: MutabilityOwnership private lateinit var heap: UMemory private lateinit var arrayType: Type private lateinit var arrayValueSort: UAddressSort @@ -34,10 +36,11 @@ class HeapMemsetTest { val components: UComponents = mockk() every { components.mkTypeSystem(any()) } returns mockk() ctx = UContext(components) + ownership = MutabilityOwnership() every { components.mkSizeExprProvider(any()) } answers { UBv32SizeExprProvider(ctx) } - val eqConstraints = UEqualityConstraints(ctx) - val typeConstraints = UTypeConstraints(components.mkTypeSystem(ctx), eqConstraints) - heap = UMemory(ctx, typeConstraints) + val eqConstraints = UEqualityConstraints(ctx, ownership) + val typeConstraints = UTypeConstraints(ownership, components.mkTypeSystem(ctx), eqConstraints) + heap = UMemory(ctx, ownership, typeConstraints) arrayType = mockk() arrayValueSort = ctx.addressSort } diff --git a/usvm-core/src/test/kotlin/org/usvm/memory/HeapRefEqTest.kt b/usvm-core/src/test/kotlin/org/usvm/memory/HeapRefEqTest.kt index eb0854c5c8..5309f81d3b 100644 --- a/usvm-core/src/test/kotlin/org/usvm/memory/HeapRefEqTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/memory/HeapRefEqTest.kt @@ -10,18 +10,21 @@ import org.usvm.UComponents import org.usvm.UContext import org.usvm.USizeSort import org.usvm.api.allocateConcreteRef +import org.usvm.collections.immutable.internal.MutabilityOwnership import kotlin.test.assertSame class HeapRefEqTest { private lateinit var ctx: UContext private lateinit var heap: UMemory + private lateinit var ownership: MutabilityOwnership @BeforeEach fun initializeContext() { val components: UComponents = mockk() every { components.mkTypeSystem(any()) } returns mockk() ctx = UContext(components) - heap = UMemory(ctx, mockk()) + ownership = MutabilityOwnership() + heap = UMemory(ctx, ownership, mockk()) } @Test @@ -156,4 +159,4 @@ class HeapRefEqTest { val expected = concreteRef1EqConcreteRef2 or (symbolicRefEqGuard and symbolicIteEq) assertSame(expected, refsEq) } -} \ No newline at end of file +} diff --git a/usvm-core/src/test/kotlin/org/usvm/memory/HeapRefSplittingTest.kt b/usvm-core/src/test/kotlin/org/usvm/memory/HeapRefSplittingTest.kt index 6a35092087..9488981e04 100644 --- a/usvm-core/src/test/kotlin/org/usvm/memory/HeapRefSplittingTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/memory/HeapRefSplittingTest.kt @@ -26,6 +26,7 @@ import org.usvm.sizeSort import org.usvm.mkSizeExpr import org.usvm.api.memcpy import org.usvm.api.allocateArray +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UEqualityConstraints import org.usvm.constraints.UTypeConstraints import kotlin.test.assertEquals @@ -36,6 +37,7 @@ import kotlin.test.assertSame class HeapRefSplittingTest { private lateinit var ctx: UContext private lateinit var heap: UMemory + private lateinit var ownership: MutabilityOwnership private lateinit var valueFieldDescr: Pair private lateinit var addressFieldDescr: Pair @@ -46,10 +48,11 @@ class HeapRefSplittingTest { val components: UComponents = mockk() every { components.mkTypeSystem(any()) } returns mockk() ctx = UContext(components) + ownership = MutabilityOwnership() every { components.mkSizeExprProvider(any()) } answers { UBv32SizeExprProvider(ctx) } - val eqConstraints = UEqualityConstraints(ctx) - val typeConstraints = UTypeConstraints(components.mkTypeSystem(ctx), eqConstraints) - heap = UMemory(ctx, typeConstraints) + val eqConstraints = UEqualityConstraints(ctx, ownership) + val typeConstraints = UTypeConstraints(ownership, components.mkTypeSystem(ctx), eqConstraints) + heap = UMemory(ctx, ownership, typeConstraints) valueFieldDescr = mockk() to ctx.bv32Sort addressFieldDescr = mockk() to ctx.addressSort diff --git a/usvm-core/src/test/kotlin/org/usvm/memory/MemoryRegionTest.kt b/usvm-core/src/test/kotlin/org/usvm/memory/MemoryRegionTest.kt index b1ec385875..03064c40dd 100644 --- a/usvm-core/src/test/kotlin/org/usvm/memory/MemoryRegionTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/memory/MemoryRegionTest.kt @@ -16,6 +16,7 @@ import org.usvm.UHeapRef import org.usvm.USizeSort import org.usvm.collection.array.UAllocatedArrayId import org.usvm.collection.array.UInputArrayId +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.mkSizeExpr import org.usvm.regions.SetRegion import org.usvm.regions.emptyRegionTree @@ -26,12 +27,14 @@ import kotlin.test.assertTrue class MemoryRegionTest { private lateinit var ctx: UContext + private lateinit var ownership: MutabilityOwnership @BeforeEach fun initializeContext() { val components: UComponents = mockk() every { components.mkTypeSystem(any()) } returns mockk() ctx = UContext(components) + ownership = MutabilityOwnership() every { components.mkSizeExprProvider(any()) } answers { UBv32SizeExprProvider(ctx) } } @@ -98,15 +101,15 @@ class MemoryRegionTest { val memoryRegion = UAllocatedArrayId<_, _, USizeSort>(mockk(), sizeSort, 0) .emptyRegion() - .write(idx1, mkBv(0), trueExpr) - .write(idx2, mkBv(1), trueExpr) + .write(idx1, mkBv(0), trueExpr, ownership) + .write(idx2, mkBv(1), trueExpr, ownership) val updatesBefore = memoryRegion.updates.toList() assertEquals(2, updatesBefore.size) assertTrue(updatesBefore.first().includesConcretely(idx1, trueExpr)) assertTrue(updatesBefore.last().includesConcretely(idx2, trueExpr)) - val memoryRegionAfter = memoryRegion.write(idx2, mkBv(2), trueExpr) + val memoryRegionAfter = memoryRegion.write(idx2, mkBv(2), trueExpr, ownership) val updatesAfter = memoryRegionAfter.updates.toList() assertEquals(2, updatesAfter.size) @@ -141,7 +144,7 @@ class MemoryRegionTest { val idx = indices.random(random) val value = refs.random(random) - memoryRegion = memoryRegion.write(ref to idx, value, trueExpr) + memoryRegion = memoryRegion.write(ref to idx, value, trueExpr, ownership) } val readRef = symbolicRefs.random(random) @@ -150,4 +153,4 @@ class MemoryRegionTest { memoryRegion.read(readRef to readIdx) } } -} \ No newline at end of file +} diff --git a/usvm-core/src/test/kotlin/org/usvm/memory/SetEntriesTest.kt b/usvm-core/src/test/kotlin/org/usvm/memory/SetEntriesTest.kt index d4d9d1f12f..6f8bb729d6 100644 --- a/usvm-core/src/test/kotlin/org/usvm/memory/SetEntriesTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/memory/SetEntriesTest.kt @@ -18,6 +18,7 @@ import org.usvm.collection.set.primitive.setEntries import org.usvm.collection.set.primitive.setUnion import org.usvm.collection.set.ref.refSetEntries import org.usvm.collection.set.ref.refSetUnion +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UEqualityConstraints import org.usvm.constraints.UTypeConstraints import org.usvm.isTrue @@ -28,6 +29,7 @@ import kotlin.test.assertTrue class SetEntriesTest { private lateinit var ctx: UContext + private lateinit var ownership: MutabilityOwnership private lateinit var heap: UMemory private lateinit var setType: Type @@ -36,10 +38,11 @@ class SetEntriesTest { val components: UComponents = mockk() every { components.mkTypeSystem(any()) } returns mockk() ctx = UContext(components) + ownership = MutabilityOwnership() every { components.mkSizeExprProvider(any()) } answers { UBv32SizeExprProvider(ctx) } - val eqConstraints = UEqualityConstraints(ctx) - val typeConstraints = UTypeConstraints(components.mkTypeSystem(ctx), eqConstraints) - heap = UMemory(ctx, typeConstraints) + val eqConstraints = UEqualityConstraints(ctx, ownership) + val typeConstraints = UTypeConstraints(ownership, components.mkTypeSystem(ctx), eqConstraints) + heap = UMemory(ctx, ownership, typeConstraints) setType = mockk() } diff --git a/usvm-core/src/test/kotlin/org/usvm/merging/CloseStatesSearcherTest.kt b/usvm-core/src/test/kotlin/org/usvm/merging/CloseStatesSearcherTest.kt index c552da72ac..922845bf75 100644 --- a/usvm-core/src/test/kotlin/org/usvm/merging/CloseStatesSearcherTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/merging/CloseStatesSearcherTest.kt @@ -8,6 +8,7 @@ import org.usvm.TestInstruction import org.usvm.TestState import org.usvm.UCallStack import org.usvm.UContext +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.ps.ExecutionTreeTracker import org.usvm.statistics.ApplicationGraph import org.usvm.statistics.distances.CfgStatisticsImpl @@ -164,7 +165,9 @@ class CloseStatesSearcherTest { val ctxMock = mockk>() every { ctxMock.getNextStateId() } returns 0u val callStack = UCallStack("") - val spyk = spyk(TestState(ctxMock, callStack, mockk(), mockk(), emptyList(), pathNode, mockk())) + val spyk = spyk( + TestState(ctxMock, MutabilityOwnership(), callStack, mockk(), mockk(), emptyList(), pathNode, mockk()) + ) spyk } val executionTreeTracker = ExecutionTreeTracker(rootNode).apply { add(states) } @@ -210,4 +213,4 @@ class CloseStatesSearcherTest { override fun statementsOf(method: String): Sequence = statements.asSequence().map { TestInstruction("", it) } } -} \ No newline at end of file +} diff --git a/usvm-core/src/test/kotlin/org/usvm/merging/MemoryMergingTest.kt b/usvm-core/src/test/kotlin/org/usvm/merging/MemoryMergingTest.kt index 21f21d745b..320f1ad912 100644 --- a/usvm-core/src/test/kotlin/org/usvm/merging/MemoryMergingTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/merging/MemoryMergingTest.kt @@ -17,6 +17,7 @@ import org.usvm.USort import org.usvm.api.allocateConcreteRef import org.usvm.api.readField import org.usvm.api.writeField +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UPathConstraints import org.usvm.memory.UMemory import org.usvm.memory.URegisterStackLValue @@ -29,6 +30,7 @@ import kotlin.test.assertFails class MemoryMergingTest { private lateinit var ctx: UContext + private lateinit var ownership: MutabilityOwnership private lateinit var translator: UExprTranslator private lateinit var smtSolver: KZ3Solver @@ -38,6 +40,7 @@ class MemoryMergingTest { every { components.mkTypeSystem(any()) } returns SingleTypeSystem every { components.mkSizeExprProvider(any()) } answers { UBv32SizeExprProvider(ctx) } ctx = UContext(components) + ownership = MutabilityOwnership() translator = UExprTranslator(ctx) smtSolver = KZ3Solver(ctx) } @@ -45,9 +48,9 @@ class MemoryMergingTest { @Test fun `Empty memory`() = with(ctx) { val byCondition = mkConst("cond", boolSort) - val pathConstraints = UPathConstraints(this) - val memoryLeft = UMemory(this, pathConstraints.typeConstraints) - val memoryRight = memoryLeft.clone(pathConstraints.typeConstraints) + val pathConstraints = UPathConstraints(this, ownership) + val memoryLeft = UMemory(this, ownership, pathConstraints.typeConstraints) + val memoryRight = memoryLeft.clone(pathConstraints.typeConstraints, MutabilityOwnership(), MutabilityOwnership()) checkMergedEqualsToOriginal( memoryLeft, @@ -62,14 +65,14 @@ class MemoryMergingTest { @Test fun `Distinct stack`() = with(ctx) { val byCondition = mkConst("cond", boolSort) - val pathConstraints = UPathConstraints(this) + val pathConstraints = UPathConstraints(this, ownership) - val memoryLeft = UMemory(this, pathConstraints.typeConstraints) + val memoryLeft = UMemory(this, ownership, pathConstraints.typeConstraints) memoryLeft.stack.push(3) memoryLeft.stack.writeRegister(0, mkBv(42)) memoryLeft.stack.writeRegister(1, mkBv(1337)) - val memoryRight = memoryLeft.clone(pathConstraints.typeConstraints) + val memoryRight = memoryLeft.clone(pathConstraints.typeConstraints, MutabilityOwnership(), MutabilityOwnership()) memoryRight.stack.writeRegister(0, mkBv(13)) memoryRight.stack.writeRegister(2, mkBv(9)) @@ -87,9 +90,10 @@ class MemoryMergingTest { fun `Distinct regions`(): Unit = with(ctx) { assertFails { // TODO: improve memory regions constraints merging val byCondition = mkConst("cond", boolSort) - val pathConstraints = UPathConstraints(this) + val pathConstraints = UPathConstraints(this, ownership) - val memoryLeft = UMemory(this, pathConstraints.typeConstraints) + val leftOwnership = ownership + val memoryLeft = UMemory(this, leftOwnership, pathConstraints.typeConstraints) val ref1 = allocateConcreteRef() val ref2 = allocateConcreteRef() @@ -99,7 +103,7 @@ class MemoryMergingTest { memoryLeft.writeField(ref2, Unit, addressSort, mkRegisterReading(2, addressSort), trueExpr) memoryLeft.writeField(ref3, Unit, addressSort, mkRegisterReading(3, addressSort), trueExpr) - val memoryRight = memoryLeft.clone(pathConstraints.typeConstraints) + val memoryRight = memoryLeft.clone(pathConstraints.typeConstraints, MutabilityOwnership(), MutabilityOwnership()) memoryRight.writeField(ref1, Unit, addressSort, mkRegisterReading(-1, addressSort), trueExpr) memoryRight.writeField(ref2, Unit, addressSort, mkRegisterReading(-2, addressSort), trueExpr) memoryRight.writeField(ref3, Unit, addressSort, mkRegisterReading(-3, addressSort), trueExpr) @@ -122,7 +126,9 @@ class MemoryMergingTest { vararg getters: (UMemory) -> UExpr, ) = with(ctx) { val mergeGuard = MutableMergeGuard(this).apply { appendThis(sequenceOf(byCondition)) } - val mergedMemory = checkNotNull(memoryLeft.mergeWith(memoryRight, mergeGuard)) + val mergedMemory = checkNotNull(memoryLeft.mergeWith( + memoryRight, mergeGuard, MutabilityOwnership(), MutabilityOwnership(), MutabilityOwnership() + )) for (getter in getters) { val leftExpr: UExpr = getter(memoryLeft).uncheckedCast() @@ -137,4 +143,4 @@ class MemoryMergingTest { smtSolver.pop() } } -} \ No newline at end of file +} diff --git a/usvm-core/src/test/kotlin/org/usvm/merging/PathConstraintsMergingTest.kt b/usvm-core/src/test/kotlin/org/usvm/merging/PathConstraintsMergingTest.kt index a7488c7238..157f6d62f7 100644 --- a/usvm-core/src/test/kotlin/org/usvm/merging/PathConstraintsMergingTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/merging/PathConstraintsMergingTest.kt @@ -9,6 +9,7 @@ import org.usvm.UBv32SizeExprProvider import org.usvm.UBv32Sort import org.usvm.UComponents import org.usvm.UContext +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UPathConstraints import org.usvm.sizeSort import org.usvm.solver.UExprTranslator @@ -21,6 +22,7 @@ import kotlin.test.assertNotNull class PathConstraintsMergingTest { private lateinit var ctx: UContext + private lateinit var ownership: MutabilityOwnership private lateinit var translator: UExprTranslator private lateinit var smtSolver: KZ3Solver @@ -30,14 +32,15 @@ class PathConstraintsMergingTest { every { components.mkTypeSystem(any()) } returns SingleTypeSystem every { components.mkSizeExprProvider(any()) } answers { UBv32SizeExprProvider(ctx) } ctx = UContext(components) + ownership = MutabilityOwnership() translator = UExprTranslator(ctx) smtSolver = KZ3Solver(ctx) } @Test fun `Empty path constraints`() = with(ctx) { - val pcLeft = UPathConstraints(this) - val pcRight = pcLeft.clone() + val pcLeft = UPathConstraints(this, ownership) + val pcRight = pcLeft.clone(MutabilityOwnership(), MutabilityOwnership()) checkMergedEqualsOriginals(pcLeft, pcRight) } @@ -84,7 +87,7 @@ class PathConstraintsMergingTest { } private fun buildCommonPrefix(): Pair, UPathConstraints> = with(ctx) { - val pcLeft = UPathConstraints(this) + val pcLeft = UPathConstraints(this, ownership) // logical constraints pcLeft += (mkRegisterReading(0, sizeSort) eq mkRegisterReading(1, sizeSort)) or @@ -98,14 +101,16 @@ class PathConstraintsMergingTest { pcLeft += mkRegisterReading(4, addressSort) eq mkRegisterReading(5, addressSort) pcLeft += mkRegisterReading(6, addressSort) neq mkRegisterReading(7, addressSort) - val pcRight = pcLeft.clone() + val pcRight = pcLeft.clone(MutabilityOwnership(), MutabilityOwnership()) return pcLeft to pcRight } private fun checkMergedEqualsOriginals(left: UPathConstraints, right: UPathConstraints) = with(ctx) { val mergeGuard = MutableMergeGuard(this) - val result = left.mergeWith(right, mergeGuard) + val result = left.mergeWith( + right, mergeGuard, MutabilityOwnership(), MutabilityOwnership(), MutabilityOwnership() + ) assertNotNull(result) result +=ctx.mkOr(mergeGuard.thisConstraint, mergeGuard.otherConstraint) val constraintsAreNotEqual = run { @@ -119,4 +124,4 @@ class PathConstraintsMergingTest { val status = smtSolver.check() assertEquals(KSolverStatus.UNSAT, status) } -} \ No newline at end of file +} diff --git a/usvm-core/src/test/kotlin/org/usvm/model/ModelCompositionTest.kt b/usvm-core/src/test/kotlin/org/usvm/model/ModelCompositionTest.kt index 32ef5be7fd..b71b570bf9 100644 --- a/usvm-core/src/test/kotlin/org/usvm/model/ModelCompositionTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/model/ModelCompositionTest.kt @@ -2,7 +2,7 @@ package org.usvm.model import io.mockk.every import io.mockk.mockk -import kotlinx.collections.immutable.persistentMapOf +import org.usvm.collections.immutable.persistentHashMapOf import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test @@ -27,6 +27,7 @@ import org.usvm.collection.array.length.UInputArrayLengthId import org.usvm.collection.field.UFieldsEagerModelRegion import org.usvm.collection.field.UFieldsRegionId import org.usvm.collection.field.UInputFieldId +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.memory.UReadOnlyMemory import org.usvm.memory.key.USizeExprKeyInfo import org.usvm.mkSizeExpr @@ -36,6 +37,7 @@ import kotlin.test.assertSame class ModelCompositionTest { private lateinit var ctx: UContext + private lateinit var ownership: MutabilityOwnership private lateinit var concreteNull: UConcreteHeapRef @BeforeEach @@ -43,8 +45,9 @@ class ModelCompositionTest { val components: UComponents<*, USizeSort> = mockk() every { components.mkTypeSystem(any()) } returns mockk() ctx = UContext(components) + ownership = MutabilityOwnership() - every { components.mkComposer(ctx) } answers { { memory: UReadOnlyMemory -> UComposer(ctx, memory) } } + every { components.mkComposer(ctx) } answers { { memory: UReadOnlyMemory, ownership: MutabilityOwnership -> UComposer(ctx, memory, ownership) } } every { components.mkSizeExprProvider(any()) } answers { UBv32SizeExprProvider(ctx) } concreteNull = ctx.mkConcreteHeapRef(NULL_ADDRESS) } @@ -57,14 +60,14 @@ class ModelCompositionTest { ) val model = UModelBase(ctx, stackModel, mockk(), mockk(), emptyMap(), concreteNull) - val composer = UComposer(this, model) + val composer = UComposer(this, model, defaultOwnership) val region = UAllocatedArrayId<_, _, USizeSort>(mockk(), bv32Sort, 1) .emptyRegion() - .write(0.toBv(), 0.toBv(), trueExpr) - .write(1.toBv(), 1.toBv(), trueExpr) - .write(mkRegisterReading(1, sizeSort), 2.toBv(), trueExpr) - .write(mkRegisterReading(2, sizeSort), 3.toBv(), trueExpr) + .write(0.toBv(), 0.toBv(), trueExpr, ownership) + .write(1.toBv(), 1.toBv(), trueExpr, ownership) + .write(mkRegisterReading(1, sizeSort), 2.toBv(), trueExpr, ownership) + .write(mkRegisterReading(2, sizeSort), 3.toBv(), trueExpr, ownership) val reading = region.read(mkRegisterReading(0, sizeSort)) val expr = composer.compose(reading) @@ -78,8 +81,9 @@ class ModelCompositionTest { val composedSymbolicHeapRef = ctx.mkConcreteHeapRef(-1) val inputArray = UMemory2DArray( - persistentMapOf((composedSymbolicHeapRef to mkBv(0)) to mkBv(1)), mkBv(0) + persistentHashMapOf(ownership, (composedSymbolicHeapRef to mkBv(0)) to mkBv(1)), mkBv(0) ) + val arrayModel = UArrayEagerModelRegion(arrayMemoryId, inputArray) val stackModel = URegistersStackEagerModel( @@ -89,7 +93,7 @@ class ModelCompositionTest { val model = UModelBase( ctx, stackModel, mockk(), mockk(), mapOf(arrayMemoryId to arrayModel), concreteNull ) - val composer = UComposer(this, model) + val composer = UComposer(this, model, defaultOwnership) val symbolicRef = mkRegisterReading(0, addressSort) as UHeapRef @@ -131,7 +135,7 @@ class ModelCompositionTest { val arrayType = mockk() val arrayLengthMemoryId = UArrayLengthsRegionId(sizeSort, arrayType) - val inputLength = UMemory1DArray(persistentMapOf(composedRef0 to mkBv(42)), mkBv(0)) + val inputLength = UMemory1DArray(persistentHashMapOf(ownership, composedRef0 to mkBv(42)), mkBv(0)) val arrayLengthModel = UArrayLengthEagerModelRegion(arrayLengthMemoryId, inputLength) val stackModel = URegistersStackEagerModel( @@ -148,13 +152,13 @@ class ModelCompositionTest { ctx, stackModel, mockk(), mockk(), mapOf(arrayLengthMemoryId to arrayLengthModel), concreteNull ) - val composer = UComposer(this, model) + val composer = UComposer(this, model, defaultOwnership) val region = UInputArrayLengthId(arrayType, bv32Sort) .emptyRegion() - .write(symbolicRef1, 0.toBv(), trueExpr) - .write(symbolicRef2, 1.toBv(), trueExpr) - .write(symbolicRef3, 2.toBv(), trueExpr) + .write(symbolicRef1, 0.toBv(), trueExpr, ownership) + .write(symbolicRef2, 1.toBv(), trueExpr, ownership) + .write(symbolicRef3, 2.toBv(), trueExpr, ownership) val reading = region.read(symbolicRef0) val expr = composer.compose(reading) @@ -176,7 +180,7 @@ class ModelCompositionTest { val field = mockk() val fieldMemoryId = UFieldsRegionId(field, addressSort) - val inputField = UMemory1DArray(persistentMapOf(composedRef0 to composedRef0), concreteNull) + val inputField = UMemory1DArray(persistentHashMapOf(ownership, composedRef0 to composedRef0), concreteNull) val fieldModel = UFieldsEagerModelRegion(fieldMemoryId, inputField) val stackModel = URegistersStackEagerModel( @@ -193,13 +197,13 @@ class ModelCompositionTest { ctx, stackModel, mockk(), mockk(), mapOf(fieldMemoryId to fieldModel), concreteNull ) - val composer = UComposer(this, model) + val composer = UComposer(this, model, defaultOwnership) val region = UInputFieldId(field, addressSort) .emptyRegion() - .write(symbolicRef1, symbolicRef1, trueExpr) - .write(symbolicRef2, symbolicRef2, trueExpr) - .write(symbolicRef3, symbolicRef3, trueExpr) + .write(symbolicRef1, symbolicRef1, trueExpr, ownership) + .write(symbolicRef2, symbolicRef2, trueExpr, ownership) + .write(symbolicRef3, symbolicRef3, trueExpr, ownership) val reading = region.read(symbolicRef0) val expr = composer.compose(reading) @@ -223,14 +227,14 @@ class ModelCompositionTest { ctx, stackModel, mockk(), mockk(), emptyMap(), concreteNull ) - val composer = UComposer(this, model) + val composer = UComposer(this, model, defaultOwnership) val emptyRegion = UAllocatedArrayId<_, _, USizeSort>(mockk(), bv32Sort, 1).emptyRegion() run { val region = emptyRegion - .write(index0, nonDefaultValue0, trueGuard) - .write(index0, nonDefaultValue1, falseGuard) + .write(index0, nonDefaultValue0, trueGuard, ownership) + .write(index0, nonDefaultValue1, falseGuard, ownership) val reading = region.read(index0) val expr = composer.compose(reading) @@ -239,8 +243,8 @@ class ModelCompositionTest { run { val region = emptyRegion - .write(index1, nonDefaultValue0, trueGuard) - .write(index0, nonDefaultValue1, falseGuard) + .write(index1, nonDefaultValue0, trueGuard, ownership) + .write(index0, nonDefaultValue1, falseGuard, ownership) val reading = region.read(index0) val expr = composer.compose(reading) diff --git a/usvm-core/src/test/kotlin/org/usvm/model/ModelDecodingTest.kt b/usvm-core/src/test/kotlin/org/usvm/model/ModelDecodingTest.kt index ad80ccf8d1..1b0cbe67d0 100644 --- a/usvm-core/src/test/kotlin/org/usvm/model/ModelDecodingTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/model/ModelDecodingTest.kt @@ -26,6 +26,7 @@ import org.usvm.api.writeField import org.usvm.collection.array.UArrayIndexLValue import org.usvm.collection.set.primitive.setEntries import org.usvm.collection.set.ref.refSetEntries +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UPathConstraints import org.usvm.memory.UMemory import org.usvm.memory.UReadOnlyMemory @@ -46,6 +47,7 @@ private typealias Type = SingleTypeSystem.SingleType class ModelDecodingTest { private lateinit var ctx: UContext + private lateinit var ownership: MutabilityOwnership private lateinit var solver: USolverBase private lateinit var pc: UPathConstraints @@ -59,20 +61,21 @@ class ModelDecodingTest { every { components.mkTypeSystem(any()) } returns SingleTypeSystem ctx = UContext(components) + ownership = MutabilityOwnership() every { components.mkSizeExprProvider(any()) } answers { UBv32SizeExprProvider(ctx) } - every { components.mkComposer(ctx) } answers { { memory: UReadOnlyMemory -> UComposer(ctx, memory) } } + every { components.mkComposer(ctx) } answers { { memory: UReadOnlyMemory, ownership: MutabilityOwnership -> UComposer(ctx, memory, ownership) } } val translator = UExprTranslator(ctx) val decoder = ULazyModelDecoder(translator) val typeSolver = UTypeSolver(SingleTypeSystem) solver = USolverBase(ctx, KZ3Solver(ctx), typeSolver, translator, decoder, timeout = INFINITE) - pc = UPathConstraints(ctx) + pc = UPathConstraints(ctx, ownership) stack = URegistersStack() stack.push(10) mocker = UIndexedMocker() - heap = UMemory(ctx, pc.typeConstraints, stack, mocker) + heap = UMemory(ctx, ownership, pc.typeConstraints, stack, mocker) } @Test @@ -131,7 +134,7 @@ class ModelDecodingTest { val field = mockk() val method = mockk() - val mockedValue = mocker.call(method, emptySequence(), addressSort) + val mockedValue = mocker.call(method, emptySequence(), addressSort, ownership) val ref1 = heap.readField(mockedValue, field, addressSort) heap.writeField(ref1, field, addressSort, allocateConcreteRef(), trueExpr) val ref2 = heap.readField(mockedValue, field, addressSort) @@ -153,7 +156,7 @@ class ModelDecodingTest { val field = mockk() val method = mockk() - val mockedValue = mocker.call(method, emptySequence(), addressSort) + val mockedValue = mocker.call(method, emptySequence(), addressSort, ownership) val ref1 = heap.readField(mockedValue, field, addressSort) heap.writeField(ref1, field, addressSort, ref1, trueExpr) val ref2 = heap.readField(mockedValue, field, addressSort) diff --git a/usvm-core/src/test/kotlin/org/usvm/solver/SoftConstraintsTest.kt b/usvm-core/src/test/kotlin/org/usvm/solver/SoftConstraintsTest.kt index 675fec1f39..603b3205a1 100644 --- a/usvm-core/src/test/kotlin/org/usvm/solver/SoftConstraintsTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/solver/SoftConstraintsTest.kt @@ -15,6 +15,7 @@ import org.usvm.UComposer import org.usvm.UContext import org.usvm.USizeSort import org.usvm.collection.array.length.UInputArrayLengthId +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UPathConstraints import org.usvm.memory.UReadOnlyMemory import org.usvm.model.ULazyModelDecoder @@ -27,6 +28,7 @@ private typealias Type = SingleTypeSystem.SingleType open class SoftConstraintsTest { private lateinit var ctx: UContext + private lateinit var ownership: MutabilityOwnership private lateinit var softConstraintsProvider: USoftConstraintsProvider private lateinit var translator: UExprTranslator private lateinit var decoder: ULazyModelDecoder @@ -38,8 +40,9 @@ open class SoftConstraintsTest { every { components.mkTypeSystem(any()) } returns SingleTypeSystem ctx = UContext(components) + ownership = MutabilityOwnership() every { components.mkSizeExprProvider(any()) } answers { UBv32SizeExprProvider(ctx) } - every { components.mkComposer(any()) } answers { { memory: UReadOnlyMemory -> UComposer(ctx, memory) } } + every { components.mkComposer(any()) } answers { { memory: UReadOnlyMemory, ownership: MutabilityOwnership -> UComposer(ctx, memory, ownership) } } softConstraintsProvider = USoftConstraintsProvider(ctx) @@ -58,7 +61,7 @@ open class SoftConstraintsTest { val sndRegister = mkRegisterReading(idx = 1, bv32Sort) val expr = mkBvSignedLessOrEqualExpr(fstRegister, sndRegister) - val pc = UPathConstraints(ctx) + val pc = UPathConstraints(ctx, ownership) pc += expr val softConstraints = softConstraintsProvider.makeSoftConstraints(pc) @@ -86,7 +89,7 @@ open class SoftConstraintsTest { every { softConstraintsProvider.provide(any()) } answers { callOriginal() } - val pc = UPathConstraints(ctx) + val pc = UPathConstraints(ctx, ownership) pc += fstExpr pc += sndExpr pc += sameAsFirstExpr @@ -129,13 +132,13 @@ open class SoftConstraintsTest { val secondInputRef = mkRegisterReading(1, addressSort) val region = UInputArrayLengthId(arrayType, sizeSort) .emptyRegion() - .write(inputRef, mkRegisterReading(3, sizeSort), guard = trueExpr) + .write(inputRef, mkRegisterReading(3, sizeSort), guard = trueExpr, ownership) val size = 25 val reading = region.read(secondInputRef) - val pc = UPathConstraints(ctx) + val pc = UPathConstraints(ctx, ownership) pc += reading eq size.toBv() pc += inputRef eq secondInputRef pc += (inputRef eq nullRef).not() @@ -155,9 +158,9 @@ open class SoftConstraintsTest { val inputRef = mkRegisterReading(0, addressSort) val region = UInputArrayLengthId(arrayType, sizeSort) .emptyRegion() - .write(inputRef, mkRegisterReading(3, sizeSort), guard = trueExpr) + .write(inputRef, mkRegisterReading(3, sizeSort), guard = trueExpr, ownership) - val pc = UPathConstraints(ctx) + val pc = UPathConstraints(ctx, ownership) pc += (inputRef eq nullRef).not() val softConstraints = softConstraintsProvider.makeSoftConstraints(pc) @@ -175,7 +178,7 @@ open class SoftConstraintsTest { val bvValue = 0.toBv() val expression = mkBvSignedLessOrEqualExpr(bvValue, inputRef).not() - val pc = UPathConstraints(ctx) + val pc = UPathConstraints(ctx, ownership) pc += expression val softConstraints = softConstraintsProvider.makeSoftConstraints(pc) diff --git a/usvm-core/src/test/kotlin/org/usvm/solver/TranslationTest.kt b/usvm-core/src/test/kotlin/org/usvm/solver/TranslationTest.kt index 16a05dcad3..a2e9f36b4d 100644 --- a/usvm-core/src/test/kotlin/org/usvm/solver/TranslationTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/solver/TranslationTest.kt @@ -32,6 +32,7 @@ import org.usvm.collection.array.USymbolicArrayInputToInputCopyAdapter import org.usvm.collection.array.length.UInputArrayLengthId import org.usvm.collection.field.UInputFieldId import org.usvm.collection.map.ref.URefMapEntryLValue +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.memory.UMemory import org.usvm.mkSizeExpr import org.usvm.memory.key.USizeExprKeyInfo @@ -41,6 +42,7 @@ import kotlin.test.assertSame class TranslationTest { private lateinit var ctx: RecordingCtx + private lateinit var ownership: MutabilityOwnership private lateinit var heap: UMemory private lateinit var translator: UExprTranslator @@ -69,8 +71,9 @@ class TranslationTest { every { components.mkTypeSystem(any()) } returns mockk() ctx = RecordingCtx(components) + ownership = MutabilityOwnership() every { components.mkSizeExprProvider(any()) } answers { UBv32SizeExprProvider(ctx) } - heap = UMemory(ctx, mockk()) + heap = UMemory(ctx, ownership, mockk()) translator = UExprTranslator(ctx) valueFieldDescr = mockk() to ctx.bv32Sort @@ -144,8 +147,8 @@ class TranslationTest { val region = UInputArrayId<_, _, USizeSort>(valueArrayDescr, bv32Sort) .emptyRegion() - .write(ref1 to idx1, val1, trueExpr) - .write(ref2 to idx2, val2, trueExpr) + .write(ref1 to idx1, val1, trueExpr, ownership) + .write(ref2 to idx2, val2, trueExpr, ownership) val ref3 = mkRegisterReading(4, addressSort) val idx3 = mkRegisterReading(5, sizeSort) @@ -175,8 +178,8 @@ class TranslationTest { val region = UInputArrayId<_, _, USizeSort>(valueArrayDescr, bv32Sort) .emptyRegion() - .write(ref1 to idx1, val1, trueExpr) - .write(ref2 to idx2, val2, trueExpr) + .write(ref1 to idx1, val1, trueExpr, ownership) + .write(ref2 to idx2, val2, trueExpr, ownership) val concreteRef = allocateConcreteRef() @@ -227,9 +230,9 @@ class TranslationTest { val region = UInputFieldId(mockk(), bv32Sort) .emptyRegion() - .write(ref1, mkBv(1), g1) - .write(ref2, mkBv(2), g2) - .write(ref3, mkBv(3), g3) + .write(ref1, mkBv(1), g1, ownership) + .write(ref2, mkBv(2), g2, ownership) + .write(ref3, mkBv(3), g3, ownership) val ref0 = mkRegisterReading(0, addressSort) val reading = region.read(ref0) @@ -253,9 +256,9 @@ class TranslationTest { val region = UInputArrayLengthId(mockk(), bv32Sort) .emptyRegion() - .write(ref1, mkBv(1), trueExpr) - .write(ref2, mkBv(2), trueExpr) - .write(ref3, mkBv(3), trueExpr) + .write(ref1, mkBv(1), trueExpr, ownership) + .write(ref2, mkBv(2), trueExpr, ownership) + .write(ref3, mkBv(3), trueExpr, ownership) val ref0 = mkRegisterReading(0, addressSort) val reading = region.read(ref0) @@ -283,8 +286,8 @@ class TranslationTest { val inputRegion1 = UInputArrayId<_, _, USizeSort>(valueArrayDescr, bv32Sort) .emptyRegion() - .write(ref1 to idx1, val1, trueExpr) - .write(ref2 to idx2, val2, trueExpr) + .write(ref1 to idx1, val1, trueExpr, ownership) + .write(ref2 to idx2, val2, trueExpr, ownership) val adapter = USymbolicArrayInputToInputCopyAdapter( @@ -383,8 +386,8 @@ class TranslationTest { val inputRegion1 = UInputArrayId<_, _, USizeSort>(valueArrayDescr, addressSort) .emptyRegion() - .write(ref1 to idx1, val1, trueExpr) - .write(ref2 to idx2, val2, trueExpr) + .write(ref1 to idx1, val1, trueExpr, ownership) + .write(ref2 to idx2, val2, trueExpr, ownership) val adapter = USymbolicArrayInputToInputCopyAdapter( ref1 to mkSizeExpr(0), @@ -424,8 +427,8 @@ class TranslationTest { val allocatedRegion1 = UAllocatedArrayId<_, _, USizeSort>(valueArrayDescr, addressSort, 1) .emptyRegion() - .write(idx1, val1, trueExpr) - .write(idx2, val2, trueExpr) + .write(idx1, val1, trueExpr, ownership) + .write(idx2, val2, trueExpr, ownership) val adapter = USymbolicArrayAllocatedToAllocatedCopyAdapter( mkSizeExpr(0), mkSizeExpr(0), mkSizeExpr(5), USizeExprKeyInfo() @@ -458,12 +461,12 @@ class TranslationTest { fun testCachingOfTranslatedMemoryUpdates() = with(ctx) { val allocatedRegion = UAllocatedArrayId<_, _, USizeSort>(valueArrayDescr, sizeSort, 0) .emptyRegion() - .write(mkRegisterReading(0, sizeSort), mkBv(0), trueExpr) - .write(mkRegisterReading(1, sizeSort), mkBv(1), trueExpr) + .write(mkRegisterReading(0, sizeSort), mkBv(0), trueExpr, ownership) + .write(mkRegisterReading(1, sizeSort), mkBv(1), trueExpr, ownership) val allocatedRegionExtended = allocatedRegion - .write(mkRegisterReading(2, sizeSort), mkBv(2), trueExpr) - .write(mkRegisterReading(3, sizeSort), mkBv(3), trueExpr) + .write(mkRegisterReading(2, sizeSort), mkBv(2), trueExpr, ownership) + .write(mkRegisterReading(3, sizeSort), mkBv(3), trueExpr, ownership) val reading = allocatedRegion.read(mkRegisterReading(4, sizeSort)) val readingExtended = allocatedRegionExtended.read(mkRegisterReading(5, sizeSort)) diff --git a/usvm-core/src/test/kotlin/org/usvm/types/TypeSolverTest.kt b/usvm-core/src/test/kotlin/org/usvm/types/TypeSolverTest.kt index 8fddbff544..ba6acb47f9 100644 --- a/usvm-core/src/test/kotlin/org/usvm/types/TypeSolverTest.kt +++ b/usvm-core/src/test/kotlin/org/usvm/types/TypeSolverTest.kt @@ -18,6 +18,7 @@ import org.usvm.api.readField import org.usvm.api.typeStreamOf import org.usvm.api.writeField import org.usvm.collection.array.UInputArrayId +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UPathConstraints import org.usvm.isFalse import org.usvm.isTrue @@ -60,6 +61,7 @@ class TypeSolverTest { private val typeSystem = testTypeSystem private val components = mockk>() private val ctx = UContext(components) + private val ownership = MutabilityOwnership() private val solver: USolverBase private val typeSolver: UTypeSolver @@ -73,12 +75,11 @@ class TypeSolverTest { every { components.mkSolver(ctx) } returns solver every { components.mkTypeSystem(ctx) } returns typeSystem every { components.mkSizeExprProvider(any()) } answers { UBv32SizeExprProvider(ctx) } - every { components.mkComposer(ctx) } answers { { memory: UReadOnlyMemory -> UComposer(ctx, memory) } } + every { components.mkComposer(ctx) } answers { { memory: UReadOnlyMemory, ownership: MutabilityOwnership -> UComposer(ctx, memory, ownership) } } } - private val pc = UPathConstraints(ctx) - private val memory = UMemory(ctx, pc.typeConstraints) - + private val pc = UPathConstraints(ctx, ownership) + private val memory = UMemory(ctx, ownership, pc.typeConstraints) @Test fun `Test concrete ref -- open type inheritance`() { val ref = memory.allocConcrete(base1) @@ -235,7 +236,7 @@ class TypeSolverTest { pc += mkHeapRefEq(b1, b2) - with(pc.clone()) { + with(pc.clone(MutabilityOwnership(), MutabilityOwnership())) { val result = solver.check(this) assertIs>>(result) @@ -248,7 +249,7 @@ class TypeSolverTest { assertTrue(concreteA != concreteB1 || concreteB1 != concreteC || concreteC != concreteA) } - with(pc.clone()) { + with(pc.clone(MutabilityOwnership(), MutabilityOwnership())) { val model = mockk> { every { eval(a) } returns mkConcreteHeapRef(INITIAL_INPUT_ADDRESS) every { eval(b1) } returns mkConcreteHeapRef(INITIAL_INPUT_ADDRESS) @@ -267,7 +268,7 @@ class TypeSolverTest { } - with(pc.clone()) { + with(pc.clone(MutabilityOwnership(), MutabilityOwnership())) { this += mkHeapRefEq(a, c) and mkHeapRefEq(b1, c) val result = solver.check(this) assertIs>>(result) @@ -344,15 +345,15 @@ class TypeSolverTest { val idx2 = 0.toBv() val field = mockk() - val heap = UMemory(ctx, mockk()) + val heap = UMemory(ctx, ownership, mockk()) heap.writeField(val1, field, bv32Sort, 1.toBv(), trueExpr) heap.writeField(val2, field, bv32Sort, 2.toBv(), trueExpr) val inputRegion = UInputArrayId<_, _, USizeSort>(mockk(), addressSort) .emptyRegion() - .write(arr1 to idx1, val1, trueExpr) - .write(arr2 to idx2, val2, trueExpr) + .write(arr1 to idx1, val1, trueExpr, ownership) + .write(arr2 to idx2, val2, trueExpr, ownership) val firstReading = inputRegion.read(arr1 to idx1) val secondReading = inputRegion.read(arr2 to idx2) @@ -503,12 +504,12 @@ class TypeSolverTest { val ref = mkConcreteHeapRef(1) pc.typeConstraints.allocate(ref.address, base1) - with(pc.clone()) { + with(pc.clone(MutabilityOwnership(), MutabilityOwnership())) { this += mkIsSubtypeExpr(ref, top).not() assertTrue(isFalse) } - with(pc.clone()) { + with(pc.clone(MutabilityOwnership(), MutabilityOwnership())) { this += mkIsSupertypeExpr(ref, derived1A).not() assertTrue(isFalse) } diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcComponents.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcComponents.kt index b50d890df5..42344b579e 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcComponents.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcComponents.kt @@ -7,13 +7,13 @@ import org.usvm.UComposer import org.usvm.UContext import org.usvm.UMachineOptions import org.usvm.USizeExprProvider +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.memory.UReadOnlyMemory import org.usvm.model.ULazyModelDecoder import org.usvm.solver.UExprTranslator import org.usvm.solver.USoftConstraintsProvider import org.usvm.solver.USolverBase import org.usvm.solver.UTypeSolver -import kotlin.time.Duration class JcComponents( private val typeSystem: JcTypeSystem, @@ -34,8 +34,8 @@ class JcComponents( override fun > mkComposer( ctx: Context - ): (UReadOnlyMemory) -> UComposer = - { memory: UReadOnlyMemory -> JcComposer(ctx, memory) } + ): (UReadOnlyMemory, MutabilityOwnership) -> UComposer = + { memory: UReadOnlyMemory, ownership: MutabilityOwnership -> JcComposer(ctx, memory, ownership) } override fun > mkSolver(ctx: Context): USolverBase { val (translator, decoder) = buildTranslatorAndLazyDecoder(ctx) diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcTransformer.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcTransformer.kt index 253aeaf092..50de13c8c4 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcTransformer.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcTransformer.kt @@ -10,6 +10,7 @@ import org.usvm.UContext import org.usvm.UExpr import org.usvm.USort import org.usvm.UTransformer +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.machine.interpreter.statics.JcStaticFieldLValue import org.usvm.machine.interpreter.statics.JcStaticFieldReading import org.usvm.machine.interpreter.statics.JcStaticFieldRegionId @@ -27,7 +28,8 @@ interface JcTransformer : UTransformer { class JcComposer( ctx: UContext, memory: UReadOnlyMemory, -) : UComposer(ctx, memory), JcTransformer { + ownership: MutabilityOwnership, +) : UComposer(ctx, memory, ownership), JcTransformer { override fun transform(expr: JcStaticFieldReading): UExpr = memory.read(JcStaticFieldLValue(expr.field, expr.sort)) } diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcCallSiteRegion.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcCallSiteRegion.kt index c40114a44e..2459b0bc0a 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcCallSiteRegion.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcCallSiteRegion.kt @@ -1,13 +1,14 @@ package org.usvm.machine.interpreter -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.persistentHashMapOf import org.jacodb.api.jvm.cfg.JcLambdaExpr import org.usvm.UAddressSort import org.usvm.UBoolExpr import org.usvm.UConcreteHeapAddress import org.usvm.UConcreteHeapRef import org.usvm.UExpr +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.persistentHashMapOf +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.machine.JcContext import org.usvm.memory.UMemoryRegion import org.usvm.memory.UMemoryRegionId @@ -22,10 +23,10 @@ class JcLambdaCallSiteRegionId(private val ctx: JcContext) : UMemoryRegionId = persistentHashMapOf() + private val callSites: UPersistentHashMap = persistentHashMapOf(), ) : UMemoryRegion { - fun writeCallSite(callSite: JcLambdaCallSite) = - JcLambdaCallSiteMemoryRegion(ctx, callSites.put(callSite.ref.address, callSite)) + fun writeCallSite(callSite: JcLambdaCallSite, ownership: MutabilityOwnership) = + JcLambdaCallSiteMemoryRegion(ctx, callSites.put(callSite.ref.address, callSite, ownership)) fun findCallSite(ref: UConcreteHeapRef): JcLambdaCallSite? = callSites[ref.address] @@ -36,7 +37,8 @@ internal class JcLambdaCallSiteMemoryRegion( override fun write( key: Nothing, value: UExpr, - guard: UBoolExpr + guard: UBoolExpr, + ownership: MutabilityOwnership, ): UMemoryRegion { error("Unsupported operation for call site region") } diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcExprResolver.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcExprResolver.kt index b615e7cc93..2adba5cf61 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcExprResolver.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcExprResolver.kt @@ -431,7 +431,7 @@ class JcExprResolver( private fun UWritableMemory.writeCallSite(callSite: JcLambdaCallSite) { val callSiteRegion = getRegion(ctx.lambdaCallSiteRegionId) as JcLambdaCallSiteMemoryRegion - val updatedRegion = callSiteRegion.writeCallSite(callSite) + val updatedRegion = callSiteRegion.writeCallSite(callSite, ownership) setRegion(ctx.lambdaCallSiteRegionId, updatedRegion) } @@ -1015,7 +1015,10 @@ class JcExprResolver( if (sort === voidSort) return@forEach val memoryRegion = memory.getRegion(JcStaticFieldRegionId(sort)) as JcStaticFieldsMemoryRegion<*> - memoryRegion.mutatePrimitiveStaticFieldValuesToSymbolic(staticInitializer.enclosingClass) + memoryRegion.mutatePrimitiveStaticFieldValuesToSymbolic( + staticInitializer.enclosingClass, + memory.ownership + ) } } } @@ -1161,4 +1164,4 @@ class JcSimpleValueResolver( scope.calcOnState { mkStringConstRef(value) } -} \ No newline at end of file +} diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt index 128dfefae1..396c791c33 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt @@ -22,7 +22,6 @@ import org.jacodb.api.jvm.cfg.JcImmediate import org.jacodb.api.jvm.cfg.JcInst import org.jacodb.api.jvm.cfg.JcInstList import org.jacodb.api.jvm.cfg.JcInstRef -import org.jacodb.api.jvm.cfg.JcLocal import org.jacodb.api.jvm.cfg.JcLocalVar import org.jacodb.api.jvm.cfg.JcReturnInst import org.jacodb.api.jvm.cfg.JcSwitchInst @@ -48,6 +47,7 @@ import org.usvm.api.mapTypeStream import org.usvm.api.targets.JcTarget import org.usvm.collection.array.UArrayIndexLValue import org.usvm.collection.field.UFieldLValue +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.forkblacklists.UForkBlackList import org.usvm.machine.JcApplicationGraph import org.usvm.machine.JcConcreteMethodCallInst @@ -101,7 +101,8 @@ class JcInterpreter( } fun getInitialState(method: JcMethod, targets: List = emptyList()): JcState { - val state = JcState(ctx, method, targets = UTargetsSet.from(targets)) + val initOwnership = MutabilityOwnership() + val state = JcState(ctx, initOwnership, method, targets = UTargetsSet.from(targets)) val typedMethod = with(applicationGraph) { method.typed } val entrypointArguments = mutableListOf>() diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/statics/JcStaticFieldsRegion.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/statics/JcStaticFieldsRegion.kt index 87e58f9244..43299ebd8d 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/statics/JcStaticFieldsRegion.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/statics/JcStaticFieldsRegion.kt @@ -1,8 +1,6 @@ package org.usvm.machine.interpreter.statics import kotlinx.collections.immutable.PersistentList -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.persistentHashMapOf import kotlinx.collections.immutable.persistentListOf import org.jacodb.api.jvm.JcClassOrInterface import org.jacodb.api.jvm.JcField @@ -16,6 +14,10 @@ import org.usvm.UBoolExpr import org.usvm.UBoolSort import org.usvm.UExpr import org.usvm.USort +import org.usvm.collections.immutable.getOrDefault +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.persistentHashMapOf import org.usvm.isTrue import org.usvm.machine.JcContext import org.usvm.machine.jctx @@ -46,7 +48,9 @@ data class JcStaticFieldRegionId( internal class JcStaticFieldsMemoryRegion( private val sort: Sort, - private var fieldValuesByClass: PersistentMap>> = persistentHashMapOf(), + // TODO multimap + private var fieldValuesByClass: UPersistentHashMap>> = + persistentHashMapOf(), private var initialStatics: PersistentList = persistentListOf() ) : UMemoryRegion, Sort> { val mutableStaticFields: List @@ -62,36 +66,30 @@ internal class JcStaticFieldsMemoryRegion( key: JcStaticFieldLValue, value: UExpr, guard: UBoolExpr, + ownership: MutabilityOwnership, ): UMemoryRegion, Sort> { val field = key.field - val enclosingClass = field.enclosingClass - if (enclosingClass !in fieldValuesByClass) { - fieldValuesByClass = fieldValuesByClass.put(enclosingClass, persistentHashMapOf()) - } + val classFields = fieldValuesByClass.getOrDefault(enclosingClass, persistentHashMapOf()) - val newFieldValues = fieldValuesByClass - .getValue(enclosingClass) - .guardedWrite(key.field, value, guard) { key.sort.sampleUValue() } - val newFieldsByClass = fieldValuesByClass.put(enclosingClass, newFieldValues) + val newFieldValues = classFields.guardedWrite(key.field, value, guard, ownership) { key.sort.sampleUValue() } + val newFieldsByClass = fieldValuesByClass.put(enclosingClass, newFieldValues, ownership) return JcStaticFieldsMemoryRegion(sort, newFieldsByClass, initialStatics) } - fun mutatePrimitiveStaticFieldValuesToSymbolic(enclosingClass: JcClassOrInterface) { - val staticFields = fieldValuesByClass[enclosingClass] ?: return + fun mutatePrimitiveStaticFieldValuesToSymbolic(enclosingClass: JcClassOrInterface, ownership: MutabilityOwnership) { + var staticFields = fieldValuesByClass[enclosingClass] ?: return - val staticsToRemove = staticFields - .keys - .filter { fieldShouldBeSymbolic(it) } + val staticsToRemove = staticFields.keys.filterTo(mutableListOf()) { fieldShouldBeSymbolic(it) } - initialStatics = initialStatics.addAll(staticsToRemove) - - // Remove concrete fields from the region - val updatedStaticFields = staticsToRemove.fold(staticFields) { acc, field -> - acc.remove(field) + for (static in staticsToRemove) { + initialStatics = initialStatics.add(static) + // Remove concrete fields from the region + staticFields = staticFields.remove(static, ownership) } - fieldValuesByClass = fieldValuesByClass.put(enclosingClass, updatedStaticFields) + + fieldValuesByClass = fieldValuesByClass.put(enclosingClass, staticFields, ownership) } companion object { diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/mocks/JcMocker.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/mocks/JcMocker.kt index d443bfdc14..342a111b73 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/mocks/JcMocker.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/mocks/JcMocker.kt @@ -34,7 +34,7 @@ fun mockMethod(scope: JcStepScope, methodCall: JcMethodCall, returnType: JcType) val mockSort = ctx.typeToSort(returnType) val mockValue = scope.calcOnState { - memory.mocker.call(method, arguments.asSequence(), mockSort) + memory.mocker.call(method, arguments.asSequence(), mockSort, memory.ownership) } if (mockSort == ctx.addressSort) { diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcState.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcState.kt index b802df8219..24f8290c4a 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcState.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcState.kt @@ -7,6 +7,7 @@ import org.usvm.PathNode import org.usvm.UCallStack import org.usvm.UState import org.usvm.api.targets.JcTarget +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UPathConstraints import org.usvm.machine.JcContext import org.usvm.memory.UMemory @@ -16,10 +17,11 @@ import org.usvm.targets.UTargetsSet class JcState( ctx: JcContext, + ownership: MutabilityOwnership, override val entrypoint: JcMethod, callStack: UCallStack = UCallStack(), - pathConstraints: UPathConstraints = UPathConstraints(ctx), - memory: UMemory = UMemory(ctx, pathConstraints.typeConstraints), + pathConstraints: UPathConstraints = UPathConstraints(ctx, ownership), + memory: UMemory = UMemory(ctx, ownership, pathConstraints.typeConstraints), models: List> = listOf(), pathNode: PathNode = PathNode.root(), forkPoints: PathNode> = PathNode.root(), @@ -27,6 +29,7 @@ class JcState( targets: UTargetsSet = UTargetsSet.empty(), ) : UState( ctx, + ownership, callStack, pathConstraints, memory, @@ -36,13 +39,20 @@ class JcState( targets ) { override fun clone(newConstraints: UPathConstraints?): JcState { - val clonedConstraints = newConstraints ?: pathConstraints.clone() + val newThisOwnership = MutabilityOwnership() + val cloneOwnership = MutabilityOwnership() + val clonedConstraints = newConstraints?.also { + this.pathConstraints.changeOwnership(newThisOwnership) + it.changeOwnership(cloneOwnership) + } ?: pathConstraints.clone(newThisOwnership, cloneOwnership) + this.ownership = newThisOwnership return JcState( ctx, + cloneOwnership, entrypoint, callStack.clone(), clonedConstraints, - memory.clone(clonedConstraints.typeConstraints), + memory.clone(clonedConstraints.typeConstraints, newThisOwnership, cloneOwnership), models, pathNode, forkPoints, @@ -57,6 +67,10 @@ class JcState( * @return the merged state. TODO: Now it may reuse some of the internal components of the former states. */ override fun mergeWith(other: JcState, by: Unit): JcState? { + val newThisOwnership = MutabilityOwnership() + val newOtherOwnership = MutabilityOwnership() + val mergedOwnership = MutabilityOwnership() + require(entrypoint == other.entrypoint) { "Cannot merge states with different entrypoints" } // TODO: copy-paste @@ -65,10 +79,13 @@ class JcState( val mergeGuard = MutableMergeGuard(ctx) val mergedCallStack = callStack.mergeWith(other.callStack, Unit) ?: return null - val mergedPathConstraints = pathConstraints.mergeWith(other.pathConstraints, mergeGuard) - ?: return null - val mergedMemory = memory.clone(mergedPathConstraints.typeConstraints).mergeWith(other.memory, mergeGuard) - ?: return null + val mergedPathConstraints = pathConstraints.mergeWith( + other.pathConstraints, mergeGuard, newThisOwnership, newOtherOwnership, mergedOwnership + ) ?: return null + val mergedMemory = + memory.clone(mergedPathConstraints.typeConstraints, newThisOwnership, newOtherOwnership) + .mergeWith(other.memory, mergeGuard, newThisOwnership, newOtherOwnership, mergedOwnership) + ?: return null val mergedModels = models + other.models val methodResult = if (other.methodResult == JcMethodResult.NoCall && methodResult == JcMethodResult.NoCall) { JcMethodResult.NoCall @@ -78,8 +95,11 @@ class JcState( val mergedTargets = targets.takeIf { it == other.targets } ?: return null mergedPathConstraints += ctx.mkOr(mergeGuard.thisConstraint, mergeGuard.otherConstraint) + this.ownership = newThisOwnership + other.ownership = newOtherOwnership return JcState( ctx, + mergedOwnership, entrypoint, mergedCallStack, mergedPathConstraints, diff --git a/usvm-python/usvm-python-main/src/main/kotlin/org/usvm/machine/Mocking.kt b/usvm-python/usvm-python-main/src/main/kotlin/org/usvm/machine/Mocking.kt index b682d031f9..5a12591581 100644 --- a/usvm-python/usvm-python-main/src/main/kotlin/org/usvm/machine/Mocking.kt +++ b/usvm-python/usvm-python-main/src/main/kotlin/org/usvm/machine/Mocking.kt @@ -22,7 +22,11 @@ fun PyState.mock(what: MockHeader): MockResult { if (cached != null) { return MockResult(UninterpretedSymbolicPythonObject(cached, typeSystem), false, cached) } - val result = memory.mocker.call(what.method, what.args.map { it.address }.asSequence(), ctx.addressSort) + val result = memory.mocker.call( + what.method, what.args.map { it.address }.asSequence(), + ctx.addressSort, + memory.ownership + ) mocks[what] = result what.methodOwner?.let { mockedObjects.add(it) } return MockResult(UninterpretedSymbolicPythonObject(result, typeSystem), true, result) diff --git a/usvm-python/usvm-python-main/src/main/kotlin/org/usvm/machine/PyMachine.kt b/usvm-python/usvm-python-main/src/main/kotlin/org/usvm/machine/PyMachine.kt index d1dfd600eb..a33fbaa349 100644 --- a/usvm-python/usvm-python-main/src/main/kotlin/org/usvm/machine/PyMachine.kt +++ b/usvm-python/usvm-python-main/src/main/kotlin/org/usvm/machine/PyMachine.kt @@ -2,6 +2,7 @@ package org.usvm.machine import org.usvm.UMachine import org.usvm.UPathSelector +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.language.PyCallable import org.usvm.language.PyPinnedCallable import org.usvm.language.PyProgram @@ -60,10 +61,12 @@ class PyMachine( ) private fun getInitialState(target: PyUnpinnedCallable): PyState { - val pathConstraints = PyPathConstraints(ctx) + val initOwnership = MutabilityOwnership() + val pathConstraints = PyPathConstraints(ctx, initOwnership) val memory = UMemory( ctx, - pathConstraints.typeConstraints + initOwnership, + pathConstraints.typeConstraints, ).apply { stack.push(target.numberOfArguments) } @@ -78,6 +81,7 @@ class PyMachine( return PyState( ctx, + initOwnership, target, symbols, pathConstraints, diff --git a/usvm-python/usvm-python-main/src/main/kotlin/org/usvm/machine/PyPathConstraints.kt b/usvm-python/usvm-python-main/src/main/kotlin/org/usvm/machine/PyPathConstraints.kt index c7384dd257..550a1f67f2 100644 --- a/usvm-python/usvm-python-main/src/main/kotlin/org/usvm/machine/PyPathConstraints.kt +++ b/usvm-python/usvm-python-main/src/main/kotlin/org/usvm/machine/PyPathConstraints.kt @@ -5,6 +5,7 @@ import kotlinx.collections.immutable.persistentHashSetOf import org.usvm.UBoolExpr import org.usvm.UBv32Sort import org.usvm.UContext +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UEqualityConstraints import org.usvm.constraints.ULogicalConstraints import org.usvm.constraints.UNumericConstraints @@ -14,26 +15,31 @@ import org.usvm.machine.types.PythonType class PyPathConstraints( ctx: UContext<*>, + override var ownership: MutabilityOwnership, logicalConstraints: ULogicalConstraints = ULogicalConstraints.empty(), - equalityConstraints: UEqualityConstraints = UEqualityConstraints(ctx), + equalityConstraints: UEqualityConstraints = UEqualityConstraints(ctx, ownership), typeConstraints: UTypeConstraints = UTypeConstraints( + ownership, ctx.typeSystem(), equalityConstraints ), - numericConstraints: UNumericConstraints = UNumericConstraints(ctx, sort = ctx.bv32Sort), + numericConstraints: UNumericConstraints = + UNumericConstraints(ctx, sort = ctx.bv32Sort, ownership = ownership), var pythonSoftConstraints: PersistentSet = persistentHashSetOf(), ) : UPathConstraints( ctx, + ownership, logicalConstraints, equalityConstraints, typeConstraints, numericConstraints ) { - override fun clone(): PyPathConstraints { + override fun clone(thisOwnership: MutabilityOwnership, cloneOwnership: MutabilityOwnership): PyPathConstraints { val clonedLogicalConstraints = logicalConstraints.clone() - val clonedEqualityConstraints = equalityConstraints.clone() - val clonedTypeConstraints = typeConstraints.clone(clonedEqualityConstraints) - val clonedNumericConstraints = numericConstraints.clone() + val clonedEqualityConstraints = equalityConstraints.clone(thisOwnership, cloneOwnership) + val clonedTypeConstraints = typeConstraints.clone(clonedEqualityConstraints, thisOwnership, cloneOwnership) + val clonedNumericConstraints = numericConstraints.clone(thisOwnership, cloneOwnership) + this.ownership = thisOwnership return PyPathConstraints( ctx = ctx, logicalConstraints = clonedLogicalConstraints, @@ -41,6 +47,7 @@ class PyPathConstraints( typeConstraints = clonedTypeConstraints, numericConstraints = clonedNumericConstraints, pythonSoftConstraints = pythonSoftConstraints, + ownership = cloneOwnership ) } } diff --git a/usvm-python/usvm-python-main/src/main/kotlin/org/usvm/machine/PyState.kt b/usvm-python/usvm-python-main/src/main/kotlin/org/usvm/machine/PyState.kt index 1c4c3e3150..87c8a1e20b 100644 --- a/usvm-python/usvm-python-main/src/main/kotlin/org/usvm/machine/PyState.kt +++ b/usvm-python/usvm-python-main/src/main/kotlin/org/usvm/machine/PyState.kt @@ -10,6 +10,7 @@ import org.usvm.UCallStack import org.usvm.UMockSymbol import org.usvm.UPathSelector import org.usvm.UState +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UPathConstraints import org.usvm.language.PyCallable import org.usvm.language.PyInstruction @@ -32,6 +33,7 @@ private val targets = UTargetsSet.empty() class PyState( ctx: PyContext, + ownership: MutabilityOwnership, private val pythonCallable: PyUnpinnedCallable, val inputSymbols: List, override val pathConstraints: PyPathConstraints, @@ -50,6 +52,7 @@ class PyState( var uniqueInstructions: PersistentSet = persistentSetOf(), ) : UState( ctx, + ownership, callStack, pathConstraints, memory, @@ -60,10 +63,16 @@ class PyState( ) { override fun clone(newConstraints: UPathConstraints?): PyState { require(newConstraints is PyPathConstraints?) - val newPathConstraints = newConstraints ?: pathConstraints.clone() - val newMemory = memory.clone(newPathConstraints.typeConstraints) + val newThisOwnership = MutabilityOwnership() + val cloneOwnership = MutabilityOwnership() + val newPathConstraints = newConstraints?.also { + this.pathConstraints.changeOwnership(newThisOwnership) + it.changeOwnership(cloneOwnership) + } ?: pathConstraints.clone(newThisOwnership, cloneOwnership) + val newMemory = memory.clone(newPathConstraints.typeConstraints, newThisOwnership, cloneOwnership) return PyState( ctx, + cloneOwnership, pythonCallable, inputSymbols, newPathConstraints, diff --git a/usvm-python/usvm-python-runner/src/test/kotlin/org/usvm/runner/PrintingResultReceiver.kt b/usvm-python/usvm-python-runner/src/test/kotlin/org/usvm/runner/PrintingResultReceiver.kt index 69f9dc5eda..b88cbb99f2 100644 --- a/usvm-python/usvm-python-runner/src/test/kotlin/org/usvm/runner/PrintingResultReceiver.kt +++ b/usvm-python/usvm-python-runner/src/test/kotlin/org/usvm/runner/PrintingResultReceiver.kt @@ -2,7 +2,7 @@ package org.usvm.runner import mu.KLogging -private val logger = object : KLogging() {}.logger +private val logger = object : KLogging() {}.logger class PrintingResultReceiver : USVMPythonAnalysisResultReceiver { var cnt: Int = 0 diff --git a/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleMachine.kt b/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleMachine.kt index ee0cf88de1..3741d29245 100644 --- a/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleMachine.kt +++ b/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleMachine.kt @@ -5,6 +5,7 @@ import org.usvm.StateCollectionStrategy import org.usvm.UContext import org.usvm.UMachine import org.usvm.UMachineOptions +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.language.Method import org.usvm.language.Program import org.usvm.language.SampleType @@ -123,7 +124,7 @@ class SampleMachine( method: Method<*>, targets: List ): SampleState = - SampleState(ctx, method, targets = UTargetsSet.from(targets)).apply { + SampleState(ctx, MutabilityOwnership(), method, targets = UTargetsSet.from(targets)).apply { addEntryMethodCall(applicationGraph, method) val model = solver.emptyModel() models = persistentListOf(model) diff --git a/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleState.kt b/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleState.kt index df33adcd57..8315b9eb09 100644 --- a/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleState.kt +++ b/usvm-sample-language/src/main/kotlin/org/usvm/machine/SampleState.kt @@ -6,6 +6,7 @@ import org.usvm.UContext import org.usvm.UExpr import org.usvm.USort import org.usvm.UState +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UPathConstraints import org.usvm.language.Method import org.usvm.language.ProgramException @@ -20,10 +21,11 @@ import org.usvm.targets.UTargetsSet class SampleState( ctx: UContext, + ownership: MutabilityOwnership, override val entrypoint: Method<*>, callStack: UCallStack, Stmt> = UCallStack(), - pathConstraints: UPathConstraints = UPathConstraints(ctx), - memory: UMemory> = UMemory(ctx, pathConstraints.typeConstraints), + pathConstraints: UPathConstraints = UPathConstraints(ctx, ownership), + memory: UMemory> = UMemory(ctx, ownership, pathConstraints.typeConstraints), models: List> = listOf(), pathNode: PathNode = PathNode.root(), forkPoints: PathNode> = PathNode.root(), @@ -32,6 +34,7 @@ class SampleState( targets: UTargetsSet = UTargetsSet.empty(), ) : UState, Stmt, UContext, SampleTarget, SampleState>( ctx, + ownership, callStack, pathConstraints, memory, @@ -41,13 +44,19 @@ class SampleState( targets ) { override fun clone(newConstraints: UPathConstraints?): SampleState { - val clonedConstraints = newConstraints ?: pathConstraints.clone() + val newThisOwnership = MutabilityOwnership() + val cloneOwnership = MutabilityOwnership() + val clonedConstraints = newConstraints?.also { + this.pathConstraints.changeOwnership(newThisOwnership) + it.changeOwnership(cloneOwnership) + } ?: pathConstraints.clone(newThisOwnership, cloneOwnership) return SampleState( ctx, + cloneOwnership, entrypoint, callStack.clone(), clonedConstraints, - memory.clone(clonedConstraints.typeConstraints), + memory.clone(clonedConstraints.typeConstraints, newThisOwnership, cloneOwnership), models, pathNode, forkPoints, @@ -64,15 +73,20 @@ class SampleState( */ override fun mergeWith(other: SampleState, by: Unit): SampleState? { require(entrypoint == other.entrypoint) { "Cannot merge states with different entrypoints" } - + val thisOwnership = MutabilityOwnership() + val otherOwnership = MutabilityOwnership() + val mergedOwnership = MutabilityOwnership() val mergedPathNode = pathNode.mergeWith(other.pathNode, Unit) ?: return null val mergedForkPoints = forkPoints.mergeWith(other.forkPoints, Unit) ?: return null val mergeGuard = MutableMergeGuard(ctx) val mergedCallStack = callStack.mergeWith(other.callStack, Unit) ?: return null - val mergedPathConstraints = pathConstraints.mergeWith(other.pathConstraints, mergeGuard) - ?: return null - val mergedMemory = memory.clone(mergedPathConstraints.typeConstraints).mergeWith(other.memory, mergeGuard) + val mergedPathConstraints = + pathConstraints.mergeWith( + other.pathConstraints, mergeGuard, thisOwnership, otherOwnership, mergedOwnership + ) ?: return null + val mergedMemory = memory.clone(mergedPathConstraints.typeConstraints, thisOwnership, otherOwnership) + .mergeWith(other.memory, mergeGuard, thisOwnership, otherOwnership, mergedOwnership) ?: return null val mergedModels = models + other.models val mergedReturnRegister = if (returnRegister == null && other.returnRegister == null) { @@ -88,8 +102,11 @@ class SampleState( val mergedTargets = targets.takeIf { it == other.targets } ?: return null mergedPathConstraints += ctx.mkOr(mergeGuard.thisConstraint, mergeGuard.otherConstraint) + this.ownership = thisOwnership + other.ownership = otherOwnership return SampleState( ctx, + mergedOwnership, entrypoint, mergedCallStack, mergedPathConstraints, diff --git a/usvm-ts/src/main/kotlin/org/usvm/TSInterpreter.kt b/usvm-ts/src/main/kotlin/org/usvm/TSInterpreter.kt index 1770bdc844..82d56090c5 100644 --- a/usvm-ts/src/main/kotlin/org/usvm/TSInterpreter.kt +++ b/usvm-ts/src/main/kotlin/org/usvm/TSInterpreter.kt @@ -16,6 +16,7 @@ import org.jacodb.ets.base.EtsThrowStmt import org.jacodb.ets.base.EtsType import org.jacodb.ets.base.EtsValue import org.jacodb.ets.model.EtsMethod +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.forkblacklists.UForkBlackList import org.usvm.memory.URegisterStackLValue import org.usvm.solver.USatResult @@ -164,7 +165,7 @@ class TSInterpreter( fun getInitialState(method: EtsMethod, targets: List): TSState { - val state = TSState(ctx, method, targets = UTargetsSet.from(targets)) + val state = TSState(ctx, MutabilityOwnership(), method, targets = UTargetsSet.from(targets)) with(ctx) { val params = List(method.parameters.size) { idx -> diff --git a/usvm-ts/src/main/kotlin/org/usvm/state/TSState.kt b/usvm-ts/src/main/kotlin/org/usvm/state/TSState.kt index d22857a653..620e6e499a 100644 --- a/usvm-ts/src/main/kotlin/org/usvm/state/TSState.kt +++ b/usvm-ts/src/main/kotlin/org/usvm/state/TSState.kt @@ -8,6 +8,7 @@ import org.usvm.TSContext import org.usvm.TSTarget import org.usvm.UCallStack import org.usvm.UState +import org.usvm.collections.immutable.internal.MutabilityOwnership import org.usvm.constraints.UPathConstraints import org.usvm.memory.UMemory import org.usvm.model.UModelBase @@ -15,10 +16,11 @@ import org.usvm.targets.UTargetsSet class TSState( ctx: TSContext, + ownership: MutabilityOwnership, override val entrypoint: EtsMethod, callStack: UCallStack = UCallStack(), - pathConstraints: UPathConstraints = UPathConstraints(ctx), - memory: UMemory = UMemory(ctx, pathConstraints.typeConstraints), + pathConstraints: UPathConstraints = UPathConstraints(ctx, ownership), + memory: UMemory = UMemory(ctx, ownership, pathConstraints.typeConstraints), models: List> = listOf(), pathNode: PathNode = PathNode.root(), forkPoints: PathNode> = PathNode.root(), @@ -26,6 +28,7 @@ class TSState( targets: UTargetsSet = UTargetsSet.empty(), ) : UState( ctx, + ownership, callStack, pathConstraints, memory, @@ -35,14 +38,21 @@ class TSState( targets ) { override fun clone(newConstraints: UPathConstraints?): TSState { - val clonedConstraints = newConstraints ?: pathConstraints.clone() + val newThisOwnership = MutabilityOwnership() + val cloneOwnership = MutabilityOwnership() + val clonedConstraints = newConstraints?.also { + this.pathConstraints.changeOwnership(newThisOwnership) + it.changeOwnership(cloneOwnership) + } ?: pathConstraints.clone(newThisOwnership, cloneOwnership) + this.ownership = newThisOwnership return TSState( ctx, + cloneOwnership, entrypoint, callStack.clone(), clonedConstraints, - memory.clone(clonedConstraints.typeConstraints), + memory.clone(clonedConstraints.typeConstraints, newThisOwnership, cloneOwnership), models, pathNode, forkPoints, diff --git a/usvm-util/src/main/kotlin/org/usvm/algorithms/PersistentMultiMapBuilder.kt b/usvm-util/src/main/kotlin/org/usvm/algorithms/PersistentMultiMapBuilder.kt deleted file mode 100644 index 46c35c81f2..0000000000 --- a/usvm-util/src/main/kotlin/org/usvm/algorithms/PersistentMultiMapBuilder.kt +++ /dev/null @@ -1,95 +0,0 @@ -package org.usvm.algorithms - -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.PersistentSet -import kotlinx.collections.immutable.persistentHashSetOf -import kotlinx.collections.immutable.toPersistentHashSet - -typealias PersistentMultiMap = PersistentMap> - -/** - * Provides an efficient way to perform multiple mutations on [PersistentMultiMap]. - * */ -class PersistentMultiMapBuilder(original: PersistentMultiMap) : Iterable> { - private val map = original.builder() - - fun build(): PersistentMultiMap = map.build() - - fun containsValue(key: K, value: V): Boolean = - map[key]?.contains(value) ?: false - - fun isEmpty(): Boolean = map.isEmpty() - - operator fun get(key: K): Set? = map[key] - - fun add(key: K, value: V) { - val current = map[key] - val updated = current?.add(value) ?: persistentHashSetOf(value) - map[key] = updated - } - - fun addAll(key: K, values: Set) { - val current = map[key] - val updated = current?.addAll(values) ?: values.toPersistentHashSet() - map[key] = updated - } - - fun remove(key: K): Set? = map.remove(key) - - fun removeValue(key: K, value: V) { - val current = map[key] ?: return - val updated = current.remove(value) - map[key] = updated - } - - fun removeAllValues(key: K, values: Set) { - val current = map[key] ?: return - val updated = current.removeAll(values) - map[key] = updated - } - - fun clear() { - map.clear() - } - - override fun iterator(): Iterator> = MultiMapIterator(map) - - override fun equals(other: Any?): Boolean = when { - this === other -> true - other is PersistentMultiMapBuilder<*, *> -> map == other.map - other is Map<*, *> -> map == other - else -> false - } - - override fun hashCode(): Int = map.hashCode() - - private class MultiMapIterator( - mapBuilder: MutableMap> - ) : Iterator> { - private val valueIterators = mapBuilder.entries.mapTo(mutableListOf()) { it.key to it.value.iterator() } - private var value: Pair? = null - - override fun hasNext(): Boolean { - propagate() - return value != null - } - - override fun next(): Pair { - propagate() - val result = value ?: throw NoSuchElementException("Iterator is empty") - value = null - return result - } - - private fun propagate() { - while (value === null) { - val lastIterator = valueIterators.lastOrNull() ?: return - if (!lastIterator.second.hasNext()) { - valueIterators.removeLast() - continue - } - value = lastIterator.first to lastIterator.second.next() - } - } - } -} diff --git a/usvm-util/src/main/kotlin/org/usvm/algorithms/SeparationUtils.kt b/usvm-util/src/main/kotlin/org/usvm/algorithms/SeparationUtils.kt index 8b3882511a..42093a5c1f 100644 --- a/usvm-util/src/main/kotlin/org/usvm/algorithms/SeparationUtils.kt +++ b/usvm-util/src/main/kotlin/org/usvm/algorithms/SeparationUtils.kt @@ -1,7 +1,9 @@ package org.usvm.algorithms -import kotlinx.collections.immutable.PersistentMap -import kotlinx.collections.immutable.PersistentSet +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.implementations.immutableSet.UPersistentHashSet +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.persistentHashMapOf data class SeparationResult( val overlap: C, @@ -9,16 +11,26 @@ data class SeparationResult( val rightUnique: C, ) -fun PersistentSet.separate(other: PersistentSet): SeparationResult> { - val overlap = this.retainAll(other) - val leftUnique = this.removeAll(overlap) - val rightUnique = other.removeAll(overlap) +fun UPersistentHashSet.separate( + other: UPersistentHashSet, + ownership: MutabilityOwnership +): SeparationResult> { + val overlap = this.retainAll(other, ownership) + val leftUnique = this.removeAll(overlap, ownership) + val rightUnique = other.removeAll(overlap, ownership) return SeparationResult(overlap, leftUnique, rightUnique) } -fun PersistentMap.separate(other: PersistentMap): SeparationResult> { - val overlap = this.builder().apply { entries.retainAll(other.entries) }.build() - val leftUnique = this.builder().apply { keys.removeAll(overlap.keys) }.build() - val rightUnique = other.builder().apply { keys.removeAll(overlap.keys) }.build() +// should be used with ownership that does not occur in both maps so that they will not be mutated +fun UPersistentHashMap.separate( + other: UPersistentHashMap, + ownership: MutabilityOwnership, +): SeparationResult> { + val overlap = + this.fold(persistentHashMapOf()) { map, entry -> + if (other.containsKey(entry.key)) map.put(entry.key, entry.value, ownership) else map + } + val leftUnique = overlap.fold(this) { map, entry -> map.remove(entry.key, ownership) } + val rightUnique = overlap.fold(other) { map, entry -> map.remove(entry.key, ownership) } return SeparationResult(overlap, leftUnique, rightUnique) -} \ No newline at end of file +} diff --git a/usvm-util/src/main/kotlin/org/usvm/algorithms/UPersistentMultiMap.kt b/usvm-util/src/main/kotlin/org/usvm/algorithms/UPersistentMultiMap.kt new file mode 100644 index 0000000000..08f90fcfee --- /dev/null +++ b/usvm-util/src/main/kotlin/org/usvm/algorithms/UPersistentMultiMap.kt @@ -0,0 +1,86 @@ +package org.usvm.algorithms + +import org.usvm.collections.immutable.* +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.implementations.immutableSet.UPersistentHashSet +import org.usvm.collections.immutable.internal.MutabilityOwnership + +typealias UPersistentMultiMap = UPersistentHashMap> + +/** + * Provides an efficient way to perform multiple mutations on [PersistentMultiMap]. + * */ + +fun UPersistentMultiMap.containsValue(key: K, value: V): Boolean = + this[key]?.contains(value) ?: false + +fun UPersistentMultiMap.addToSet( + key: K, + value: V, + ownership: MutabilityOwnership, +): UPersistentMultiMap { + val current = getOrDefault(key, persistentHashSetOf()) + val updated = current.add(value, ownership) + return this.put(key, updated, ownership) +} + +fun UPersistentMultiMap.addAll( + key: K, + values: Set, + ownership: MutabilityOwnership, +): UPersistentMultiMap { + val current = getOrDefault(key, persistentHashSetOf()) + val updated = current.addAll(values, ownership) + return this.put(key, updated, ownership) +} + +fun UPersistentMultiMap.removeValue( + key: K, + value: V, + ownership: MutabilityOwnership, +): UPersistentMultiMap { + val current = this[key] ?: return this + val updated = current.remove(value, ownership) + if (updated.isEmpty()) return this.remove(key, ownership) + return this.put(key, updated, ownership) +} + +fun UPersistentMultiMap.removeAllValues( + key: K, + values: Iterable, + ownership: MutabilityOwnership, +): UPersistentMultiMap { + val current = this[key] ?: return this + val updated = current.removeAll(values, ownership) + return this.put(key, updated, ownership) +} + +fun UPersistentMultiMap.multiMapIterator() = MultiMapIterator(this) + +class MultiMapIterator(multiMap: UPersistentMultiMap) : Iterator> { + private val valueIterators = multiMap.mapTo(mutableListOf()) { it.key to it.value.iterator() } + private var value: Pair? = null + + override fun hasNext(): Boolean { + propagate() + return value != null + } + + override fun next(): Pair { + propagate() + val result = value ?: throw NoSuchElementException("Iterator is empty") + value = null + return result + } + + private fun propagate() { + while (value === null) { + val lastIterator = valueIterators.lastOrNull() ?: return + if (!lastIterator.second.hasNext()) { + valueIterators.removeLast() + continue + } + value = lastIterator.first to lastIterator.second.next() + } + } +} diff --git a/usvm-util/src/main/kotlin/org/usvm/collections/immutable/extensions.kt b/usvm-util/src/main/kotlin/org/usvm/collections/immutable/extensions.kt new file mode 100644 index 0000000000..7be7c79109 --- /dev/null +++ b/usvm-util/src/main/kotlin/org/usvm/collections/immutable/extensions.kt @@ -0,0 +1,70 @@ +/* + * Copyright 2016-2019 JetBrains s.r.o. + * Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file. + */ + + +package org.usvm.collections.immutable + +import org.usvm.collections.immutable.implementations.immutableMap.TrieNode +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.implementations.immutableSet.UPersistentHashSet +import org.usvm.collections.immutable.internal.MutabilityOwnership + + +fun UPersistentHashSet.isEmpty() = none() + +fun UPersistentHashSet.isNotEmpty() = any() + +fun UPersistentHashSet.addAll(elements: Collection, owner: MutabilityOwnership) : UPersistentHashSet = + elements.fold(this) { node, e -> node.add(e, owner) } + +fun UPersistentHashSet.removeAll(elements: Collection, owner: MutabilityOwnership): UPersistentHashSet = + elements.fold(this) { node, e -> node.remove(e, owner) } + +fun UPersistentHashSet.removeAll(elements: Iterable, owner: MutabilityOwnership): UPersistentHashSet = + elements.fold(this) { node, e -> node.remove(e, owner) } + +fun UPersistentHashSet.containsAll(elements: Collection): Boolean = elements.all { e -> this.contains(e) } + +fun UPersistentHashMap.isEmpty() = none() + +fun UPersistentHashMap.isNotEmpty() = any() + +fun UPersistentHashMap.getOrDefault(key: K, defaultValue: V) = get(key) ?: defaultValue + +inline fun UPersistentHashMap.getOrPut( + key: K, + owner: MutabilityOwnership, + defaultValue: () -> V, +): Pair, V> { + val current = get(key) ?: defaultValue().let { return put(key, it, owner) to it } + return this to current +} + +fun UPersistentHashMap.removeAll(keys: Iterable, owner: MutabilityOwnership): UPersistentHashMap = + keys.fold(this) { node, k -> node.remove(k, owner) } + +fun UPersistentHashMap.putAll(map: Map, owner: MutabilityOwnership): TrieNode = + map.asSequence().fold(this) { acc, entry -> acc.put(entry.key, entry.value, owner) } + +fun UPersistentHashMap.toMutableMap(): MutableMap = + mutableMapOf().also { this.forEach { entry -> it[entry.key] = entry.value } } + +/** + * Returns an empty persistent map. + */ +@Suppress("UNCHECKED_CAST") +fun persistentHashMapOf(): UPersistentHashMap = UPersistentHashMap.EMPTY as UPersistentHashMap + +fun persistentHashMapOf(map: Map, owner: MutabilityOwnership): UPersistentHashMap = + persistentHashMapOf().putAll(map, owner) + +fun persistentHashMapOf(owner: MutabilityOwnership, vararg pairs: Pair): UPersistentHashMap = + pairs.fold(persistentHashMapOf()) { acc, (k, v) -> acc.put(k, v, owner) } + +/** + * Returns an empty persistent set. + */ +@Suppress("UNCHECKED_CAST") +public fun persistentHashSetOf(): UPersistentHashSet = UPersistentHashSet.EMPTY as UPersistentHashSet diff --git a/usvm-util/src/main/kotlin/org/usvm/collections/immutable/implementations/immutableMap/TrieIterator.kt b/usvm-util/src/main/kotlin/org/usvm/collections/immutable/implementations/immutableMap/TrieIterator.kt new file mode 100644 index 0000000000..e6a76a812f --- /dev/null +++ b/usvm-util/src/main/kotlin/org/usvm/collections/immutable/implementations/immutableMap/TrieIterator.kt @@ -0,0 +1,180 @@ +/* + * Copyright 2016-2019 JetBrains s.r.o. + * Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file. + */ + +package org.usvm.collections.immutable.implementations.immutableMap + + +internal const val TRIE_MAX_HEIGHT = 7 + +internal abstract class TrieNodeBaseIterator : Iterator { + protected var buffer = TrieNode.EMPTY.buffer + private set + private var dataSize = 0 + protected var index = 0 + + fun reset(buffer: Array, dataSize: Int, index: Int) { + this.buffer = buffer + this.dataSize = dataSize + this.index = index + } + + fun reset(buffer: Array, dataSize: Int) { + reset(buffer, dataSize, 0) + } + + fun hasNextKey(): Boolean { + return index < dataSize + } + + fun currentKey(): K { + assert(hasNextKey()) + @Suppress("UNCHECKED_CAST") + return buffer[index] as K + } + + fun moveToNextKey() { + assert(hasNextKey()) + index += 2 + } + + fun hasNextNode(): Boolean { + assert(index >= dataSize) + return index < buffer.size + } + + fun currentNode(): TrieNode { + assert(hasNextNode()) + @Suppress("UNCHECKED_CAST") + return buffer[index] as TrieNode + } + + fun moveToNextNode() { + assert(hasNextNode()) + index++ + } + + override fun hasNext(): Boolean { + return hasNextKey() + } +} + +internal class TrieNodeKeysIterator : TrieNodeBaseIterator() { + override fun next(): K { + assert(hasNextKey()) + index += 2 + @Suppress("UNCHECKED_CAST") + return buffer[index - 2] as K + } +} + +internal class TrieNodeValuesIterator : TrieNodeBaseIterator() { + override fun next(): V { + assert(hasNextKey()) + index += 2 + @Suppress("UNCHECKED_CAST") + return buffer[index - 1] as V + } +} + +internal class TrieNodeEntriesIterator : TrieNodeBaseIterator>() { + override fun next(): Map.Entry { + assert(hasNextKey()) + index += 2 + @Suppress("UNCHECKED_CAST") + return MapEntry(buffer[index - 2] as K, buffer[index - 1] as V) + } +} + +internal open class MapEntry(override val key: K, override val value: V) : Map.Entry { + override fun hashCode(): Int = key.hashCode() xor value.hashCode() + override fun equals(other: Any?): Boolean = + (other as? Map.Entry<*, *>)?.let { it.key == key && it.value == value } ?: false + + override fun toString(): String = key.toString() + "=" + value.toString() +} + + +internal abstract class UPersistentHashMapBaseIterator( + node: TrieNode, + protected val path: Array>, +) : Iterator { + + protected var pathLastIndex = 0 + private var hasNext = true + + init { + path[0].reset(node.buffer, ENTRY_SIZE * node.entryCount()) + pathLastIndex = 0 + ensureNextEntryIsReady() + } + + private fun moveToNextNodeWithData(pathIndex: Int): Int { + if (path[pathIndex].hasNextKey()) { + return pathIndex + } + if (path[pathIndex].hasNextNode()) { + val node = path[pathIndex].currentNode() + if (pathIndex == TRIE_MAX_HEIGHT - 1) { // collision + path[pathIndex + 1].reset(node.buffer, node.buffer.size) + } else { + path[pathIndex + 1].reset(node.buffer, ENTRY_SIZE * node.entryCount()) + } + return moveToNextNodeWithData(pathIndex + 1) + } + return -1 + } + + private fun ensureNextEntryIsReady() { + if (path[pathLastIndex].hasNextKey()) { + return + } + for (i in pathLastIndex downTo 0) { + var result = moveToNextNodeWithData(i) + + if (result == -1 && path[i].hasNextNode()) { + path[i].moveToNextNode() + result = moveToNextNodeWithData(i) + } + if (result != -1) { + pathLastIndex = result + return + } + if (i > 0) { + path[i - 1].moveToNextNode() + } + path[i].reset(TrieNode.EMPTY.buffer, 0) + } + hasNext = false + } + + protected fun currentKey(): K { + checkHasNext() + return path[pathLastIndex].currentKey() + } + + override fun hasNext(): Boolean { + return hasNext + } + + override fun next(): T { + checkHasNext() + val result = path[pathLastIndex].next() + ensureNextEntryIsReady() + return result + } + + private fun checkHasNext() { + if (!hasNext()) + throw NoSuchElementException() + } +} + +internal class UPersistentHashMapEntriesIterator( + node: TrieNode +) : UPersistentHashMapBaseIterator>( + node, Array(TRIE_MAX_HEIGHT + 1) { TrieNodeEntriesIterator() }) + +internal class UPersistentHashMapKeysIterator(node: TrieNode) + : UPersistentHashMapBaseIterator(node, Array(TRIE_MAX_HEIGHT + 1) { TrieNodeKeysIterator() }) diff --git a/usvm-util/src/main/kotlin/org/usvm/collections/immutable/implementations/immutableMap/TrieNode.kt b/usvm-util/src/main/kotlin/org/usvm/collections/immutable/implementations/immutableMap/TrieNode.kt new file mode 100644 index 0000000000..8722004522 --- /dev/null +++ b/usvm-util/src/main/kotlin/org/usvm/collections/immutable/implementations/immutableMap/TrieNode.kt @@ -0,0 +1,935 @@ +/* + * Copyright 2016-2019 JetBrains s.r.o. + * Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file. + */ + +package org.usvm.collections.immutable.implementations.immutableMap + +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.internal.forEachOneBit + +typealias UPersistentHashMap = TrieNode + + +internal const val MAX_BRANCHING_FACTOR = 32 +internal const val LOG_MAX_BRANCHING_FACTOR = 5 +internal const val MAX_BRANCHING_FACTOR_MINUS_ONE = MAX_BRANCHING_FACTOR - 1 +internal const val ENTRY_SIZE = 2 +internal const val MAX_SHIFT = 30 + +/** + * Gets trie index segment of the specified [index] at the level specified by [shift]. + * + * `shift` equal to zero corresponds to the root level. + * For each lower level `shift` increments by [LOG_MAX_BRANCHING_FACTOR]. + */ +internal fun indexSegment(index: Int, shift: Int): Int = + (index shr shift) and MAX_BRANCHING_FACTOR_MINUS_ONE + +private fun Array.insertEntryAtIndex(keyIndex: Int, key: K, value: V): Array { + val newBuffer = arrayOfNulls(this.size + ENTRY_SIZE) + this.copyInto(newBuffer, endIndex = keyIndex) + this.copyInto(newBuffer, keyIndex + ENTRY_SIZE, startIndex = keyIndex, endIndex = this.size) + newBuffer[keyIndex] = key + newBuffer[keyIndex + 1] = value + return newBuffer +} + +private fun Array.replaceNodeWithEntry(nodeIndex: Int, keyIndex: Int, key: K, value: V): Array { + val newBuffer = this.copyOf(this.size + 1) + newBuffer.copyInto(newBuffer, nodeIndex + 2, nodeIndex + 1, this.size) + newBuffer.copyInto(newBuffer, keyIndex + 2, keyIndex, nodeIndex) + newBuffer[keyIndex] = key + newBuffer[keyIndex + 1] = value + return newBuffer +} + +private fun Array.replaceEntryWithNode(keyIndex: Int, nodeIndex: Int, newNode: TrieNode<*, *>): Array { + val newNodeIndex = nodeIndex - ENTRY_SIZE // place where to insert new node in the new buffer + val newBuffer = arrayOfNulls(this.size - ENTRY_SIZE + 1) + this.copyInto(newBuffer, endIndex = keyIndex) + this.copyInto(newBuffer, keyIndex, startIndex = keyIndex + ENTRY_SIZE, endIndex = nodeIndex) + newBuffer[newNodeIndex] = newNode + this.copyInto(newBuffer, newNodeIndex + 1, startIndex = nodeIndex, endIndex = this.size) + return newBuffer +} + +private fun Array.removeEntryAtIndex(keyIndex: Int): Array { + val newBuffer = arrayOfNulls(this.size - ENTRY_SIZE) + this.copyInto(newBuffer, endIndex = keyIndex) + this.copyInto(newBuffer, keyIndex, startIndex = keyIndex + ENTRY_SIZE, endIndex = this.size) + return newBuffer +} + +private fun Array.removeNodeAtIndex(nodeIndex: Int): Array { + val newBuffer = arrayOfNulls(this.size - 1) + this.copyInto(newBuffer, endIndex = nodeIndex) + this.copyInto(newBuffer, nodeIndex, startIndex = nodeIndex + 1, endIndex = this.size) + return newBuffer +} + + +class TrieNode( + private var dataMap: Int, + private var nodeMap: Int, + buffer: Array, + private val ownedBy: MutabilityOwnership? +) : Iterable> { + constructor(dataMap: Int, nodeMap: Int, buffer: Array) : this(dataMap, nodeMap, buffer, null) + + internal var buffer: Array = buffer + private set + + /** Returns number of entries stored in this trie node (not counting subnodes) */ + internal fun entryCount(): Int = dataMap.countOneBits() + + // here and later: + // positionMask — an int in form 2^n, i.e. having the single bit set, whose ordinal is a logical position in buffer + + + /** Returns true if the data bit map has the bit specified by [positionMask] set, indicating there's a data entry + * in the buffer at that position. */ + internal fun hasEntryAt(positionMask: Int): Boolean { + return dataMap and positionMask != 0 + } + + /** Returns true if the node bit map has the bit specified by [positionMask] set, indicating there's a subtrie node + * in the buffer at that position. */ + private fun hasNodeAt(positionMask: Int): Boolean { + return nodeMap and positionMask != 0 + } + + /** Gets the index in buffer of the data entry key corresponding to the position specified by [positionMask]. */ + internal fun entryKeyIndex(positionMask: Int): Int { + return ENTRY_SIZE * (dataMap and (positionMask - 1)).countOneBits() + } + + /** Gets the index in buffer of the subtrie node entry corresponding to the position specified by [positionMask]. */ + internal fun nodeIndex(positionMask: Int): Int { + return buffer.size - 1 - (nodeMap and (positionMask - 1)).countOneBits() + } + + /** Retrieves the buffer element at the given [keyIndex] as key of a data entry. */ + private fun keyAtIndex(keyIndex: Int): K { + @Suppress("UNCHECKED_CAST") + return buffer[keyIndex] as K + } + + /** Retrieves the buffer element next to the given [keyIndex] as value of a data entry. */ + private fun valueAtKeyIndex(keyIndex: Int): V { + @Suppress("UNCHECKED_CAST") + return buffer[keyIndex + 1] as V + } + + /** Retrieves the buffer element at the given [nodeIndex] as subtrie node. */ + internal fun nodeAtIndex(nodeIndex: Int): TrieNode { + @Suppress("UNCHECKED_CAST") + return buffer[nodeIndex] as TrieNode + } + + private fun insertEntryAt(positionMask: Int, key: K, value: V): TrieNode { +// assert(!hasEntryAt(positionMask)) + + val keyIndex = entryKeyIndex(positionMask) + val newBuffer = buffer.insertEntryAtIndex(keyIndex, key, value) + return TrieNode(dataMap or positionMask, nodeMap, newBuffer) + } + + private fun mutableInsertEntryAt(positionMask: Int, key: K, value: V, owner: MutabilityOwnership): TrieNode { +// assert(!hasEntryAt(positionMask)) + val keyIndex = entryKeyIndex(positionMask) + if (ownedBy === owner) { + buffer = buffer.insertEntryAtIndex(keyIndex, key, value) + dataMap = dataMap or positionMask + return this + } + val newBuffer = buffer.insertEntryAtIndex(keyIndex, key, value) + return TrieNode(dataMap or positionMask, nodeMap, newBuffer, owner) + } + + private fun updateValueAtIndex(keyIndex: Int, value: V): TrieNode { +// assert(buffer[keyIndex + 1] !== value) + + val newBuffer = buffer.copyOf() + newBuffer[keyIndex + 1] = value + return TrieNode(dataMap, nodeMap, newBuffer) + } + + private fun mutableUpdateValueAtIndex(keyIndex: Int, value: V, owner: MutabilityOwnership): TrieNode { +// assert(buffer[keyIndex + 1] !== value) + // If the [mutator] is exclusive owner of this node, update value at specified index in-place. + if (ownedBy === owner) { + buffer[keyIndex + 1] = value + return this + } + // Create new node with updated value at specified index. + val newBuffer = buffer.copyOf() + newBuffer[keyIndex + 1] = value + return TrieNode(dataMap, nodeMap, newBuffer, owner) + } + + private fun updateNodeAtIndex(nodeIndex: Int, positionMask: Int, newNode: TrieNode): TrieNode { +// assert(buffer[nodeIndex] !== newNode) + val newNodeBuffer = newNode.buffer + if (newNodeBuffer.size == 2 && newNode.nodeMap == 0) { + if (buffer.size == 1) { +// assert(dataMap == 0 && nodeMap xor positionMask == 0) + newNode.dataMap = nodeMap + return newNode + } + + val keyIndex = entryKeyIndex(positionMask) + val newBuffer = buffer.replaceNodeWithEntry(nodeIndex, keyIndex, newNodeBuffer[0], newNodeBuffer[1]) + return TrieNode( + dataMap xor positionMask, + nodeMap xor positionMask, + newBuffer + ) + } + + val newBuffer = buffer.copyOf(buffer.size) + newBuffer[nodeIndex] = newNode + return TrieNode(dataMap, nodeMap, newBuffer) + } + + /** The given [newNode] must not be a part of any persistent map instance. */ + private fun mutableUpdateNodeAtIndex( + nodeIndex: Int, + newNode: TrieNode, + owner: MutabilityOwnership + ): TrieNode { + assert(newNode.ownedBy === owner) +// assert(buffer[nodeIndex] !== newNode) + + // nodes (including collision nodes) that have only one entry are upped if they have no siblings + if (buffer.size == 1 && newNode.buffer.size == ENTRY_SIZE && newNode.nodeMap == 0) { +// assert(dataMap == 0 && nodeMap xor positionMask == 0) + newNode.dataMap = nodeMap + return newNode + } + + if (ownedBy === owner) { + buffer[nodeIndex] = newNode + return this + } + val newBuffer = buffer.copyOf() + newBuffer[nodeIndex] = newNode + return TrieNode(dataMap, nodeMap, newBuffer, owner) + } + + private fun removeNodeAtIndex(nodeIndex: Int, positionMask: Int): TrieNode? { +// assert(hasNodeAt(positionMask)) + if (buffer.size == 1) return null + + val newBuffer = buffer.removeNodeAtIndex(nodeIndex) + return TrieNode(dataMap, nodeMap xor positionMask, newBuffer) + } + + private fun mutableRemoveNodeAtIndex( + nodeIndex: Int, + positionMask: Int, + owner: MutabilityOwnership + ): TrieNode? { +// assert(hasNodeAt(positionMask)) + if (buffer.size == 1) return null + + if (ownedBy === owner) { + buffer = buffer.removeNodeAtIndex(nodeIndex) + nodeMap = nodeMap xor positionMask + return this + } + val newBuffer = buffer.removeNodeAtIndex(nodeIndex) + return TrieNode(dataMap, nodeMap xor positionMask, newBuffer, owner) + } + + private fun bufferMoveEntryToNode( + keyIndex: Int, positionMask: Int, newKeyHash: Int, + newKey: K, newValue: V, shift: Int, owner: MutabilityOwnership? + ): Array { + val storedKey = keyAtIndex(keyIndex) + val storedKeyHash = storedKey.hashCode() + val storedValue = valueAtKeyIndex(keyIndex) + val newNode = makeNode( + storedKeyHash, storedKey, storedValue, + newKeyHash, newKey, newValue, shift + LOG_MAX_BRANCHING_FACTOR, owner + ) + + val nodeIndex = nodeIndex(positionMask) + 1 // place where to insert new node in the current buffer + + return buffer.replaceEntryWithNode(keyIndex, nodeIndex, newNode) + } + + private fun moveEntryToNode( + keyIndex: Int, positionMask: Int, newKeyHash: Int, + newKey: K, newValue: V, shift: Int + ): TrieNode { +// assert(hasEntryAt(positionMask)) +// assert(!hasNodeAt(positionMask)) + + val newBuffer = bufferMoveEntryToNode(keyIndex, positionMask, newKeyHash, newKey, newValue, shift, null) + return TrieNode(dataMap xor positionMask, nodeMap or positionMask, newBuffer) + } + + private fun mutableMoveEntryToNode( + keyIndex: Int, + positionMask: Int, + newKeyHash: Int, + newKey: K, + newValue: V, + shift: Int, + owner: MutabilityOwnership, + ): TrieNode { +// assert(hasEntryAt(positionMask)) +// assert(!hasNodeAt(positionMask)) + if (ownedBy === owner) { + buffer = bufferMoveEntryToNode(keyIndex, positionMask, newKeyHash, newKey, newValue, shift, owner) + dataMap = dataMap xor positionMask + nodeMap = nodeMap or positionMask + return this + } + val newBuffer = bufferMoveEntryToNode(keyIndex, positionMask, newKeyHash, newKey, newValue, shift, owner) + return TrieNode(dataMap xor positionMask, nodeMap or positionMask, newBuffer, owner) + } + + /** Creates a new TrieNode for holding two given key value entries */ + private fun makeNode( + keyHash1: Int, key1: K, value1: V, + keyHash2: Int, key2: K, value2: V, shift: Int, owner: MutabilityOwnership? + ): TrieNode { + if (shift > MAX_SHIFT) { +// assert(key1 != key2) + // when two key hashes are entirely equal: the last level subtrie node stores them just as unordered list + return TrieNode(0, 0, arrayOf(key1, value1, key2, value2), owner) + } + + val setBit1 = indexSegment(keyHash1, shift) + val setBit2 = indexSegment(keyHash2, shift) + + if (setBit1 != setBit2) { + val nodeBuffer = if (setBit1 < setBit2) { + arrayOf(key1, value1, key2, value2) + } else { + arrayOf(key2, value2, key1, value1) + } + return TrieNode((1 shl setBit1) or (1 shl setBit2), 0, nodeBuffer, owner) + } + // hash segments at the given shift are equal: move these entries into the subtrie + val node = makeNode(keyHash1, key1, value1, keyHash2, key2, value2, shift + LOG_MAX_BRANCHING_FACTOR, owner) + return TrieNode(0, 1 shl setBit1, arrayOf(node), owner) + } + + private fun removeEntryAtIndex(keyIndex: Int, positionMask: Int): TrieNode? { +// assert(hasEntryAt(positionMask)) + if (buffer.size == ENTRY_SIZE) return null + val newBuffer = buffer.removeEntryAtIndex(keyIndex) + return TrieNode(dataMap xor positionMask, nodeMap, newBuffer) + } + + private fun mutableRemoveEntryAtIndex( + keyIndex: Int, + positionMask: Int, + owner: MutabilityOwnership + ): TrieNode? { +// assert(hasEntryAt(positionMask)) + if (buffer.size == ENTRY_SIZE) return null + + if (ownedBy === owner) { + buffer = buffer.removeEntryAtIndex(keyIndex) + dataMap = dataMap xor positionMask + return this + } + val newBuffer = buffer.removeEntryAtIndex(keyIndex) + return TrieNode(dataMap xor positionMask, nodeMap, newBuffer, owner) + } + + private fun collisionRemoveEntryAtIndex(i: Int): TrieNode? { + if (buffer.size == ENTRY_SIZE) return null + val newBuffer = buffer.removeEntryAtIndex(i) + return TrieNode(0, 0, newBuffer) + } + + private fun mutableCollisionRemoveEntryAtIndex(i: Int, owner: MutabilityOwnership): TrieNode? { + if (buffer.size == ENTRY_SIZE) return null + + if (ownedBy === owner) { + buffer = buffer.removeEntryAtIndex(i) + return this + } + val newBuffer = buffer.removeEntryAtIndex(i) + return TrieNode(0, 0, newBuffer, owner) + } + + private fun collisionKeyIndex(key: Any?): Int { + for (i in 0 until buffer.size step ENTRY_SIZE) { + if (key == keyAtIndex(i)) return i + } + return -1 + } + + private fun collisionContainsKey(key: K): Boolean { + return collisionKeyIndex(key) != -1 + } + + private fun collisionGet(key: K): V? { + val keyIndex = collisionKeyIndex(key) + return if (keyIndex != -1) valueAtKeyIndex(keyIndex) else null + } + + private fun collisionPut(key: K, value: V): TrieNode { + val keyIndex = collisionKeyIndex(key) + if (keyIndex != -1) { + if (value === valueAtKeyIndex(keyIndex)) { + return this + } + val newBuffer = buffer.copyOf() + newBuffer[keyIndex + 1] = value + return TrieNode(0, 0, newBuffer) + } + val newBuffer = buffer.insertEntryAtIndex(0, key, value) + return TrieNode(0, 0, newBuffer) + } + + private fun mutableCollisionPut(key: K, value: V, owner: MutabilityOwnership): TrieNode { + // Check if there is an entry with the specified key. + val keyIndex = collisionKeyIndex(key) + if (keyIndex != -1) { // found entry with the specified key + // If the [mutator] is exclusive owner of this node, update value of the entry in-place. + if (ownedBy === owner) { + buffer[keyIndex + 1] = value + return this + } + + // Structural change due to node replacement. + // Create new node with updated entry value. + val newBuffer = buffer.copyOf() + newBuffer[keyIndex + 1] = value + return TrieNode(0, 0, newBuffer, owner) + } + // Create new collision node with the specified entry added to it. + val newBuffer = buffer.insertEntryAtIndex(0, key, value) + return TrieNode(0, 0, newBuffer, owner) + } + + private fun collisionRemove(key: K): Pair?, Boolean> { + val keyIndex = collisionKeyIndex(key) + if (keyIndex != -1) { + return collisionRemoveEntryAtIndex(keyIndex) to true + } + return this to false + } + + private fun mutableCollisionRemove(key: K, owner: MutabilityOwnership): Pair?, Boolean> { + val keyIndex = collisionKeyIndex(key) + if (keyIndex != -1) { + return mutableCollisionRemoveEntryAtIndex(keyIndex, owner) to true + } + return this to false + } + + private fun mutableCollisionRemoveAndGetValue(key: K, owner: MutabilityOwnership): Pair?, V?> { + val keyIndex = collisionKeyIndex(key) + if (keyIndex != -1) { + val value = valueAtKeyIndex(keyIndex) + return mutableCollisionRemoveEntryAtIndex(keyIndex, owner) to value + } + return this to null + } + + private fun mutableCollisionPutAll(otherNode: TrieNode, owner: MutabilityOwnership): TrieNode { + assert(nodeMap == 0) + assert(dataMap == 0) + assert(otherNode.nodeMap == 0) + assert(otherNode.dataMap == 0) + val tempBuffer = this.buffer.copyOf(newSize = this.buffer.size + otherNode.buffer.size) + var i = this.buffer.size + for (j in 0 until otherNode.buffer.size step ENTRY_SIZE) { + @Suppress("UNCHECKED_CAST") + if (!this.collisionContainsKey(otherNode.buffer[j] as K)) { + tempBuffer[i] = otherNode.buffer[j] + tempBuffer[i + 1] = otherNode.buffer[j + 1] + i += ENTRY_SIZE + } + } + + return when (val newSize = i) { + this.buffer.size -> this + otherNode.buffer.size -> otherNode + tempBuffer.size -> TrieNode(0, 0, tempBuffer, owner) + else -> TrieNode(0, 0, tempBuffer.copyOf(newSize), owner) + } + } + + /** + * Updates the cell of this node at [positionMask] with entries from the cell of [otherNode] at [positionMask]. + */ + private fun mutablePutAllFromOtherNodeCell( + otherNode: TrieNode, + positionMask: Int, + shift: Int, + owner: MutabilityOwnership + ): TrieNode = when { + this.hasNodeAt(positionMask) -> { + val targetNode = this.nodeAtIndex(nodeIndex(positionMask)) + when { + otherNode.hasNodeAt(positionMask) -> { + val otherTargetNode = otherNode.nodeAtIndex(otherNode.nodeIndex(positionMask)) + targetNode.mutablePutAll(otherTargetNode, shift + LOG_MAX_BRANCHING_FACTOR, owner) + } + + otherNode.hasEntryAt(positionMask) -> { + val keyIndex = otherNode.entryKeyIndex(positionMask) + val key = otherNode.keyAtIndex(keyIndex) + val value = otherNode.valueAtKeyIndex(keyIndex) + targetNode.mutablePut(key.hashCode(), key, value, shift + LOG_MAX_BRANCHING_FACTOR, owner) + } + + else -> targetNode + } + } + + otherNode.hasNodeAt(positionMask) -> { + val otherTargetNode = otherNode.nodeAtIndex(otherNode.nodeIndex(positionMask)) + when { + this.hasEntryAt(positionMask) -> { + // if otherTargetNode already has a value associated with the key, do not put this entry + val keyIndex = this.entryKeyIndex(positionMask) + val key = this.keyAtIndex(keyIndex) + if (otherTargetNode.containsKey(key.hashCode(), key, shift + LOG_MAX_BRANCHING_FACTOR)) { + otherTargetNode + } else { + val value = this.valueAtKeyIndex(keyIndex) + otherTargetNode.mutablePut(key.hashCode(), key, value, shift + LOG_MAX_BRANCHING_FACTOR, owner) + } + } + + else -> otherTargetNode + } + } + + else -> { // two entries, and they are not equal by key. See (**) in mutablePutAll + val thisKeyIndex = this.entryKeyIndex(positionMask) + val thisKey = this.keyAtIndex(thisKeyIndex) + val thisValue = this.valueAtKeyIndex(thisKeyIndex) + val otherKeyIndex = otherNode.entryKeyIndex(positionMask) + val otherKey = otherNode.keyAtIndex(otherKeyIndex) + val otherValue = otherNode.valueAtKeyIndex(otherKeyIndex) + makeNode( + thisKey.hashCode(), + thisKey, + thisValue, + otherKey.hashCode(), + otherKey, + otherValue, + shift + LOG_MAX_BRANCHING_FACTOR, + owner + ) + } + } + + private fun elementsIdentityEquals(otherNode: TrieNode): Boolean { + if (this === otherNode) return true + if (nodeMap != otherNode.nodeMap) return false + if (dataMap != otherNode.dataMap) return false + for (i in 0 until buffer.size) { + if (buffer[i] !== otherNode.buffer[i]) return false + } + return true + } + + private fun containsKey(keyHash: Int, key: K, shift: Int): Boolean { + val keyPositionMask = 1 shl indexSegment(keyHash, shift) + if (hasEntryAt(keyPositionMask)) { // key is directly in buffer + return key == keyAtIndex(entryKeyIndex(keyPositionMask)) + } + if (hasNodeAt(keyPositionMask)) { // key is in node + val targetNode = nodeAtIndex(nodeIndex(keyPositionMask)) + if (shift == MAX_SHIFT) { + return targetNode.collisionContainsKey(key) + } + return targetNode.containsKey(keyHash, key, shift + LOG_MAX_BRANCHING_FACTOR) + } + + // key is absent + return false + } + + private fun get(keyHash: Int, key: K, shift: Int): V? { + val keyPositionMask = 1 shl indexSegment(keyHash, shift) + + if (hasEntryAt(keyPositionMask)) { // key is directly in buffer + val keyIndex = entryKeyIndex(keyPositionMask) + + if (key == keyAtIndex(keyIndex)) { + return valueAtKeyIndex(keyIndex) + } + return null + } + if (hasNodeAt(keyPositionMask)) { // key is in node + val targetNode = nodeAtIndex(nodeIndex(keyPositionMask)) + if (shift == MAX_SHIFT) { + return targetNode.collisionGet(key) + } + return targetNode.get(keyHash, key, shift + LOG_MAX_BRANCHING_FACTOR) + } + + // key is absent + return null + } + + private fun mutablePutAll(otherNode: TrieNode, shift: Int, owner: MutabilityOwnership): TrieNode { + if (this === otherNode) { + return this + } + // the collision case + if (shift > MAX_SHIFT) { + return mutableCollisionPutAll(otherNode, owner) + } + + // new nodes are where either of the old ones were + var newNodeMap = nodeMap or otherNode.nodeMap + // entries stay being entries only if one bits were in exactly one of input nodes + // but not in the new data nodes + var newDataMap = dataMap xor otherNode.dataMap and newNodeMap.inv() + // (**) now, this is tricky: we have a number of entry-entry pairs and we don't know yet whether + // they result in an entry (if keys are equal) or a new node (if they are not) + // but we want to keep it to single allocation, so we check and mark equal ones here + (dataMap and otherNode.dataMap).forEachOneBit { positionMask, _ -> + val leftKey = this.keyAtIndex(this.entryKeyIndex(positionMask)) + val rightKey = otherNode.keyAtIndex(otherNode.entryKeyIndex(positionMask)) + // if they are equal, put them in the data map + if (leftKey == rightKey) newDataMap = newDataMap or positionMask + // if they are not, put them in the node map + else newNodeMap = newNodeMap or positionMask + // we can use this later to skip calling equals() again + } + check(newNodeMap and newDataMap == 0) + val mutableNode = when { + this.ownedBy == owner && this.dataMap == newDataMap && this.nodeMap == newNodeMap -> this + else -> { + val newBuffer = arrayOfNulls(newDataMap.countOneBits() * ENTRY_SIZE + newNodeMap.countOneBits()) + TrieNode(newDataMap, newNodeMap, newBuffer) + } + } + newNodeMap.forEachOneBit { positionMask, index -> + val newNodeIndex = mutableNode.buffer.size - 1 - index + mutableNode.buffer[newNodeIndex] = mutablePutAllFromOtherNodeCell(otherNode, positionMask, shift, owner) + } + newDataMap.forEachOneBit { positionMask, index -> + val newKeyIndex = index * ENTRY_SIZE + when { + !otherNode.hasEntryAt(positionMask) -> { + val oldKeyIndex = this.entryKeyIndex(positionMask) + mutableNode.buffer[newKeyIndex] = this.keyAtIndex(oldKeyIndex) + mutableNode.buffer[newKeyIndex + 1] = this.valueAtKeyIndex(oldKeyIndex) + } + // there is either only one entry in otherNode, or + // both entries are here => they are equal, see ** above + // so just overwrite that + else -> { + val oldKeyIndex = otherNode.entryKeyIndex(positionMask) + mutableNode.buffer[newKeyIndex] = otherNode.keyAtIndex(oldKeyIndex) + mutableNode.buffer[newKeyIndex + 1] = otherNode.valueAtKeyIndex(oldKeyIndex) + } + } + } + return when { + this.elementsIdentityEquals(mutableNode) -> this + otherNode.elementsIdentityEquals(mutableNode) -> otherNode + else -> mutableNode + } + } + + fun remove(keyHash: Int, key: K, shift: Int): Pair?, Boolean> { + val keyPositionMask = 1 shl indexSegment(keyHash, shift) + + if (hasEntryAt(keyPositionMask)) { // key is directly in buffer + val keyIndex = entryKeyIndex(keyPositionMask) + + if (key == keyAtIndex(keyIndex)) { + return removeEntryAtIndex(keyIndex, keyPositionMask) to true + } + return this to false + } + if (hasNodeAt(keyPositionMask)) { // key is in node + val nodeIndex = nodeIndex(keyPositionMask) + + val targetNode = nodeAtIndex(nodeIndex) + val (newNode, hasChanged) = if (shift == MAX_SHIFT) { + targetNode.collisionRemove(key) + } else { + targetNode.remove(keyHash, key, shift + LOG_MAX_BRANCHING_FACTOR) + } + return replaceNode(targetNode, newNode, nodeIndex, keyPositionMask) to hasChanged + } + + // key is absent + return this to false + } + + private fun replaceNode(targetNode: TrieNode, newNode: TrieNode?, nodeIndex: Int, positionMask: Int) = + when { + newNode == null -> + removeNodeAtIndex(nodeIndex, positionMask) + + targetNode !== newNode -> + updateNodeAtIndex(nodeIndex, positionMask, newNode) + + else -> + this + } + + fun put(keyHash: Int, key: K, value: @UnsafeVariance V, shift: Int): TrieNode { + val keyPositionMask = 1 shl indexSegment(keyHash, shift) + + if (hasEntryAt(keyPositionMask)) { // key is directly in buffer + val keyIndex = entryKeyIndex(keyPositionMask) + + if (key == keyAtIndex(keyIndex)) { + if (valueAtKeyIndex(keyIndex) === value) return this + + return updateValueAtIndex(keyIndex, value) + } + return moveEntryToNode(keyIndex, keyPositionMask, keyHash, key, value, shift) + } + if (hasNodeAt(keyPositionMask)) { // key is in node + val nodeIndex = nodeIndex(keyPositionMask) + + val targetNode = nodeAtIndex(nodeIndex) + val putResult = if (shift == MAX_SHIFT) { + targetNode.collisionPut(key, value) + } else { + targetNode.put(keyHash, key, value, shift + LOG_MAX_BRANCHING_FACTOR) + } + return replaceNode(targetNode, putResult, nodeIndex, keyPositionMask)!! + } + + // no entry at this key hash segment + return insertEntryAt(keyPositionMask, key, value) + } + + + private fun mutablePut( + keyHash: Int, + key: K, + value: @UnsafeVariance V, + shift: Int, + owner: MutabilityOwnership + ): TrieNode { + val keyPositionMask = 1 shl indexSegment(keyHash, shift) + + if (hasEntryAt(keyPositionMask)) { // key is directly in buffer + val keyIndex = entryKeyIndex(keyPositionMask) + + if (key == keyAtIndex(keyIndex)) { + if (valueAtKeyIndex(keyIndex) === value) { + return this + } + + return mutableUpdateValueAtIndex(keyIndex, value, owner) + } + return mutableMoveEntryToNode(keyIndex, keyPositionMask, keyHash, key, value, shift, owner) + } + if (hasNodeAt(keyPositionMask)) { // key is in node + val nodeIndex = nodeIndex(keyPositionMask) + + val targetNode = nodeAtIndex(nodeIndex) + val newNode = if (shift == MAX_SHIFT) { + targetNode.mutableCollisionPut(key, value, owner) + } else { + targetNode.mutablePut(keyHash, key, value, shift + LOG_MAX_BRANCHING_FACTOR, owner) + } + if (targetNode === newNode) { + return this + } + return mutableUpdateNodeAtIndex(nodeIndex, newNode, owner) + } + + // key is absent + return mutableInsertEntryAt(keyPositionMask, key, value, owner) + } + + private fun mutableRemove( + keyHash: Int, + key: K, + shift: Int, + owner: MutabilityOwnership + ): Pair?, Boolean> { + val keyPositionMask = 1 shl indexSegment(keyHash, shift) + + if (hasEntryAt(keyPositionMask)) { // key is directly in buffer + val keyIndex = entryKeyIndex(keyPositionMask) + + if (key == keyAtIndex(keyIndex)) { + return mutableRemoveEntryAtIndex(keyIndex, keyPositionMask, owner) to true + } + return this to false + } + if (hasNodeAt(keyPositionMask)) { // key is in node + val nodeIndex = nodeIndex(keyPositionMask) + + val targetNode = nodeAtIndex(nodeIndex) + val (newNode, hasChanged) = if (shift == MAX_SHIFT) { + targetNode.mutableCollisionRemove(key, owner) + } else { + targetNode.mutableRemove(keyHash, key, shift + LOG_MAX_BRANCHING_FACTOR, owner) + } + return mutableReplaceNode(targetNode, newNode, nodeIndex, keyPositionMask, owner) to hasChanged + } + + // key is absent + return this to false + } + + private fun mutableRemoveAndGetValue( + keyHash: Int, + key: K, + shift: Int, + owner: MutabilityOwnership + ): Pair?, V?> { + val keyPositionMask = 1 shl indexSegment(keyHash, shift) + + if (hasEntryAt(keyPositionMask)) { // key is directly in buffer + val keyIndex = entryKeyIndex(keyPositionMask) + + if (key == keyAtIndex(keyIndex)) { + val value = valueAtKeyIndex(keyIndex) + return mutableRemoveEntryAtIndex(keyIndex, keyPositionMask, owner) to value + } + return this to null + } + if (hasNodeAt(keyPositionMask)) { // key is in node + val nodeIndex = nodeIndex(keyPositionMask) + + val targetNode = nodeAtIndex(nodeIndex) + val (newNode, value) = if (shift == MAX_SHIFT) { + targetNode.mutableCollisionRemoveAndGetValue(key, owner) + } else { + targetNode.mutableRemoveAndGetValue(keyHash, key, shift + LOG_MAX_BRANCHING_FACTOR, owner) + } + return mutableReplaceNode(targetNode, newNode, nodeIndex, keyPositionMask, owner) to value + } + + // key is absent + return this to null + } + + private fun mutableReplaceNode( + targetNode: TrieNode, + newNode: TrieNode?, + nodeIndex: Int, + positionMask: Int, + owner: MutabilityOwnership + ) = when { + newNode == null -> + mutableRemoveNodeAtIndex(nodeIndex, positionMask, owner) + + targetNode !== newNode -> + mutableUpdateNodeAtIndex(nodeIndex, newNode, owner) + + else -> this + } + + // For testing trie structure + internal fun accept(visitor: (node: TrieNode, shift: Int, hash: Int, dataMap: Int, nodeMap: Int) -> Unit) { + accept(visitor, 0, 0) + } + + private fun accept( + visitor: (node: TrieNode, shift: Int, hash: Int, dataMap: Int, nodeMap: Int) -> Unit, + hash: Int, + shift: Int + ) { + visitor(this, shift, hash, dataMap, nodeMap) + + var nodePositions = nodeMap + while (nodePositions != 0) { + val mask = nodePositions.takeLowestOneBit() +// assert(hasNodeAt(mask)) + + val hashSegment = mask.countTrailingZeroBits() + + val childNode = nodeAtIndex(nodeIndex(mask)) + childNode.accept(visitor, hash + (hashSegment shl shift), shift + LOG_MAX_BRANCHING_FACTOR) + + nodePositions -= mask + } + } + + private fun equalsWith(that: TrieNode, equalityComparator: (V, V1) -> Boolean): Boolean { + if (this === that) return true + if (dataMap != that.dataMap || nodeMap != that.nodeMap) return false + if (dataMap == 0 && nodeMap == 0) { // collision node + if (buffer.size != that.buffer.size) return false + return (0 until buffer.size step ENTRY_SIZE).all { i -> + val thatKey = that.keyAtIndex(i) + val thatValue = that.valueAtKeyIndex(i) + val keyIndex = collisionKeyIndex(thatKey) + if (keyIndex != -1) { + val value = valueAtKeyIndex(keyIndex) + equalityComparator(value, thatValue) + } else false + } + } + + val valueSize = dataMap.countOneBits() * ENTRY_SIZE + for (i in 0 until valueSize step ENTRY_SIZE) { + if (keyAtIndex(i) != that.keyAtIndex(i)) return false + if (!equalityComparator(valueAtKeyIndex(i), that.valueAtKeyIndex(i))) return false + } + for (i in valueSize until buffer.size) { + if (!nodeAtIndex(i).equalsWith(that.nodeAtIndex(i), equalityComparator)) return false + } + return true + } + + val keys: Sequence get() = UPersistentHashMapKeysIterator(this).asSequence() + + fun containsKey(key: K): Boolean = containsKey(key.hashCode(), key, 0) + + operator fun get(key: K) = get(key.hashCode(), key, 0) + + operator fun contains(key: K) = containsKey(key) + + fun put(key: K, value: V, owner: MutabilityOwnership): TrieNode = + mutablePut(key.hashCode(), key, value, 0, owner) + + fun remove(key: K, owner: MutabilityOwnership): TrieNode = + removeWithChangeInfo(key, owner).first + + @Suppress("UNCHECKED_CAST") + fun removeWithChangeInfo(key: K, owner: MutabilityOwnership): Pair, Boolean> { + val (node, hasChanged) = mutableRemove(key.hashCode(), key, 0, owner) + return (node ?: EMPTY as TrieNode) to hasChanged + } + + fun removeAndGetValue(key: K, owner: MutabilityOwnership): Pair, V?> { + val (node, value) = mutableRemoveAndGetValue(key.hashCode(), key, 0, owner) + @Suppress("UNCHECKED_CAST") + return (node ?: EMPTY as TrieNode) to value + } + + fun putAll(otherNode: TrieNode, owner: MutabilityOwnership): TrieNode { + return mutablePutAll(otherNode, 0, owner) + } + + @Suppress("UNCHECKED_CAST") + fun clear() = EMPTY as TrieNode + + @Suppress("UNCHECKED_CAST") + override fun equals(other: Any?): Boolean { + other as? TrieNode ?: return false + return this.equalsWith(other) { v1, v2 -> v1 == v2 } + } + + override fun toString(): String = + iterator().asSequence() + .joinToString(separator = "\n", prefix = "{", postfix = "}") { "${it.key} -> ${it.value}" } + + override fun hashCode(): Int = sumOf { it.hashCode() } + + override fun iterator(): Iterator> = UPersistentHashMapEntriesIterator(this) + + companion object { + internal val EMPTY = TrieNode(0, 0, emptyArray()) + } +} diff --git a/usvm-util/src/main/kotlin/org/usvm/collections/immutable/implementations/immutableSet/TrieNode.kt b/usvm-util/src/main/kotlin/org/usvm/collections/immutable/implementations/immutableSet/TrieNode.kt new file mode 100644 index 0000000000..100c6ef9ab --- /dev/null +++ b/usvm-util/src/main/kotlin/org/usvm/collections/immutable/implementations/immutableSet/TrieNode.kt @@ -0,0 +1,807 @@ +/* + * Copyright 2016-2019 JetBrains s.r.o. + * Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file. + */ + +package org.usvm.collections.immutable.implementations.immutableSet + +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.internal.forEachOneBit + +typealias UPersistentHashSet = TrieNode + + +internal const val MAX_BRANCHING_FACTOR = 32 +internal const val LOG_MAX_BRANCHING_FACTOR = 5 +internal const val MAX_BRANCHING_FACTOR_MINUS_ONE = MAX_BRANCHING_FACTOR - 1 +internal const val MAX_SHIFT = 30 + +/** + * Gets trie index segment of the specified [index] at the level specified by [shift]. + * + * `shift` equal to zero corresponds to the root level. + * For each lower level `shift` increments by [LOG_MAX_BRANCHING_FACTOR]. + */ +internal fun indexSegment(index: Int, shift: Int): Int = + (index shr shift) and MAX_BRANCHING_FACTOR_MINUS_ONE + + +private fun Array.addElementAtIndex(index: Int, element: E): Array { + val newBuffer = arrayOfNulls(this.size + 1) + this.copyInto(newBuffer, endIndex = index) + this.copyInto(newBuffer, index + 1, index, this.size) + newBuffer[index] = element + return newBuffer +} + +private fun Array.removeCellAtIndex(cellIndex: Int): Array { + val newBuffer = arrayOfNulls(this.size - 1) + this.copyInto(newBuffer, endIndex = cellIndex) + this.copyInto(newBuffer, cellIndex, cellIndex + 1, this.size) + return newBuffer +} + +/** + * Writes all elements from [this] to [newArray], starting with [newArrayOffset], filtering + * on the fly using [predicate]. By default filters out [TrieNode.EMPTY] instances + * + * return number of elements written to [newArray] + **/ +private inline fun Array.filterTo( + newArray: Array, + newArrayOffset: Int = 0, + predicate: (Any?) -> Boolean = { it !== TrieNode.EMPTY }, +): Int { + var i = 0 + var j = 0 + while (i < size) { + assert(j <= i) // this is extremely important if newArray === this + val e = this[i] + if (predicate(e)) { + newArray[newArrayOffset + j] = this[i] + ++j + assert(newArrayOffset + j <= newArray.size) + } + ++i + } + return j +} + +class TrieNode( + private var bitmap: Int, + internal var buffer: Array, + private var ownedBy: MutabilityOwnership? +) : Iterable { + + constructor(bitmap: Int, buffer: Array) : this(bitmap, buffer, null) + + // here and later: + // positionMask — an int in form 2^n, i.e. having the single bit set, whose ordinal is a logical position in buffer + + private fun hasNoCellAt(positionMask: Int): Boolean { + return bitmap and positionMask == 0 + } + + internal fun indexOfCellAt(positionMask: Int): Int { + return (bitmap and (positionMask - 1)).countOneBits() + } + + private fun elementAtIndex(index: Int): E { + @Suppress("UNCHECKED_CAST") + return buffer[index] as E + } + + private fun nodeAtIndex(index: Int): TrieNode { + @Suppress("UNCHECKED_CAST") + return buffer[index] as TrieNode + } + + private fun addElementAt(positionMask: Int, element: E, owner: MutabilityOwnership?): TrieNode { +// assert(hasNoCellAt(positionMask)) + + val index = indexOfCellAt(positionMask) + val newBitmap = bitmap or positionMask + val newBuffer = buffer.addElementAtIndex(index, element) + return setProperties(newBitmap, newBuffer, owner) + } + + private fun setProperties(newBitmap: Int, newBuffer: Array, owner: MutabilityOwnership?): TrieNode { + if (ownedBy != null && ownedBy === owner) { + bitmap = newBitmap + buffer = newBuffer + return this + } + return TrieNode(newBitmap, newBuffer, owner) + } + + /** The given [newNode] must not be a part of any persistent set instance. */ + private fun canonicalizeNodeAtIndex(nodeIndex: Int, newNode: TrieNode, owner: MutabilityOwnership?): TrieNode { +// assert(buffer[nodeIndex] !== newNode) + val cell: Any? + + val newNodeBuffer = newNode.buffer + if (newNodeBuffer.size == 1 && newNodeBuffer[0] !is TrieNode<*>) { + if (buffer.size == 1) { + newNode.bitmap = bitmap + return newNode + } + cell = newNodeBuffer[0] + } else { + cell = newNode + } + + return setCellAtIndex(nodeIndex, cell, owner) + } + + private fun setCellAtIndex(cellIndex: Int, newCell: Any?, owner: MutabilityOwnership?): TrieNode { + if (ownedBy != null && ownedBy === owner) { + buffer[cellIndex] = newCell + return this + } + val newBuffer = buffer.copyOf() + newBuffer[cellIndex] = newCell + return TrieNode(bitmap, newBuffer, owner) + } + + private fun makeNodeAtIndex( + elementIndex: Int, + newElementHash: Int, + newElement: E, + shift: Int, + owner: MutabilityOwnership?, + ): TrieNode { + val storedElement = elementAtIndex(elementIndex) + return makeNode(storedElement.hashCode(), storedElement, + newElementHash, newElement, shift + LOG_MAX_BRANCHING_FACTOR, owner) + } + + private fun moveElementToNode( + elementIndex: Int, + newElementHash: Int, + newElement: E, + shift: Int, + owner: MutabilityOwnership? + ): TrieNode { + val node = makeNodeAtIndex(elementIndex, newElementHash, newElement, shift, owner) + return setCellAtIndex(elementIndex, node, owner) + } + + private fun makeNode(elementHash1: Int, element1: E, elementHash2: Int, element2: E, + shift: Int, owner: MutabilityOwnership?): TrieNode { + if (shift > MAX_SHIFT) { +// assert(element1 != element2) + // when two element hashes are entirely equal: the last level subtrie node stores them just as unordered list + return TrieNode(0, arrayOf(element1, element2), owner) + } + + val setBit1 = indexSegment(elementHash1, shift) + val setBit2 = indexSegment(elementHash2, shift) + + if (setBit1 != setBit2) { + val nodeBuffer = if (setBit1 < setBit2) { + arrayOf(element1, element2) + } else { + arrayOf(element2, element1) + } + return TrieNode((1 shl setBit1) or (1 shl setBit2), nodeBuffer, owner) + } + // hash segments at the given shift are equal: move these elements into the subtrie + val node = makeNode(elementHash1, element1, elementHash2, element2, shift + LOG_MAX_BRANCHING_FACTOR, owner) + return TrieNode(1 shl setBit1, arrayOf(node), owner) + } + + + private fun removeCellAtIndex(cellIndex: Int, positionMask: Int, owner: MutabilityOwnership?): TrieNode { +// assert(!hasNoCellAt(positionMask)) +// assert(buffer.size > 1) can be false only for the root node + + val newBitmap = bitmap xor positionMask + val newBuffer = buffer.removeCellAtIndex(cellIndex) + return setProperties(newBitmap, newBuffer, owner) + } + + private fun collisionRemoveElementAtIndex(i: Int, owner: MutabilityOwnership?): TrieNode { + val newBuffer = buffer.removeCellAtIndex(i) + return setProperties(newBitmap = 0, newBuffer, owner) + } + + private fun collisionContainsElement(element: E): Boolean { + return buffer.contains(element) + } + + private fun collisionAdd(element: E): TrieNode { + if (collisionContainsElement(element)) return this + val newBuffer = buffer.addElementAtIndex(0, element) + return setProperties(newBitmap = 0, newBuffer, owner = null) + } + + private fun mutableCollisionAdd(element: E, owner: MutabilityOwnership): TrieNode { + if (collisionContainsElement(element)) return this + val newBuffer = buffer.addElementAtIndex(0, element) + return setProperties(newBitmap = 0, newBuffer, owner = owner) + } + + private fun collisionRemove(element: E): TrieNode { + val index = buffer.indexOf(element) + if (index != -1) { + return collisionRemoveElementAtIndex(index, owner = null) + } + return this + } + + private fun mutableCollisionRemove(element: E, owner: MutabilityOwnership): Pair, Boolean> { + val index = buffer.indexOf(element) + if (index != -1) { + return collisionRemoveElementAtIndex(index, owner) to true + } + return this to false + } + + private fun mutableCollisionAddAll(otherNode: TrieNode, owner: MutabilityOwnership): TrieNode { + if (this === otherNode) { + return this + } + val tempBuffer = this.buffer.copyOf(newSize = this.buffer.size + otherNode.buffer.size) + val totalWritten = otherNode.buffer.filterTo(tempBuffer, newArrayOffset = this.buffer.size) { + @Suppress("UNCHECKED_CAST") + !this.collisionContainsElement(it as E) + } + val totalSize = totalWritten + this.buffer.size + if (totalSize == this.buffer.size) return this + if (totalSize == otherNode.buffer.size) return otherNode + + val newBuffer = if (totalSize == tempBuffer.size) tempBuffer else tempBuffer.copyOf(newSize = totalSize) + return setProperties(newBitmap = 0, newBuffer, owner) + } + + private fun mutableCollisionRetainAll(otherNode: TrieNode, owner: MutabilityOwnership): Any? { + if (this === otherNode) { + return this + } + val tempBuffer = + if (owner === ownedBy) buffer + else arrayOfNulls(minOf(buffer.size, otherNode.buffer.size)) + val totalWritten = buffer.filterTo(tempBuffer) { + @Suppress("UNCHECKED_CAST") + otherNode.collisionContainsElement(it as E) + } + return when (totalWritten) { + 0 -> EMPTY + 1 -> tempBuffer[0] + this.buffer.size -> this + otherNode.buffer.size -> otherNode + tempBuffer.size -> setProperties(newBitmap = 0, newBuffer = tempBuffer, owner) + else -> setProperties(newBitmap = 0, newBuffer = tempBuffer.copyOf(newSize = totalWritten), owner) + } + } + + private fun mutableCollisionRemoveAll(otherNode: TrieNode, owner: MutabilityOwnership): Any? { + if (this === otherNode) { + return EMPTY + } + val tempBuffer = if (owner === ownedBy) buffer else arrayOfNulls(buffer.size) + val totalWritten = buffer.filterTo(tempBuffer) { + @Suppress("UNCHECKED_CAST") + !otherNode.collisionContainsElement(it as E) + } + return when (totalWritten) { + 0 -> EMPTY + 1 -> tempBuffer[0] + this.buffer.size -> this + tempBuffer.size -> setProperties(newBitmap = 0, newBuffer = tempBuffer, owner) + else -> setProperties(newBitmap = 0, newBuffer = tempBuffer.copyOf(newSize = totalWritten), owner) + } + } + + fun calculateSize(): Int { + if (bitmap == 0) return buffer.size + var result = 0 + for (e in buffer) { + result += when (e) { + is TrieNode<*> -> e.calculateSize() + else -> 1 + } + } + return result + } + + private fun elementsIdentityEquals(otherNode: TrieNode): Boolean { + if (this === otherNode) return true + if (bitmap != otherNode.bitmap) return false + for (i in 0 until buffer.size) { + if (buffer[i] !== otherNode.buffer[i]) return false + } + return true + } + + fun contains(elementHash: Int, element: E, shift: Int): Boolean { + val cellPositionMask = 1 shl indexSegment(elementHash, shift) + + if (hasNoCellAt(cellPositionMask)) { // element is absent + return false + } + + val cellIndex = indexOfCellAt(cellPositionMask) + if (buffer[cellIndex] is TrieNode<*>) { // element may be in node + val targetNode = nodeAtIndex(cellIndex) + if (shift == MAX_SHIFT) { + return targetNode.collisionContainsElement(element) + } + return targetNode.contains(elementHash, element, shift + LOG_MAX_BRANCHING_FACTOR) + } + // element is directly in buffer + return element == buffer[cellIndex] + } + + fun mutableAddAll(otherNode: TrieNode, shift: Int, owner: MutabilityOwnership): TrieNode { + if (this === otherNode) { + return this + } + if (shift > MAX_SHIFT) { + return mutableCollisionAddAll(otherNode, owner) + } + // union mask contains all the bits from input masks + val newBitMap = bitmap or otherNode.bitmap + // first allocate the node and then fill it in + // we are doing a union, so all the array elements are guaranteed to exist + val mutableNode = when { + newBitMap == bitmap && ownedBy == owner -> this + else -> TrieNode(newBitMap, arrayOfNulls(newBitMap.countOneBits()), owner) + } + // for each bit set in the resulting mask, + // either left, right or both masks contain the same bit + // Note: we shouldn't overrun MAX_SHIFT because both sides are correct TrieNodes, right? + newBitMap.forEachOneBit { positionMask, newNodeIndex -> + val thisIndex = indexOfCellAt(positionMask) + val otherNodeIndex = otherNode.indexOfCellAt(positionMask) + mutableNode.buffer[newNodeIndex] = when { + // no element on left -> pick right + hasNoCellAt(positionMask) -> otherNode.buffer[otherNodeIndex] + // no element on right -> pick left + otherNode.hasNoCellAt(positionMask) -> buffer[thisIndex] + // both nodes contain something at the masked bit + else -> { + val thisCell = buffer[thisIndex] + val otherNodeCell = otherNode.buffer[otherNodeIndex] + val thisIsNode = thisCell is TrieNode<*> + val otherIsNode = otherNodeCell is TrieNode<*> + when { + // both are nodes -> merge them recursively + thisIsNode && otherIsNode -> @Suppress("UNCHECKED_CAST") { + thisCell as TrieNode + otherNodeCell as TrieNode + thisCell.mutableAddAll( + otherNodeCell, + shift + LOG_MAX_BRANCHING_FACTOR, + owner + ) + } + // one of them is a node -> add the other one to it + thisIsNode -> @Suppress("UNCHECKED_CAST") { + thisCell as TrieNode + otherNodeCell as E + thisCell.mutableAdd( + otherNodeCell.hashCode(), + otherNodeCell, + shift + LOG_MAX_BRANCHING_FACTOR, + owner + ) + } + // same as last case, but reversed + otherIsNode -> @Suppress("UNCHECKED_CAST") { + otherNodeCell as TrieNode + thisCell as E + otherNodeCell.mutableAdd( + thisCell.hashCode(), + thisCell, + shift + LOG_MAX_BRANCHING_FACTOR, + owner + ) + } + // both are just E => compare them + thisCell == otherNodeCell -> thisCell + // both are just E, but different => make a collision-ish node + else -> @Suppress("UNCHECKED_CAST") { + thisCell as E + otherNodeCell as E + makeNode( + thisCell.hashCode(), + thisCell, + otherNodeCell.hashCode(), + otherNodeCell, + shift + LOG_MAX_BRANCHING_FACTOR, + owner + ) + } + } + } + } + } + return when { + this.elementsIdentityEquals(mutableNode) -> this + otherNode.elementsIdentityEquals(mutableNode) -> otherNode + else -> mutableNode + } + } + + fun mutableRetainAll(otherNode: TrieNode, shift: Int, owner: MutabilityOwnership): Any? { + if (this === otherNode) { + return this + } + if (shift > MAX_SHIFT) { + return mutableCollisionRetainAll(otherNode, owner) + } + // intersection mask contains bits that are set in both inputs + // this mask is not final 'cos some children may have no intersection + val newBitMap = bitmap and otherNode.bitmap + // zero means no nodes intersect + if (newBitMap == 0) return EMPTY + val mutableNode = + if (ownedBy == owner && newBitMap == bitmap) this + else TrieNode(newBitMap, arrayOfNulls(newBitMap.countOneBits()), owner) + // we need to keep track of the real mask 'cos some of the children may intersect to nothing + var realBitMap = 0 + // for each bit in intersection mask, try to intersect children + newBitMap.forEachOneBit { positionMask, newNodeIndex -> + val thisIndex = indexOfCellAt(positionMask) + val otherNodeIndex = otherNode.indexOfCellAt(positionMask) + val newValue = run { + val thisCell = buffer[thisIndex] + val otherNodeCell = otherNode.buffer[otherNodeIndex] + val thisIsNode = thisCell is TrieNode<*> + val otherIsNode = otherNodeCell is TrieNode<*> + when { + // both are nodes -> merge them recursively + thisIsNode && otherIsNode -> @Suppress("UNCHECKED_CAST") { + thisCell as TrieNode + otherNodeCell as TrieNode + thisCell.mutableRetainAll( + otherNodeCell, + shift + LOG_MAX_BRANCHING_FACTOR, + owner + ) + } + // one of them is a node -> check containment + thisIsNode -> @Suppress("UNCHECKED_CAST") { + thisCell as TrieNode + otherNodeCell as E + if (thisCell.contains(otherNodeCell.hashCode(), otherNodeCell, shift + LOG_MAX_BRANCHING_FACTOR)) { + otherNodeCell + } else EMPTY + } + // same as last case, but reversed + otherIsNode -> @Suppress("UNCHECKED_CAST") { + otherNodeCell as TrieNode + thisCell as E + if (otherNodeCell.contains(thisCell.hashCode(), thisCell, shift + LOG_MAX_BRANCHING_FACTOR)) { + thisCell + } else EMPTY + } + // both are just E => compare them + thisCell == otherNodeCell -> thisCell + // both are just E, but different => return nothing + else -> EMPTY + } + } + if (newValue !== EMPTY) { + // elements that are not in realBitMap will be removed later + realBitMap = realBitMap or positionMask + } + mutableNode.buffer[newNodeIndex] = newValue + } + // resulting array's size is the popcount of resulting mask + val realSize = realBitMap.countOneBits() + return when { + realBitMap == 0 -> EMPTY + realBitMap == newBitMap -> { + when { + mutableNode.elementsIdentityEquals(this) -> this + mutableNode.elementsIdentityEquals(otherNode) -> otherNode + else -> mutableNode + } + } + // single values are kept only on root level + realSize == 1 && shift != 0 -> when (val single = mutableNode.buffer[mutableNode.indexOfCellAt(realBitMap)]) { + is TrieNode<*> -> TrieNode(realBitMap, arrayOf(single), owner) + else -> single + } + else -> { + // clean up all the EMPTYs in the resulting buffer + val realBuffer = arrayOfNulls(realSize) + mutableNode.buffer.filterTo(realBuffer) + TrieNode(realBitMap, realBuffer, owner) + } + } + } + + private fun mutableRemoveAll(otherNode: TrieNode, shift: Int, owner: MutabilityOwnership): Any? { + if (this === otherNode) { + return EMPTY + } + if (shift > MAX_SHIFT) { + return mutableCollisionRemoveAll(otherNode, owner) + } + // same as with intersection, only children of both nodes are considered + // this mask is not final 'cos some children may have no intersection + val removalBitmap = bitmap and otherNode.bitmap + // zero means no intersection => nothing to remove + if (removalBitmap == 0) return this + // node here is either us (if we are mutable) or a mutable copy + val mutableNode = + if (ownedBy == owner) this + else TrieNode(bitmap, buffer.copyOf(), owner) + // keep track of the real mask + var realBitMap = bitmap + removalBitmap.forEachOneBit { positionMask, _ -> + val thisIndex = indexOfCellAt(positionMask) + val otherNodeIndex = otherNode.indexOfCellAt(positionMask) + val newValue = run { + val thisCell = buffer[thisIndex] + val otherNodeCell = otherNode.buffer[otherNodeIndex] + val thisIsNode = thisCell is TrieNode<*> + val otherIsNode = otherNodeCell is TrieNode<*> + when { + // both are nodes -> merge them recursively + thisIsNode && otherIsNode -> @Suppress("UNCHECKED_CAST") { + thisCell as TrieNode + otherNodeCell as TrieNode + thisCell.mutableRemoveAll( + otherNodeCell, + shift + LOG_MAX_BRANCHING_FACTOR, + owner + ) + } + // one of them is a node -> remove single element + thisIsNode -> @Suppress("UNCHECKED_CAST") { + thisCell as TrieNode + otherNodeCell as E + val (removed, hasChanged) = thisCell.mutableRemove( + otherNodeCell.hashCode(), + otherNodeCell, + shift + LOG_MAX_BRANCHING_FACTOR, + owner) + + // additional check needed for removal + if (hasChanged) { + if (removed.buffer.size == 1 && removed.buffer[0] !is TrieNode<*>) removed.buffer[0] + else removed + } else thisCell + } + // same as last case, but reversed + otherIsNode -> @Suppress("UNCHECKED_CAST") { + otherNodeCell as TrieNode + thisCell as E + // "removing" a node from a value is basically checking if the value is contained in the node + if (otherNodeCell.contains(thisCell.hashCode(), thisCell, shift + LOG_MAX_BRANCHING_FACTOR)) { + EMPTY + } else thisCell + } + // both are just E => compare them + thisCell == otherNodeCell -> { + EMPTY + } + // both are just E, but different => nothing to remove, return left + else -> thisCell + } + } + if (newValue === EMPTY) { + // if we removed something, keep track + realBitMap = realBitMap xor positionMask + } + mutableNode.buffer[thisIndex] = newValue + } + // resulting size is popcount of the resulting mask + val realSize = realBitMap.countOneBits() + return when { + realBitMap == 0 -> EMPTY + realBitMap == bitmap -> { + when { + mutableNode.elementsIdentityEquals(this) -> this + else -> mutableNode + } + } + // single values are kept only on root level + realSize == 1 && shift != 0 -> when (val single = mutableNode.buffer[mutableNode.indexOfCellAt(realBitMap)]) { + is TrieNode<*> -> TrieNode(realBitMap, arrayOf(single), owner) + else -> single + } + else -> { + // clean up all the EMPTYs in the resulting buffer + val realBuffer = arrayOfNulls(realSize) + mutableNode.buffer.filterTo(realBuffer) + TrieNode(realBitMap, realBuffer, owner) + } + } + } + + private fun containsAll(otherNode: TrieNode, shift: Int): Boolean { + if (this === otherNode) return true + // essentially `buffer.containsAll(otherNode.buffer)` + if (shift > MAX_SHIFT) return otherNode.buffer.all { it in buffer } + + // potential bitmap is an intersection of input bitmaps + val potentialBitMap = bitmap and otherNode.bitmap + // left bitmap must contain right bitmap => right bitmap must be equal to intersection + if (potentialBitMap != otherNode.bitmap) return false + // check each child, shortcut to false if any one isn't contained + potentialBitMap.forEachOneBit { positionMask, _ -> + val thisIndex = indexOfCellAt(positionMask) + val otherNodeIndex = otherNode.indexOfCellAt(positionMask) + val thisCell = buffer[thisIndex] + val otherNodeCell = otherNode.buffer[otherNodeIndex] + val thisIsNode = thisCell is TrieNode<*> + val otherIsNode = otherNodeCell is TrieNode<*> + when { + // both are nodes => check recursively + thisIsNode && otherIsNode -> @Suppress("UNCHECKED_CAST") { + thisCell as TrieNode + otherNodeCell as TrieNode + thisCell.containsAll(otherNodeCell, shift + LOG_MAX_BRANCHING_FACTOR) || return false + } + // left is node, right is just E => check containment + thisIsNode -> @Suppress("UNCHECKED_CAST") { + thisCell as TrieNode + otherNodeCell as E + thisCell.contains(otherNodeCell.hashCode(), otherNodeCell, shift + LOG_MAX_BRANCHING_FACTOR) || return false + } + // left is just E, right is node => not possible + otherIsNode -> return false + // both are just E => containment is just equality + else -> thisCell == otherNodeCell || return false + } + } + return true + } + + fun add(elementHash: Int, element: E, shift: Int): TrieNode { + val cellPositionMask = 1 shl indexSegment( + elementHash, + shift + ) + + if (hasNoCellAt(cellPositionMask)) { // element is absent + return addElementAt(cellPositionMask, element, owner = null) + } + + val cellIndex = indexOfCellAt(cellPositionMask) + if (buffer[cellIndex] is TrieNode<*>) { // element may be in node + val targetNode = nodeAtIndex(cellIndex) + val newNode = if (shift == MAX_SHIFT) { + targetNode.collisionAdd(element) + } else { + targetNode.add(elementHash, element, shift + LOG_MAX_BRANCHING_FACTOR) + } + if (targetNode === newNode) return this + return setCellAtIndex(cellIndex, newNode, owner = null) + } + // element is directly in buffer + if (element == buffer[cellIndex]) return this + return moveElementToNode(cellIndex, elementHash, element, shift, owner = null) + } + + private fun mutableAdd(elementHash: Int, element: E, shift: Int, owner: MutabilityOwnership): TrieNode { + val cellPosition = 1 shl indexSegment(elementHash, shift) + + if (hasNoCellAt(cellPosition)) { // element is absent + return addElementAt(cellPosition, element, owner) + } + + val cellIndex = indexOfCellAt(cellPosition) + if (buffer[cellIndex] is TrieNode<*>) { // element may be in node + val targetNode = nodeAtIndex(cellIndex) + val newNode = if (shift == MAX_SHIFT) { + targetNode.mutableCollisionAdd(element, owner) + } else { + targetNode.mutableAdd(elementHash, element, shift + LOG_MAX_BRANCHING_FACTOR, owner) + } + if (targetNode === newNode) return this + return setCellAtIndex(cellIndex, newNode, owner) + } + // element is directly in buffer + if (element == buffer[cellIndex]) return this + return moveElementToNode(cellIndex, elementHash, element, shift, owner) + } + + fun remove(elementHash: Int, element: E, shift: Int): TrieNode { + val cellPositionMask = 1 shl indexSegment( + elementHash, + shift + ) + + if (hasNoCellAt(cellPositionMask)) { // element is absent + return this + } + + val cellIndex = indexOfCellAt(cellPositionMask) + if (buffer[cellIndex] is TrieNode<*>) { // element may be in node + val targetNode = nodeAtIndex(cellIndex) + val newNode = if (shift == MAX_SHIFT) { + targetNode.collisionRemove(element) + } else { + targetNode.remove(elementHash, element, shift + LOG_MAX_BRANCHING_FACTOR) + } + if (targetNode === newNode) return this + return canonicalizeNodeAtIndex(cellIndex, newNode, owner = null) + } + // element is directly in buffer + if (element == buffer[cellIndex]) { + return removeCellAtIndex(cellIndex, cellPositionMask, owner = null) + } + return this + } + + private fun mutableRemove( + elementHash: Int, + element: E, + shift: Int, + owner: MutabilityOwnership + ): Pair, Boolean> { + val cellPositionMask = 1 shl indexSegment(elementHash, shift) + + if (hasNoCellAt(cellPositionMask)) { // element is absent + return this to false + } + + val cellIndex = indexOfCellAt(cellPositionMask) + if (buffer[cellIndex] is TrieNode<*>) { // element may be in node + val targetNode = nodeAtIndex(cellIndex) + val (newNode, hasChanged) = if (shift == MAX_SHIFT) { + targetNode.mutableCollisionRemove(element, owner) + } else { + targetNode.mutableRemove(elementHash, element, shift + LOG_MAX_BRANCHING_FACTOR, owner) + } + // If newNode is a single-element node, it is newly created, or targetNode is owned by mutator and a cell was removed in-place. + // Otherwise the single element would have been lifted up. + // If targetNode is owned by mutator, this node is also owned by mutator. Thus no new node will be created to replace this node. + // If newNode !== targetNode, it is newly created. + if (targetNode.ownedBy !== owner && targetNode === newNode) return this to hasChanged + return canonicalizeNodeAtIndex(cellIndex, newNode, owner) to hasChanged + } + // element is directly in buffer + if (element == buffer[cellIndex]) { + return removeCellAtIndex(cellIndex, cellPositionMask, owner) to true // check is empty + } + return this to false + } + + fun remove(element: E, owner: MutabilityOwnership): TrieNode = + mutableRemove(element.hashCode(), element, 0, owner).first + + @Suppress("UNCHECKED_CAST") + fun removeAll(otherNode: TrieNode, owner: MutabilityOwnership): TrieNode = + mutableRemoveAll(otherNode, 0, owner) as TrieNode + + fun add(element: E, owner: MutabilityOwnership): TrieNode = + mutableAdd(element.hashCode(), element, 0, owner) + + fun addAll(otherNode: TrieNode, owner: MutabilityOwnership): TrieNode = + mutableAddAll(otherNode, 0, owner) + + fun contains(element: E): Boolean = contains(element.hashCode(), element, 0) + + fun containsAll(otherNode: TrieNode): Boolean = containsAll(otherNode, 0) + + @Suppress("UNCHECKED_CAST") + fun retainAll(otherNode: TrieNode, owner: MutabilityOwnership) = + mutableRetainAll(otherNode, 0, owner) as TrieNode + + @Suppress("UNCHECKED_CAST") + fun clear() = EMPTY as TrieNode + + override fun toString(): String = + iterator().asSequence().joinToString(separator = ", ", prefix = "{", postfix = "}") { it.toString() } + + override fun hashCode(): Int = sumOf { it.hashCode() } + + override fun equals(other: Any?): Boolean { + @Suppress("UNCHECKED_CAST") + other as? TrieNode ?: return false + return other.calculateSize() == this.calculateSize() && other.containsAll(this) + } + + override fun iterator(): Iterator = UPersistentHashSetIterator(this) + + internal companion object { + internal val EMPTY = TrieNode(0, emptyArray()) + } +} diff --git a/usvm-util/src/main/kotlin/org/usvm/collections/immutable/implementations/immutableSet/UPersistentHashSetIterator.kt b/usvm-util/src/main/kotlin/org/usvm/collections/immutable/implementations/immutableSet/UPersistentHashSetIterator.kt new file mode 100644 index 0000000000..a57bfa711c --- /dev/null +++ b/usvm-util/src/main/kotlin/org/usvm/collections/immutable/implementations/immutableSet/UPersistentHashSetIterator.kt @@ -0,0 +1,120 @@ +/* + * Copyright 2016-2019 JetBrains s.r.o. + * Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file. + */ + +package org.usvm.collections.immutable.implementations.immutableSet + +internal open class UPersistentHashSetIterator(val node: TrieNode) : Iterator { + protected val path = mutableListOf(TrieNodeIterator()) + protected var pathLastIndex = 0 + private var hasNext = true + + init { + path[0].reset(node.buffer) + pathLastIndex = 0 + ensureNextElementIsReady() + } + + private fun moveToNextNodeWithData(pathIndex: Int): Int { + if (path[pathIndex].hasNextElement()) { + return pathIndex + } + if (path[pathIndex].hasNextNode()) { + val node = path[pathIndex].currentNode() + + if (pathIndex + 1 == path.size) { + path.add(TrieNodeIterator()) + } + path[pathIndex + 1].reset(node.buffer) + return moveToNextNodeWithData(pathIndex + 1) + } + return -1 + } + + private fun ensureNextElementIsReady() { + if (path[pathLastIndex].hasNextElement()) { + return + } + for (i in pathLastIndex downTo 0) { + var result = moveToNextNodeWithData(i) + + if (result == -1 && path[i].hasNextCell()) { + path[i].moveToNextCell() + result = moveToNextNodeWithData(i) + } + if (result != -1) { + pathLastIndex = result + return + } + if (i > 0) { + path[i - 1].moveToNextCell() + } + path[i].reset(TrieNode.EMPTY.buffer, 0) + } + hasNext = false + } + + override fun hasNext(): Boolean { + return hasNext + } + + override fun next(): E { + if (!hasNext) + throw NoSuchElementException() + + val result = path[pathLastIndex].nextElement() + ensureNextElementIsReady() + return result + } + + protected fun currentElement(): E { + assert(hasNext()) + return path[pathLastIndex].currentElement() + } +} + +internal class TrieNodeIterator { + private var buffer = TrieNode.EMPTY.buffer + private var index = 0 + + fun reset(buffer: Array, index: Int = 0) { + this.buffer = buffer + this.index = index + } + + fun hasNextCell(): Boolean { + return index < buffer.size + } + + fun moveToNextCell() { + assert(hasNextCell()) + index++ + } + + fun hasNextElement(): Boolean { + return hasNextCell() && buffer[index] !is TrieNode<*> + } + + fun currentElement(): E { + assert(hasNextElement()) + @Suppress("UNCHECKED_CAST") + return buffer[index] as E + } + + fun nextElement(): E { + assert(hasNextElement()) + @Suppress("UNCHECKED_CAST") + return buffer[index++] as E + } + + fun hasNextNode(): Boolean { + return hasNextCell() && buffer[index] is TrieNode<*> + } + + fun currentNode(): TrieNode { + assert(hasNextNode()) + @Suppress("UNCHECKED_CAST") + return buffer[index] as TrieNode + } +} diff --git a/usvm-util/src/main/kotlin/org/usvm/collections/immutable/internal/ForEachOneBit.kt b/usvm-util/src/main/kotlin/org/usvm/collections/immutable/internal/ForEachOneBit.kt new file mode 100644 index 0000000000..c2421d7ca9 --- /dev/null +++ b/usvm-util/src/main/kotlin/org/usvm/collections/immutable/internal/ForEachOneBit.kt @@ -0,0 +1,18 @@ +/* + * Copyright 2016-2020 JetBrains s.r.o. + * Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file. + */ + +package org.usvm.collections.immutable.internal + +// 'iterate' all the bits set to one in a given integer, in the form of one-bit masks +internal inline fun Int.forEachOneBit(body: (mask: Int, index: Int) -> Unit) { + var mask = this + var index = 0 + while (mask != 0) { + val bit = mask.takeLowestOneBit() + body(bit, index) + index++ + mask = mask xor bit + } +} diff --git a/usvm-util/src/main/kotlin/org/usvm/collections/immutable/internal/MutabilityOwnership.kt b/usvm-util/src/main/kotlin/org/usvm/collections/immutable/internal/MutabilityOwnership.kt new file mode 100644 index 0000000000..c86c67ca7d --- /dev/null +++ b/usvm-util/src/main/kotlin/org/usvm/collections/immutable/internal/MutabilityOwnership.kt @@ -0,0 +1,13 @@ +/* + * Copyright 2016-2019 JetBrains s.r.o. + * Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file. + */ + +package org.usvm.collections.immutable.internal + +/** + * The mutability ownership token of a persistent collection builder. + * + * Used to mark persistent data structures, that are owned by a collection builder and can be mutated by it. + */ +class MutabilityOwnership