Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jefremof committed Aug 23, 2024
1 parent af6a944 commit 8ebc128
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 50 deletions.
14 changes: 9 additions & 5 deletions usvm-python/cpythonadapter/src/main/c/virtual_objects.c
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ initialize_virtual_object_ready_types() {
void
deinitialize_virtual_object_ready_types() {
Py_DECREF(ready_virtual_object_types);
ready_virtual_object_types = 0;
}

void
Expand All @@ -261,6 +262,7 @@ initialize_virtual_object_available_slots() {
void
deinitialize_virtual_object_available_slots() {
PyMem_RawFree(AVAILABLE_SLOTS);
AVAILABLE_SLOTS = 0;
}

#define MASK_SIZE (sizeof(unsigned char) * CHAR_BIT)
Expand Down Expand Up @@ -366,13 +368,15 @@ allocate_raw_virtual_object(JNIEnv *env, jobject object, jbyteArray mask) {
return result;
}

// Since there are about 80 slots, a mask with 96 bits (12 bytes) in it
// should be enough to cover all of them
#define MAX_NEEDED_MASK_BYTE_NUMBER 12

PyObject *
allocate_raw_virtual_object_with_all_slots(JNIEnv *env, jobject object) {
// There are less than 90 slots, so 12 bytes are enough.
// That array should be able to cover all available slots.
const unsigned char all = 0b11111111;
const unsigned char mask[12] = {all, all, all, all, all, all, all, all, all, all, all, all};
return _allocate_raw_virtual_object(env, object, mask, 12);
const unsigned char all = 0b11111111; // This byte enables all 8 slots that сorrespond to it.
const unsigned char mask[MAX_NEEDED_MASK_BYTE_NUMBER] = {all, all, all, all, all, all, all, all, all, all, all, all};
return _allocate_raw_virtual_object(env, object, mask, MAX_NEEDED_MASK_BYTE_NUMBER);
}

void
Expand Down
38 changes: 18 additions & 20 deletions usvm-python/src/test/kotlin/org/usvm/samples/VirtualObjectsTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,13 @@ class VirtualObjectsTest {

fun checkSlotDisabling(
slotId: SlotId,
checkMethod: ((PyObject) -> Boolean)? = null
): Boolean? {
val obj = VirtualPythonObject(-1 - slotId.ordinal)
obj.slotMask.setSlotBit(slotId, false)
val pyObj = ConcretePythonInterpreter.allocateVirtualObject(obj)
val type = ConcretePythonInterpreter.getPythonObjectType(pyObj)
val method = checkMethod ?: (slotMethods[slotId] ?: return null)
return method(type)
val method = slotMethods[slotId] ?: return null
return !method(type)
}

@Test
Expand All @@ -64,68 +63,67 @@ class VirtualObjectsTest {
@Test
fun testNbBoolDisabled() {
val result = checkSlotDisabling(SlotId.NbBool) ?: return
assertFalse(result)
assertTrue(result)
}

@Test
fun testNbAddDisabled() {
val result = checkSlotDisabling(SlotId.NbAdd) ?: return
assertFalse(result)

assertTrue(result)
}

@Test
fun testNbSubtractDisabled() {
val result = checkSlotDisabling(SlotId.NbSubtract) ?: return
assertFalse(result)
assertTrue(result)
}

@Test
fun testNbMultiplyDisabled() {
val result = checkSlotDisabling(SlotId.NbMultiply) ?: return
assertFalse(result)
assertTrue(result)
}

@Test
fun testNbMatrixMultiplyDisabled() {
val result = checkSlotDisabling(SlotId.NbMatrixMultiply) ?: return
assertFalse(result)
assertTrue(result)
}

@Test
fun testNbNegativeDisabled() {
val result = checkSlotDisabling(SlotId.NbNegative) ?: return
assertFalse(result)
assertTrue(result)
}

@Test
fun testNbPositiveDisabled() {
val result = checkSlotDisabling(SlotId.NbPositive) ?: return
assertFalse(result)
assertTrue(result)
}

@Test
fun testSqLengthDisabled() {
val result = checkSlotDisabling(SlotId.SqLength) ?: return
assertFalse(result)
assertTrue(result)
}

@Test
fun testSqConcatDisabled() {
val result = checkSlotDisabling(SlotId.SqConcat) ?: return
assertFalse(result)
assertTrue(result)
}

@Test
fun testMpSubscriptDisabled() {
val result = checkSlotDisabling(SlotId.MpSubscript) ?: return
assertFalse(result)
assertTrue(result)
}

@Test
fun testMpAssSubscriptDisabled() {
val result = checkSlotDisabling(SlotId.MpAssSubscript) ?: return
assertFalse(result)
assertTrue(result)
}

@Test
Expand All @@ -147,30 +145,30 @@ class VirtualObjectsTest {
@Test
fun testTpGetattroDisabled() {
val result = checkSlotDisabling(SlotId.TpGetattro) ?: return
assertTrue(result)
assertFalse(result) // tp_getattro is marked as mandatory
}

@Test
fun testTpSetattroDisabled() {
val result = checkSlotDisabling(SlotId.TpSetattro) ?: return
assertTrue(result) // tp_setattro is marked as mandatory
assertFalse(result) // tp_setattro is marked as mandatory
}

@Test
fun testTpIterDisabled() {
val result = checkSlotDisabling(SlotId.TpIter) ?: return
assertFalse(result)
assertTrue(result)
}

@Test
fun testTpCallDisabled() {
val result = checkSlotDisabling(SlotId.TpCall) ?: return
assertFalse(result)
assertTrue(result)
}

@Test
fun testTpHashDisabled() {
val result = checkSlotDisabling(SlotId.TpHash) ?: return
assertTrue(result) // tp_hash is marked as mandatory
assertFalse(result) // tp_hash is marked as mandatory
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package org.usvm.annotations.codegeneration
import org.usvm.annotations.ids.SlotId

fun generateAvailableSlotInitialization(): String {
val filtered = SlotId.entries.filter {!it.mandatory};
val size = filtered.size;
val filtered = SlotId.entries.filter { !it.mandatory };

Check warning

Code scanning / detekt

Detects semicolons Warning

Unnecessary semicolon
val size = filtered.size
val prefix = """
AVAILABLE_SLOTS = PyMem_RawMalloc(sizeof(PyType_Slot) * ${size + 1});
""".trimIndent()
Expand All @@ -20,17 +20,18 @@ fun generateAvailableSlotInitialization(): String {
}

fun generateMandatorySlotMacro(): String {
val size = SlotId.values().filter {it.mandatory}.size;
val number_macro = "#define MANDATORY_SLOTS_NUMBER $size".trimIndent()
val filtered = SlotId.values().filter { it.mandatory }
val size = filtered.size
val numberMacro = "#define MANDATORY_SLOTS_NUMBER $size".trimIndent()
val prefix = "#define INCLUDE_MANDATORY_SLOTS".trimIndent()
if (size == 0) {
return number_macro + "\n\n" + prefix + "\n"
return numberMacro + "\n\n" + prefix + "\n"
}
val items = SlotId.entries.filter {it.mandatory}.map {
val items = filtered.map {
requireNotNull(it.slotName)
"slots[i++] = Virtual_${it.slotName};".trimIndent()
}
return number_macro + "\n\n" +
return numberMacro + "\n\n" +
prefix + " \\\n" +
items.joinToString("\n").replace("\n", "\\\n") + "\n"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@ package org.usvm.annotations.ids
* tp_hash cannot be disabled, however
* if you don't explicitly specify that the type HAS that slot,
* you will not be able to disable the tp_richcompare slot
*
*
* For that reason, some slots are marked as mandatory.
* They cannot be disabled using mask, so they do not have
* a mask bit number.
*
*
* The usage of swapSlotBit or setSlotBit with these slots
* will not have any effect on the mask.
*/
enum class SlotId(
val slotName: String,
val mandatory: Boolean = false
val mandatory: Boolean = false,
) {
TpGetattro("tp_getattro"),
TpGetattro("tp_getattro", true),
TpSetattro("tp_setattro", true),
TpRichcompare("tp_richcompare"),
TpIter("tp_iter"),
Expand All @@ -33,37 +33,39 @@ enum class SlotId(
SqLength("sq_length"),
MpSubscript("mp_subscript"),
MpAssSubscript("mp_ass_subscript"),
SqConcat("sq_concat");
SqConcat("sq_concat"),
;

companion object {
init {
values().filter {!it.mandatory}.forEachIndexed {
index, entry -> entry.maskBit = index
values().filter { !it.mandatory }.forEachIndexed {
index, entry ->
entry.maskBit = index

Check warning

Code scanning / detekt

Reports mis-indented code Warning

Unexpected indentation (20) (should be 16)
}
}
}
private var maskBit: Int? = null
fun getMaskBit(): Int = maskBit!!
fun getMaskBit(): Int = maskBit ?: error("No bits in the mask correspond to a mandatory slot.")
}

fun ByteArray.swapSlotBit(slot: SlotId): ByteArray {
if (slot.mandatory) return this
val bitPosition = this.size * 8 - 1 - slot.getMaskBit()
val byteIndex = bitPosition / 8
val bitMask = 1 shl (slot.getMaskBit() % 8)
val bitPosition = this.size * Byte.SIZE_BITS - 1 - slot.getMaskBit()
val byteIndex = bitPosition / Byte.SIZE_BITS
val bitMask = 1 shl (slot.getMaskBit() % Byte.SIZE_BITS)
this[byteIndex] = (this[byteIndex].toInt() xor bitMask).toByte()
return this // just to allow Builder-like usage
}

fun ByteArray.setSlotBit(slot: SlotId, state: Boolean): ByteArray {
if (slot.mandatory) return this
val bitPosition = this.size * 8 - 1 - slot.getMaskBit()
val byteIndex = bitPosition / 8
val bitMask = 1 shl (slot.getMaskBit() % 8)
val bitPosition = this.size * Byte.SIZE_BITS - 1 - slot.getMaskBit()
val byteIndex = bitPosition / Byte.SIZE_BITS
val bitMask = 1 shl (slot.getMaskBit() % Byte.SIZE_BITS)
if (state) {
this[byteIndex] = (this[byteIndex].toInt() or bitMask).toByte()
} else {
this[byteIndex] = (this[byteIndex].toInt() and (bitMask.inv())).toByte()
}
return this // just to allow Builder-like usage
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ object ConcretePythonInterpreter {
*
* pythonAdapter.allocateRawVirtualObjectWithAllSlots(object) does exactly the same as
* pythonAdapter.allocateRawVirtualObject(virtualObject, List(12) {0b11111111.toByte()}.toByteArray())
*
* In order to manually enable/disable some slots, use swapSlotBit or setSlotBit:
* pythonAdapter.allocateRawVirtualObject(obj, obj.slotMask.swapSlotBit(SlotId.NbAdd))
* pythonAdapter.allocateRawVirtualObject(obj, obj.slotMask.setSlotBit(SlotId.NbAdd, false))
*/
val ref = pythonAdapter.allocateRawVirtualObject(virtualObject, virtualObject.slotMask)
if (ref == 0L) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package org.usvm.machine.interpreters.concrete.utils

const val MAX_NEEDED_MASK_BYTE_NUMBER: Int = 12
const val ALL_SLOTS_BYTE: Int = 0b11111111

class VirtualPythonObject(
@JvmField
val interpretedObjRef: Int,
val slotMask: ByteArray = List(12) { 0b11111111.toByte() }.toByteArray()
val slotMask: ByteArray = List(MAX_NEEDED_MASK_BYTE_NUMBER) { ALL_SLOTS_BYTE.toByte()}.toByteArray(),

Check warning

Code scanning / detekt

Reports spaces around curly braces Warning

Missing spacing before "}"
)
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ data class ConcreteTypeNegation(val concreteType: ConcretePythonType) : VirtualP
}
}

sealed class TypeProtocol(private val slotId: SlotId? = null) : VirtualPythonType() {
/*
* Temporary slotId is nullable,
* since some slots, such as nb_int, nb_index and mp_length,
* are missing their implementation in virtual_objects.c
*/
sealed class TypeProtocol(val slotId: SlotId? = null) : VirtualPythonType() {
abstract fun acceptsConcrete(type: ConcretePythonType): Boolean
override fun accepts(type: PythonType): Boolean {
if (type == this || type is MockType) {
Expand Down Expand Up @@ -129,11 +134,13 @@ object HasTpRichcmp : TypeProtocol(SlotId.TpRichcompare) {
ConcretePythonInterpreter.typeHasTpRichcmp(type.asObject)
}

// TODO: since we cannot turn off this slot on virtual object, we may need to remove this [TypeProtocol] in the future.
object HasTpGetattro : TypeProtocol(SlotId.TpGetattro) {
override fun acceptsConcrete(type: ConcretePythonType): Boolean =
ConcretePythonInterpreter.typeHasTpGetattro(type.asObject)
}

// TODO: since we cannot turn off this slot on virtual object, we may need to remove this [TypeProtocol] in the future.
object HasTpSetattro : TypeProtocol(SlotId.TpSetattro) {
override fun acceptsConcrete(type: ConcretePythonType): Boolean =
ConcretePythonInterpreter.typeHasTpSetattro(type.asObject)
Expand All @@ -149,6 +156,7 @@ object HasTpCall : TypeProtocol(SlotId.TpCall) {
ConcretePythonInterpreter.typeHasTpCall(type.asObject)
}

// TODO: since we cannot turn off this slot on virtual object, we may need to remove this [TypeProtocol] in the future.
object HasTpHash : TypeProtocol(SlotId.TpHash) {
override fun acceptsConcrete(type: ConcretePythonType): Boolean =
ConcretePythonInterpreter.typeHasTpHash(type.asObject)
Expand Down

0 comments on commit 8ebc128

Please sign in to comment.