Skip to content

Commit 15a7e75

Browse files
refine docstring of get_buffer
1 parent b3a6408 commit 15a7e75

File tree

2 files changed

+27
-21
lines changed

2 files changed

+27
-21
lines changed

source/api_cc/include/DeepPotPD.h

+10-4
Original file line numberDiff line numberDiff line change
@@ -235,14 +235,20 @@ class DeepPotPD : public DeepPotBase {
235235
void get_type_map(std::string& type_map);
236236

237237
/**
238-
* @brief Get the type map (element name of the atom types) of this model.
239-
* @param[out] type_map The type map of this model.
238+
* @brief Get the buffer of this model.
239+
* @param[in] buffer_name Buffer name.
240+
* @param[out] buffer_array Buffer array.
240241
**/
241242
template<typename BUFFERTYPE>
242-
void get_buffer(const std::string &buffer_name, std::vector<BUFFERTYPE> &buffer_arr);
243+
void get_buffer(const std::string &buffer_name, std::vector<BUFFERTYPE> &buffer_array);
243244

245+
/**
246+
* @brief Get the buffer of this model.
247+
* @param[in] buffer_name Buffer name.
248+
* @param[out] buffer_scalar Buffer scalar.
249+
**/
244250
template<typename BUFFERTYPE>
245-
void get_buffer(const std::string &buffer_name, BUFFERTYPE &buffer_arr);
251+
void get_buffer(const std::string &buffer_name, BUFFERTYPE &buffer_scalar);
246252

247253
/**
248254
* @brief Get whether the atom dimension of aparam is nall instead of fparam.

source/api_cc/src/DeepPotPD.cc

+17-17
Original file line numberDiff line numberDiff line change
@@ -379,23 +379,6 @@ template void DeepPotPD::compute<float, std::vector<ENERGYTYPE>>(
379379
const std::vector<float>& aparam,
380380
const bool atomic);
381381

382-
/* general function except for string buffer */
383-
template<typename BUFFERVTYPE>
384-
void DeepPotPD::get_buffer(const std::string &buffer_name, std::vector<BUFFERVTYPE> &buffer_arr) {
385-
auto buffer_tensor = predictor->GetOutputHandle(buffer_name);
386-
auto buffer_shape = buffer_tensor->shape();
387-
int buffer_size = std::accumulate(buffer_shape.begin(), buffer_shape.end(), 1, std::multiplies<int>());
388-
buffer_arr.resize(buffer_size);
389-
buffer_tensor->CopyToCpu(buffer_arr.data());
390-
}
391-
392-
template<typename BUFFERTYPE>
393-
void DeepPotPD::get_buffer(const std::string &buffer_name, BUFFERTYPE &buffer) {
394-
std::vector<BUFFERTYPE> buffer_arr(1);
395-
DeepPotPD::get_buffer<BUFFERTYPE>(buffer_name, buffer_arr);
396-
buffer = buffer_arr[0];
397-
}
398-
399382
/* type_map is regarded as a special string buffer
400383
that need to be postprocessed */
401384
void DeepPotPD::get_type_map(std::string& type_map) {
@@ -410,6 +393,23 @@ void DeepPotPD::get_type_map(std::string& type_map) {
410393
}
411394
}
412395

396+
/* general function except for string buffer */
397+
template<typename BUFFERTYPE>
398+
void DeepPotPD::get_buffer(const std::string &buffer_name, std::vector<BUFFERTYPE> &buffer_array) {
399+
auto buffer_tensor = predictor->GetOutputHandle(buffer_name);
400+
auto buffer_shape = buffer_tensor->shape();
401+
int buffer_size = std::accumulate(buffer_shape.begin(), buffer_shape.end(), 1, std::multiplies<int>());
402+
buffer_array.resize(buffer_size);
403+
buffer_tensor->CopyToCpu(buffer_array.data());
404+
}
405+
406+
template<typename BUFFERTYPE>
407+
void DeepPotPD::get_buffer(const std::string &buffer_name, BUFFERTYPE &buffer_scalar) {
408+
std::vector<BUFFERTYPE> buffer_array(1);
409+
DeepPotPD::get_buffer<BUFFERTYPE>(buffer_name, buffer_array);
410+
buffer_scalar = buffer_array[0];
411+
}
412+
413413
// forward to template method
414414
void DeepPotPD::computew(std::vector<double>& ener,
415415
std::vector<double>& force,

0 commit comments

Comments
 (0)