diff --git a/dcm/src/main/java/com/vmware/dcm/backend/ortools/DetermineIndexes.java b/dcm/src/main/java/com/vmware/dcm/backend/ortools/DetermineIndexes.java index 7a27f62c..3e36d729 100644 --- a/dcm/src/main/java/com/vmware/dcm/backend/ortools/DetermineIndexes.java +++ b/dcm/src/main/java/com/vmware/dcm/backend/ortools/DetermineIndexes.java @@ -5,10 +5,11 @@ package com.vmware.dcm.backend.ortools; +import com.vmware.dcm.ModelException; +import com.vmware.dcm.compiler.IRTable; import com.vmware.dcm.compiler.ir.BinaryOperatorPredicate; import com.vmware.dcm.compiler.ir.ColumnIdentifier; import com.vmware.dcm.compiler.ir.GroupByComprehension; -import com.vmware.dcm.compiler.ir.JoinPredicate; import com.vmware.dcm.compiler.ir.ListComprehension; import com.vmware.dcm.compiler.ir.SimpleVisitor; import com.vmware.dcm.compiler.ir.TableRowGenerator; @@ -17,6 +18,7 @@ import java.util.HashSet; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -44,40 +46,71 @@ protected VoidType visitListComprehension(final ListComprehension node, final Vo final List tableRowGenerators = inner.getQualifiers().stream() .filter(e -> e instanceof TableRowGenerator) .map(e -> (TableRowGenerator) e).collect(Collectors.toList()); - final List joinPredicates = inner.getQualifiers().stream() - .filter(e -> e instanceof JoinPredicate) - .map(e -> (JoinPredicate) e).collect(Collectors.toList()); + final List joinPredicates = inner.getQualifiers().stream() + .filter(e -> e instanceof BinaryOperatorPredicate) + .map(e -> (BinaryOperatorPredicate) e).collect(Collectors.toList()); // For now, we always use a scan for the first Table being iterated over. - tableRowGenerators.subList(1, tableRowGenerators.size()) - .forEach(tr -> { + tableRowGenerators.forEach(tr -> { // We might be able to use an index here, look for equality based accesses // across the join predicates for (final BinaryOperatorPredicate binaryOp : joinPredicates) { - if (binaryOp.getOperator().equals(BinaryOperatorPredicate.Operator.EQUAL)) { - final ColumnIdentifier left = (ColumnIdentifier) binaryOp.getLeft(); - final ColumnIdentifier right = (ColumnIdentifier) binaryOp.getRight(); - if (right.getTableName().equals(tr.getTable().getName())) { - indexes.add(new IndexDescription(tr, right)); - } else if (left.getTableName().equals(tr.getTable().getName())) { - indexes.add(new IndexDescription(tr, left)); - } - } + maybeIndex(binaryOp, tr.getTable()) + .ifPresent(indexes::add); } }); return super.visitListComprehension(node, context); } } + static Optional maybeIndex(final BinaryOperatorPredicate op, + final IRTable innerTable) { + if (!(op.getLeft() instanceof ColumnIdentifier && op.getRight() instanceof ColumnIdentifier)) { + return Optional.empty(); + } + final ColumnIdentifier left = (ColumnIdentifier) op.getLeft(); + final ColumnIdentifier right = (ColumnIdentifier) op.getRight(); + final String leftTableNameOrAlias = left.getField().getIRTable().getAliasedName(); + final String rightTableNameOrAlias = right.getField().getIRTable().getAliasedName(); + final String innerTableNameOrAlias = innerTable.getAliasedName(); + if (rightTableNameOrAlias.equals(innerTableNameOrAlias)) { + return Optional.of(new IndexDescription(innerTable, right)); + } else if (leftTableNameOrAlias.equals(innerTableNameOrAlias)) { + return Optional.of(new IndexDescription(innerTable, left)); + } + return Optional.empty(); + } + + static class IndexedAccess { + final ColumnIdentifier indexedColumn; + final ColumnIdentifier scanColumn; + + IndexedAccess(final ColumnIdentifier indexedColumn, final ColumnIdentifier scanColumn) { + this.indexedColumn = indexedColumn; + this.scanColumn = scanColumn; + } + } + static class IndexDescription { - final TableRowGenerator relation; + final IRTable relation; final ColumnIdentifier columnBeingAccessed; - IndexDescription(final TableRowGenerator relation, final ColumnIdentifier columnBeingAccessed) { + IndexDescription(final IRTable relation, final ColumnIdentifier columnBeingAccessed) { this.relation = relation; this.columnBeingAccessed = columnBeingAccessed; } + public IndexedAccess toIndexedAccess(final BinaryOperatorPredicate op) { + final ColumnIdentifier left = (ColumnIdentifier) op.getLeft(); + final ColumnIdentifier right = (ColumnIdentifier) op.getRight(); + if (left.getField().getIRTable().getAliasedName().equals(relation.getAliasedName())) { + return new IndexedAccess(left, right); + } else if (right.getField().getIRTable().getAliasedName().equals(relation.getAliasedName())) { + return new IndexedAccess(right, left); + } + throw new ModelException("Unreachable"); + } + @Override public boolean equals(final Object o) { if (this == o) { @@ -87,14 +120,14 @@ public boolean equals(final Object o) { return false; } final IndexDescription that = (IndexDescription) o; - return Objects.equals(relation.getTable().getAliasedName(), that.relation.getTable().getAliasedName()) + return Objects.equals(relation.getAliasedName(), that.relation.getAliasedName()) && Objects.equals(columnBeingAccessed.getField().getName(), that.columnBeingAccessed.getField().getName()); } @Override public int hashCode() { - return Objects.hash(relation.getTable().getAliasedName(), columnBeingAccessed.getField().getName()); + return Objects.hash(relation.getAliasedName(), columnBeingAccessed.getField().getName()); } @Override diff --git a/dcm/src/main/java/com/vmware/dcm/backend/ortools/Ops.java b/dcm/src/main/java/com/vmware/dcm/backend/ortools/Ops.java index 66db09a1..629893a1 100644 --- a/dcm/src/main/java/com/vmware/dcm/backend/ortools/Ops.java +++ b/dcm/src/main/java/com/vmware/dcm/backend/ortools/Ops.java @@ -1002,11 +1002,11 @@ public List findSufficientAssumptions(final CpSolver solver) { .collect(Collectors.toList()); } - public Map> toIndex(final List resultSet, final Function extractKey) { - final Map> index = new HashMap<>(); + public Index toIndex(final List resultSet, final Function extractKey) { + final Index index = new Index<>(); for (int i = 0; i < resultSet.size(); i++) { final K key = extractKey.apply(resultSet.get(i)); - index.computeIfAbsent(key, (k) -> new ArrayList<>()).add(i); + index.set(key, i); } return index; } @@ -1053,4 +1053,16 @@ public CpSolver solve() { throw new SolverException(status.toString()); } } + + public static class Index { + final Map> index = new HashMap<>(); + + void set(final K key, final int i) { + index.computeIfAbsent(key, (k) -> new ArrayList<>()).add(i); + } + + public List get(final K key) { + return index.getOrDefault(key, Collections.emptyList()); + } + } } \ No newline at end of file diff --git a/dcm/src/main/java/com/vmware/dcm/backend/ortools/OrToolsSolver.java b/dcm/src/main/java/com/vmware/dcm/backend/ortools/OrToolsSolver.java index e6f810ed..20f0d5f3 100644 --- a/dcm/src/main/java/com/vmware/dcm/backend/ortools/OrToolsSolver.java +++ b/dcm/src/main/java/com/vmware/dcm/backend/ortools/OrToolsSolver.java @@ -87,7 +87,6 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -270,16 +269,21 @@ public List generateModelCode(final IRContext context, final Program createdIndices = new HashSet<>(); determineIndexes.indexes().forEach( id -> { - final IRTable table = id.relation.getTable(); + final IRTable table = id.relation; final IRColumn column = id.columnBeingAccessed.getField(); - final int fieldIndex = tupleMetadata.getFieldIndexInTable(table.getName(), column.getName()); - output.addStatement("final var $1L$2LIndex = " + - "o.toIndex($1L, (r) -> r.get($3L /* $4L */, $5L.class));", - tableNameStr(table.getName()), - fieldNameStr(table.getName(), column.getName()), fieldIndex, column.getName(), - tupleMetadata.getTypeForField(table, column)); + final String indexName = indexName(id); + // Indexed accesses can happen for the same column under different aliases. + // We disambiguate here. + if (!createdIndices.contains(indexName)) { + final int fieldIndex = tupleMetadata.getFieldIndexInTable(table.getName(), column.getName()); + output.addStatement("final var $1L = o.toIndex($2L, (r) -> r.get($3L /* $4L */, $5L.class))", + indexName, tableNameStr(table.getName()), fieldIndex, column.getName(), + tupleMetadata.getTypeForField(table, column)); + createdIndices.add(indexName); + } } ); } @@ -434,7 +438,7 @@ private OutputIR.Block innerComprehensionBlock(final String viewName, final QualifiersByVarType qualifiersByVarType = extractQualifiersByVarType(comprehension); // Start control flows to iterate over tables/views - final OutputIR.ForBlock iterationBlock = tableIterationBlock(viewName, qualifiersByVarType.nonVar); + final OutputIR.ForBlock iterationBlock = tableIterationBlock(viewName, qualifiersByVarType.nonVar, context); viewBlock.addBody(iterationBlock); context.enterScope(iterationBlock); @@ -558,39 +562,45 @@ private CodeBlock forLoopFromTableRowGeneratorBlock(final TableRowGenerator tabl * iteration indices pointing to the relevant lists of tuples. These iteration indices * may be obtained via nested for loops or using indexes if available. */ - private OutputIR.ForBlock tableIterationBlock(final String viewName, - final QualifiersByType nonVarQualifiers) { - final List loopStatements = forLoopsOrIndicesFromTableRowGenerators( + private OutputIR.ForBlock tableIterationBlock(final String viewName, final QualifiersByType nonVarQualifiers, + final TranslationContext context) { + final List loopStatements = + context.isSubQueryContext() ? forLoopsOrIndicesFromTableRowGenerators( nonVarQualifiers.tableRowGenerators, - nonVarQualifiers.joinPredicates); + nonVarQualifiers.joinPredicates) + : subQueryAccess(nonVarQualifiers); return outputIR.newForBlock(viewName, loopStatements); } + private List subQueryAccess(final QualifiersByType nonVarQualifiers) { + // Check if this is a correlated sub-query that can be expressed as an equijoin + if (nonVarQualifiers.wherePredicates.size() == 1 && nonVarQualifiers.tableRowGenerators.size() == 1) { + final BinaryOperatorPredicate op = nonVarQualifiers.wherePredicates.get(0); + final TableRowGenerator tr = nonVarQualifiers.tableRowGenerators.get(0); + final Optional idx = + DetermineIndexes.maybeIndex(op, tr.getTable()); + if (idx.isPresent()) { + return List.of(indexedAccess(idx.get(), op)); + } + } + return forLoopsOrIndicesFromTableRowGenerators(nonVarQualifiers.tableRowGenerators, + nonVarQualifiers.joinPredicates); + } private List forLoopsOrIndicesFromTableRowGenerators(final List tableRowGenerators, final List joinPredicates) { final TableRowGenerator forLoopTable = tableRowGenerators.get(0); final List loopStatements = new ArrayList<>(); loopStatements.add(forLoopFromTableRowGeneratorBlock(forLoopTable)); - final Function trByName = - (s) -> tableRowGenerators.stream() - .filter(e -> e.getTable().getAliasedName() - .equalsIgnoreCase(s.getTableName())) - .findAny().orElseThrow(); - // XXX: Use the IndexDescription from the DetermineIndexes pass instead tableRowGenerators.subList(1, tableRowGenerators.size()).stream() .map(tr -> { if (configUseIndicesForEqualityBasedJoins) { for (final BinaryOperatorPredicate binaryOp: joinPredicates) { - if (binaryOp.getOperator().equals(BinaryOperatorPredicate.Operator.EQUAL)) { - final ColumnIdentifier left = (ColumnIdentifier) binaryOp.getLeft(); - final ColumnIdentifier right = (ColumnIdentifier) binaryOp.getRight(); - if (right.getTableName().equals(tr.getTable().getName())) { - return indexedAccess(tr, right, trByName.apply(left), left); - } else if (left.getTableName().equals(tr.getTable().getName())) { - return indexedAccess(tr, left, trByName.apply(right), right); - } + final Optional idx = + DetermineIndexes.maybeIndex(binaryOp, tr.getTable()); + if (idx.isPresent()) { + return indexedAccess(idx.get(), binaryOp); } } } @@ -601,20 +611,19 @@ private List forLoopsOrIndicesFromTableRowGenerators(final List $1LList = $2L$3LIndex.get($4L)", idxIterStr, idxTableNameStr, - idxFieldName, fieldAccessFromScan) - .beginControlFlow("if ($LList == null)", idxIterStr) - .addStatement("continue") - .endControlFlow() + .addStatement("final List $1LList = $2L.get($3L)", idxIterStr, idxName, + fieldAccessFromScan) .add("for (int $1L : $1LList)", idxIterStr) .build(); } @@ -1046,6 +1055,12 @@ private String fieldNameStrWithIter(final String tableName, final String fieldNa } } + private static String indexName(final DetermineIndexes.IndexDescription index) { + final IRTable table = index.relation; + final IRColumn column = index.columnBeingAccessed.getField(); + return tableNameStr(table.getName()) + fieldNameStr(table.getName(), column.getName()) + "Index"; + } + private static String tableNumRowsStr(final String tableName) { return String.format("%s.size()", CaseFormat.UPPER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, tableName)); } diff --git a/dcm/src/test/java/com/vmware/dcm/ModelTest.java b/dcm/src/test/java/com/vmware/dcm/ModelTest.java index 2a56aef0..1de34d65 100644 --- a/dcm/src/test/java/com/vmware/dcm/ModelTest.java +++ b/dcm/src/test/java/com/vmware/dcm/ModelTest.java @@ -1252,6 +1252,31 @@ public void subqueryDifferentContexts(final SolverConfig solver) { model.solve("HOSTS"); } + @Test + public void correlatedSubqueryEquiJoin() { + final DSLContext conn = setup(); + conn.execute("CREATE TABLE t1 (controllable__c1 integer, c2 integer)"); + conn.execute("CREATE TABLE t2 (c1 integer, c2 integer)"); + final List constraint = List.of( + "CREATE CONSTRAINT subq AS SELECT * FROM t1 " + + "CHECK t1.controllable__c1 IN (SELECT t2.c1 FROM t2 WHERE t2.c2 = t1.c2)"); + conn.execute("insert into t1 values (1, 1)"); + conn.execute("insert into t1 values (1, 2)"); + conn.execute("insert into t1 values (1, 2)"); + conn.execute("insert into t1 values (1, 3)"); + conn.execute("insert into t2 values (1151, 1)"); + conn.execute("insert into t2 values (138, 2)"); + conn.execute("insert into t2 values (17, 3)"); + conn.execute("insert into t2 values (5, 4)"); + final Model model = Model.build(conn, constraint); + final Result t1 = model.solve("T1"); + final String matchString = + "final List t2IterList = t2t2C2Index.get((Integer) t1.get(t1Iter).get(1 /* C2 */));"; + assertTrue(model.compilationOutput().stream() + .map(String::trim) + .anyMatch(e -> e.contains(matchString))); + assertEquals(List.of(1151, 138, 138, 17), t1.getValues(0, Integer.class)); + } @ParameterizedTest @MethodSource("solvers")