Skip to content

Commit a3c4663

Browse files
update CMAKE
1 parent 482d588 commit a3c4663

File tree

5 files changed

+669
-172
lines changed

5 files changed

+669
-172
lines changed

source/api_cc/CMakeLists.txt

+1-4
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,7 @@ if(ENABLE_PYTORCH
2323
target_link_libraries(${libname} PRIVATE "${TORCH_LIBRARIES}")
2424
target_compile_definitions(${libname} PRIVATE BUILD_PYTORCH)
2525
endif()
26-
if(ENABLE_PADDLE
27-
AND "${OP_CXX_ABI_PT}" EQUAL "${OP_CXX_ABI}"
28-
# LAMMPS and i-PI in the Python package are not ready - needs more work
29-
AND NOT BUILD_PY_IF)
26+
if(ENABLE_PADDLE AND NOT BUILD_PY_IF)
3027
target_link_libraries(${libname} PRIVATE "${PADDLE_LIBRARIES}")
3128
target_compile_definitions(${libname} PRIVATE BUILD_PADDLE)
3229
endif()

source/api_cc/include/DeepPotPD.h

+20-14
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
// SPDX-License-Identifier: LGPL-3.0-or-later
22
#pragma once
33

4-
#include "paddle/include/paddle_inference_api.h"
4+
// #include "paddle/include/paddle_inference_api.h"
55
// #include "paddle/extension.h"
66
// #include "paddle/phi/backends/all_context.h"
77

88
#include "DeepPot.h"
99
#include "common.h"
10+
#include "commonPD.h"
1011
#include "neighbor_list.h"
1112

1213
namespace deepmd {
@@ -177,19 +178,19 @@ class DeepPotPD : public DeepPotBase {
177178
*same aparam.
178179
* @param[in] atomic Whether to compute the atomic energy and virial.
179180
**/
180-
template <typename VALUETYPE, typename ENERGYVTYPE>
181-
void compute_mixed_type(ENERGYVTYPE& ener,
182-
std::vector<VALUETYPE>& force,
183-
std::vector<VALUETYPE>& virial,
184-
std::vector<VALUETYPE>& atom_energy,
185-
std::vector<VALUETYPE>& atom_virial,
186-
const int& nframes,
187-
const std::vector<VALUETYPE>& coord,
188-
const std::vector<int>& atype,
189-
const std::vector<VALUETYPE>& box,
190-
const std::vector<VALUETYPE>& fparam,
191-
const std::vector<VALUETYPE>& aparam,
192-
const bool atomic);
181+
// template <typename VALUETYPE, typename ENERGYVTYPE>
182+
// void compute_mixed_type(ENERGYVTYPE& ener,
183+
// std::vector<VALUETYPE>& force,
184+
// std::vector<VALUETYPE>& virial,
185+
// std::vector<VALUETYPE>& atom_energy,
186+
// std::vector<VALUETYPE>& atom_virial,
187+
// const int& nframes,
188+
// const std::vector<VALUETYPE>& coord,
189+
// const std::vector<int>& atype,
190+
// const std::vector<VALUETYPE>& box,
191+
// const std::vector<VALUETYPE>& fparam,
192+
// const std::vector<VALUETYPE>& aparam,
193+
// const bool atomic);
193194

194195
public:
195196
/**
@@ -327,6 +328,10 @@ class DeepPotPD : public DeepPotBase {
327328
private:
328329
int num_intra_nthreads, num_inter_nthreads;
329330
bool inited;
331+
332+
template <class VT>
333+
VT get_scalar(const std::string& name) const;
334+
330335
int ntypes;
331336
int ntypes_spin;
332337
int dfparam;
@@ -336,6 +341,7 @@ class DeepPotPD : public DeepPotBase {
336341
std::shared_ptr<paddle_infer::Predictor> predictor = nullptr;
337342
std::shared_ptr<paddle_infer::Config> config = nullptr;
338343
double rcut;
344+
double cell_size;
339345
NeighborListData nlist_data;
340346
int max_num_neighbors;
341347
InputNlist nlist;

source/api_cc/include/version.h.in

+1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ const std::string global_git_branch="@GIT_BRANCH@";
1010
const std::string global_tf_include_dir="@TensorFlow_INCLUDE_DIRS@";
1111
const std::string global_tf_lib="@TensorFlow_LIBRARY@";
1212
const std::string global_pt_lib="@TORCH_LIBRARIES@";
13+
const std::string global_pd_lib="@PADDLE_LIBRARIES@";
1314
const std::string global_model_version="@MODEL_VERSION@";

0 commit comments

Comments
 (0)