Skip to content

Commit

Permalink
[fix](Nereids) using join bugs (apache#48030)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Related PR: apache#15311

Problem Summary:

1. `select *` on a using join should return the using columns only from the left relation
2. binding expressions on a using join should not deduplicate slots by name on
the using join's output


### Release note

Change the columns returned by `select *` on a using join.
Previously the key columns from both sides were returned; now only the key columns
from the left side are returned.
  • Loading branch information
morrySnow authored Feb 21, 2025
1 parent 6c1778f commit 5a7454e
Show file tree
Hide file tree
Showing 30 changed files with 920 additions and 596 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,27 +61,44 @@ public class Scope {

private final Optional<Scope> outerScope;
private final List<Slot> slots;
private final List<Slot> asteriskSlots;
private final Set<Slot> correlatedSlots;
private final boolean buildNameToSlot;
private final Supplier<ListMultimap<String, Slot>> nameToSlot;
private final Supplier<ListMultimap<String, Slot>> nameToAsteriskSlot;

public Scope(List<? extends Slot> slots) {
public Scope(List<Slot> slots) {
this(Optional.empty(), slots);
}

/**
 * Creates a scope whose asterisk slots are the same list as its slots.
 *
 * @param outerScope the enclosing scope, or an empty Optional when there is none
 * @param slots the slots of this scope, also passed as the asterisk slots
 */
public Scope(Optional<Scope> outerScope, List<Slot> slots) {
this(outerScope, slots, slots);
}

/**
 * Creates a scope without an outer scope, with separate slot and asterisk slot lists.
 *
 * @param slots the slots of this scope
 * @param asteriskSlots the slots returned by {@link #getAsteriskSlots()}
 */
public Scope(List<Slot> slots, List<Slot> asteriskSlots) {
this(Optional.empty(), slots, asteriskSlots);
}

/** Scope */
public Scope(Optional<Scope> outerScope, List<? extends Slot> slots) {
public Scope(Optional<Scope> outerScope, List<Slot> slots, List<Slot> asteriskSlots) {
this.outerScope = Objects.requireNonNull(outerScope, "outerScope can not be null");
this.slots = Utils.fastToImmutableList(Objects.requireNonNull(slots, "slots can not be null"));
this.correlatedSlots = Sets.newLinkedHashSet();
this.buildNameToSlot = slots.size() > 500;
this.nameToSlot = buildNameToSlot ? Suppliers.memoize(this::buildNameToSlot) : null;
this.nameToAsteriskSlot = buildNameToSlot ? Suppliers.memoize(this::buildNameToAsteriskSlot) : null;
this.asteriskSlots = Utils.fastToImmutableList(
Objects.requireNonNull(asteriskSlots, "asteriskSlots can not be null"));
}

/** Returns the slot list held by this scope. */
public List<Slot> getSlots() {
return slots;
}

/**
 * Returns the asterisk slot list held by this scope. It is the same list as
 * {@link #getSlots()} unless a separate list was passed to the constructor.
 */
public List<Slot> getAsteriskSlots() {
return asteriskSlots;
}

/** Returns the enclosing scope, or an empty Optional when this scope has none. */
public Optional<Scope> getOuterScope() {
return outerScope;
}
Expand All @@ -91,17 +108,18 @@ public Set<Slot> getCorrelatedSlots() {
}

/** findSlotIgnoreCase */
public List<Slot> findSlotIgnoreCase(String slotName) {
public List<Slot> findSlotIgnoreCase(String slotName, boolean all) {
List<Slot> slots = all ? this.slots : this.asteriskSlots;
Supplier<ListMultimap<String, Slot>> nameToSlot = all ? this.nameToSlot : this.nameToAsteriskSlot;
if (!buildNameToSlot) {
Object[] array = new Object[slots.size()];
Slot[] array = new Slot[slots.size()];
int filterIndex = 0;
for (int i = 0; i < slots.size(); i++) {
Slot slot = slots.get(i);
for (Slot slot : slots) {
if (slot.getName().equalsIgnoreCase(slotName)) {
array[filterIndex++] = slot;
}
}
return (List) Arrays.asList(array).subList(0, filterIndex);
return Arrays.asList(array).subList(0, filterIndex);
} else {
return nameToSlot.get().get(slotName.toUpperCase(Locale.ROOT));
}
Expand All @@ -114,4 +132,12 @@ private ListMultimap<String, Slot> buildNameToSlot() {
}
return map;
}

/**
 * Builds the lookup index over {@code asteriskSlots}, keyed by each slot's name
 * upper-cased with {@link Locale#ROOT}. Slots sharing a name keep their list order.
 */
private ListMultimap<String, Slot> buildNameToAsteriskSlot() {
    ListMultimap<String, Slot> index = LinkedListMultimap.create(asteriskSlots.size());
    asteriskSlots.forEach(candidate -> {
        String upperCasedName = candidate.getName().toUpperCase(Locale.ROOT);
        index.put(upperCasedName, candidate);
    });
    return index;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.logical.UsingJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalUsingJoin;
import org.apache.doris.nereids.types.AggStateType;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BigIntType;
Expand Down Expand Up @@ -3604,8 +3604,7 @@ private LogicalPlan withJoinRelations(LogicalPlan input, RelationContext ctx) {
last,
plan(join.relationPrimary()), null);
} else {
last = new UsingJoin<>(joinType, last,
plan(join.relationPrimary()), ImmutableList.of(), ids, distributeHint);
last = new LogicalUsingJoin<>(joinType, last, plan(join.relationPrimary()), ids, distributeHint);

}
if (distributeHint.distributeType != DistributeType.NONE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,43 +37,54 @@
*/
public class LogicalProperties {
protected final Supplier<List<Slot>> outputSupplier;
protected final Supplier<List<Id>> outputExprIdsSupplier;
protected final Supplier<List<Id<?>>> outputExprIdsSupplier;
protected final Supplier<Set<Slot>> outputSetSupplier;
protected final Supplier<Map<Slot, Slot>> outputMapSupplier;
protected final Supplier<Set<ExprId>> outputExprIdSetSupplier;
protected final Supplier<List<Slot>> asteriskOutputSupplier;
protected final Supplier<DataTrait> dataTraitSupplier;
private Integer hashCode = null;

/**
 * Constructor used when the asterisk output is the same as the regular output.
 *
 * @param outputSupplier provides the output
 * @param dataTraitSupplier provides the data trait
 */
public LogicalProperties(Supplier<List<Slot>> outputSupplier, Supplier<DataTrait> dataTraitSupplier) {
// Pass null as the second argument so the three-argument constructor reuses the
// memoized output supplier for the asterisk output instead of wrapping a new one.
this(outputSupplier, null, dataTraitSupplier);
}

/**
* constructor of LogicalProperties.
*
* @param outputSupplier provide the output. Supplier can lazy compute output without
* throw exception for which children have UnboundRelation
* @param asteriskOutputSupplier provide the output when do select *.
* @param dataTraitSupplier provide the data trait.
*/
public LogicalProperties(Supplier<List<Slot>> outputSupplier,
public LogicalProperties(Supplier<List<Slot>> outputSupplier, Supplier<List<Slot>> asteriskOutputSupplier,
Supplier<DataTrait> dataTraitSupplier) {
this.outputSupplier = Suppliers.memoize(
Objects.requireNonNull(outputSupplier, "outputSupplier can not be null")
);
this.outputExprIdsSupplier = Suppliers.memoize(() -> {
List<Slot> output = this.outputSupplier.get();
ImmutableList.Builder<Id> exprIdSet
ImmutableList.Builder<Id<?>> exprIdSet
= ImmutableList.builderWithExpectedSize(output.size());
for (Slot slot : output) {
exprIdSet.add(slot.getExprId());
}
return exprIdSet.build();
});
this.outputSetSupplier = Suppliers.memoize(() -> {
List<Slot> output = outputSupplier.get();
List<Slot> output = this.outputSupplier.get();
ImmutableSet.Builder<Slot> slots = ImmutableSet.builderWithExpectedSize(output.size());
for (Slot slot : output) {
slots.add(slot);
}
return slots.build();
});
this.outputMapSupplier = Suppliers.memoize(() -> {
Set<Slot> slots = outputSetSupplier.get();
Set<Slot> slots = this.outputSetSupplier.get();
ImmutableMap.Builder<Slot, Slot> map = ImmutableMap.builderWithExpectedSize(slots.size());
for (Slot slot : slots) {
map.put(slot, slot);
Expand All @@ -89,6 +100,9 @@ public LogicalProperties(Supplier<List<Slot>> outputSupplier,
}
return exprIdSet.build();
});
this.asteriskOutputSupplier = asteriskOutputSupplier == null ? this.outputSupplier : Suppliers.memoize(
Objects.requireNonNull(asteriskOutputSupplier, "asteriskOutputSupplier can not be null")
);
this.dataTraitSupplier = Suppliers.memoize(
Objects.requireNonNull(dataTraitSupplier, "Data Trait can not be null")
);
Expand All @@ -110,12 +124,16 @@ public Set<ExprId> getOutputExprIdSet() {
return outputExprIdSetSupplier.get();
}

public DataTrait getTrait() {
return dataTraitSupplier.get();
public List<Id<?>> getOutputExprIds() {
return outputExprIdsSupplier.get();
}

public List<Id> getOutputExprIds() {
return outputExprIdsSupplier.get();
public List<Slot> getAsteriskOutput() {
return asteriskOutputSupplier.get();
}

public DataTrait getTrait() {
return dataTraitSupplier.get();
}

@Override
Expand All @@ -126,6 +144,7 @@ public String toString() {
+ "\noutputSetSupplier=" + outputSetSupplier.get()
+ "\noutputMapSupplier=" + outputMapSupplier.get()
+ "\noutputExprIdSetSupplier=" + outputExprIdSetSupplier.get()
+ "\nasteriskOutputSupplier=" + asteriskOutputSupplier.get()
+ "\nhashCode=" + hashCode
+ '}';
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.SqlCacheContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.analyzer.MappingSlot;
Expand Down Expand Up @@ -88,7 +87,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation;
import org.apache.doris.nereids.trees.plans.logical.UsingJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalUsingJoin;
import org.apache.doris.nereids.trees.plans.visitor.InferPlanOutputAlias;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.StructField;
Expand Down Expand Up @@ -130,7 +129,7 @@
*/
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class BindExpression implements AnalysisRuleFactory {
public static final Logger LOG = LogManager.getLogger(NereidsPlanner.class);
public static final Logger LOG = LogManager.getLogger(BindExpression.class);

@Override
public List<Rule> buildRules() {
Expand Down Expand Up @@ -161,7 +160,7 @@ protected boolean condition(Rule rule, Plan plan) {
logicalFilter().thenApply(this::bindFilter)
),
RuleType.BINDING_USING_JOIN_SLOT.build(
usingJoin().thenApply(this::bindUsingJoin)
logicalUsingJoin().thenApply(this::bindUsingJoin)
),
RuleType.BINDING_JOIN_SLOT.build(
logicalJoin().thenApply(this::bindJoin)
Expand Down Expand Up @@ -546,7 +545,7 @@ private LogicalJoin<Plan, Plan> bindJoin(MatchingContext<LogicalJoin<Plan, Plan>

return new LogicalJoin<>(join.getJoinType(),
hashJoinConjuncts.build(), otherJoinConjuncts.build(),
join.getDistributeHint(), join.getMarkJoinSlotReference(),
join.getDistributeHint(), join.getMarkJoinSlotReference(), join.getExceptAsteriskOutputs(),
join.children(), null);
}

Expand Down Expand Up @@ -591,16 +590,17 @@ private String getDbName(Plan plan) {
}
}

private LogicalJoin<Plan, Plan> bindUsingJoin(MatchingContext<UsingJoin<Plan, Plan>> ctx) {
UsingJoin<Plan, Plan> using = ctx.root;
private LogicalPlan bindUsingJoin(MatchingContext<LogicalUsingJoin<Plan, Plan>> ctx) {
LogicalUsingJoin<Plan, Plan> using = ctx.root;
CascadesContext cascadesContext = ctx.cascadesContext;
List<Expression> unboundHashJoinConjunct = using.getHashJoinConjuncts();
List<Expression> unboundHashJoinConjunct = using.getUsingSlots();

Scope leftScope = toScope(cascadesContext, ExpressionUtils.distinctSlotByName(using.left().getOutput()));
Scope rightScope = toScope(cascadesContext, ExpressionUtils.distinctSlotByName(using.right().getOutput()));
Scope leftScope = toScope(cascadesContext, using.left().getOutput(), using.left().getAsteriskOutput());
Scope rightScope = toScope(cascadesContext, using.right().getOutput(), using.right().getAsteriskOutput());
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext);

Builder<Expression> hashEqExprs = ImmutableList.builderWithExpectedSize(unboundHashJoinConjunct.size());
List<Slot> rightConjunctsSlots = Lists.newArrayList();
for (Expression usingColumn : unboundHashJoinConjunct) {
ExpressionAnalyzer leftExprAnalyzer = new ExpressionAnalyzer(
using, leftScope, cascadesContext, true, false);
Expand All @@ -609,13 +609,14 @@ private LogicalJoin<Plan, Plan> bindUsingJoin(MatchingContext<UsingJoin<Plan, Pl
ExpressionAnalyzer rightExprAnalyzer = new ExpressionAnalyzer(
using, rightScope, cascadesContext, true, false);
Expression usingRightSlot = rightExprAnalyzer.analyze(usingColumn, rewriteContext);
rightConjunctsSlots.add((Slot) usingRightSlot);
hashEqExprs.add(new EqualTo(usingLeftSlot, usingRightSlot));
}

return new LogicalJoin<>(
using.getJoinType() == JoinType.CROSS_JOIN ? JoinType.INNER_JOIN : using.getJoinType(),
hashEqExprs.build(), using.getOtherJoinConjuncts(),
using.getDistributeHint(), using.getMarkJoinSlotReference(),
hashEqExprs.build(), ImmutableList.of(),
using.getDistributeHint(), Optional.empty(), rightConjunctsSlots,
using.children(), null);
}

Expand Down Expand Up @@ -1035,11 +1036,11 @@ private List<Expression> bindGroupBy(
private Supplier<Scope> buildAggOutputScopeWithoutAggFun(
List<? extends NamedExpression> boundAggOutput, CascadesContext cascadesContext) {
return Suppliers.memoize(() -> {
Builder<MappingSlot> nonAggFunOutput = ImmutableList.builderWithExpectedSize(boundAggOutput.size());
Builder<Slot> nonAggFunOutput = ImmutableList.builderWithExpectedSize(boundAggOutput.size());
for (NamedExpression output : boundAggOutput) {
if (!output.containsType(AggregateFunction.class)) {
Slot outputSlot = output.toSlot();
MappingSlot mappingSlot = new MappingSlot(outputSlot,
Slot mappingSlot = new MappingSlot(outputSlot,
output instanceof Alias ? output.child(0) : output);
nonAggFunOutput.add(mappingSlot);
}
Expand Down Expand Up @@ -1226,7 +1227,7 @@ private <E extends Expression> E checkBoundExceptLambda(E expression, Plan plan)
return expression;
}

private Scope toScope(CascadesContext cascadesContext, List<? extends Slot> slots) {
private Scope toScope(CascadesContext cascadesContext, List<Slot> slots) {
Optional<Scope> outerScope = cascadesContext.getOuterScope();
if (outerScope.isPresent()) {
return new Scope(outerScope, slots);
Expand All @@ -1235,11 +1236,20 @@ private Scope toScope(CascadesContext cascadesContext, List<? extends Slot> slot
}
}

/**
 * Builds a scope from separate slot and asterisk slot lists, attaching the
 * outer scope of the cascades context when one is present.
 */
private Scope toScope(CascadesContext cascadesContext, List<Slot> slots, List<Slot> asteriskSlots) {
    Optional<Scope> enclosing = cascadesContext.getOuterScope();
    return enclosing.isPresent()
            ? new Scope(enclosing, slots, asteriskSlots)
            : new Scope(slots, asteriskSlots);
}

private SimpleExprAnalyzer buildSimpleExprAnalyzer(
Plan currentPlan, CascadesContext cascadesContext, List<Plan> children,
boolean enableExactMatch, boolean bindSlotInOuterScope) {
List<Slot> childrenOutputs = PlanUtils.fastGetChildrenOutputs(children);
Scope scope = toScope(cascadesContext, childrenOutputs);
Scope scope = toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(children),
PlanUtils.fastGetChildrenAsteriskOutputs(children));
return buildSimpleExprAnalyzer(currentPlan, cascadesContext, scope, enableExactMatch, bindSlotInOuterScope);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ protected void couldNotFoundColumn(UnboundSlot unboundSlot, String tableName) {
public Expression visitUnboundStar(UnboundStar unboundStar, ExpressionRewriteContext context) {
List<String> qualifier = unboundStar.getQualifier();
boolean showHidden = Util.showHiddenColumns();
List<Slot> slots = getScope().getSlots()
List<Slot> slots = getScope().getAsteriskSlots()
.stream()
.filter(slot -> !(slot instanceof SlotReference)
|| (((SlotReference) slot).isVisible()) || showHidden)
Expand Down Expand Up @@ -920,7 +920,7 @@ private boolean shouldBindSlotBy(int namePartSize, Slot boundSlot) {
private List<Slot> bindSingleSlotByName(String name, Scope scope) {
int namePartSize = 1;
Builder<Slot> usedSlots = ImmutableList.builderWithExpectedSize(1);
for (Slot boundSlot : scope.findSlotIgnoreCase(name)) {
for (Slot boundSlot : scope.findSlotIgnoreCase(name, false)) {
if (!shouldBindSlotBy(namePartSize, boundSlot)) {
continue;
}
Expand All @@ -933,7 +933,7 @@ private List<Slot> bindSingleSlotByName(String name, Scope scope) {
private List<Slot> bindSingleSlotByTable(String table, String name, Scope scope) {
int namePartSize = 2;
Builder<Slot> usedSlots = ImmutableList.builderWithExpectedSize(1);
for (Slot boundSlot : scope.findSlotIgnoreCase(name)) {
for (Slot boundSlot : scope.findSlotIgnoreCase(name, true)) {
if (!shouldBindSlotBy(namePartSize, boundSlot)) {
continue;
}
Expand All @@ -951,7 +951,7 @@ private List<Slot> bindSingleSlotByTable(String table, String name, Scope scope)
private List<Slot> bindSingleSlotByDb(String db, String table, String name, Scope scope) {
int namePartSize = 3;
Builder<Slot> usedSlots = ImmutableList.builderWithExpectedSize(1);
for (Slot boundSlot : scope.findSlotIgnoreCase(name)) {
for (Slot boundSlot : scope.findSlotIgnoreCase(name, true)) {
if (!shouldBindSlotBy(namePartSize, boundSlot)) {
continue;
}
Expand All @@ -970,7 +970,7 @@ private List<Slot> bindSingleSlotByDb(String db, String table, String name, Scop
private List<Slot> bindSingleSlotByCatalog(String catalog, String db, String table, String name, Scope scope) {
int namePartSize = 4;
Builder<Slot> usedSlots = ImmutableList.builderWithExpectedSize(1);
for (Slot boundSlot : scope.findSlotIgnoreCase(name)) {
for (Slot boundSlot : scope.findSlotIgnoreCase(name, true)) {
if (!shouldBindSlotBy(namePartSize, boundSlot)) {
continue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ private AnalyzedResult analyzeSubquery(SubqueryExpr expr) {
cascadesContext, expr.getQueryPlan(), cascadesContext.getCteContext());
// don't use `getScope()` because we only need `getScope().getOuterScope()` and `getScope().getSlots()`
// otherwise unexpected errors may occur
Scope subqueryScope = new Scope(getScope().getOuterScope(), getScope().getSlots());
Scope subqueryScope = new Scope(getScope().getOuterScope(),
getScope().getSlots(), getScope().getAsteriskSlots());
subqueryContext.setOuterScope(subqueryScope);
subqueryContext.newAnalyzer().analyze();
return new AnalyzedResult((LogicalPlan) subqueryContext.getRewritePlan(),
Expand Down
Loading

0 comments on commit 5a7454e

Please sign in to comment.