Skip to content
This repository has been archived by the owner on May 10, 2024. It is now read-only.

Commit

Permalink
ortools-backend: use indexes for correlated subqueries (#154)
Browse files Browse the repository at this point in the history
Signed-off-by: Lalith Suresh <lsuresh@vmware.com>
  • Loading branch information
lalithsuresh committed Feb 22, 2022
1 parent 78e919c commit 52f9754
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

package com.vmware.dcm.backend.ortools;

import com.vmware.dcm.ModelException;
import com.vmware.dcm.compiler.IRTable;
import com.vmware.dcm.compiler.ir.BinaryOperatorPredicate;
import com.vmware.dcm.compiler.ir.ColumnIdentifier;
import com.vmware.dcm.compiler.ir.GroupByComprehension;
import com.vmware.dcm.compiler.ir.JoinPredicate;
import com.vmware.dcm.compiler.ir.ListComprehension;
import com.vmware.dcm.compiler.ir.SimpleVisitor;
import com.vmware.dcm.compiler.ir.TableRowGenerator;
Expand All @@ -17,6 +18,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -44,40 +46,71 @@ protected VoidType visitListComprehension(final ListComprehension node, final Vo
final List<TableRowGenerator> tableRowGenerators = inner.getQualifiers().stream()
.filter(e -> e instanceof TableRowGenerator)
.map(e -> (TableRowGenerator) e).collect(Collectors.toList());
final List<JoinPredicate> joinPredicates = inner.getQualifiers().stream()
.filter(e -> e instanceof JoinPredicate)
.map(e -> (JoinPredicate) e).collect(Collectors.toList());
final List<BinaryOperatorPredicate> joinPredicates = inner.getQualifiers().stream()
.filter(e -> e instanceof BinaryOperatorPredicate)
.map(e -> (BinaryOperatorPredicate) e).collect(Collectors.toList());

// For now, we always use a scan for the first Table being iterated over.
tableRowGenerators.subList(1, tableRowGenerators.size())
.forEach(tr -> {
tableRowGenerators.forEach(tr -> {
// We might be able to use an index here, look for equality based accesses
// across the join predicates
for (final BinaryOperatorPredicate binaryOp : joinPredicates) {
if (binaryOp.getOperator().equals(BinaryOperatorPredicate.Operator.EQUAL)) {
final ColumnIdentifier left = (ColumnIdentifier) binaryOp.getLeft();
final ColumnIdentifier right = (ColumnIdentifier) binaryOp.getRight();
if (right.getTableName().equals(tr.getTable().getName())) {
indexes.add(new IndexDescription(tr, right));
} else if (left.getTableName().equals(tr.getTable().getName())) {
indexes.add(new IndexDescription(tr, left));
}
}
maybeIndex(binaryOp, tr.getTable())
.ifPresent(indexes::add);
}
});
return super.visitListComprehension(node, context);
}
}

static Optional<IndexDescription> maybeIndex(final BinaryOperatorPredicate op,
final IRTable innerTable) {
if (!(op.getLeft() instanceof ColumnIdentifier && op.getRight() instanceof ColumnIdentifier)) {
return Optional.empty();
}
final ColumnIdentifier left = (ColumnIdentifier) op.getLeft();
final ColumnIdentifier right = (ColumnIdentifier) op.getRight();
final String leftTableNameOrAlias = left.getField().getIRTable().getAliasedName();
final String rightTableNameOrAlias = right.getField().getIRTable().getAliasedName();
final String innerTableNameOrAlias = innerTable.getAliasedName();
if (rightTableNameOrAlias.equals(innerTableNameOrAlias)) {
return Optional.of(new IndexDescription(innerTable, right));
} else if (leftTableNameOrAlias.equals(innerTableNameOrAlias)) {
return Optional.of(new IndexDescription(innerTable, left));
}
return Optional.empty();
}

static class IndexedAccess {
final ColumnIdentifier indexedColumn;
final ColumnIdentifier scanColumn;

IndexedAccess(final ColumnIdentifier indexedColumn, final ColumnIdentifier scanColumn) {
this.indexedColumn = indexedColumn;
this.scanColumn = scanColumn;
}
}

static class IndexDescription {
final TableRowGenerator relation;
final IRTable relation;
final ColumnIdentifier columnBeingAccessed;

IndexDescription(final TableRowGenerator relation, final ColumnIdentifier columnBeingAccessed) {
IndexDescription(final IRTable relation, final ColumnIdentifier columnBeingAccessed) {
this.relation = relation;
this.columnBeingAccessed = columnBeingAccessed;
}

public IndexedAccess toIndexedAccess(final BinaryOperatorPredicate op) {
final ColumnIdentifier left = (ColumnIdentifier) op.getLeft();
final ColumnIdentifier right = (ColumnIdentifier) op.getRight();
if (left.getField().getIRTable().getAliasedName().equals(relation.getAliasedName())) {
return new IndexedAccess(left, right);
} else if (right.getField().getIRTable().getAliasedName().equals(relation.getAliasedName())) {
return new IndexedAccess(right, left);
}
throw new ModelException("Unreachable");
}

@Override
public boolean equals(final Object o) {
if (this == o) {
Expand All @@ -87,14 +120,14 @@ public boolean equals(final Object o) {
return false;
}
final IndexDescription that = (IndexDescription) o;
return Objects.equals(relation.getTable().getAliasedName(), that.relation.getTable().getAliasedName())
return Objects.equals(relation.getAliasedName(), that.relation.getAliasedName())
&& Objects.equals(columnBeingAccessed.getField().getName(),
that.columnBeingAccessed.getField().getName());
}

@Override
public int hashCode() {
return Objects.hash(relation.getTable().getAliasedName(), columnBeingAccessed.getField().getName());
return Objects.hash(relation.getAliasedName(), columnBeingAccessed.getField().getName());
}

@Override
Expand Down
18 changes: 15 additions & 3 deletions dcm/src/main/java/com/vmware/dcm/backend/ortools/Ops.java
Original file line number Diff line number Diff line change
Expand Up @@ -1002,11 +1002,11 @@ public List<String> findSufficientAssumptions(final CpSolver solver) {
.collect(Collectors.toList());
}

public <K, V> Map<K, List<Integer>> toIndex(final List<V> resultSet, final Function<V, K> extractKey) {
final Map<K, List<Integer>> index = new HashMap<>();
public <K, V> Index<K> toIndex(final List<V> resultSet, final Function<V, K> extractKey) {
final Index<K> index = new Index<>();
for (int i = 0; i < resultSet.size(); i++) {
final K key = extractKey.apply(resultSet.get(i));
index.computeIfAbsent(key, (k) -> new ArrayList<>()).add(i);
index.set(key, i);
}
return index;
}
Expand Down Expand Up @@ -1053,4 +1053,16 @@ public CpSolver solve() {
throw new SolverException(status.toString());
}
}

public static class Index<K> {
final Map<K, List<Integer>> index = new HashMap<>();

void set(final K key, final int i) {
index.computeIfAbsent(key, (k) -> new ArrayList<>()).add(i);
}

public List<Integer> get(final K key) {
return index.getOrDefault(key, Collections.emptyList());
}
}
}
95 changes: 55 additions & 40 deletions dcm/src/main/java/com/vmware/dcm/backend/ortools/OrToolsSolver.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand Down Expand Up @@ -270,16 +269,21 @@ public List<String> generateModelCode(final IRContext context, final Program<Lis

private void addIndexes(final DetermineIndexes determineIndexes, final MethodSpec.Builder output) {
if (configUseIndicesForEqualityBasedJoins) {
final Set<String> createdIndices = new HashSet<>();
determineIndexes.indexes().forEach(
id -> {
final IRTable table = id.relation.getTable();
final IRTable table = id.relation;
final IRColumn column = id.columnBeingAccessed.getField();
final int fieldIndex = tupleMetadata.getFieldIndexInTable(table.getName(), column.getName());
output.addStatement("final var $1L$2LIndex = " +
"o.toIndex($1L, (r) -> r.get($3L /* $4L */, $5L.class));",
tableNameStr(table.getName()),
fieldNameStr(table.getName(), column.getName()), fieldIndex, column.getName(),
tupleMetadata.getTypeForField(table, column));
final String indexName = indexName(id);
// Indexed accesses can happen for the same column under different aliases.
// We disambiguate here.
if (!createdIndices.contains(indexName)) {
final int fieldIndex = tupleMetadata.getFieldIndexInTable(table.getName(), column.getName());
output.addStatement("final var $1L = o.toIndex($2L, (r) -> r.get($3L /* $4L */, $5L.class))",
indexName, tableNameStr(table.getName()), fieldIndex, column.getName(),
tupleMetadata.getTypeForField(table, column));
createdIndices.add(indexName);
}
}
);
}
Expand Down Expand Up @@ -434,7 +438,7 @@ private OutputIR.Block innerComprehensionBlock(final String viewName,
final QualifiersByVarType qualifiersByVarType = extractQualifiersByVarType(comprehension);

// Start control flows to iterate over tables/views
final OutputIR.ForBlock iterationBlock = tableIterationBlock(viewName, qualifiersByVarType.nonVar);
final OutputIR.ForBlock iterationBlock = tableIterationBlock(viewName, qualifiersByVarType.nonVar, context);
viewBlock.addBody(iterationBlock);
context.enterScope(iterationBlock);

Expand Down Expand Up @@ -558,39 +562,45 @@ private CodeBlock forLoopFromTableRowGeneratorBlock(final TableRowGenerator tabl
* iteration indices pointing to the relevant lists of tuples. These iteration indices
* may be obtained via nested for loops or using indexes if available.
*/
private OutputIR.ForBlock tableIterationBlock(final String viewName,
final QualifiersByType nonVarQualifiers) {
final List<CodeBlock> loopStatements = forLoopsOrIndicesFromTableRowGenerators(
private OutputIR.ForBlock tableIterationBlock(final String viewName, final QualifiersByType nonVarQualifiers,
final TranslationContext context) {
final List<CodeBlock> loopStatements =
context.isSubQueryContext() ? forLoopsOrIndicesFromTableRowGenerators(
nonVarQualifiers.tableRowGenerators,
nonVarQualifiers.joinPredicates);
nonVarQualifiers.joinPredicates)
: subQueryAccess(nonVarQualifiers);
return outputIR.newForBlock(viewName, loopStatements);
}

private List<CodeBlock> subQueryAccess(final QualifiersByType nonVarQualifiers) {
// Check if this is a correlated sub-query that can be expressed as an equijoin
if (nonVarQualifiers.wherePredicates.size() == 1 && nonVarQualifiers.tableRowGenerators.size() == 1) {
final BinaryOperatorPredicate op = nonVarQualifiers.wherePredicates.get(0);
final TableRowGenerator tr = nonVarQualifiers.tableRowGenerators.get(0);
final Optional<DetermineIndexes.IndexDescription> idx =
DetermineIndexes.maybeIndex(op, tr.getTable());
if (idx.isPresent()) {
return List.of(indexedAccess(idx.get(), op));
}
}
return forLoopsOrIndicesFromTableRowGenerators(nonVarQualifiers.tableRowGenerators,
nonVarQualifiers.joinPredicates);
}

private List<CodeBlock> forLoopsOrIndicesFromTableRowGenerators(final List<TableRowGenerator> tableRowGenerators,
final List<JoinPredicate> joinPredicates) {
final TableRowGenerator forLoopTable = tableRowGenerators.get(0);
final List<CodeBlock> loopStatements = new ArrayList<>();
loopStatements.add(forLoopFromTableRowGeneratorBlock(forLoopTable));
final Function<ColumnIdentifier, TableRowGenerator> trByName =
(s) -> tableRowGenerators.stream()
.filter(e -> e.getTable().getAliasedName()
.equalsIgnoreCase(s.getTableName()))
.findAny().orElseThrow();

// XXX: Use the IndexDescription from the DetermineIndexes pass instead
tableRowGenerators.subList(1, tableRowGenerators.size()).stream()
.map(tr -> {
if (configUseIndicesForEqualityBasedJoins) {
for (final BinaryOperatorPredicate binaryOp: joinPredicates) {
if (binaryOp.getOperator().equals(BinaryOperatorPredicate.Operator.EQUAL)) {
final ColumnIdentifier left = (ColumnIdentifier) binaryOp.getLeft();
final ColumnIdentifier right = (ColumnIdentifier) binaryOp.getRight();
if (right.getTableName().equals(tr.getTable().getName())) {
return indexedAccess(tr, right, trByName.apply(left), left);
} else if (left.getTableName().equals(tr.getTable().getName())) {
return indexedAccess(tr, left, trByName.apply(right), right);
}
final Optional<DetermineIndexes.IndexDescription> idx =
DetermineIndexes.maybeIndex(binaryOp, tr.getTable());
if (idx.isPresent()) {
return indexedAccess(idx.get(), binaryOp);
}
}
}
Expand All @@ -601,20 +611,19 @@ private List<CodeBlock> forLoopsOrIndicesFromTableRowGenerators(final List<Table
return loopStatements;
}

private CodeBlock indexedAccess(final TableRowGenerator indexedTable, final ColumnIdentifier indexedColumn,
final TableRowGenerator scanTable, final ColumnIdentifier scanColumn) {
final String idxIterStr = iterStr(indexedTable.getTable().getAliasedName());
final String idxTableNameStr = tableNameStr(indexedTable.getTable().getName());
final String fieldAccessFromScan = fieldNameStrWithIter(scanTable.getTable().getName(),
scanColumn.getField().getName(),
iterStr(scanTable.getTable().getAliasedName()));
final String idxFieldName = fieldNameStr(indexedColumn.getTableName(), indexedColumn.getField().getName());
private CodeBlock indexedAccess(final DetermineIndexes.IndexDescription idx,
final BinaryOperatorPredicate op) {
final DetermineIndexes.IndexedAccess access = idx.toIndexedAccess(op);
final IRTable indexedTable = access.indexedColumn.getField().getIRTable();
final IRTable scanTable = access.scanColumn.getField().getIRTable();
final String idxIterStr = iterStr(indexedTable.getAliasedName());
final String idxName = indexName(idx);
final String fieldAccessFromScan = fieldNameStrWithIter(scanTable.getName(),
access.scanColumn.getField().getName(),
iterStr(scanTable.getAliasedName()));
return CodeBlock.builder()
.addStatement("final List<Integer> $1LList = $2L$3LIndex.get($4L)", idxIterStr, idxTableNameStr,
idxFieldName, fieldAccessFromScan)
.beginControlFlow("if ($LList == null)", idxIterStr)
.addStatement("continue")
.endControlFlow()
.addStatement("final List<Integer> $1LList = $2L.get($3L)", idxIterStr, idxName,
fieldAccessFromScan)
.add("for (int $1L : $1LList)", idxIterStr)
.build();
}
Expand Down Expand Up @@ -1046,6 +1055,12 @@ private String fieldNameStrWithIter(final String tableName, final String fieldNa
}
}

private static String indexName(final DetermineIndexes.IndexDescription index) {
final IRTable table = index.relation;
final IRColumn column = index.columnBeingAccessed.getField();
return tableNameStr(table.getName()) + fieldNameStr(table.getName(), column.getName()) + "Index";
}

private static String tableNumRowsStr(final String tableName) {
return String.format("%s.size()", CaseFormat.UPPER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, tableName));
}
Expand Down
25 changes: 25 additions & 0 deletions dcm/src/test/java/com/vmware/dcm/ModelTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,31 @@ public void subqueryDifferentContexts(final SolverConfig solver) {
model.solve("HOSTS");
}

@Test
public void correlatedSubqueryEquiJoin() {
final DSLContext conn = setup();
conn.execute("CREATE TABLE t1 (controllable__c1 integer, c2 integer)");
conn.execute("CREATE TABLE t2 (c1 integer, c2 integer)");
final List<String> constraint = List.of(
"CREATE CONSTRAINT subq AS SELECT * FROM t1 " +
"CHECK t1.controllable__c1 IN (SELECT t2.c1 FROM t2 WHERE t2.c2 = t1.c2)");
conn.execute("insert into t1 values (1, 1)");
conn.execute("insert into t1 values (1, 2)");
conn.execute("insert into t1 values (1, 2)");
conn.execute("insert into t1 values (1, 3)");
conn.execute("insert into t2 values (1151, 1)");
conn.execute("insert into t2 values (138, 2)");
conn.execute("insert into t2 values (17, 3)");
conn.execute("insert into t2 values (5, 4)");
final Model model = Model.build(conn, constraint);
final Result<? extends Record> t1 = model.solve("T1");
final String matchString =
"final List<Integer> t2IterList = t2t2C2Index.get((Integer) t1.get(t1Iter).get(1 /* C2 */));";
assertTrue(model.compilationOutput().stream()
.map(String::trim)
.anyMatch(e -> e.contains(matchString)));
assertEquals(List.of(1151, 138, 138, 17), t1.getValues(0, Integer.class));
}

@ParameterizedTest
@MethodSource("solvers")
Expand Down

0 comments on commit 52f9754

Please sign in to comment.