1
1
// SPDX-License-Identifier: LGPL-3.0-or-later
2
2
#pragma once
3
3
4
- #include " paddle/include/paddle_inference_api.h"
4
+ // #include "paddle/include/paddle_inference_api.h"
5
5
// #include "paddle/extension.h"
6
6
// #include "paddle/phi/backends/all_context.h"
7
7
8
8
#include " DeepPot.h"
9
9
#include " common.h"
10
+ #include " commonPD.h"
10
11
#include " neighbor_list.h"
11
12
12
13
namespace deepmd {
@@ -177,19 +178,19 @@ class DeepPotPD : public DeepPotBase {
177
178
*same aparam.
178
179
* @param[in] atomic Whether to compute the atomic energy and virial.
179
180
**/
180
- template <typename VALUETYPE, typename ENERGYVTYPE>
181
- void compute_mixed_type (ENERGYVTYPE& ener,
182
- std::vector<VALUETYPE>& force,
183
- std::vector<VALUETYPE>& virial,
184
- std::vector<VALUETYPE>& atom_energy,
185
- std::vector<VALUETYPE>& atom_virial,
186
- const int & nframes,
187
- const std::vector<VALUETYPE>& coord,
188
- const std::vector<int >& atype,
189
- const std::vector<VALUETYPE>& box,
190
- const std::vector<VALUETYPE>& fparam,
191
- const std::vector<VALUETYPE>& aparam,
192
- const bool atomic);
181
+ // template <typename VALUETYPE, typename ENERGYVTYPE>
182
+ // void compute_mixed_type(ENERGYVTYPE& ener,
183
+ // std::vector<VALUETYPE>& force,
184
+ // std::vector<VALUETYPE>& virial,
185
+ // std::vector<VALUETYPE>& atom_energy,
186
+ // std::vector<VALUETYPE>& atom_virial,
187
+ // const int& nframes,
188
+ // const std::vector<VALUETYPE>& coord,
189
+ // const std::vector<int>& atype,
190
+ // const std::vector<VALUETYPE>& box,
191
+ // const std::vector<VALUETYPE>& fparam,
192
+ // const std::vector<VALUETYPE>& aparam,
193
+ // const bool atomic);
193
194
194
195
public:
195
196
/* *
@@ -327,6 +328,10 @@ class DeepPotPD : public DeepPotBase {
327
328
private:
328
329
int num_intra_nthreads, num_inter_nthreads;
329
330
bool inited;
331
+
332
+ template <class VT >
333
+ VT get_scalar (const std::string& name) const ;
334
+
330
335
int ntypes;
331
336
int ntypes_spin;
332
337
int dfparam;
@@ -336,6 +341,7 @@ class DeepPotPD : public DeepPotBase {
336
341
std::shared_ptr<paddle_infer::Predictor> predictor = nullptr ;
337
342
std::shared_ptr<paddle_infer::Config> config = nullptr ;
338
343
double rcut;
344
+ double cell_size;
339
345
NeighborListData nlist_data;
340
346
int max_num_neighbors;
341
347
InputNlist nlist;
0 commit comments