diff --git a/pywhy_stats/kernel_utils.py b/pywhy_stats/kernel_utils.py index faae0c3..88b5e0c 100644 --- a/pywhy_stats/kernel_utils.py +++ b/pywhy_stats/kernel_utils.py @@ -123,6 +123,18 @@ def compute_kernel( ------- kernel : ArrayLike of shape (n_samples_X, n_samples_X) or (n_samples_X, n_samples_Y) The kernel matrix. + + Notes + ----- + If the metric is a callable, it will have either one input for ``X``, or two inputs for ``X`` and + ``Y``. If one input is passed in, it is assumed that the kernel operates on the entire array to compute + the kernel array. If two inputs are passed in, then it is assumed that the kernel operates on + pairwise row vectors from each input array. This callable is parallelized across rows of the input + using :func:`~sklearn.metrics.pairwise.pairwise_kernels`. Note that ``(X, Y=None)`` is a valid input + signature for the kernel function and would then get passed to the pairwise kernel function. If + a callable is passed in, it is generally faster and more efficient if one can define a vectorized + operation that operates on the whole array at once. Otherwise, the pairwise kernel function will + call the function for each combination of rows in the input arrays. """ # if the width of the kernel is not set, then use the median trick to set the # kernel width based on the data X