Skip to content

Commit

Permalink
Improve COALESCE inference for MySQL
Browse files Browse the repository at this point in the history
  • Loading branch information
janedbal authored and ondrejmirtes committed Jul 2, 2024
1 parent 2004f84 commit c87ee29
Show file tree
Hide file tree
Showing 3 changed files with 270 additions and 18 deletions.
88 changes: 82 additions & 6 deletions src/Type/Doctrine/Query/QueryResultTypeWalker.php
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,7 @@ public function walkJoin($join): string
*/
public function walkCoalesceExpression($coalesceExpression): string
{
$rawTypes = [];
$expressionTypes = [];
$allTypesContainNull = true;

Expand All @@ -987,22 +988,67 @@ public function walkCoalesceExpression($coalesceExpression): string
continue;
}

$type = $this->unmarshalType($expression->dispatch($this));
$allTypesContainNull = $allTypesContainNull && $this->canBeNull($type);
$rawType = $this->unmarshalType($expression->dispatch($this));
$rawTypes[] = $rawType;

$allTypesContainNull = $allTypesContainNull && $this->canBeNull($rawType);

// Some drivers manipulate the types, lets avoid false positives by generalizing constant types
// e.g. sqlsrv: "COALESCE returns the data type of value with the highest precedence"
// e.g. mysql: COALESCE(1, 'foo') === '1' (undocumented? https://gist.github.com/jrunning/4535434)
$expressionTypes[] = $this->generalizeConstantType($type, false);
$expressionTypes[] = $this->generalizeConstantType($rawType, false);
}

$type = TypeCombinator::union(...$expressionTypes);
$generalizedUnion = TypeCombinator::union(...$expressionTypes);

if (!$allTypesContainNull) {
$type = TypeCombinator::removeNull($type);
$generalizedUnion = TypeCombinator::removeNull($generalizedUnion);
}

return $this->marshalType($type);
if ($this->driverType === DriverDetector::MYSQLI || $this->driverType === DriverDetector::PDO_MYSQL) {
return $this->marshalType(
$this->inferCoalesceForMySql($rawTypes, $generalizedUnion)
);
}

return $this->marshalType($generalizedUnion);
}

/**
* @param list<Type> $rawTypes
*/
private function inferCoalesceForMySql(array $rawTypes, Type $originalResult): Type
{
$containsString = false;
$containsFloat = false;
$allIsNumericExcludingLiteralString = true;

foreach ($rawTypes as $rawType) {
$rawTypeNoNull = TypeCombinator::removeNull($rawType);
$isLiteralString = $rawTypeNoNull instanceof DqlConstantStringType && $rawTypeNoNull->getOriginLiteralType() === AST\Literal::STRING;

if (!$this->containsOnlyNumericTypes($rawTypeNoNull) || $isLiteralString) {
$allIsNumericExcludingLiteralString = false;
}

if ($rawTypeNoNull->isString()->yes()) {
$containsString = true;
}

if (!$rawTypeNoNull->isFloat()->yes()) {
continue;
}

$containsFloat = true;
}

if ($containsFloat && $allIsNumericExcludingLiteralString) {
return $this->simpleFloatify($originalResult);
} elseif ($containsString) {
return $this->simpleStringify($originalResult);
}

return $originalResult;
}

/**
Expand Down Expand Up @@ -2107,4 +2153,34 @@ private function isSupportedDriver(): bool
], true);
}

private function simpleStringify(Type $type): Type
{
return TypeTraverser::map($type, static function (Type $type, callable $traverse): Type {
if ($type instanceof UnionType || $type instanceof IntersectionType) {
return $traverse($type);
}

if ($type instanceof IntegerType || $type instanceof FloatType || $type instanceof BooleanType) {
return $type->toString();
}

return $traverse($type);
});
}

private function simpleFloatify(Type $type): Type
{
return TypeTraverser::map($type, static function (Type $type, callable $traverse): Type {
if ($type instanceof UnionType || $type instanceof IntersectionType) {
return $traverse($type);
}

if ($type instanceof IntegerType || $type instanceof BooleanType || $type instanceof StringType) {
return $type->toFloat();
}

return $traverse($type);
});
}

}
Loading

0 comments on commit c87ee29

Please sign in to comment.