@@ -379,23 +379,6 @@ template void DeepPotPD::compute<float, std::vector<ENERGYTYPE>>(
379
379
const std::vector<float >& aparam,
380
380
const bool atomic);
381
381
382
- /* general function except for string buffer */
383
- template <typename BUFFERVTYPE>
384
- void DeepPotPD::get_buffer (const std::string &buffer_name, std::vector<BUFFERVTYPE> &buffer_arr) {
385
- auto buffer_tensor = predictor->GetOutputHandle (buffer_name);
386
- auto buffer_shape = buffer_tensor->shape ();
387
- int buffer_size = std::accumulate (buffer_shape.begin (), buffer_shape.end (), 1 , std::multiplies<int >());
388
- buffer_arr.resize (buffer_size);
389
- buffer_tensor->CopyToCpu (buffer_arr.data ());
390
- }
391
-
392
- template <typename BUFFERTYPE>
393
- void DeepPotPD::get_buffer (const std::string &buffer_name, BUFFERTYPE &buffer) {
394
- std::vector<BUFFERTYPE> buffer_arr (1 );
395
- DeepPotPD::get_buffer<BUFFERTYPE>(buffer_name, buffer_arr);
396
- buffer = buffer_arr[0 ];
397
- }
398
-
399
382
/* type_map is regarded as a special string buffer
400
383
that need to be postprocessed */
401
384
void DeepPotPD::get_type_map (std::string& type_map) {
@@ -410,6 +393,23 @@ void DeepPotPD::get_type_map(std::string& type_map) {
410
393
}
411
394
}
412
395
396
+ /* general function except for string buffer */
397
+ template <typename BUFFERTYPE>
398
+ void DeepPotPD::get_buffer (const std::string &buffer_name, std::vector<BUFFERTYPE> &buffer_array) {
399
+ auto buffer_tensor = predictor->GetOutputHandle (buffer_name);
400
+ auto buffer_shape = buffer_tensor->shape ();
401
+ int buffer_size = std::accumulate (buffer_shape.begin (), buffer_shape.end (), 1 , std::multiplies<int >());
402
+ buffer_array.resize (buffer_size);
403
+ buffer_tensor->CopyToCpu (buffer_array.data ());
404
+ }
405
+
406
+ template <typename BUFFERTYPE>
407
+ void DeepPotPD::get_buffer (const std::string &buffer_name, BUFFERTYPE &buffer_scalar) {
408
+ std::vector<BUFFERTYPE> buffer_array (1 );
409
+ DeepPotPD::get_buffer<BUFFERTYPE>(buffer_name, buffer_array);
410
+ buffer_scalar = buffer_array[0 ];
411
+ }
412
+
413
413
// forward to template method
414
414
void DeepPotPD::computew (std::vector<double >& ener,
415
415
std::vector<double >& force,
0 commit comments