diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java index 9d6048f484b83..21ca1c5943b14 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java @@ -1651,7 +1651,7 @@ public static Double mapCosineSimilarity(@SqlType("map(varchar,double)") Block l @ScalarFunction("cosine_similarity") @SqlNullable @SqlType(StandardTypes.DOUBLE) - public static Double arrayCosineSimilarity(@SqlType("array(double)") Block leftArray, @SqlType("array(double)") Block rightArray) + public static Double arrayCosineSimilarityDouble(@SqlType("array(double)") Block leftArray, @SqlType("array(double)") Block rightArray) { checkCondition( leftArray.getPositionCount() == rightArray.getPositionCount(), @@ -1675,6 +1675,34 @@ public static Double arrayCosineSimilarity(@SqlType("array(double)") Block leftA return dotProduct / (normLeftArray * normRightArray); } + @Description("cosine similarity between the given identical sized vectors represented as arrays") + @ScalarFunction("cosine_similarity") + @SqlNullable + @SqlType(StandardTypes.REAL) + public static Long arrayCosineSimilarityReal(@SqlType("array(real)") Block leftArray, @SqlType("array(real)") Block rightArray) + { + checkCondition( + leftArray.getPositionCount() == rightArray.getPositionCount(), + INVALID_FUNCTION_ARGUMENT, + "Both array arguments need to have identical size"); + + checkCondition( + !(leftArray.mayHaveNull() || rightArray.mayHaveNull()), + INVALID_FUNCTION_ARGUMENT, + "Both arrays must not have nulls"); + + Float normLeftArray = array2NormReal(leftArray); + Float normRightArray = array2NormReal(rightArray); + + if (normLeftArray == null || normRightArray == null) { + return null; + } + + long dotProduct = arrayDotProductReal(leftArray, rightArray); + + return (long) floatToRawIntBits(intBitsToFloat((int) dotProduct) / (normLeftArray * normRightArray)); + } + @Description("squared Euclidean distance between the given identical sized vectors represented as arrays") @ScalarFunction("l2_squared") @SqlType(StandardTypes.REAL) @@ -1831,6 +1859,19 @@ private static Double array2Norm(Block array) return Math.sqrt(norm); } + private static Float array2NormReal(Block array) + { + float norm = 0.0f; + for (int i = 0; i < array.getPositionCount(); i++) { + if (array.isNull(i)) { + return null; + } + norm += intBitsToFloat((int) REAL.getLong(array, i)) * intBitsToFloat((int) REAL.getLong(array, i)); + } + + return (float) Math.sqrt(norm); + } + @Description("factorial of a given integer in the range of 0 to 20") @ScalarFunction @SqlType(StandardTypes.BIGINT)