diff --git a/src/09python_ffi/src/main.cpp b/src/09python_ffi/src/main.cpp
index 48a4ea6f..0cf0323c 100644
--- a/src/09python_ffi/src/main.cpp
+++ b/src/09python_ffi/src/main.cpp
@@ -3,6 +3,7 @@
 #include "import.h"
 #include "llm/operators.h"
 #include "onnx/operators.h"
+#include "search.h"
 #include <pybind11/stl.h>// keep this line to convert stl types
 
 namespace py = pybind11;
@@ -28,4 +29,5 @@ namespace refactor::python_ffi {
         m   .def("config_log"     , &configLog    , return_::automatic )
+            .def("random_search"  , &randomSearch , return_::move      )
             .def("find_device"    , &findDevice   , return_::move      )
             .def("_make_operator" , &makeOp       , return_::move      )
             .def("_make_tensor"   , &makeTensor   , return_::move      )
diff --git a/src/09python_ffi/src/search.cpp b/src/09python_ffi/src/search.cpp
new file mode 100644
--- /dev/null
+++ b/src/09python_ffi/src/search.cpp
@@ -0,0 +1,117 @@
+#include "search.h"
+#include "functions.h"
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <functional>
+#include <numeric>
+#include <random>
+#include <span>
+#include <utility>
+#include <vector>
+
+namespace refactor::python_ffi {
+
+    namespace {
+
+        /// @brief Draws one token index from a row of logits with top-K / top-P sampling.
+        ///
+        /// @param logits Non-empty row of logits, already divided by the temperature.
+        /// @param topK   Number of best candidates to keep; must lie in [1, logits.size()].
+        /// @param topP   Share of the full softmax mass at which the candidate list is cut.
+        ///               The candidate that crosses the threshold is kept, so the list
+        ///               never becomes empty.
+        /// @param gen    Random engine the sample is drawn from.
+        /// @return Position in `logits` of the sampled token.
+        int64_t sampleToken(std::span<float const> logits, int64_t topK, float topP, std::mt19937 &gen) {
+            using Candidate = std::pair<float, int64_t>;// (softmax numerator, token index)
+
+            // Softmax numerators; subtracting the maximum keeps std::exp from overflowing.
+            // `total` is the softmax denominator.
+            auto const max = *std::max_element(logits.begin(), logits.end());
+            auto const size = static_cast<int64_t>(logits.size());
+            std::vector<Candidate> candidates(logits.size());
+            auto total = 0.;
+            for (int64_t j = 0; j < size; ++j) {
+                auto const weight = std::exp(logits[j] - max);
+                candidates[j] = {weight, j};
+                total += weight;
+            }
+
+            // top-K: only the first topK entries need to be in order.
+            std::partial_sort(candidates.begin(), candidates.begin() + topK, candidates.end(),
+                              [](auto const &a, auto const &b) { return a.first > b.first; });
+
+            // top-P: keep the shortest prefix whose probability mass exceeds topP.
+            int64_t kept = 0;
+            auto keptMass = 0.;
+            while (kept < topK) {
+                keptMass += candidates[kept++].first;
+                if (keptMass > topP * total) { break; }
+            }
+
+            // Drawing from [0, keptMass) and walking the cumulative weights samples each
+            // kept candidate with its probability renormalised over the kept set.
+            auto u = std::uniform_real_distribution<double>(0, 1)(gen) * keptMass;
+            for (int64_t j = 0; j < kept; ++j) {
+                if (u < candidates[j].first) { return candidates[j].second; }
+                u -= candidates[j].first;
+            }
+            // Only reachable through floating-point rounding in the walk above.
+            return candidates[kept - 1].second;
+        }
+
+    }// namespace
+
+    /// @brief Samples one token per row of `logits_` with temperature, top-K and top-P.
+    ///
+    /// The last dimension of `logits_` is the vocabulary; all leading dimensions are
+    /// treated as a batch. The buffer is read as densely packed rows.
+    ///
+    /// @param logits_     FP16 or F32 array with at least one dimension.
+    /// @param topK        Maximum number of candidates per row; must be positive and is
+    ///                    clamped to the vocabulary size.
+    /// @param topP        Nucleus threshold, as a share of the softmax probability mass.
+    /// @param temperature Divisor applied to the logits before softmax; must be positive.
+    /// @return I64 array shaped like `logits_` without its last dimension, holding the
+    ///         sampled token indices.
+    pybind11::array randomSearch(pybind11::array logits_, int topK, float topP, float temperature) {
+        auto shape = std::span(logits_.shape(), logits_.ndim());
+        ASSERT(!shape.empty(), "");
+        ASSERT(topK > 0, "topK must be positive");
+        ASSERT(temperature > 0, "temperature must be positive");
+
+        auto shapeBack = shape.begin() + shape.size() - 1;
+        int64_t const vocabSize = *shapeBack;
+        ASSERT(vocabSize > 0, "vocabulary dimension must not be empty");
+        auto const batch = std::accumulate(shape.begin(), shapeBack, int64_t{1}, std::multiplies<>{});
+        auto const k = std::min<int64_t>(topK, vocabSize);
+
+        auto const type = parseNumpyDType(logits_.dtype());
+        auto const isF16 = type == DataType::FP16;
+        if (!isF16 && !(type == DataType::F32)) {
+            RUNTIME_ERROR("unsupported data type.");
+        }
+
+        std::random_device rd;
+        std::mt19937 gen(rd());
+        std::vector<int64_t> result(batch);
+        std::vector<float> logits(vocabSize);// scratch row, reused across the batch
+        for (auto i : range0_(batch)) {
+            // Convert the row to f32 and apply the temperature.
+            if (isF16) {
+                auto data = reinterpret_cast<fp16_t const *>(logits_.data()) + i * vocabSize;
+                std::transform(data, data + vocabSize, logits.begin(),
+                               [temperature](auto x) { return x.to_f32() / temperature; });
+            } else {
+                auto data = reinterpret_cast<float const *>(logits_.data()) + i * vocabSize;
+                std::transform(data, data + vocabSize, logits.begin(),
+                               [temperature](auto x) { return x / temperature; });
+            }
+            result[i] = sampleToken(logits, k, topP, gen);
+        }
+
+        return pybind11::array(buildNumpyDType(DataType::I64), std::span(shape.begin(), shape.size() - 1), result.data());
+    }
+
+}// namespace refactor::python_ffi
diff --git a/src/09python_ffi/src/search.h b/src/09python_ffi/src/search.h
new file mode 100644
index 00000000..f9d8d398
--- /dev/null
+++ b/src/09python_ffi/src/search.h
@@ -0,0 +1,12 @@
+#ifndef PYTHON_FFI_SEARCH_H
+#define PYTHON_FFI_SEARCH_H
+
+#include <pybind11/numpy.h>
+
+namespace refactor::python_ffi {
+
+    pybind11::array randomSearch(pybind11::array, int topK, float topP, float temperature);
+
+}// namespace refactor::python_ffi
+
+#endif// PYTHON_FFI_SEARCH_H