diff --git a/cpp/src/neighbors/detail/ann_utils.cuh b/cpp/src/neighbors/detail/ann_utils.cuh index fa0e2a9217..2d33c10190 100644 --- a/cpp/src/neighbors/detail/ann_utils.cuh +++ b/cpp/src/neighbors/detail/ann_utils.cuh @@ -148,7 +148,7 @@ struct config { }; template <> struct config { - using value_t = half; + using value_t = float; static constexpr double kDivisor = 1.0; }; template <> diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh index 9bf4ae6784..f097767c1a 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh @@ -138,8 +138,8 @@ struct loadAndComputeDist { for (int k = 0; k < Veclen; ++k) { compute_dist(dist, queryRegs[k], encV[k]); if constexpr (ComputeNorm) { - norm_query += queryRegs[k] * queryRegs[k]; - norm_data += encV[k] * encV[k]; + norm_query += (AccT)(queryRegs[k] * queryRegs[k]); + norm_data += (AccT)(encV[k] * encV[k]); } } } @@ -173,8 +173,8 @@ struct loadAndComputeDist { T q = raft::shfl(queryReg, d + k, raft::WarpSize); compute_dist(dist, q, encV[k]); if constexpr (ComputeNorm) { - norm_query += q * q; - norm_data += encV[k] * encV[k]; + norm_query += (AccT)(q * q); + norm_data += (AccT)(encV[k] * encV[k]); } } } @@ -199,8 +199,8 @@ struct loadAndComputeDist { T q = raft::shfl(queryReg, d + k, raft::WarpSize); compute_dist(dist, q, enc[k]); if constexpr (ComputeNorm) { - norm_query += q * q; - norm_data += enc[k] * enc[k]; + norm_query += (AccT)(q * q); + norm_data += (AccT)(enc[k] * enc[k]); } } }