Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] First step in supporting type inference, based on the usage of sq_concat #207

Merged
merged 6 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,11 @@ JNIEXPORT jint JNICALL Java_org_usvm_interpreter_CPythonAdapter_typeHasNbPositiv
return type->tp_as_number && type->tp_as_number->nb_positive;
}

JNIEXPORT jint JNICALL Java_org_usvm_interpreter_CPythonAdapter_typeHasSqConcat(JNIEnv *env, jobject _, jlong type_ref) {
QUERY_TYPE_HAS_PREFIX
return type->tp_as_sequence && type->tp_as_sequence->sq_concat;
}

JNIEXPORT jint JNICALL Java_org_usvm_interpreter_CPythonAdapter_typeHasSqLength(JNIEnv *env, jobject _, jlong type_ref) {
QUERY_TYPE_HAS_PREFIX
return type->tp_as_sequence && type->tp_as_sequence->sq_length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@ val sampleStringFunction = StringProgramProvider(
/**
* Sample of a function that cannot be covered right now.
* */
val listConcatProgram = StringProgramProvider(
val tupleConcatProgram = StringProgramProvider(
"""
def list_concat(x):
y = x + [1]
if len(y[::-1]) == 5:
return 1
return 2
def tuple_concat(x, y):
z = x + y
return z + (1, 2, 3)
""".trimIndent(),
"list_concat",
) { listOf(PythonAnyType) }
"tuple_concat",
) { listOf(PythonAnyType, PythonAnyType) }
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,20 @@ class SimpleTypeInferenceTest: PythonTestRunnerForPrimitiveProgram("SimpleTypeIn
)
}

@Test
fun testListConcatUsage() {
check2WithConcreteRun(
constructFunction("list_concat_usage", List(2) { PythonAnyType }),
ignoreNumberOfAnalysisResults,
standardConcolicAndConcreteChecks,
/* invariants = */ emptyList(),
/* propertiesToDiscover = */ listOf(
{ _, _, res -> res.selfTypeName == "AssertionError" },
{ _, _, res -> res.repr == "None" }
)
)
}

@Test
fun testLenUsage() {
check1WithConcreteRun(
Expand Down
6 changes: 6 additions & 0 deletions usvm-python/src/test/resources/samples/SimpleTypeInference.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ def isinstance_sample(x):
return "Not reachable"


def list_concat_usage(x, y):
z = x + y
z += []
assert z


def len_usage(x):
if len(x) == 5:
return 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ public class CPythonAdapter {
public native int typeHasNbMatrixMultiply(long type);
public native int typeHasNbNegative(long type);
public native int typeHasNbPositive(long type);
public native int typeHasSqConcat(long type);
public native int typeHasSqLength(long type);
public native int typeHasMpLength(long type);
public native int typeHasMpSubscript(long type);
Expand Down Expand Up @@ -1071,6 +1072,18 @@ public static void notifyNbPositive(ConcolicRunContext context, SymbolForCPython
nbPositiveKt(context, on.obj);
}

@CPythonAdapterJavaMethod(cName = "sq_concat")
@CPythonFunction(
argCTypes = {CType.PyObject, CType.PyObject},
argConverters = {ObjectConverter.StandardConverter, ObjectConverter.StandardConverter}
)
public static void notifySqConcat(ConcolicRunContext context, SymbolForCPython left, SymbolForCPython right) {
if (left.obj == null || right.obj == null)
return;
context.curOperation = new MockHeader(SqConcatMethod.INSTANCE, Arrays.asList(left.obj, right.obj), null);
sqConcatKt(context, left.obj, right.obj);
}

@CPythonAdapterJavaMethod(cName = "sq_length")
@CPythonFunction(
argCTypes = {CType.PyObject},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ data object NbMultiplyMethod : TypeMethod(false)
data object NbMatrixMultiplyMethod : TypeMethod(false)
data object NbNegativeMethod : TypeMethod(false)
data object NbPositiveMethod : TypeMethod(false)
data object SqConcatMethod : TypeMethod(false)
data object SqLengthMethod : TypeMethod(true)
data object MpSubscriptMethod : TypeMethod(false)
data object MpAssSubscriptMethod : TypeMethod(false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ object ConcretePythonInterpreter {
val typeHasNbMatrixMultiply = createTypeQuery { pythonAdapter.typeHasNbMatrixMultiply(it) }
val typeHasNbNegative = createTypeQuery { pythonAdapter.typeHasNbNegative(it) }
val typeHasNbPositive = createTypeQuery { pythonAdapter.typeHasNbPositive(it) }
val typeHasSqConcat = createTypeQuery { pythonAdapter.typeHasSqConcat(it) }
val typeHasSqLength = createTypeQuery { pythonAdapter.typeHasSqLength(it) }
val typeHasMpLength = createTypeQuery { pythonAdapter.typeHasMpLength(it) }
val typeHasMpSubscript = createTypeQuery { pythonAdapter.typeHasMpSubscript(it) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.usvm.machine.types.HasNbNegative
import org.usvm.machine.types.HasNbPositive
import org.usvm.machine.types.HasNbSubtract
import org.usvm.machine.types.HasSqConcat
import org.usvm.machine.types.HasSqLength
import org.usvm.machine.types.HasTpCall
import org.usvm.machine.types.HasTpGetattro
Expand All @@ -37,7 +38,15 @@
context.ctx
) {
context.curState ?: return
pyAssert(context, left.evalIsSoft(context, HasNbAdd) or right.evalIsSoft(context, HasNbAdd))
/*
The __add__ method corresponds both to the nb_add and sq_concat slots,
so it is crucial not to assert the presence of nb_add, but to fork on these
two possible options.
tochilinak marked this conversation as resolved.
Show resolved Hide resolved
*/
Fixed Show fixed Hide fixed
val nbAdd = left.evalIsSoft(context, HasNbAdd) or right.evalIsSoft(context, HasNbAdd)
val sqConcat = left.evalIsSoft(context, HasSqConcat) and right.evalIsSoft(context, HasSqConcat)
pyAssert(context, nbAdd.not() implies sqConcat)
pyFork(context, nbAdd)
}

fun nbSubtractKt(
Expand Down Expand Up @@ -74,6 +83,17 @@
pyAssert(context, on.evalIsSoft(context, HasNbPositive))
}

fun sqConcatKt(
context: ConcolicRunContext,
left: UninterpretedSymbolicPythonObject,
right: UninterpretedSymbolicPythonObject,
) = with(
context.ctx
) {
context.curState ?: return
pyAssert(context, left.evalIsSoft(context, HasSqConcat) and right.evalIsSoft(context, HasSqConcat))
}

fun sqLengthKt(context: ConcolicRunContext, on: UninterpretedSymbolicPythonObject) {
context.curState ?: return
val sqLength = on.evalIsSoft(context, HasSqLength)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import org.usvm.language.NbMultiplyMethod
import org.usvm.language.NbNegativeMethod
import org.usvm.language.NbPositiveMethod
import org.usvm.language.NbSubtractMethod
import org.usvm.language.SqConcatMethod
import org.usvm.language.SqLengthMethod
import org.usvm.language.TpCallMethod
import org.usvm.language.TpGetattro
Expand Down Expand Up @@ -95,6 +96,10 @@ class SymbolTypeTree(
listOf(createBinaryProtocol("__mul__", pythonAnyType, returnType))
}

SqConcatMethod -> { returnType: UtType ->
listOf(createBinaryProtocol("__add__", pythonAnyType, returnType))
}

SqLengthMethod -> { _: UtType ->
listOf(createUnaryProtocol("__len__", typeHintsStorage.pythonInt))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ object HasNbPositive : TypeProtocol() {
ConcretePythonInterpreter.typeHasNbPositive(type.asObject)
}

object HasSqConcat : TypeProtocol() {
override fun acceptsConcrete(type: ConcretePythonType): Boolean =
ConcretePythonInterpreter.typeHasSqConcat(type.asObject)
}

object HasSqLength : TypeProtocol() {
override fun acceptsConcrete(type: ConcretePythonType): Boolean =
ConcretePythonInterpreter.typeHasSqLength(type.asObject)
Expand Down
Loading