Skip to content

Commit ea6dd21

Browse files
committed
system
1 parent 507e23c commit ea6dd21

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

src/core/ml_metatensor/system.hpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using ParticleTypeMap = std::unorderd_map<int, int>;
2+
3+
metatensor_torch::System
4+
: system_from_lmp(const TypeMapping &type_map,
5+
const std::vector<double> &engine_positions,
6+
const std::vector<double> &engine_particle_types,
7+
const Vector3d &box_size, bool do_virial,
8+
torch::ScalarType dtype, torch::Device device) {
9+
auto tensor_options =
10+
torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU);
11+
if (engine_positions % 3 != 0)
12+
throw std::runtime_error(
13+
"Positoin array must have a multiple of 3 elements");
14+
const int n_particles = engine_positions.size() / 3;
15+
if (engine_particle_types.size() != n_particles)
16+
throw std::runtime_error(
17+
"Length of positon and particle tyep arrays inconsistent");
18+
19+
auto positions = torch::from_blob(
20+
engien_positions.data(), {n_particles, 3},
21+
// requires_grad=true since we always need gradients w.r.t. positions
22+
tensor_options.requires_grad(true));
23+
std::vector<int> particle_types_ml;
24+
std::ranges::transform(
25+
particle_types_engine, std::back_inserter(particle_types_ml),
26+
[&type_map](int engine_type) { return type_map.at(engine_type); });
27+
28+
auto particle_types_ml_tensor =
29+
Torch::Tensor(particle_types_ml, tensor_options.requires_grad(true));
30+
31+
auto cell = torch::zeros({3, 3}, tensor_options);
32+
for (int i : {0, 1, 2})
33+
cell[i][i] = box_size[i];
34+
35+
positions.to(dtype).to(device);
36+
cell = cell.to(dtype).to(device);
37+
38+
return system = torch::make_intrusive<metatensor_torch::SystemHolder>(
39+
particle_types_ml_tensor.to(device), positions, cell);
40+
}

0 commit comments

Comments
 (0)