Skip to content

Commit e9495d4

Browse files
shashankiiitShashank Paliwal
andauthored
Add a new method for standard DotProduct for users seeking non-normalized score (#876)
Co-authored-by: Shashank Paliwal <spaliwal@spaliwal-mn1.linkedin.biz>
1 parent 90a6b6d commit e9495d4

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

feathr-impl/src/main/java/com/linkedin/feathr/common/util/MvelContextUDFs.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,31 @@ public static Float cosineSimilarity(Object obj1, Object obj2) {
364364
}
365365
}
366366

367+
/**
368+
* Returns a standard dotProduct of two vector objects.
369+
* Use {@link MvelContextUDFs#cosineSimilarity(Object, Object)} for normalized dot-product.
370+
*/
371+
@ExportToMvel
372+
public static Double dotProduct(Object obj1, Object obj2) {
373+
if (obj1 == null || obj2 == null) {
374+
return null;
375+
}
376+
Map<String, Float> mapA = CoercionUtils.coerceToVector(obj1);
377+
Map<String, Float> mapB = CoercionUtils.coerceToVector(obj2);
378+
double dotProduct = 0;
379+
380+
for (Map.Entry<String, Float> entry : mapA.entrySet()) {
381+
String k = entry.getKey();
382+
float valA = entry.getValue();
383+
Float valB = mapB.get(k);
384+
if (valB != null) {
385+
dotProduct += ((double) valA * valB);
386+
}
387+
}
388+
389+
return dotProduct;
390+
}
391+
367392
/**
368393
* convert input to lower case string
369394
* @param input input string

feathr-impl/src/test/java/com/linkedin/feathr/offline/TestMvelContext.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,24 @@ public void testCosineSimilarity() {
3030
categoricalOutput2.clear();
3131
assertEquals(cosineSimilarity(categoricalOutput1, categoricalOutput2), 0.0F);
3232
}
33+
34+
@Test
35+
public void testDotProduct() {
36+
// Test basic dot product calculation
37+
Map<String, Float> categoricalOutput1 = new HashMap<>();
38+
categoricalOutput1.put("A", 1F);
39+
categoricalOutput1.put("B", 1F);
40+
41+
Map<String, Float> categoricalOutput2 = new HashMap<>();
42+
categoricalOutput2.put("B", 1F);
43+
categoricalOutput2.put("C", 1F);
44+
45+
assertEquals(dotProduct(categoricalOutput1, categoricalOutput2), 1.0D);
46+
47+
// Test dot product of zero vectors
48+
categoricalOutput1.clear();
49+
assertEquals(dotProduct(categoricalOutput1, categoricalOutput2), 0.0D);
50+
categoricalOutput2.clear();
51+
assertEquals(dotProduct(categoricalOutput1, categoricalOutput2), 0.0D);
52+
}
3353
}

0 commit comments

Comments
 (0)