diff --git a/dowhy/gcm/divergence.py b/dowhy/gcm/divergence.py index 88ddf4e74..59b141c15 100644 --- a/dowhy/gcm/divergence.py +++ b/dowhy/gcm/divergence.py @@ -61,7 +61,7 @@ def estimate_kl_divergence_continuous_knn( # Making sure that X and Y have no overlapping values, which would lead to a distance of 0 with k=1 and, thus, to # a division by zero. if remove_common_elements: - X = setdiff2d(X, Y, assume_unique=True) + X = setdiff2d(X, Y, assume_unique=False) if X.shape[0] < k + 1: # All elements are equal (or at least less than k samples are different) return 0