Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Civl] Permissions cleanup #956

Merged
merged 10 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 57 additions & 21 deletions Source/Concurrency/LinearRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,12 @@ public List<Cmd> RewriteCallCmd(CallCmd callCmd)
{
switch (Monomorphizer.GetOriginalDecl(callCmd.Proc).Name)
{
case "One_New":
case "One_To_Fractions":
case "Fractions_To_One":
case "Cell_Pack":
case "Cell_Unpack":
case "Loc_New":
case "KeyedLocSet_New":
case "Set_MakeEmpty":
case "Map_MakeEmpty":
case "Map_Pack":
case "Map_Unpack":
case "Map_Assume":
return new List<Cmd>{callCmd};
case "Set_Split":
return RewriteSetSplit(callCmd);
Expand All @@ -180,6 +176,10 @@ public List<Cmd> RewriteCallCmd(CallCmd callCmd)
return RewriteMapGet(callCmd);
case "Map_Put":
return RewriteMapPut(callCmd);
case "Map_GetValue":
return RewriteMapGetValue(callCmd);
case "Map_PutValue":
return RewriteMapPutValue(callCmd);
default:
Contract.Assume(false);
return null;
Expand Down Expand Up @@ -270,14 +270,6 @@ private Function OneConstructor(Type type)
return oneConstructor;
}

private Function CellConstructor(Type keyType, Type valType)
{
var actualTypeParams = new List<Type>() { keyType, valType };
var cellTypeCtorDecl = (DatatypeTypeCtorDecl)monomorphizer.InstantiateTypeCtorDecl("Cell", actualTypeParams);
var cellConstructor = cellTypeCtorDecl.Constructors[0];
return cellConstructor;
}

private Function SetConstructor(Type type)
{
var actualTypeParams = new List<Type>() { type };
Expand Down Expand Up @@ -451,7 +443,8 @@ private List<Cmd> RewriteMapGet(CallCmd callCmd)
var cmdSeq = new List<Cmd>();
var path = callCmd.Ins[0];
var k = callCmd.Ins[1];
var c = callCmd.Outs[0];
var l = callCmd.Outs[0];
var v = callCmd.Outs[1];

var instantiation = monomorphizer.GetTypeInstantiation(callCmd.Proc);
var domain = instantiation["K"];
Expand All @@ -460,9 +453,9 @@ private List<Cmd> RewriteMapGet(CallCmd callCmd)
var mapRemoveFunc = MapRemove(domain, range);
var mapAtFunc = MapAt(domain, range);
cmdSeq.Add(AssertCmd(callCmd.tok, ExprHelper.FunctionCall(mapContainsFunc, path, k), "Map_Get failed"));
var cellConstructor = CellConstructor(domain, range);
cmdSeq.Add(
CmdHelper.AssignCmd(c.Decl, ExprHelper.FunctionCall(cellConstructor, k, ExprHelper.FunctionCall(mapAtFunc, path, k))));
var oneConstructor = OneConstructor(domain);
cmdSeq.Add(CmdHelper.AssignCmd(l.Decl, ExprHelper.FunctionCall(oneConstructor, k)));
cmdSeq.Add(CmdHelper.AssignCmd(v.Decl, ExprHelper.FunctionCall(mapAtFunc, path, k)));
cmdSeq.Add(
CmdHelper.AssignCmd(CmdHelper.ExprToAssignLhs(path), ExprHelper.FunctionCall(mapRemoveFunc, path, k)));

Expand All @@ -474,17 +467,60 @@ private List<Cmd> RewriteMapPut(CallCmd callCmd)
{
var cmdSeq = new List<Cmd>();
var path = callCmd.Ins[0];
var c = callCmd.Ins[1];
var l = callCmd.Ins[1];
var v = callCmd.Ins[2];

var instantiation = monomorphizer.GetTypeInstantiation(callCmd.Proc);
var domain = instantiation["K"];
var range = instantiation["V"];
var mapContainsFunc = MapContains(domain, range);
var mapUpdateFunc = MapUpdate(domain, range);
var attribute = new QKeyValue(Token.NoToken, "linear", new List<object>(), null);
cmdSeq.Add(new AssumeCmd(Token.NoToken, Expr.Not(ExprHelper.FunctionCall(mapContainsFunc, path, Key(c))), attribute));
cmdSeq.Add(new AssumeCmd(Token.NoToken, Expr.Not(ExprHelper.FunctionCall(mapContainsFunc, path, Val(l))), attribute));
cmdSeq.Add(
CmdHelper.AssignCmd(CmdHelper.ExprToAssignLhs(path), ExprHelper.FunctionCall(mapUpdateFunc, path, Val(l), v)));

ResolveAndTypecheck(options, cmdSeq);
return cmdSeq;
}

private List<Cmd> RewriteMapGetValue(CallCmd callCmd)
{
var cmdSeq = new List<Cmd>();
var path = callCmd.Ins[0];
var k = callCmd.Ins[1];
var v = callCmd.Outs[0];

var instantiation = monomorphizer.GetTypeInstantiation(callCmd.Proc);
var domain = instantiation["K"];
var range = instantiation["V"];
var mapContainsFunc = MapContains(domain, range);
var mapRemoveFunc = MapRemove(domain, range);
var mapAtFunc = MapAt(domain, range);
cmdSeq.Add(AssertCmd(callCmd.tok, ExprHelper.FunctionCall(mapContainsFunc, path, k), "Map_GetValue failed"));
var oneConstructor = OneConstructor(domain);
cmdSeq.Add(CmdHelper.AssignCmd(v.Decl, ExprHelper.FunctionCall(mapAtFunc, path, k)));
cmdSeq.Add(
CmdHelper.AssignCmd(CmdHelper.ExprToAssignLhs(path), ExprHelper.FunctionCall(mapRemoveFunc, path, k)));

ResolveAndTypecheck(options, cmdSeq);
return cmdSeq;
}

private List<Cmd> RewriteMapPutValue(CallCmd callCmd)
{
var cmdSeq = new List<Cmd>();
var path = callCmd.Ins[0];
var k = callCmd.Ins[1];
var v = callCmd.Ins[2];

var instantiation = monomorphizer.GetTypeInstantiation(callCmd.Proc);
var domain = instantiation["K"];
var range = instantiation["V"];
var mapContainsFunc = MapContains(domain, range);
var mapUpdateFunc = MapUpdate(domain, range);
cmdSeq.Add(
CmdHelper.AssignCmd(CmdHelper.ExprToAssignLhs(path), ExprHelper.FunctionCall(mapUpdateFunc, path, Key(c), Val(c))));
CmdHelper.AssignCmd(CmdHelper.ExprToAssignLhs(path), ExprHelper.FunctionCall(mapUpdateFunc, path, k, v)));

ResolveAndTypecheck(options, cmdSeq);
return cmdSeq;
Expand Down
94 changes: 87 additions & 7 deletions Source/Concurrency/LinearTypeChecker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,10 @@ public override Cmd VisitAssignCmd(AssignCmd node)
{
Error(rhs, $"linear variable {rhs.Decl.Name} can occur at most once as the source of an assignment");
}
else if (InvalidAssignmentWithKeyCollection(lhs.DeepAssignedVariable, rhs.Decl))
{
Error(rhs, $"Mismatch in key collection between source and target");
}
else
{
rhsVars.Add(rhs.Decl);
Expand All @@ -486,6 +490,10 @@ public override Cmd VisitAssignCmd(AssignCmd node)
{
Error(arg, $"linear variable {ie.Decl.Name} can occur at most once as the source of an assignment");
}
else if (InvalidAssignmentWithKeyCollection(field, ie.Decl))
{
Error(arg, $"Mismatch in key collection between source and target");
}
else
{
rhsVars.Add(ie.Decl);
Expand All @@ -507,9 +515,14 @@ public override Cmd VisitUnpackCmd(UnpackCmd node)
continue;
}
isLinearUnpack = true;
if (FindLinearKind(node.Constructor.InParams[j]) == LinearKind.ORDINARY)
var field = node.Constructor.InParams[j];
if (FindLinearKind(field) == LinearKind.ORDINARY)
{
Error(unpackedLhs[j], $"source of unpack must be linear field: {field}");
}
else if (InvalidAssignmentWithKeyCollection(unpackedLhs[j].Decl, field))
{
Error(unpackedLhs[j], $"source of unpack must be linear field: {node.Constructor.InParams[j]}");
Error(unpackedLhs[j], $"Mismatch in key collection between source and target");
}
}
if (isLinearUnpack)
Expand Down Expand Up @@ -568,6 +581,11 @@ public override Cmd VisitCallCmd(CallCmd node)
Error(node, $"linear variable {actual.Decl.Name} can occur only once as an input parameter");
continue;
}
if (!isPrimitive && InvalidAssignmentWithKeyCollection(formal, actual.Decl))
{
Error(node, $"Mismatch in key collection between source and target");
continue;
}
inVars.Add(actual.Decl);
if (actual.Decl is GlobalVariable && actualKind == LinearKind.LINEAR_IN)
{
Expand All @@ -590,6 +608,11 @@ public override Cmd VisitCallCmd(CallCmd node)
Error(node, $"only linear parameter can be assigned to a linear variable: {formal}");
continue;
}
if (!isPrimitive && InvalidAssignmentWithKeyCollection(actual.Decl, formal))
{
Error(node, $"Mismatch in key collection between source and target");
continue;
}
}

var globalOutVars = node.Outs.Select(ie => ie.Decl).ToHashSet();
Expand All @@ -598,6 +621,8 @@ public override Cmd VisitCallCmd(CallCmd node)
Error(node, $"global variable passed as input to pure call but not received as output: {v}");
});

var originalProc = (Procedure)Monomorphizer.GetOriginalDecl(node.Proc);

if (isPrimitive)
{
var modifiedArgument = CivlPrimitives.ModifiedArgument(node)?.Decl;
Expand All @@ -612,17 +637,52 @@ public override Cmd VisitCallCmd(CallCmd node)
Error(node, $"primitive assigns to input variable that is also an output variable: {modifiedArgument}");
}
else if (modifiedArgument is GlobalVariable &&
enclosingProc is not YieldProcedureDecl &&
enclosingProc.Modifies.All(v => v.Decl != modifiedArgument))
enclosingProc is not YieldProcedureDecl &&
enclosingProc.Modifies.All(v => v.Decl != modifiedArgument))
{
var str = enclosingProc is ActionDecl ? "action's" : "procedure's";
Error(node,
$"primitive assigns to a global variable that is not in the enclosing {str} modifies clause: {modifiedArgument}");
}

if (originalProc.Name == "Map_Split")
{
if (InvalidAssignmentWithKeyCollection(node.Outs[0].Decl, modifiedArgument))
{
Error(node.Outs[0], $"Mismatch in key collection between source and target");
}
}
else if (originalProc.Name == "Map_Join")
{
if (node.Ins[1] is IdentifierExpr ie && InvalidAssignmentWithKeyCollection(modifiedArgument, ie.Decl))
{
Error(node.Ins[1], $"Mismatch in key collection between source and target");
}
}
else if (originalProc.Name == "Map_Get" || originalProc.Name == "Map_Put")
{
if (!AreKeysCollected(modifiedArgument))
{
Error(node, $"Keys must be collected");
}
}
else if (originalProc.Name == "Map_GetValue" || originalProc.Name == "Map_PutValue")
{
if (AreKeysCollected(modifiedArgument))
{
Error(node, $"Keys must not be collected");
}
}
}
else if (originalProc.Name == "Map_Unpack")
{
if (node.Ins[0] is IdentifierExpr ie && !AreKeysCollected(ie.Decl))
{
Error(node.Ins[0], $"Mismatch in key collection between source and target");
}
}
}

var originalProc = (Procedure)Monomorphizer.GetOriginalDecl(node.Proc);
if (originalProc.Name == "create_multi_asyncs" || originalProc.Name == "create_asyncs")
{
var actionDecl = GetActionDeclFromCreateAsyncs(node);
Expand Down Expand Up @@ -749,6 +809,25 @@ public override Variable VisitVariable(Variable node)
return node;
}

private bool AreKeysCollected(Variable v)
{
var attr = QKeyValue.FindAttribute(v.Attributes, x => x.Key == "linear");
var attrParams = attr == null ? new List<object>() : attr.Params;
foreach (var param in attrParams)
{
if (param is string s && s == "no_collect_keys")
{
return false;
}
}
return true;
}

private bool InvalidAssignmentWithKeyCollection(Variable target, Variable source)
{
return AreKeysCollected(target) && !AreKeysCollected(source);
}

private void CheckLinearStoreAccessInGuards()
{
program.Implementations.ForEach(impl => {
Expand Down Expand Up @@ -871,7 +950,7 @@ public Type GetPermissionType(Type type)
{
var originalTypeCtorDecl = Monomorphizer.GetOriginalDecl(datatypeTypeCtorDecl);
var typeName = originalTypeCtorDecl.Name;
if (typeName == "Map" || typeName == "Set" || typeName == "Cell" | typeName == "One")
if (typeName == "Map" || typeName == "Set" || typeName == "One")
{
var actualTypeParams = program.monomorphizer.GetTypeInstantiation(datatypeTypeCtorDecl);
return actualTypeParams[0];
Expand Down Expand Up @@ -975,6 +1054,7 @@ private IEnumerable<Variable> FilterVariables(LinearDomain domain, IEnumerable<V
{
return scope.Where(v =>
FindLinearKind(v) != LinearKind.ORDINARY &&
AreKeysCollected(v) &&
collectors.ContainsKey(v.TypedIdent.Type) &&
collectors[v.TypedIdent.Type].ContainsKey(domain.permissionType));
}
Expand Down Expand Up @@ -1077,7 +1157,7 @@ private void CheckType(Type type)
return;
}
var typeCtorDeclName = Monomorphizer.GetOriginalDecl(ctorType.Decl).Name;
if (typeCtorDeclName == "Map" || typeCtorDeclName == "Cell")
if (typeCtorDeclName == "Map")
{
hasLinearStoreAccess = true;
}
Expand Down
14 changes: 5 additions & 9 deletions Source/Core/CivlAttributes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,9 @@ public static class CivlPrimitives
{
public static HashSet<string> LinearPrimitives = new()
{
"One_New", "One_To_Fractions", "Fractions_To_One",
"Cell_Pack", "Cell_Unpack",
"Map_MakeEmpty", "Map_Pack", "Map_Unpack", "Map_Split", "Map_Join", "Map_Get", "Map_Put", "Map_Assume",
"Loc_New", "KeyedLocSet_New",
"Map_MakeEmpty", "Map_Pack", "Map_Unpack", "Map_Split", "Map_Join",
"Map_Get", "Map_Put", "Map_GetValue", "Map_PutValue",
"Set_MakeEmpty", "Set_Split", "Set_Get", "Set_Put", "One_Split", "One_Get", "One_Put"
};

Expand All @@ -243,16 +243,12 @@ public static IdentifierExpr ModifiedArgument(CallCmd callCmd)
{
switch (Monomorphizer.GetOriginalDecl(callCmd.Proc).Name)
{
case "One_New":
case "One_To_Fractions":
case "Fractions_To_One":
case "Cell_Pack":
case "Cell_Unpack":
case "Loc_New":
case "KeyedLocSet_New":
case "Set_MakeEmpty":
case "Map_MakeEmpty":
case "Map_Pack":
case "Map_Unpack":
case "Map_Assume":
return null;
default:
return ExtractRootFromAccessPathExpr(callCmd.Ins[0]);
Expand Down
Loading
Loading