Skip to content

Commit

Permalink
Coerce return, support methods unique coercion, fix for void
Browse files Browse the repository at this point in the history
  • Loading branch information
sakehl committed Aug 8, 2024
1 parent db267fc commit a084769
Show file tree
Hide file tree
Showing 12 changed files with 482 additions and 101 deletions.
4 changes: 4 additions & 0 deletions src/col/vct/col/ast/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1878,6 +1878,10 @@ final case class NewConstPointerArray[G](element: Type[G], size: Expr[G])(
)(implicit val o: Origin)
extends NewPointer[G] with NewConstPointerArrayImpl[G]

final case class UniquePointerCoercion[G](e: Expr[G], t: Type[G])(
implicit val o: Origin
) extends Expr[G] with UniquePointerCoercionImpl[G]

final case class FreePointer[G](pointer: Expr[G])(
val blame: Blame[PointerFreeError]
)(implicit val o: Origin)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package vct.col.ast.unsorted
package vct.col.ast.family.coercion

import vct.col.ast.CoerceBetweenUniquePointer
import vct.col.ast.ops.CoerceBetweenUniquePointerOps
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package vct.col.ast.unsorted
package vct.col.ast.family.coercion

import vct.col.ast.CoerceFromUniquePointer
import vct.col.ast.ops.CoerceFromUniquePointerOps
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package vct.col.ast.unsorted
package vct.col.ast.family.coercion

import vct.col.ast.{CoerceToUniquePointer, TUniquePointer, Type}
import vct.col.ast.ops.CoerceToUniquePointerOps
Expand Down
23 changes: 18 additions & 5 deletions src/col/vct/col/typerules/CoercingRewriter.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package vct.col.typerules

import com.typesafe.scalalogging.LazyLogging
import hre.util.FuncTools
import hre.util.{FuncTools, ScopedStack}
import vct.col.ast._
import vct.col.ast.rewrite.BaseCoercingRewriter
import vct.col.ast.`type`.typeclass.TFloats
Expand Down Expand Up @@ -53,6 +53,7 @@ abstract class CoercingRewriter[Pre <: Generation]()
import CoercingRewriter._

type Post = Rewritten[Pre]
val resultType: ScopedStack[Type[Pre]] = ScopedStack()

val coercedDeclaration: SuccessionMap[Declaration[Pre], Declaration[Pre]] =
SuccessionMap()
Expand Down Expand Up @@ -375,9 +376,16 @@ abstract class CoercingRewriter[Pre <: Generation]()
def postCoerce(decl: Declaration[Pre]): Unit =
allScopes.anySucceed(decl, decl.rewriteDefault())
override final def dispatch(decl: Declaration[Pre]): Unit = {
val coercedDecl = coerce(preCoerce(decl))
coercedDeclaration(decl) = coercedDecl
postCoerce(coercedDecl)
def rewrite() : Unit = {
val coercedDecl = coerce(preCoerce(decl))
coercedDeclaration(decl) = coercedDecl
postCoerce(coercedDecl)
}
decl match {
case m: AbstractMethod[Pre] =>
resultType.having(m.returnType)({rewrite()})
case _ => rewrite()
}
}

def coerce(node: Coercion[Pre]): Coercion[Pre] = {
Expand Down Expand Up @@ -1981,6 +1989,7 @@ abstract class CoercingRewriter[Pre <: Generation]()
UMinus(float(arg)),
UMinus(rat(arg)),
)
case u: UniquePointerCoercion[Pre] => u
case u @ Unfolding(pred, body) => Unfolding(pred, body)(u.blame)
case UntypedLiteralBag(values) =>
val sharedType = Types.leastCommonSuperType(values.map(_.t))
Expand Down Expand Up @@ -2251,7 +2260,11 @@ abstract class CoercingRewriter[Pre <: Generation]()
case Recv(ref) => Recv(ref)
case r @ Refute(assn) => Refute(res(assn))(r.blame)
case Return(result) =>
Return(result) // TODO coerce return, make AmbiguousReturn?
if(resultType.nonEmpty){
Return(coerce(result, resultType.top)) // TODO coerce return, make AmbiguousReturn?
} else {
Return(result)
}
case Scope(locals, body) => Scope(locals, body)
case send @ Send(decl, offset, resource) =>
Send(decl, offset, res(resource))(send.blame)
Expand Down
22 changes: 18 additions & 4 deletions src/col/vct/col/typerules/CoercionUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,28 @@ case object CoercionUtils {
def getCoercion[G](source: Type[G], target: Type[G]): Option[Coercion[G]] =
getAnyCoercion(source, target).filter(_.isCPromoting)

// We don't want pointers to coerce just between anything, just some things we allow
def getPointerCoercion[G](source: Type[G], target: Type[G], innerSource: Type[G], innerTarget: Type[G]) : Option[Coercion[G]] = {
Some((innerSource, innerTarget) match {
case (i,l) if i == l => CoerceIdentity(source)
case (TUnique(l, lId), TUnique(r, rId)) =>
case (l,r) if l == r => CoerceIdentity(source)
case (TCInt(), TInt()) => CoerceIdentity(source)
case (CPrimitiveType(specs), r) =>
specs.collectFirst { case spec: CSpecificationType[G] => spec } match {
case Some(CSpecificationType(t)) =>
return getPointerCoercion(source, target, t, r)
case None => return None
}
case (l, CPrimitiveType(specs)) =>
specs.collectFirst { case spec: CSpecificationType[G] => spec } match {
case Some(CSpecificationType(t)) =>
return getPointerCoercion(source, target, l, t)
case None => return None
}
case (TUnique(l, _), TUnique(r, _)) =>
if(l == r) CoerceBetweenUniquePointer(source, target) else return None
case (TUnique(l, lId), r) =>
case (TUnique(l, _), r) =>
if(l == r) CoerceFromUniquePointer(source, target) else return None
case (TUnique(l, lId), r) =>
case (l, TUnique(r, _)) =>
if(l == r) CoerceToUniquePointer(source, target) else return None
case _ => return None
})
Expand Down
2 changes: 2 additions & 0 deletions src/main/vct/main/stages/Transformation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import vct.rewrite.{
MonomorphizeClass,
SmtlibToProverTypes,
TypeQualifierCoercion,
MakeUniqueMethodCopies,
}
import vct.rewrite.lang.ReplaceSYCLTypes
import vct.rewrite.veymont._
Expand Down Expand Up @@ -306,6 +307,7 @@ case class SilverTransformation(
ReplaceSYCLTypes,
CFloatIntCoercion,
TypeQualifierCoercion,
MakeUniqueMethodCopies,
ComputeBipGlue,
InstantiateBipSynchronizations,
EncodeBipPermissions,
Expand Down
15 changes: 9 additions & 6 deletions src/rewrite/vct/rewrite/EncodeArrayValues.scala
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,14 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
): (Procedure[Post], FreePointer[Pre] => PointerFreeFailed[Pre]) = {
implicit val o: Origin = freeFuncOrigin
var errors: Seq[Expr[Pre] => PointerFreeError] = Seq()
val innerT = t match {
case TPointer(it) => it
case TUniquePointer(it, _) => it
}

val proc = globalDeclarations.declare({
val (vars, ptr) = variables.collect {
val a_var = new Variable[Post](TPointer(t))(o.where(name = "p"))
val a_var = new Variable[Post](t)(o.where(name = "p"))
variables.declare(a_var)
Local[Post](a_var.ref)
}
Expand Down Expand Up @@ -179,7 +183,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
IteratedPtrInjective,
)
requiresT =
if (!typeIsRef(t))
if (!typeIsRef(innerT))
requiresT
else {
// I think this error actually can never happen, since we require full write permission already
Expand All @@ -192,7 +196,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
// If structure contains structs, the permission for those fields need to be released as well
val permFields =
t match {
case t: TClass[Post] => unwrapStructPerm(access, t, o, makeStruct)
case innerT: TClass[Post] => unwrapStructPerm(access, innerT, o, makeStruct)
case _ => Seq()
}
requiresT =
Expand All @@ -213,7 +217,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
body = None,
requires = requiresPred,
decreases = Some(DecreasesClauseNoRecursion[Post]()),
)(o.where("free_" + t.toString))
)(o.where("free_" + innerT.toString))
})
(proc, (node: FreePointer[Pre]) => PointerFreeFailed(node, errors))
}
Expand Down Expand Up @@ -633,8 +637,7 @@ case class EncodeArrayValues[Pre <: Generation]() extends Rewriter[Pre] {
)(PointerArrayCreationFailed(ncpa))
case free @ FreePointer(xs) =>
val newXs = dispatch(xs)
val TPointer(t) = newXs.t
val (freeFunc, freeBlame) = freeMethods.getOrElseUpdate(t, makeFree(t))
val (freeFunc, freeBlame) = freeMethods.getOrElseUpdate(newXs.t, makeFree(newXs.t))
ProcedureInvocation[Post](freeFunc.ref, Seq(newXs), Nil, Nil, Nil, Nil)(
freeBlame(free)
)(free.o)
Expand Down
Loading

0 comments on commit a084769

Please sign in to comment.