diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/SymbolTable.java b/src/main/java/org/biscuitsec/biscuit/datalog/SymbolTable.java index 489dae2f..0216b29d 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/SymbolTable.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/SymbolTable.java @@ -262,6 +262,11 @@ public SymbolTable(SymbolTable s) { publicKeys.addAll(s.publicKeys); } + public SymbolTable(List symbols) { + this.symbols = new ArrayList<>(symbols); + this.publicKeys = new ArrayList<>(); + } + public SymbolTable(List symbols, List publicKeys) { this.symbols = new ArrayList<>(); this.symbols.addAll(symbols); diff --git a/src/main/java/org/biscuitsec/biscuit/token/Block.java b/src/main/java/org/biscuitsec/biscuit/token/Block.java index 382224a8..5931d52f 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/Block.java +++ b/src/main/java/org/biscuitsec/biscuit/token/Block.java @@ -94,7 +94,12 @@ public String print(SymbolTable symbol_table) { s.append("\n\t\texternal key: "); s.append(this.externalKey.get().toString()); } - s.append("\n\t\tfacts: ["); + s.append("\n\t\tscopes: ["); + for (Scope scope : this.scopes) { + s.append("\n\t\t\t"); + s.append(symbol_table.print_scope(scope)); + } + s.append("\n\t\t]\n\t\tfacts: ["); for (Fact f : this.facts) { s.append("\n\t\t\t"); s.append(symbol_table.print_fact(f)); diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/Block.java b/src/main/java/org/biscuitsec/biscuit/token/builder/Block.java index 302e6930..c4591e03 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/Block.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/Block.java @@ -32,6 +32,8 @@ public class Block { List rules; List checks; List scopes; + List publicKeys; + Option externalKey; public Block(long index, SymbolTable base_symbols) { this.index = index; @@ -43,6 +45,33 @@ public Block(long index, SymbolTable base_symbols) { this.rules = new ArrayList<>(); this.checks = new ArrayList<>(); this.scopes = new ArrayList<>(); + this.publicKeys = new ArrayList<>(); + this.externalKey = Option.none(); + } + + public Block setExternalKey(Option externalKey) { + this.externalKey = externalKey; + return this; + } + + public Block addPublicKey(PublicKey publicKey) { + this.publicKeys.add(publicKey); + return this; + } + + public Block addPublicKeys(List publicKeys) { + this.publicKeys.addAll(publicKeys); + return this; + } + + public Block setPublicKeys(List publicKeys) { + this.publicKeys = publicKeys; + return this; + } + + public Block addSymbol(String symbol) { + this.symbols.add(symbol); + return this; } public Block add_fact(org.biscuitsec.biscuit.token.builder.Fact f) { @@ -124,7 +153,7 @@ public org.biscuitsec.biscuit.token.Block build() { SchemaVersion schemaVersion = new SchemaVersion(this.facts, this.rules, this.checks, this.scopes); return new org.biscuitsec.biscuit.token.Block(symbols, this.context, this.facts, this.rules, this.checks, - this.scopes, publicKeys, Option.none(), schemaVersion.version()); + this.scopes, publicKeys, this.externalKey, schemaVersion.version()); } public Block check_right(String right) { diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/ExpressionParser.java b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/ExpressionParser.java index 9c3c00a2..86664e7b 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/ExpressionParser.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/ExpressionParser.java @@ -306,43 +306,57 @@ public static Either> expr7(String s) { return Either.left(res1.getLeft()); } Tuple2 t1 = res1.get(); - s = space(t1._1); - Expression e1 = t1._2; - if(!s.startsWith(".")) { - return Either.right(new Tuple2<>(s, e1)); - } - s = s.substring(1); - - Either> res2 = binary_op7(s); - if (res2.isLeft()) { - return Either.left(res2.getLeft()); - } - Tuple2 t2 = res2.get(); - s = space(t2._1); - Expression.Op op = t2._2; - - if(!s.startsWith("(")) { - return Either.left(new Error(s, "missing (")); - } - - s = space(s.substring(1)); + s = t1._1; + Expression e = t1._2; - Either> res3 = expr(s); - if (res3.isLeft()) { - return Either.left(res3.getLeft()); - } + while(true) { + s = space(s); + if(s.isEmpty()) { + break; + } - Tuple2 t3 = res3.get(); + if (!s.startsWith(".")) { + return Either.right(new Tuple2<>(s, e)); + } - s = space(t3._1); - if(!s.startsWith(")")) { - return Either.left(new Error(s, "missing )")); + s = s.substring(1); + Either> res2 = binary_op7(s); + if (!res2.isLeft()) { + Tuple2 t2 = res2.get(); + s = space(t2._1); + Expression.Op op = t2._2; + + if (!s.startsWith("(")) { + return Either.left(new Error(s, "missing (")); + } + + s = space(s.substring(1)); + + Either> res3 = expr_term(s); + if (res3.isLeft()) { + return Either.left(res3.getLeft()); + } + + Tuple2 t3 = res3.get(); + + s = space(t3._1); + if (!s.startsWith(")")) { + return Either.left(new Error(s, "missing )")); + } + s = space(s.substring(1)); + Expression e2 = t3._2; + + e = new Expression.Binary(op, e, e2); + } else { + if (s.startsWith("length()")) { + e = new Expression.Unary(Expression.Op.Length, e); + s = s.substring(9); + } + } } - s = space(s.substring(1)); - Expression e2 = t3._2; - return Either.right(new Tuple2<>(s, new Expression.Binary(op, e1, e2))); + return Either.right(new Tuple2<>(s, e)); } public static Either> expr_term(String s) { @@ -517,6 +531,12 @@ public static Either> binary_op6(String s) } public static Either> binary_op7(String s) { + if(s.startsWith("intersection")) { + return Either.right(new Tuple2<>(s.substring(12), Expression.Op.Intersection)); + } + if(s.startsWith("union")) { + return Either.right(new Tuple2<>(s.substring(5), Expression.Op.Union)); + } if(s.startsWith("contains")) { return Either.right(new Tuple2<>(s.substring(8), Expression.Op.Contains)); } diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java index dbd1c2c8..e3bac7cb 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java @@ -1,7 +1,9 @@ package org.biscuitsec.biscuit.token.builder.parser; import biscuit.format.schema.Schema; +import io.vavr.collection.Stream; import org.biscuitsec.biscuit.crypto.PublicKey; +import org.biscuitsec.biscuit.datalog.SymbolTable; import org.biscuitsec.biscuit.token.Policy; import io.vavr.Tuple2; import io.vavr.Tuple4; @@ -10,9 +12,7 @@ import java.time.OffsetDateTime; import java.time.format.DateTimeParseException; -import java.util.ArrayList; -import java.util.List; -import java.util.HashSet; +import java.util.*; import java.util.function.Function; public class Parser { @@ -524,14 +524,14 @@ public static Either> integer(String s) { return Either.left(new Error(s, "not an integer")); } - Integer i = Integer.parseInt(s.substring(0, index2)); + long i = Long.parseLong(s.substring(0, index2)); String remaining = s.substring(index2); - return Either.right(new Tuple2(remaining, (Term.Integer) Utils.integer(i.intValue()))); + return Either.right(new Tuple2(remaining, (Term.Integer) Utils.integer(i))); } public static Either> date(String s) { - Tuple2 t = take_while(s, (c) -> c != ' ' && c != ',' && c != ')'); + Tuple2 t = take_while(s, (c) -> c != ' ' && c != ',' && c != ')' && c != ']'); try { OffsetDateTime d = OffsetDateTime.parse(t._1); diff --git a/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java b/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java index 3c76890b..0230b235 100644 --- a/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java +++ b/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java @@ -5,6 +5,7 @@ import org.biscuitsec.biscuit.datalog.SymbolTable; import org.biscuitsec.biscuit.datalog.TemporarySymbolTable; import org.biscuitsec.biscuit.datalog.expressions.Op; +import org.biscuitsec.biscuit.token.Biscuit; import org.biscuitsec.biscuit.token.builder.parser.Error; import org.biscuitsec.biscuit.token.builder.parser.Parser; import io.vavr.Tuple2; @@ -15,7 +16,7 @@ import org.junit.jupiter.api.Test; import static org.biscuitsec.biscuit.datalog.Check.Kind.One; -import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.*; import java.util.*; @@ -132,6 +133,43 @@ void testRuleWithExpressionOrdering() { res); } + @Test + void expressionIntersectionAndContainsTest() { + Either> res = + Parser.expression("[1, 2, 3].intersection([1, 2]).contains(1)"); + + assertEquals(Either.right(new Tuple2<>("", + new Expression.Binary( + Expression.Op.Contains, + new Expression.Binary( + Expression.Op.Intersection, + new Expression.Value(Utils.set(new HashSet<>(Arrays.asList(Utils.integer(1), Utils.integer(2), Utils.integer(3))))), + new Expression.Value(Utils.set(new HashSet<>(Arrays.asList(Utils.integer(1), Utils.integer(2))))) + ), + new Expression.Value(Utils.integer(1)) + ))), res); + } + + @Test + void expressionIntersectionAndContainsAndLengthEqualsTest() { + Either> res = + Parser.expression("[1, 2, 3].intersection([1, 2]).length() == 2"); + + assertEquals(Either.right(new Tuple2<>("", + new Expression.Binary( + Expression.Op.Equal, + new Expression.Unary( + Expression.Op.Length, + new Expression.Binary( + Expression.Op.Intersection, + new Expression.Value(Utils.set(new HashSet<>(Arrays.asList(Utils.integer(1), Utils.integer(2), Utils.integer(3))))), + new Expression.Value(Utils.set(new HashSet<>(Arrays.asList(Utils.integer(1), Utils.integer(2))))) + ) + ), + new Expression.Value(Utils.integer(2)) + ))), res); + } + @Test void ruleWithFreeExpressionVariables() { Either> res = diff --git a/src/test/java/org/biscuitsec/biscuit/datalog/ExpressionTest.java b/src/test/java/org/biscuitsec/biscuit/datalog/ExpressionTest.java index c11fbebb..5e38166b 100644 --- a/src/test/java/org/biscuitsec/biscuit/datalog/ExpressionTest.java +++ b/src/test/java/org/biscuitsec/biscuit/datalog/ExpressionTest.java @@ -9,6 +9,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; public class ExpressionTest { @@ -114,4 +115,27 @@ public void testNegativeContainsStr() throws Error.Execution { e.evaluate(new HashMap<>(), new TemporarySymbolTable(symbols)) ); } + + @Test + public void testIntersectionAndContains() throws Error.Execution { + SymbolTable symbols = new SymbolTable(); + + Expression e = new Expression(new ArrayList(Arrays.asList( + new Op.Value(new Term.Set(new HashSet<>(Arrays.asList(new Term.Integer(1), new Term.Integer(2), new Term.Integer(3))))), + new Op.Value(new Term.Set(new HashSet<>(Arrays.asList(new Term.Integer(1), new Term.Integer(2))))), + new Op.Binary(Op.BinaryOp.Intersection), + new Op.Value(new Term.Integer(1)), + new Op.Binary(Op.BinaryOp.Contains) + ))); + + assertEquals( + "[1, 2, 3].intersection([1, 2]).contains(1)", + e.print(symbols).get() + ); + + assertEquals( + new Term.Bool(true), + e.evaluate(new HashMap<>(), new TemporarySymbolTable(symbols)) + ); + } } diff --git a/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java b/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java index 02549fe1..bb5907c4 100644 --- a/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java +++ b/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java @@ -4,15 +4,20 @@ import com.google.gson.*; import com.google.protobuf.MapEntry; import io.vavr.Tuple2; +import io.vavr.control.Option; import org.biscuitsec.biscuit.crypto.KeyPair; import org.biscuitsec.biscuit.crypto.PublicKey; import org.biscuitsec.biscuit.datalog.Rule; import org.biscuitsec.biscuit.datalog.RunLimits; +import org.biscuitsec.biscuit.datalog.SymbolTable; import org.biscuitsec.biscuit.datalog.TrustedOrigins; import org.biscuitsec.biscuit.error.Error; import io.vavr.control.Either; import io.vavr.control.Try; import org.biscuitsec.biscuit.token.builder.Check; +import org.biscuitsec.biscuit.token.builder.Expression; +import org.biscuitsec.biscuit.token.builder.parser.ExpressionParser; +import org.biscuitsec.biscuit.token.builder.parser.Parser; import org.junit.jupiter.api.DynamicTest; import org.junit.jupiter.api.TestFactory; @@ -22,6 +27,7 @@ import java.time.Duration; import java.util.*; import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; import static org.junit.jupiter.api.Assertions.*; @@ -142,19 +148,11 @@ DynamicTest process_testcase(final TestCase testCase, final PublicKey publicKey, }); } - class Block { List symbols; - - public String getCode() { - return code; - } - - public void setCode(String code) { - this.code = code; - } - String code; + List public_keys; + String external_key; public List getSymbols() { return symbols; @@ -163,17 +161,37 @@ public List getSymbols() { public void setSymbols(List symbols) { this.symbols = symbols; } - } - class Token { - List blocks; + public String getCode() { return code; } + + public void setCode(String code) { this.code = code; } - public List getBlocks() { - return blocks; + public List getPublicKeys() { + return this.public_keys.stream() + .map(pk -> + Parser.publicKey(pk).fold(e -> { throw new IllegalArgumentException(e.toString());}, r -> r._2) + ) + .collect(Collectors.toList()); + } + + public void setPublicKeys(List publicKeys) { + this.public_keys = publicKeys.stream() + .map(PublicKey::toString) + .collect(Collectors.toList()); + } + + public Option getExternalKey() { + if (this.external_key != null) { + PublicKey externalKey = Parser.publicKey(this.external_key) + .fold(e -> { throw new IllegalArgumentException(e.toString());}, r -> r._2); + return Option.of(externalKey); + } else { + return Option.none(); + } } - public void setBlocks(List blocks) { - this.blocks = blocks; + public void setExternalKey(Option externalKey) { + this.external_key = externalKey.map(PublicKey::toString).getOrElse((String) null); } } @@ -189,7 +207,7 @@ public void setTitle(String title) { } String filename; - List tokens; + List token; JsonElement validations; public String getFilename() { @@ -200,12 +218,12 @@ public void setFilename(String filename) { this.filename = filename; } - public List getTokens() { - return tokens; + public List getToken() { + return token; } - public void setTokens(List tokens) { - this.tokens = tokens; + public void setTokens(List token) { + this.token = token; } public JsonElement getValidations() {