From a3eeaa730f70895294c7929ea261f7c41eeae1d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Julian=20Ho=C3=9Fbach?= Date: Fri, 2 Aug 2024 13:40:23 +0200 Subject: [PATCH 1/9] add metatensor to cmake --- CMakeLists.txt | 33 +++++++++++++++++++++++++++++ cmake/espresso_cmake_config.cmakein | 2 ++ src/config/features.def | 1 + 3 files changed, 36 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index fd98f91fac..662ec25815 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -93,6 +93,7 @@ option(ESPRESSO_BUILD_TESTS "Enable tests" ON) option(ESPRESSO_BUILD_WITH_SCAFACOS "Build with ScaFaCoS support" OFF) option(ESPRESSO_BUILD_WITH_STOKESIAN_DYNAMICS "Build with Stokesian Dynamics" OFF) +option(ESPRESSO_BUILD_WITH_METATENSOR "Build with Metatensor support" OFF) option(ESPRESSO_BUILD_WITH_WALBERLA "Build with waLBerla lattice-Boltzmann support" OFF) option(ESPRESSO_BUILD_WITH_WALBERLA_AVX @@ -595,6 +596,38 @@ if(ESPRESSO_BUILD_BENCHMARKS) add_subdirectory(maintainer/benchmarks) endif() +# +# Metatensor +# + +if(ESPRESSO_BUILD_WITH_METATENSOR) + # Bring the `torch` target in scope to allow evaluation of cmake generator + # expression from `metatensor_torch` + find_package(Torch REQUIRED) + + # cmake-format: off + set(METATENSOR_URL_BASE "https://github.com/lab-cosmo/metatensor/releases/download") + set(METATENSOR_CORE_VERSION "0.1.8") + + include(FetchContent) + FetchContent_Declare( + metatensor + URL "${METATENSOR_URL_BASE}/metatensor-core-v${METATENSOR_CORE_VERSION}/metatensor-core-cxx-${METATENSOR_CORE_VERSION}.tar.gz" + URL_HASH SHA1=3ed389770e5ec6dbb8cbc9ed88f84d6809b552ef + ) + + # workaround for https://gitlab.kitware.com/cmake/cmake/-/issues/21146 + if(NOT DEFINED metatensor_SOURCE_DIR OR NOT EXISTS "${metatensor_SOURCE_DIR}") + message(STATUS "Fetching metatensor v${METATENSOR_CORE_VERSION} from github") + FetchContent_Populate(metatensor) + endif() + # cmake-format: on + + set(BUILD_SHARED_LIBS on CACHE BOOL "") + set(METATENSOR_INSTALL_BOTH_STATIC_SHARED off CACHE BOOL "") + add_subdirectory("${metatensor_SOURCE_DIR}") +endif() + # # waLBerla # diff --git a/cmake/espresso_cmake_config.cmakein b/cmake/espresso_cmake_config.cmakein index ffdeb99912..57322a5aca 100644 --- a/cmake/espresso_cmake_config.cmakein +++ b/cmake/espresso_cmake_config.cmakein @@ -13,6 +13,8 @@ #cmakedefine ESPRESSO_BUILD_WITH_STOKESIAN_DYNAMICS +#cmakedefine ESPRESSO_BUILD_WITH_METATENSOR + #cmakedefine ESPRESSO_BUILD_WITH_WALBERLA #cmakedefine ESPRESSO_BUILD_WITH_WALBERLA_FFT diff --git a/src/config/features.def b/src/config/features.def index ff8eb2a041..6e887b5574 100644 --- a/src/config/features.def +++ b/src/config/features.def @@ -111,6 +111,7 @@ HDF5 external SCAFACOS external GSL external STOKESIAN_DYNAMICS external +METATENSOR external WALBERLA external WALBERLA_FFT external VALGRIND external From ea6dd211f12b6b0f43ce44859fa7941f3fdb924c Mon Sep 17 00:00:00 2001 From: Rudolf Weeber Date: Tue, 6 Aug 2024 11:23:18 +0200 Subject: [PATCH 2/9] system --- src/core/ml_metatensor/system.hpp | 40 +++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 src/core/ml_metatensor/system.hpp diff --git a/src/core/ml_metatensor/system.hpp b/src/core/ml_metatensor/system.hpp new file mode 100644 index 0000000000..d9a5dac359 --- /dev/null +++ b/src/core/ml_metatensor/system.hpp @@ -0,0 +1,40 @@ +using ParticleTypeMap = std::unorderd_map; + +metatensor_torch::System + : system_from_lmp(const TypeMapping &type_map, + const std::vector &engine_positions, + const std::vector &engine_particle_types, + const Vector3d &box_size, bool do_virial, + torch::ScalarType dtype, torch::Device device) { + auto tensor_options = + torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU); + if (engine_positions % 3 != 0) + throw std::runtime_error( + "Positoin array must have a multiple of 3 elements"); + const int n_particles = engine_positions.size() / 3; + if (engine_particle_types.size() != n_particles) + throw std::runtime_error( + "Length of positon and particle tyep arrays inconsistent"); + + auto positions = torch::from_blob( + engien_positions.data(), {n_particles, 3}, + // requires_grad=true since we always need gradients w.r.t. positions + tensor_options.requires_grad(true)); + std::vector particle_types_ml; + std::ranges::transform( + particle_types_engine, std::back_inserter(particle_types_ml), + [&type_map](int engine_type) { return type_map.at(engine_type); }); + + auto particle_types_ml_tensor = + Torch::Tensor(particle_types_ml, tensor_options.requires_grad(true)); + + auto cell = torch::zeros({3, 3}, tensor_options); + for (int i : {0, 1, 2}) + cell[i][i] = box_size[i]; + + positions.to(dtype).to(device); + cell = cell.to(dtype).to(device); + + return system = torch::make_intrusive( + particle_types_ml_tensor.to(device), positions, cell); +} From 7f0b5fe9d4c53cbe2699d0a58657d9140df15079 Mon Sep 17 00:00:00 2001 From: Rudolf Weeber Date: Tue, 6 Aug 2024 13:44:14 +0200 Subject: [PATCH 3/9] neighbor list --- src/core/ml_metatensor/CMakeLists.txt | 20 ++++++ src/core/ml_metatensor/add_neighbor_list.hpp | 71 ++++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 src/core/ml_metatensor/CMakeLists.txt create mode 100644 src/core/ml_metatensor/add_neighbor_list.hpp diff --git a/src/core/ml_metatensor/CMakeLists.txt b/src/core/ml_metatensor/CMakeLists.txt new file mode 100644 index 0000000000..fa3eb2b411 --- /dev/null +++ b/src/core/ml_metatensor/CMakeLists.txt @@ -0,0 +1,20 @@ +# +# Copyright (C) 2018-2022 The ESPResSo project +# +# This file is part of ESPResSo. +# +# ESPResSo is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# ESPResSo is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# + +#target_sources(espresso_core PRIVATE bonded_interaction_data.cpp) diff --git a/src/core/ml_metatensor/add_neighbor_list.hpp b/src/core/ml_metatensor/add_neighbor_list.hpp new file mode 100644 index 0000000000..ceb18cf1ee --- /dev/null +++ b/src/core/ml_metatensor/add_neighbor_list.hpp @@ -0,0 +1,71 @@ +struct PairInfo { + int part_id_1, + int part_id_2, + Utils::Vector3d distance; +} + +using Sample = std::array; +using Distances = + std::variant>, std::vector>>; + + +template +TorchTensorBlock neighbor_list_from_pairs(const metatensor_torch::System& system, const PairIterable& pairs) { + auto dtype = system->positions().scalar_type(); + auto device = system->positions().device(); + std::vector samples; + Distances distances; + if (dtype == torch::kFloat64) { + distances = {std::vector>()}; + } + else if (dtype == torch::kFloat32) { + distances = {std::vector>()}; + } + else { + throw std::runtime_error("Unsupported floating poitn data type"); + } + + for (auto const& pair: pairs) { + auto sample = Sample{ + pair.particle_id_1, pair.particle_id_2, 0, 0, 0}; + samples.push_back(sample); + (*distances).push_back(pair.distance); + } + + + int64_t n_pairs = samples.size(); + auto samples_tensor = torch::from_blob( + reinterpret_cast(samples.data()), + {n_pairs, 5}, + torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU) + ); + + auto samples = torch::make_intrusive( + std::vector{"first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"}, + samples_values + ); + + distances_vectors = torch::from_blob( + (*distances).data(), + {n_pairs, 3, 1}, + torch::TensorOptions().dtype(dtype).device(torch::kCPU) + ); + return neighbors = torch::make_intrusive( + distances_vectors.to(dtype).to(device), + samples->to(device), + std::vector{ + metatensor_torch::LabelsHolder::create({"xyz"}, {{0}, {1}, {2}})->to(device), + }, + metatensor_torch::LabelsHolder::create({"distance"}, {{0}})->to(device) + ); + +} + +void add_neighbor_list_to_system(MetatensorTorch::system& system, + const TorchTensorBlock& neighbors, + const NeighborListOptions& options) { + metatensor_torch::register_autograd_neighbors(system, neighbors, options_.check_consistency); + system->add_neighbor_list(options, neighbors); +} + + From 5be4e0a2d2b4011405f251599d0688726c47668a Mon Sep 17 00:00:00 2001 From: Rudolf Weeber Date: Tue, 6 Aug 2024 14:34:10 +0200 Subject: [PATCH 4/9] load model --- src/core/ml_metatensor/load_model.hpp | 47 +++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 src/core/ml_metatensor/load_model.hpp diff --git a/src/core/ml_metatensor/load_model.hpp b/src/core/ml_metatensor/load_model.hpp new file mode 100644 index 0000000000..b749c41a3a --- /dev/null +++ b/src/core/ml_metatensor/load_model.hpp @@ -0,0 +1,47 @@ +using ModelPtr = std::unique_ptr; +using NeighborListRequest = + std::pair; + + +ModelPtr load_model(const std::string& path, const std::string& extensions_directory, torch::device device) { + + return std::make_unique( + metatensor_torch::load_atomistic_model(path, extensions) + ); +} + + +metatensor_torch::ModelCapabilitiesHolder +get_model_capabilites(const ModelPtr& model) { + auto capabilities_ivalue = model->run_method("capabilities"); + return capabilities_ivalue.toCustomClass(); +}; + +bool modle_provides_energy(const torch_metatensor::ModelCapabilitiesHolder& capabilities) { + return (capabilities->outputs().contains("energy")); +} + + +metatensor_toch::ModelMetadataHolder get_model_metadata(ModelPtr& model) { + auto metadata_ivalue = model->run_method("metadata"); + return metadata_ivalue.toCustomClass(); +} + + + +double required_range(const metatensor_torch::ModelCapabilities& capabilities, const metatensor_torch::ModelEvaluatoinOptoins& evaluation_optoins) { + return range = mts_data->capabilities->engine_interaction_range(evaluation_options->length_unit()); +} + +std::vector get_requested_neighbor_lists(ModelPtr& model) { + + std::vector res; + auto requested_nl = mts_data->model->run_method("requested_neighbor_lists"); + for (const auto& ivalue: requested_nl.toList()) { + auto options = ivalue.get().toCustomClass(); + auto cutoff = options->engine_cutoff(mts_data->evaluation_options->length_unit()); + + res.push_back({cutoff, options}); + } + return res; +} From 5f6f684df968503cb521bf189c5147ddc0a2817e Mon Sep 17 00:00:00 2001 From: Julian Hossbach Date: Tue, 6 Aug 2024 15:53:05 +0200 Subject: [PATCH 5/9] add metatensor lj test --- lennard-jones.pt | Bin 0 -> 34673 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 lennard-jones.pt diff --git a/lennard-jones.pt b/lennard-jones.pt new file mode 100644 index 0000000000000000000000000000000000000000..760ebca085bbd878ecef79aaea4a2cd5afb89f57 GIT binary patch literal 34673 zcmbSy1yo#1wl?nW?(XgqAXspBr*UcA-8BSvm*5cG9fG?{kl-FPSboUdnYr`k-C6Vg zvsOWOon5uR?Oo?oDawFDfPlcjfc)D<00ISK1F*9*ax`VMvbO^`F$3J49gUc60nSFI zM$SeOqF^A}+FoR@p8;fiBxGzXY#fX%Jd7+{sw}*GEUbJSJWTAI>|9)&^eilVEG%RU zBxFDvBWE*vN82}gcOEVSE)GUnOFI{LG9N`5XqeM1+#RV`9B)nWFL4opF#NB$UeN*U zoGk6_oFqhPUXlHGmyX^WXWBMdM6g$`Z|$!+Q-RR^ubkgB^2XcVkvsc7Cj&DV8yf>-BNJ=4*Jv`Zu(!8%0F&3meY@)jaBu-QIRi`$ z>;RVL7RL6D1~#uOogBbT{)WNI(wWHsXlY_?1AudUrL=bh7&v-POIQ>Z+IQc*3mVKRC0qFb|C2$Uqv;ms_uNb`GNM0|p-^1P_^M4^GWB7x{ z#o6A>?3IXviIuT*A6Hco&)`d@R2xwD0Vi=E}$ z9WoI+b0=pEBb71MGjJ zo8K0Ev+SD_!8yRaY6a>w5&WSQBNt~2dq<~N$U47O@LN@qE|xYnMlQAh5(!5m7k3f? z^S^g6nZ0$q0|3ozm`p5$$bOH}`@8nu4D%-Zzp(tfZcPDBCXSZBjq{u4)r)M6Y)G6e zEX`isgGBa^T>RlJ7n1<(Uo(!grI8H-iL*V4F@VIy>6OOK0$@kt{OTmH{-&-7H^O zGW{K=ossSDgs=L_=U-9%ox%}d258|{hJ68qCII0vLx?^*`^ zr+2-j5tHBX{I4#?%Jv^T=g(9AUFW86IpNK3%xvt9oVhsu-v8zRjz%VLmNalQvNM09 z`Rm~Qt04Wuf$d$KfiBMf!>Rx~fTOv`>$Ly0^lyXwBcJ^Cc7@+V|3LM2|Lr0JxC3o0 zO)Q-aUVYTm^3^kbYvE63{Y?b~I2ydU>njguN0&bm%I^u@=5TN^dW{68%jymo7SO-TkiqNE081NtI|oe3 zcigWuua_p&>mLIrOLJQz2P{e9``?{^?|sWM4%m`iyV3NeKza!mV+WkqhRgI;l|$-R zWJzDqzqPljjPd$|@!x+jzWcYTZ1SoA=6_`KKLs4~@5Ijhj}*lGJ9wFZ9{=b1J&wN` zyV`ssXh1q9WBQKMxt7wyAL9VnjtA7sKTm@?^|f`$+4AhrzX$!JOXgOOZViU@_(AqnUv!lXrHRvBTq_0}j8Ltt-t# z{wa>r2Jb4NpQl@E7dO$3t@;Asd2@}=O-^Mi|BK08Ntv;m*2^KS_1#IWrBKhKL3-=` zg*O9GPRCmaQ_tW;OD_@mqnq3)-DY7BK)CNNxwg~tN^&sy$t#U~U%W|xW+SCb>YeVd zY7+x4mm!aS%EyXHkaHflemi1UHKHjRQw7yy^?ONB< zD>E_WIB`ov;+lDeE9-br)zst1(=GpwyBB8h_@snc%Vh*bYDHmlm^kVQeRPW(s5JMp z4poS%H`1I+E>i$883{>q&1nb>ETvWgr)IlqQ z9ta&wnc(&hK_YHJbIohj+mGE_-N!t}9ZoD%3^8Gzhc-)u-FsLT7#uZ-ir>@paW?Ht z0<3Gc$FDXKw}$Y^BFq)>RTWkBSf%`e1(~_6F34>L1LHCnU_OtF&YJZai{!3F(e(?P zvtqp`=m4$JRNq{FCjZ8Dl6OEV*)|dawstHU+6ppAHtA5rcj^yLVFLp~Wrt}9!-H&n z*RvUj4t4qd^Cp-tDlOk8#Jh|F)Lq_S{qD0mg70qj5WV=wlszBaL)S0R+J=MIFTmPn zKDEe5R5@CYQ~XHK{Sv)#1xJ2|6`|W%u-elo|6uuG(ij$Q0g7nJ&Gs;gw>Q|TrRt4u zMTKUyS>jEbfLKDR&yZZk5ucu`oqQY-dc*(D;5$wdOM?qKdFD8F-mf4u2^`?ESWak8 zk{}>hXnD=Z!se{j_-pT1O`ZZuYtM_x{V)`AOdWJFP>p*~QkfqnR3N{+p-dTWLNBD+#jS%~acWU2H^|j79c?0ewH}O*} z`WB~VJ<60l!!m6^cP5DN3gGit;i&Rh`F3Oe|1@ zlr;Z7iF-Ao+R{##`ABPlMkh0ZvnVLBStW;*MSDW!J8xvc%}N`oE-@QW_iO08`2-|5 z966M+pF@oT48O9H^e$)ZJ$eZHAysqt+f=k}0PBpuq8NJ>Y**_;kA~6-Xww(4U*;hP zA)k+%3JT?KOe^~f)e`eyowCsxuuXI>3FvQ!jXg9tf0j&eqWj+?uTnPtFv8iY*LqfL zrs#a1F#uD&15{lDaO?Q5#^~3nNwH+6KWpux`FoBM$zz&T_+v@FVE9v$ zMi2%$YR&oV$1cKxYN&&1Osl@@M=09zj~nV!gZNVAnw;b zXJFHpqz%DkUu0H4a61v@3#)<+_T+gluBj@+oJwWcT3W@Scaphe+5X(R(Ecb`ov21F z!VL`ttDXeH)Be96$8+!f_c^Z`bw)}#4M=DEbT6P;1;%g~$B3@)QE{`Zo$3{mF2hSX1!!UC`CfW0csUrF_33{QSrP$LPx+>6n4t?HxO2%N;%Q^Rw^xP`ZLG zxQaV#3NDn%+P8ZHb}$UJnb;`>RlsZYkec;p8mcRPzRvzQ=Z6|AA20 zj3K3@FcUQb5-8D9E)K)E1B|K*Fs3+IW5eb0=2hPDXO_O5Bu5=*8m@j8wl5~f2Xs_f zCewV)sJUa6<>vc&R%c(zJh6E0DV(Ruz$Xdo3PjS~Am2GE@zY2c%J>Gv#h6IGSXk_Z zJP?Fee{f`&v>U;b)jIPaopq2&#js0K$~qlt0wXND-O=%WzC4{^XFylUG?DE_p&?2_ zwz6?}S`;=D^z)D8<?9%HeaWmtgP!(qwfg$_y`#SWW2EUa^pp+Q@;b%DC{VL1Ef)S*PAmb3VB+$DWvX5w`|!mAq8HiO2Z^K-2OCC(odzZa zarnZg=#-*MPxxH^qQu}tcSVU?APjuim#Ra0j5P}n#e63`4A(di4MtgK0?gY$w-hQ{Z)#sxvj zCHrgcQ^J&me#Wpl83+{yVz>dKZR4OBC5O1+v7bB=rYr(9(d1SxTzGH>0xxCpv)^N{ zUO4E8MZ;7mi@k%|$CalI-~oK0(nfJa6_+O|eYtOE)zT6C(Ecn5g*@ybB|0Hi<$NS{ zR=fD~>N8G@>yMH6REwgMV0~E_)=zT*yb;h3dwjB=;~Z-B!z4WxZq6{pPbQZM!@tJ> zOG&aR2ESXqcZjbQVDFkHFQdqlknuQb^N*?cCQC?Eyz6fsy3PCXTRWF&?VNBmqCOuQ zz{cKm00hHfd_2~QqR8o)qsFSg|6Vy~aHo#r{tSgtP=jcvE2tOd34`G4E+f`2{}(f8 z$P0DYugcM$t8zFJN@x#CnN=FnU*vsKvE|{52_t5+VrM>#iNS`-h~a&P{4B(@>dGfi zo-SMP$rLz&{N5-i4V^Oi2$)XzQM|*gGCs)tf|;Dm_#kaXvtt2O!KRV7z8%Vwn#!*? z=9^a53Ph9YTHT}9ZH+0GC_L_LB6S})FRQczaj~^v9fXGevSyErU%A38BzZp`yv+Dl zv4C(=L3m$VK zd26ego49UHINc8QfT=F_!v4NKPcip>g^0RMoMrT+gahucij{_Yj zjLs-7(nc100saJ>@m+n* zptX0MI13@c%|&J$lx=!&*nu=N6-ah>eRwY?uSru{hzU$cJ_bQ1TQ_bJ<39AZ1fy1z zlP8SfZJJJZ_3X^=ZjkVbV!oY*jAia9R|PezMK5^HCnAvgOMnF)I6lZ5vT)S{@F7Rz(V-yNc))kfySbWs zU@QwQEeDC@h~V=$ajk}VS0EDDsrCu5v6fal5DD76M$a~=k~_wS)~3 zt$UM9e)F3z*QnO-@VxPypyRy5<+fWZGg+LTLF~Z)xB*l>e?UBr)o*Kc>ze^ss}Bw} zq~;C{$c(|>4|P);BH{&2_lFCkT3imuj6?17`o*ie*Q{eRvm;*!5y%x+Y?-FiQYFgT zYF8^FYbRw?W~)zY(qY8pVv2*SyYH&|@tp2c53Nj2clKs(Sqh`nLp5F_`3I`fv$3oI zt|beM>dT!a?K$_DSk;4}lr!Q>-?!EFL{b#M{uuIX^}eKYYLL`cQJHOZ?w5jOG@;n~ zI8fowI_8PGv}*eY(=*{#`F${RdhX(~BiHW_l8n zme^giSh#^>aq}GIk~7%V>?L0P{OK!BwHfb1Je`6{|%UdY^3(6&>m zIEro0fxRbn9vS0`s8edFVbB{{!f1hdFdlxYrmUm=$g+EBrty-vE*FOj-f4oLU4lb6 zHZg;pqxjVf(sjVotCf5N8qtpg#d!LqW*rOdnYORjGuD_umJK39jzeKAEs37h1!4@| zhk*RHZzlOiS7#?;L_K5)A*?nk_z%&bf{{nImp*=4%$ot~G(m*_YHVVT?gIonbqqLBfs=Ntjp_Il7sq)HXiH z$mzJ29NxBYS#xYx{$Kg0_4&4hDd(}Rpnk0_%~CqY^{JfQtq{rutfSM4jOEV|S$Evj zSO;rIQN+1g_Qm+)Y?tEVtkgn1)eV|vsY9L*FMXr7ALq)0mq$z0eL`F4B-Pw?UyfXY z^VH}*x8k$6vO`@7h1L`YP8cC)z4t26kt^nYEOeU^j*Fz))MOvuBZK-_^887ahI3`z z^y2~q{NzZKPa1ND;^C~`JIU)zTO$2Jb2_WL)#nq^?RQQhpr~n_CtRNFgRVr&5<>(T z4LhITnZww##u(#|HQ=Q;>{aUQnY86g?o{U}>L_I_d0NiEWw^w8tf?Sd5@gR zFfNgGPGhydPOd=+_E#fV{8eR~t4I+a$zl8k{YxL|Em-c#j3!@;Nb#Yt=`jVxY2iR;?2K zmxfu?loJBi9UW%7gk~kGf~+k?;q;=w>94)9Kd<-_V3eFpUIG|jNJm%m(9>`!9@!mg zXFGQWC3dXGrKbYRSzGRcOxCB)5KZI(fEHh3o(X|DOC5$Dw<9j=tT2SkbOdaNSnfNg z2W}coD14;!+eH30lBFs!S zlpnV7$I%3C2ISmS5JC!`eRYW|DoNKw+Sq58IjT$geD|VqOkKZon7Ldx)>^QcLi>q7 zY*V-Ka$?Lq7}f2L&0v4~E?*+$FDOpGurzH&fjV%yfR!GRs>^ruW6f0(#V&p@)oO0^ z(-(2#HWwX!=)vqKlLpwNpgfBY$H!Tc$6VL^^Vc2^q*%A|Aq8 zp0npvsl6}5dD-!c9X&-b)eUVqdP(WZ^d7IY`KRLxGO(1V*^{x`?q?HW- z9%uU+J^jSEU$noNJJUA_S8ndne>Q{eN{N8b<`RMTL!L-gs_f3a>%myKf2?HL7c1Fc z@4s3^lpaXEmH0YQ#CuqSUYi`q*TDa>RQXepzR*YPN+K0=IJfAm?Ur)3O;exCoIuto z!;Fl2{L+mop(n zyz3L8eIb}(pV2A#Fv4tsdFmdjN{}BuSA{-5Okjaq>nAp)Ykssj7KP0fB-Q90@UTa8 z^Wt(U;@tU3G`m^hDye2k%f>l@$21K#BWu#aW;Zb^Y4D`X*s+2Z5&Pcp%`NJcC^$8koBFbV}EOp!w0e1J#uL zsS@an8?^N~r(Beq4s>jBr4=-5U%rTne9gx-t@{Nw=>PX{_P`2E`jCP}=5ObmDx1o4rX6E|J+Y(>x zCNWliZ1X>CR1q4FoT&5Y*5IQbIE1+tPQ0g(k%)0d;CAQt-n%b;oJBNvlCNA!u4{3_ z$1%KXt0yK|;u=Nx9-2f-Uvy{dAo0$kWz}XBJ&x5gOgIi*)$rMC%qI$a!YJGe9;^P; z;FMwK1F`8^N{P#!3AJf!-a0OCkr0BT-}D_LkBkGYf~z?{ufo)Ez;^XoZJ5K{+9GXsPIOae0D4$4@c03^&?6#48R$Ay<5v~jeWTg4`edjw zs;N(~vM}Ni#(jayX!0}9LZ|#g0MX(GHvrPSB2l$sJ(RtX>py173{uP#=0)1 zl62;MS#lLg&$}v9$?(ha?o9p8k{;_B_cWxFI?(J)d0EopxdJ~Ti{|8TGfKFmk)N2e z?+_UE?Q15fgmC9dhU$RdBR$ncyNZWP@QsP^gU=@hJ=S?K*@YQZJ?#7OtPbrP_U}UJ zSy5J6snNbU-!@U9+G~r2Q@dYH1|~g!V=JFN0}`xWtWaJysa(2g_cw#KJ?|-&6>wu2 z+uz)jh;4!B(-rb~K5l=rI<%}+KJMpXl^^`3_G?r4;e3XKJ^y`TB%2Zw-TR`?X!H3J z_ZLm-b6fTu7AP^s?WrVGk<#4>1N!%EO{rwJsE?7Q#=@*}ve=J1%b>j|?C%p&Y0V;* zkA?&Cwd%vwr;DcurkVC4k)~<)B7S}t1gk}P#*PT;Hj`Ef2N`83^&sd?wh{xzB%O2By+&zQ=A>MN)A27 zDsQG~H*Zs}J+q3tKPKjCuRyx{7Xmv<%q#SFl=aUD_C}5^M6)nL3;A(gYLlsAQ;eYy z_G}-rC!RlVAPMiRO`0(=YLe=^(mMIhpFH}v#Xn%U_S5`qhw zf53CtbnvgE(CY&>$4-D|BFf?FYe*u45z!p(Beg@)|IYM`W zFGc#yeJ)+fuY+dyo(r{&k9FES_IIK{Jho9K=B}!f*85`W?n3yCSEDzBjxaQ@!k)jR zD}Bxn1T%zgS1=a_G^Zf)Q>uMj)Fkd>KDIcAqTQ2DAG~PyY{fe9mJT%X2r|Z&!o#l? zpA5rwU8Hx-D;MKG1qBC2L=k~A0@Vr%tSyY2urs=s1@C}$ zuPJnLx%Ce1vd9$%B_cb-%Ib*Onv-ZS4g263zI82eWRaK;zZ$+pfKNnqD3|3UbN-0b zw}XGES`#MounzGt=tX4zWegayUZ*pIiVmt^-_Q4<_Pug7I5UWV=!|-n;CR3yI1%lk zP-r#xj2mdmhX?A=Hk;v56cH+qM2qx72U~qRBqu++LNwqgW!CmPyH>C;e{dO4jSsD= z=5DexCl}ikld&%^6oL%yEq<4D7quy+ITbhQEX#dzGc6OtM5Ny`j3u)gwly6xdGssL zj|ZR}0(q2dFp7C#{S>e=`f2keGb7JC@)hpIyKqoIDH!oXLViIH#KAW?uvN`u8caN~54W{&2<1^};E~BhbdKfG$h*3D z_I3z(U>e{a2-zl!_Nc|XL~|%@mepmdLeYpU_pgR>ifLBYgbd)LHjQ1RT#@|U++CSBE81p}Sl=^<+900%faTUSjUvVvao>o0)Lh8z= zzWzB}Q|Gm#u?fz+x%X1uv;&tgLkPYMLJ7u%!AqM4MVcSCI>rhRd?yXT(Lk85nl)B* z1ZE1+(jiSULAV_^ox9N~pL^-Qr|{j`drtmZn%psclm~Z{f~^Ey3-22Y7zL5DDYPUi zWp}Yl-nfc;{iHW$OP8nAoq*DWlTOKj3Gc=x&lpOaIQ6mHPL#%2Q~kq@p~9zY*gmJz zLBaT!Frx+NV)?}-O9W_@N4`K#-D?7d)Nuiw$MC6lqj{{G$`utgLaM zPg-DdAte$47(wD(L0t!e@Rvs<&3+{$(b~(U-@b6ilbkYKic&Pe>S7X8`pOwsI5Mvm*>rZ`I z`a*UiV+>glZQ-P|!Toka)75 z%Qx=2as9iPG-Bfu>%+Gy;2>4vxyGi*7O=7iv-8Fe2(knfQoPyw9_bDM-SvfjX7{d; zK48t5wZPCJ5t$5*s?05XOggk!l1KX=9C9poVQ}fuIG%W4n4yzFW4F&r>n7Qr>H|9w z-CGNra5t2sCqZIK4>&Bh+$Utai$BcY<4d8MTRAm2yIK@e{bUI3sw{+s*gpvOxnVgHFyAri?Rg>isY$niOeSBW_7SSV zWd{ig!UIDQC&ldqWcb~a4^BSS)koyy*5R95f)AK^n4mbt%$sf&yJr{Ld$k~^K9)sJ zfVg!Q=6JN^9P3`$bWd18GW&ZXq^Qm+zl(60UcNz4C@r1^>4No3J^kcWI-z&A>pHe@ z;#-o^=GAC+2i>vOH|C=$M=Z7%!n4r6X76mb7T~{HZb`i<9ye0c5vsnE?#b#138!pgAafpMpu=r)U4|44e)XX2kCtzA(U#&S}}9yUr@Y3yErSlvxzx3F-R2 z-qxYD>e+@vv?q4UFp#vwis0qV0d-3!?ZU{lhTzVL6oeI{H+Er4d>RtHiziL5UGGgb zT+^yvyZE8?y*aaqd*1pT9|a-f3dY9tvfE97@ANkBO_#BqLuG4}!N-Eu`93*o$fn$6 zc~Cd-Yuknlwsuq(@GQN_n!4mpkapOJY>g@0vFEbn7bLs3!i@H{SanAG$G9ysoiU{T zJ#|}SPFp;k)#4c5s4XeNtf-8=ah?U7Yz;{bkPR}+tO+}d*&m=P<_dPR6T$jfH^(Qr z6Hu!HOvDQJ_pX)EZt2`<0e%LZN2fjOdpNN1u4+uR%vu-H zNbt^>IGbDE$Jv<8b?zkoz}DUsKm`2=Qe0g}%SVYi+|0vne8L+tsD6O>B+!49zXdtG{Dg z$f3*qod6AyvRg~o0J(9ByQ%;c5w=1&^?Y4P|LpxE!iyi5lb;vbT@&(wzav|mreN^y zIfAdh(niv@QD!NEYvfR&`}-LMT!CIg>{xo{tB9sajjZg5iY(|oW}0_0m!Aud{ST(l zv|)n;s1AR6+%j-S?)P79Ltf^)OSd!VqgF>Q_8)H}UKUsG=)e$MA}h5Et40ph%4vSx zL*pY9&3-rFsLwJgmx+8G|o=B&OvA!J#g>F z&YcwPexHW-UECB2;ZvP3k_a=Yv;o8cbCIn-vq>MjcT03<&|r1${lrBHKq{`t9ka9u zTo}|_ZYGAWEo&>s=AKWqYGR}JVu#JDudF?woUpF~Jdp^3rz-OZ^jp7(SKNX-YTquV zGvlU?V$hZ%WERprSJfE8WK}Lb$hV%_!n$*!?;ZKaYj-$j6}dy8U648qx9)S-1Ixap z_im;`zrfz!bNOWcz?H-U0f~M8xA$DH+Y|qNlOYHX2>O41&-J&hg@5dU|I?ns-@6Nu z|L0q+zdmFj3PSS#&FIfB(g^-1%YQkT;t!^Y^D3x>g0h^5s)2^Is+56>sGAddBOcDt>}?%H{Y-wUcn7sn{!R!ZHXUXO|eaLhlDC&x?pF>ZMBH{g z;`DvKpAU)LWv=ul#Fql$7>D7sV2pr7*yvBF=}Ho-JEIJ9L*+q&W6i-%Jg{CQ{v#lf zESk`O9(*nFgU63R5Lr+Z^9T5fqM-cH#*RrtFOpV5cPt};lO6ITMT~%fQN2LE+YaNY zR>qUkma-JNW(#RhHW(cV z2Wc%k*9EYvhDlrL=0)Zu*N`K+lW!MSPF%JagJy?uDT_hU#B`9ZinP{9dJ57BMmOXs z3A)e|Br(U3SzsQ8bdPZAsjFIDzgcm~leP(ML5HxTMCes-5x@ zlf%OUl{KRKutr@7onxeEQqDXov9IQY0g}(`qMRN9Y#vWA<1B0rJ{(&vfiN%R6@|Uq zO{$sUkp(7m?4|reR5Y0eGcda$JK>!<$hC6@J9{Jfi*5Tj+59@0l3)&4ToHW)KG&GUJOlz1Z#bzwmcZ4JIr*j=ke5WB z9)3$bxfAI^Csq6mPb#YVJfHd11dX{E=1?xUU(~w;atjKq&wu6{djb?bY0_8$%9;!b z_idAsvf$@cDy3*!rTH#fb&i@8 zLpr=Mhlpki7WcKd5LYND<@n0*q*Au7fn0la?5T%cQ5glUjvg1$L>Xk3cq6x< zp%5%(y&1Z|)=nz!xqIFk((dAzE?CLS~M^s7X^Q3b{o*z z!cz0gGqk17JMb1)Uo-0Jdp`F1ui`7BWRJ2PUIEopN3P0wMvjy1Sh4|+7bj7FC%u&1 z$s3=ls!0KSpRRI`pU%VuFKoZ$VB15J$+;Y7Kk`?ujz-IIQ26JUO6;y#67v{zuAi3Y z?B2SZ-zJ$XAxtLH$}D0eE$Wu483AmMnKrI?r9k4Pb*A8_7bQ7h?0 z#=OYYe3{Z=Owt=Q{I%ZoR&TX)^_rVrj~m`${HHP^0)qGdTSop|1+o63pd@L)+7Qg+ zHoA3$4aHXDz=TpHxQM!D-KE&3SU#B06z<@f8? z%UzdA7he3HnwpyJDyKg;(%$Vg;@Q_zB7%d}7D4W22bEvyb#%0#dH{8q@e9gcb5&gQ zAq49#n;0rcb_&k>^TQ$a`r3NK(4vr&(V)fxj8hlmE-QYqHa!jmbx!5TTDB4S(i8uA zEY5R;<=7;wTGld)wza!CAu^YXyu|i`=snY+-0ijxu~r9PxaOEY(xl4~<|lC|8=Ete z3w%CkG%{-8XqRU(nK717Cqze&tkT?pM$fn7$uHAjLmz8)<}&#poV=aLsMr33uBRf2 zOo%o$aBcu68Dn%*c}zYMQ4R`sLiS6b7Ig!w{+HriU2?U&Sh2o9E#^~0gp~myJr`9i zm1;#pb*i~`vX#(g1JtW8Od~(rjaT!Jrs5j;%EAXz4=^zJ-15zED0?!=8I8Vshg2Nk zsCUL)8nS(2PC}>cv58Y>YgFyaSjbGj&9NxO=+ISVSXPREDJ;pgic$8;wMtUVP9m0J z$tVPf$m|wW3(nAJ^yi6p{N7QZ2i})?|R(^HcU2UCF}xz8TqkC z(@ybBW?TcCsU@we2ovq_M0R~M*xF=rp?_OtEcR|@k`CgT@27PQOLNojUgduD}+ zHkfp9Ji{=q@)3nJ7xB!WO=1nR=@FqgRtE%Ya-C0eQ@JB+axs-YT|w4Tlw-ArZpgE2rjJD~}C7S|u#VCnVn}DK(N3>M?v>S^D04L%qt{uDOxzkUi(sntegU zRwh(N@O8MLVEC+lan@2PPoXF<^K_yjFhbzhRn$oq6`SA6c1TC!w<7U*SMK?&(5mji zns<0G?o$e#Hd>Kq-=Cg~XjUH&Y@9?@;0HNB6!zORej$b3>$Ti53ZZV=o|W>9QeYWi zH{+)5-O9stB+hP(#+Exi#-$VLhwaD)!+8 z>WGQsSK!u5$=%`i-8{8^<(%CxnwFH8cby{1+=G6!xczUx-}TAC~9dypINEx)Yn$_peUd-15F?A9vm{T-bPz=J*Gb0Z!*{vZ^4<;6$z1(>; zrOH~>(Qwp0ueOeygWd8mq#hTE2V5l7F39P1cH&a4Ofr9x8DnB~_?pG4d7@Ry z%F`0(%ACTbyTI&dKcg4hT5X7tRMi35*C@HEdb*M?>x_x1t%g^vvbuUm(>M#z@z`^b zU(IkbNEu0vrzPZIPA(IXXW?!YRytrVG}e+C+2(VQWjEC{Cg1*Evz|k?LCz2==iBPL zC9$bY2EVr#i6H1;PFhDvQx`pA~JxN@PuC9ZuAhqt(GHT**E1OS} z$nx%|GAVf45Pvds{&pgn*#~D)D)_)gh+P(GDS?*u1((_VF5Hc zMJTlb6u1IB3=t&^F$dKsi2L%OE>8|<&RmH(h{^HKkmDie>i7KcFB1~`xgHljOjct1 zc~@JOLfZaVWT=dR`w&)=QPXJdtssoz(utyk7Pz?_NAGbBTNSyFL`J9K+?&uUMa!o{ zED*J$7V>T^z(?}+RmIsmVaCb#KpqJBB&LEO5)=ahpKjP3-bF^_}B$yHLFG zN1tBbD&Wz{#{4o#wnW5+0yg927|2bDK``HJ>G~t1uVI>j8GLoJcAtD(Ip?43}H$AUPEGk{BE4y?5XONSvF?LFl?YLA{BNy6Yr0^Z04qWN?U~&WIs?1@E5auQm1E zb8YIA+|!IcM=w4>jtahS5e~kr>wZqyRAIQIom3mbVZs1EQ)f=WhDEB=m5aWL8f@=u z$FY5%F=bW{X)C@uY$54{;#Z2dI+enVDj=*5%!9fxfVjJ#GGCk} zxeL8~MEMQ?lL6cGA`+0kNb$I($Kaap1|vJb3`DM5ehGCfw+0det3^e;o3S>nA%uJD z6jZ{kfkd0f5b&41BsRv@un(9$qubF(vCaGD0bBSG=LI#4f&63;9k8&9%?yFFAm7OD zbp*VqLpo5~Q8!m>oGwq4;EG_FK)k~?ciQJ*vwdPxJQK{n6<+`$ei3YT!Po65ZeL>K&pU@Ba; z&s-QO%nzXoT~I?WjB3z(1I+-8s+etke`1vO2=gXb**Sa~YRSunZXc-&AE76{?q5B( zq(J1>}Q%%h5(P{DUwdP>;75Ka?yU6ngW7n0e}=pi4)^=?g;5@@}t;bbCm6$~qG z2)vL&^jOUOcM=Sw>I4KQmSCqY->x9=%WXa^K0IWf;6z zZ-CHHL-&Rr2s6HiW?kY2AvW0&LM8dAB5UkU$rA>WV;a-858OZ4<3(sna%3K;!3raU zA-Jpgz6<&TW&nPhsS5N7WFX674Uq@z{u0D5b9>!E0E!2WXQ4ZM(sFZHA{hg}nM?kQ zt|a(k@j63OrsLhfWe^D-__TmbEc91!jBhBf=W}Ag?X15?HCKLe5W*w*$oXUThSZwC zetX8=%D5S8L{BZM3pBQ2b4mI#PupA25cEPE)R+G_8tglX%m+~An@go`7k}j1!9`97 z>a80FHxqO>HlKuDJGk%-(aYTi9_VHwpaQ!=if^QF#eglCEbk^Ou_OpLNH)v`bb6X0 zbVdtBwP{Ov`s+Wl7HhD41C!yM@iCWck7SR4)^*y=ww_AVH6i0+h4dENO^i{shKtxs ztk4SNxGP#AhRt&@J23CWZS;nF@z0v-Xmu)#vxc(NDaG%NaGyDR-Ct#^M8R7Nl1eaJ=wB7Cj$_ij%5$qEc?@GeWD zy@h9Riofv)F(SBPi~lM7wp94OtWcl-deXR$;6E)DL_wJT+oi(49RU90WclxhPjr8= zR`}Cnr+<9x(9*FlkizmkHwbt_i^J(Bv?;x_^vW@FEnHt{=TQ$yJ%WfqB$CuQSatI+8bsgpd7SWn5>oN_ z#BLSH8`|>|wT4f}CmHOWa@lO+Ktp&aSPpGttYQnfZd7u;FCVEW49w6MnyALFt^)=f z-217$qTw(k_tRj{a~Sm(xK`o)hIj=NBh*$uO!&3aDZW%-){*0NC&vd0OsDtb--c75 zrN`xyz6TMQUPz3r!3_kk_vxJIgR&+9oi7^2b_kA)&!3pVy}TboO6Ejv zn5o;B?<3bBekAWcML^(#Z8y|&y`DI6&2Hlbk@8r=rjl-RGNo;RtT<>o5vg(JR#X^` z<-Z1MY+5&(VZ$5^53NsQNIW)iB(8LHPmm=`ev2BE)gx= zZn!x*Gn!iBmFGR?y##;1hy%*>$e0YBt(b{2j*Lv=>xMy0p91|8g2%~|SrZ(|&1(+| z4klAC7-KQk`sP_a7^;mj-!S|9k?;&soDJ@Pj6z_R#A~nCv{xCyeVEI7MQgc#kH-8v zkE?;jLl|v+^X9NkUhi06y+~jc&?@pQ<^Q$z7EpC$OWQCIT!Op1ySuvwhXBFd-642z zcL)S`f(D1+?(XhRAn1QGbLS3|k$d0!e|xQ7-RG>+^;GTMeY$q_u6h^*;SFrKnnRPM zlngm@#pyr8ttWM9nL|-gJEUYgn?9mwbAXn4mva}jgITUK*kby~^1q9)j&kakxgk`8 z142KtI(RlW9`@Oc9Tn9WF23BIUNYHkrs%L#VD>2}2vSMD=YCBOvm>>`z3Er2x3Y>@ zdU{yfG_MZd4C}$A8Fvhi@M%vuNwF&b5Nw%W++1L%jxN95lID_}xI%qDm?7XUGVobY z{T^5oWv$i57ozues*0en^kgtH)SH<=2&Xe9ZHRDbPz12VJ%DR*(cA~fDCs!&b|G}< z3)hMLW(=02kEwjRpG2^~^^U+{-A4z@sSHdEsu{nm0~&hYP|7(u-2P-)vD0?;u>fOH zs8z?kh#vffz^=*Q(X*3NHZOjOU5YEB`6Q{YYoLiXk^;q#hx_y!PMX14cu!}7 z&!4*=-%QDtKwj)jKI#qKRxnMDS*h#@TQ_rl<5QOPvl|~`{}K^fNwD2!VE7uB@XVxm z2Z`U-i^Arths5#K&<&UjDH*vrXL8ClT_UKt*qX-5*7Ja@@X^1Tfedo7?zYFs^u-;XoS@3PqywjPo9MC)BiY=Hw#GF_ZjZw9}~F zZR0B4SqN;{d*s1i-><1FyA1lh9W)Q<(!35fL_`SSz7{uVS>uDHQeSK7&hI_n_Q1|p zC6ArVLT@-#<&#cAz!ZxwrWIiOy zr`<+XNC`;C@^fU=qwjGKII-Pif8f4#%he-+geQB*B!x^Yg%h$Cwt@(~-gylZwf-*+YZYw5zzfmmF0^IC~wq$1{R zd1s9xy_JtYBiA5L5{ePY3(eHHjo9*|`gP3m3yxf~ZTjuCNt1{@CDY%!^Aj+`_Yp?m zgH>-bLicQr-vggZ5Y4#3L28WDE#Thdk;OH(C&GIJB|?2S?__(5^V8~sx7Kwc#A7sP zVJIdmIm@`Iq_R)C+B(H?St>nuJb@x#GT@2ppPoQ@%Mcsi8CGO@lt}-;fNQ{zoHm$a z#GfR(fB1z4W7sHW2`g@{g(?XQ-9h7MWY~!0oBREo-|H!vT=y=$Q%nzYLH_0KUV zWKO8Q0o~U;(5yE_{jG9QADk9AyNq{_jXs)$VUblx;>n3G&?okwR zS{>_J$ma9bDki_~ma7)*B&|g)Lynv(rZ_w{w?2?Z(!B+4%~TRhY4%;!$2KtetJ$@E z_z_NEP#%OYyt}#29J{%5jk?=lPseXZhB)^!Q1qSoQGOu*L zNL~u6f<>LFOy)>uYLmv}Zmaa?o0Sr^n-wa7CQjvp#{-|F)0KO#)s`D=om^HF6qNL8 zXI`y^$u={U(Sav8-b1xAjq%m$tJ3K~WH22#d%+xqG~bGke7RQ;LhkP@d@QFzqOd3{ zxD|6{{d3oJ&1sDGh>NU^oH&fGtQm?bmNO`3G>?%ac^Qj0ci+|9J9E6d$`3be0@Tpp z{XAnCC6#1a&3wF%$m`qr=uIp(TpoeFu0DmouYHU^MNCO`G?9ATStscEmDimGRhr2s zT`JL>`8f8j+pc8(BO}^u6Dn~cY=9@($G0h};r3R`Cn8@|t!K|%c}_$E>-u|)SYCZV zyuiNwWa6%>FZ@9lGrNPrPu-Cs5l3*+I8-JIdI$>+?I<$~1rrBJCdgdA>0IV~gf1E@fyfiY7S0p^@!>DYm*gD7xEBab=A6fg2?Lf)cD z`407Wjv`Sx^2NxJjmRns`_3MOUqcESxw#d|wg(rs!$^a<=eH9X$%Rf*5g(beaj(cYYq=REFSk-+c^ud^w9*Sc zvidB>Rqa*kI|-FYUz!akSqap?-#4>ex&WmHdhQ56Wmr5cFmBG7x8hzbYN1@nG0@7T z{;H1K3g=C=;GFZMTt+AFMKjM*#@nlM>_nzJ%%wl0RRA+-N*7*jSycM95t+HBD2H3! zs5)|~(2B<|3jC|!F*c5o@w*eYiVD>fZ5GLneS*i22}kfl>gxC*hwL7E~_3hwSz_6S@R zp+OUWkxNhS)n46*l{-^lc+$bpPOUdDs{BEbaw#(so(Ll`N?7L6Vd`Q^=YUJsqhHyR zUPi^3nKhD|{X0*Z@to-}k59Z=Cta4kOKz;}lB1Lsn}P1Z>S@Po*2o6Vr_dt9e zlF5WD3te{OUP{pc>t2{&Fq*MvUs^BxE;G)!Ra47!)6gok*>Wbt)0s9)C2aIc$Yamh zTW-NIv4^Ty-D*plsI(Kr%3s;1Z?kkM@?n`sFKz#Jj!a32ue+*r9@=v8V|)UM9@4j< zp>xF)+HG8y4dvk^8=LX{qVe+8(yuADK_R#Mry+-E?Q!I~IL@;=A3@-yKUGG&VVY%qWQEictX&{K_v7P0q4nd;5%#-2X{Y|d} zb3`(gn{&H{5i4)tJ5I?kkM4Gfh=slJikB*MAj-_DM9`(U{}tob#XJ- zi#Nybo6~N6k40W_=n@Tw3gewo4rXBcPJ7)1j;Dv@NN`GPQTJ9>ugoyr8;Q%*_FEi% zok6?kb+QpxYnljz(j}w$Vp}@^io5G*0yF49!(GWEku$Mz2U69yCoz>qxs+%zc#)3y z^!Pd<1;}Ls$4ZXPCpELXH#@_n)G&oMgI~qooeBjz0yIJJ?YY=S z4+ffTo&$0WGN+u=oK1Ko-BI^9olG0y#H56D50%JgPGg5-pKjSXx{lMXU9ObWg^;iF zwvA(UNqBO1{(9l_QW;B?Oz)lEs@8((n{vx0R|O16yaOLWRg`*njxnO;$Z~m=eR7cg zNWS~PxD`HAMSiO8!m)gE)y|!4+x^A$2&qAR#M1z)ac{JFDaD0+hQkKFRC@?Zcn10q z{6wx4wH&7NcAvo-9@k%+Tqa(dWAg1uev}5Y4-@%8b2_tNpq zEQ5$iTj$C}gSUr;HHzGu9ASyH&^8+eIIm|{!F+nzyZiE+>8ZQQnF-BRYl-6>#@Ke* z_)z3FE)E>wh-)j%4b>uJEyr`*;bIrIBhTogxWsiJOhsizj!C-cj$MC#t0WsiHW zLlc9aLFoe)pZ6G}8~0XZJV*8HFp`@pa5V2`5wOrSD7SwR_FCQ>YV=|_T|ZB@e!eDB zVn-6zsU>)4#jDjBRKbt;+A?DAq+{HcWG+fr4zwP_MSpH_cRJ$Q+)p!FE?|iM!K4I6 zIJYp+JVq0co=R7%Osh0F4Lx0^TD&<$#iLd{ec)OvR;F@z_O?P!fxI|ws^8q7OElwy zZhq0^H0Bt#Y>TAIR7$b9l`2eOn})e47ostB*j#ZAOPX@BaXng4XK7n997@5w;TdF9X zlEJI`joDqfTw`+S$uboTFJR9wV_V~ICxsG0*x#!mkJ2%Hfocv+#{RY)>rHCsfyfP3 z>U$$`=aTykpR~2D4`)f{%zL9n^oYO_{+2ThADG=AUh>YO3xXySjE`o&Li)H((la;Y z5z!syRy%DKVif#Y#r)JJ_9;cVt+?x=2Ba6JeyxO|%zeUA5NbdbrCPN?gnHmLauQlv zoofg9`DXVK!4=!t*DidtuR$zn!gUEDbSj4+^q&v1K>4~!M{23|he*G%Zz1jKL!3V# z)xj>kZjERt3TWAvs(yV2oWy|+LHY=$M!`L&me*NvzM)sN!%a65_)Hb5d!a8#*(FPA ziNJP8FhV95UrD?sXA4oK*jWxVO8QKT`snc`H;Z^$SF(9=gD&;c8%9ePsXL>*jdUNS zi8k8Wwu*S{@Ddr60z|c@p_tKE9w8RmGQe4|nc_U%8{gNgC3R`0b%mDp^$Yyaed!ag zum>Ahzf{YlfAerXZ)F-ex_8@-U}D66e2=~wyYPU#I<#xFF@X~AurUQq9jO^zWkPIF zPo(QH&E^#89=j0vu$X;!R}s>m`)kJX!&FJf!VE7s5@8Y zsK67SlD4_=86Aa5$^g@~m8ewps2m3pP(oKo=$cI}YuhuQ%pb6wz{#wKTj4 z7X`8-A+tHYMx_iID)O%D(2&cP{`y=$V1z9f8#cKwhBwrq<% zfSa&3x_kt4rvKs9#{TA9xSg3ZRyUiS1!_|+w6#b%=WCv_)Md>K;{tzt0v(o2US2#8 za7=YD@5-R#GRbOYN_2!a?Lo9DLU2mCl8%bms_SbrXS7?%%DeEpH>tFrXk;Bl9<@_5 z;O?@_tqJg>uT~E{AhjjhcJtr^HN``9RCT-61+;qob(D1Ft5HAhx#}#+IkuEd@D7>rR7uw%qu_vKx@ah;F3jse?3~73$;zWk?)>WkZ{+)kAJMic}KI}Y2 z^RTPR1^JvL)y5RpvgU+leDO_puxTx7i#D z?(O=QSctX+-4=FZ_lsK5f!pL9QS6(SVeqRVi_>bqTGgU1Jr&1L4EXl7qIGj-NnyBs$|C24fm6B6>LwFbWhk2CTDHy`c(XcdQKs!Y3=n*Gy*4h_~XgPskheA2fIyvLX{bH}R!6bX}_|b>2RFaF;7BMfWOBL_d_o zq0mEQ;THI`-Mt;#5aX%`#f^Ie;Unj%2IT?HBavbVqUJUNs%S~x`YAw#fE>IX=6h#- z%;o57@kSPN*#n}$gxka zHm2$7&~WmrOLE)9csc@^VMnIdR%K9-A8sYrlJjixHrq)$^wcCSiY4_O&}ETu?L%${ z-qgXk2Sl=UHA&CYt|4M&=ANc>4R0UEy28`*5rjfmu0TXiWT(-426T~2&nqDqIQp)2 z*GnwfgxZC)ZFK8O4?iVxs1MZnB0FYg&~P28#ECwFZwTidy|P7mm@+Is_m91WJ*Un) z0;~@HxaPA9Giv5Lr`QxAd!@4x$y5i)10Bg8B4{OXR|j5RjT_q#hh7oi zRpNH8DcRhy<9)_9Fz2}*jNRqr7LSf^gD|@drL#oNdQW0!gjoV+15uGhG%~V7nS)2L zqSTWKsw# za6{qyFSd?NAf8xuW?*T+t+5;G6A&S7H0)9CYJN^yHsHE&a0AKf;ZsOEsrSRLOHKe>V3lXA!gKkPrGh3_{TbvJLN(#@1?bZbgp zA@FN;=m64@(vVhyb8ll3J{GmK{kthHqjiOxV6x8%~Pu}yZ~vZL^B_oAd{ z|FHM#bFvVZ%Z(P2R`cm*%#F?F+kmMfhpzjjA+;3adS4%^eQ}aU^scu0Tiw09e%Nz+ z#Z`&}WbDB2X~g}<5&DlE%F*dv7NpxhJ3nrSOz0x9~y6ROvi)Wt#uZ~lc4?e~|`6~O+pI;^%4NJL)B^%2x zdO|+na*;{gorYwU8Ah!gIAh7wG^TbX4u;%> zGk@3Zh_h_Pf8e!E%FPmeKrea1UF#q_k9Dic)tw-@?Y5R|)nGEAKeyPN!+zJu$}9h{ z!|{#pR*s8b@@}!~nRBFhazzv!R*N6{(b|5IY@GCoY(GolnxB}i=A`e&7gdzO$B$^B zfJ*<`sq4OEZg9=R4Ev3%{8Lb_j`d8rzK!qctE8da4%I4Mzf%ac5tM7FuKVI4nEO|2 zn9OD0Rsx?0T+4GuiJz*wJ{_n9H|);nq8QSfNWbe8 z-8tYflEEWKRWMdxAUoEQmzQ*q;`Smpin?`O-G%1*P)$FlyV}8V!{5R0(M?-cL3_vQ zz_EngfutF%Jvz2q?sPh9S7S_K(OX-i!E5<9V_2fuSZTT1l}znIo_AugDTjSx!mS9M zolbpVx#dcOXy?!6Ot>ltsrY<=|4~bz`gULsKK4nmy3(+_T4bJ8nx!D-_D-`7#uI@j z%hiRp3Fg7Z@gV*RZ^NRDUHM8bZ*&IV_lc%OFqPEZR%Ajv_l|MHAB|t!=Di8j4nIXd zyyUblN#tymMLR#p^C=A<_JiUZezMl`$YoIPOeQlJZPdfL`gZ>Ow-ZDMvj|b@6|3Gt zMYY}Lr_OuyW78@Vr-a^98|eF>Q$(;k?6D7Z%8zkqHW2x2>nM@37>5I3O}NNB zbavk!f$~W7sSH@`!;Oq#0E~G5*L@(cCf0)L{j@QaS+w+M1l=S{kqr}zH9oN!T%RsY z>QCX&Ql{6h2IlA3%bW1=@h>XgS()zg3H_|<-lAQT zgYL?mS!P^SrrKC$JU%CNGwTpkf1(Zc3QIQPH5;Cu!qK)xvvwf*x|bz^SjP`vQkXi=q|WPc!Xz$a(*g6LTz2NvH)NBd7#~o&P@7LlB7Vf4!Dp#+jM^ zeO3Sa_4YnSVa;Ow_i3S#9>gt}cCVI&Vew)yl;8bK6PAMGg!Z?2fO_DF?_ zUBWZ)8aVm!T&;y!rukX1;mTPGz8zU?mQCfIQ0OkN2Wg^a)M+@Jqa!cTW|f|Fo0wCu z!A|vw&ghpV5flN(PkqLrHnEf;bR|$yLPeS8y?9&xJ$UZ$3n+GpP5DgB2l0T-<7iwh zO;c-@VC{F9T};b$22E)JO~nWzo|()2oSM7o>TjGyTT~YVk}5T<+_cQ6#$_g(XhlK< zR~y|n7jkym+IPN15(eEt)RiuDA}IPGe97r?PM!`9LLkMdU|1<^HG@genVP?s{B{Va zk@S%`#-I3TnGoXOPBRULp%BYrd~wQ&l&!OzbKtengT1&dD0h!FJlK7spEEALiHk#J z{W9#)v~@W$p}$*>vx65pk83?}cwtJ(XOsDPJ!8z$z!dkH2XqaiIn$ciWKOxt8+#yE zOFOf3UspkvZd|#5?LfJ-cj9>pGGWa$K)Gy~N6eoM?2-xx825{JvJAsKl5MwINn}SI zqTdihDs~ikwYM~l9?8zWg3q=FS`ZQYz#6hoWS~SUrT~o767Sz1c%(B}^9d)4SqObJ zk|iqEo(pQ2f4JNe>?`#g`M&MgZOT&2$8sAyFFq*mMBV`*NmgWWG49oAImkDgP^|=0 z{f~RY$OyYVI2sLc34WyxuRM!O-qN*^bL<1Xvg7-}Bs1=yHC_r9lSnx!+)ZM~FqBqF zj!AFRu1p(`R!AKP>)f~rMDzxj+&cBM*+phBdQ6e2n|8W&HnkOmjq3(JXyNzWd6(*p z-b&HlontDEm=6TP9o*T>@eTIbIetN((ptJhXG0IGKt`{H5R(F-2P}y<+hPWJ<{>f< zucB>g0;sFH6D*XZ(z6vS9g)TatF(aY>R|?2+R>O8bG#+Csy>A<+clzcO`L4xD)tvQ z)9r+qNgkUFw{;~3DWJ+WGOdlf_BdSE_!rZ2Czefg366hu<#zNfOqwquN(fT!H40&sqJg(KdCukAT3smZK(};D>qe>b;GNN`1BP{~ z%?<&V?>$WQoKQ08oH}|Cmyk0LB}amp;uCcKQ#lFw$hOtX(#F9uLtZAqke8fCqQMPj z25F^~xz^{~1ZVH1zDu){EH!-sQl{MuSMPH|mI()~B1usSFr?RCH%;2w*JD8@_8PAE zT<>6g!{Wt2>Z$0S3unRv9p|#w-4)3DVUh;#eWrIMf<|0Y28k16j9hYIj_(sq1vO+3PNQxf7i+{;PY>HL4-fePoL|+^~ECg#T>TWqbj~?)i9vYLc&EU=Y zaIMTfLw@fN>G}4ClJ1@(BaeOp`HF;BR5u4uqWOxJ@|!k_hZ=pvKC}7*`NNJBCIr0n zj5b#`zV9W!@Qs}6jpD`w4R5dKA%TZ zn?%L9+Mbg;ppL{mYq=wniM>l3=lsUcZlmb~W8RhnoE@V0LDLv3tD zZx6oElS2Qn`V;n+JYXTe6V&(9=z@X^cPpy4rzK^wIo>=Je3=v(-(`#u~sPrVwu&zNsAGe#)T3oZwYNS|HaWMv#2 zJ_kKihgV%C-A&U2p6?Klcxogm(g2nX4}@Qz_kT^z#s-A=-#qVs%Ypra>k9YBhx{+U zA<%teWME@&DCT7S3sDRgfE&d?&(;AzR6@W-ApGkQ^bh?~^a=H;0coc{w!crE6#ydt zAN70bju8J*)4z5_{tx-H(m;Q~sHG1eZZgybkmUSetb(?CMWIb$y9#3egXuyVJMUaJ)+&^)WSTf`~wV*=1*S~MsHL0Hh zjx<~@RN&_3Xo~yLMv5bW$@>BHb{${Z3d%L)joACA~Z0$+9LhoYMh~uoN5Xsb(eul86YlFV_5kmM-rpcvNGCGP+DB+$Yg$T=Y}0-MPfzD*5DBB^fvS?6@)ciXAC<8wlHxf?5<>iysZKbQUBbT|}mZxv8j=r88@|~n$O3w7z3%Sgm zi%z6-&OUQG7+*9Yx|^Gtnp*2sF*UvQ;%$3qXsl`Gcs^2ea{IP28SF_RZs!^TKdJIy zR!LUmPP6AOH+>AAVI5tK^FW8NdJQs@tnsa}R6q!NaQ;wf)glQ#L;r(n1e z-pe5;^;XL>M0a!}1kZ@rcO_p3Na}}hw%!t0lv^CG!eq4D<4)>SrN3j9?>R*fwK(H- z#F)qMq#q7EDkU>Z@n*GcaW7do;&NiFU2u}Ik>01z^$|fAvOj+_fJE}yHA$t+xpziZ zRx9PRyk;W%ZLcHcR`G)hfr@8b%5%mg`IUMf=adQ*!!6&zh5%*dO6?LGb&J)P);8lA ztvu3tOv_X}YNxv7YDJ`>ad_{dVG!o%krJl1O~zFh>AKn=L%~6({IED=(pBg=*WS`L z=Ghe8*hm`DNS0%Lld#hQb|T*^4OKEloP zKqsZn0;Ls;#|xM*9gb$SOS`C=3M*7VgYynaU0>3V;jnH8RMkw}j-#*#t-w~T?&sMO zO3hp(Gh$RN?}tPec9}vXDP4gpg;N#Ad=nofgrSoH+3@wsu5;Fp44BqgQNF~AGlUU; z+O`MzLE_bL{Q-vvn@rQ7r}Di-dI)drsS9uRrxw)yw&#lu8brHTJcb0Cg6y8mHW@fV zg2;*5fY?+&)+zhRr3)hgQiyFK{o*FpI&03z?NS4dS=kCbgCG&QicrqJcCAxrv%)7v zO?gPoPCZ6s#CPn`$mdorN&@02d9?M1W`fATj8;)85oHY+>VB>VRgH^cP;fzPjk$-+ z{`dtq$g$25OMq!peaM$}O}IcJ-TfdgAF+t&$_ zepD-tOYf5;@mbdA_=@b_izi2n@zwm!sYae-@EQv?B{Lkx^-9w~cGf^o6-sIneh?4K zPO5~+ozHlmC>wJ|&PefVF|mju=pD2}=UTyjSUk97U^mHZH3CjP?wv4L;bLeIs>iv+ zyoFY)!ZobdN<}dh5`K=E1SN`*`ke3%IKQ;KVM!R#=(vK>tWCJHl(FB*il#-Di>IOA&#dj%7KW*r8Ykn(F2W)Znzmjom#MGUpUhG8SCLg(52z~&;iWGarQ-1`N1OwVAN?4qJ%xKF&0I2 zB(RP(HD=cy75)0yF*a@xqj5ksB(4AMo+rM1eG<|g6_u3}Zh)COn9RnfM7BfPZ*E)} zh%hRY&aXSvgUH+=QIg;z+G}CGn?wje)7w86@3dYM4~FugtDS%Z@_9G7en6@Vb$EPT zq9q%~%0#s1xM-m+g!t9Eg-#0+bJR|AAlxJ;x!7666vd-K0zR-BuUKE92Zhco#c6jW zqc0IN_HC%cf!OGt3s(NAXmhN{kAzk>M_q~)x}sDT5&V3f&c)>{(ZILkVBSW@#T~kDLx^H^u6})!y)pEoCYE`p?%E%k^8v2A4uxtlWUpJ@LjP@C4 z>0{w>=o%B^d9)U&Cj%>oOf1a0mhcQ zeg6>bKGZP!I4Y;(%&RhO7(WD}lWf%yQp22O$~#M!EYpS^HR{Q3<3axqrlnsxK@y`RqEX4=q^cYq26fy@F1u?!XlaVeYYNIaWKBYW=y zfwCOocIu3Qsfw2Vr_6c(XxHX!S{N*6luQk0c&71-PEUkR%zMZZXzP=K|VVTQ< z!hVPKj=BFlm) za(Ntw1+9+-H_Qj1tD#K;w%i7xm*WDN;+t3-=DhvBWUUCe;OjmpI>V?Q(x9nJpC86Q!UGAR5WXB9UAiqB@$+HF90E)7Um3V53II8PnRv z7^(6)w;iByFZtMn?q~Z2@W5B0Z=*Odim6p?%08+et6#gYG0!w6F1JJZanLf}R@j$t zqS16Bz*=4E4&t9b2}{GaGlaL}&B5TW-}IhA!9q=VCA6K+@3a+d1yi#J)l>U)jb^W_ z^3-2M>9M-;W7q=mbr(ULJ3zkPyg}rE?p@>1@d77b1C&}RSYt`|tG0bT@)g|{7&1~j zK>C&!I%!?H0Yr@u!VSaULc!P0Q>RNFitHVdPr*A)9qnN3i<6SG6N_8$QyU>n>-ANu z9Wj-(Uuas5Qo+ugE~I9}G}k!C^QqaQHpez-T&=Xg05< zb3MQ+{5r93UyG}*d8h~CfBay;`jA`+?qmw&PgW`nQ$gz*X6*^VRui)VjLy~G@d_Vg zuM^yI!^`o8&;GJayzaYvfC%_VA23gEBG(=@oSxB} zSJ`u-PSI~dze9*2=~0CWa%i~uclmU8WGQt$quc~4B`~c#!<{DG>}~R&?lZ#N+SG?V z1OITA!!19>eS`)AO2hr9Ngp;4`k$O-YX?UFH>bm|%*@Dtv$p0+qB^Nv1ALJwhKK+%X-^tH^mi3ZU z9k6KsMMDBi{{BkV??thGBQ*b6=1cvEUq}Q3O86U@zw7r4Is4B7k6&<}6aGOU5}=F! zO37d12Y;vEOS0A<{Num3v_FaajW+#Pm42sneMvba@VfxO5hLJ#;EyW(6x;Z{1OJrc z`5VdW&uA|Tze)W8tpff3KicmT^4~iCugl841Wsc554gWZaDEH0{cbZLlJgtNZw$FF ztF;Ne1OWm9yrBNazWfK2pRM|R>|U1j`4{o;zfjx%6W;H3{6<9lTSd5^hwNoSu;3qz zg5dfOM*Tg=-`n&u^X89z$Rze(!~N-c{5|LRXY*bzA&~ilc^vZpFYx}}(w8rM|2TT9 z`u_p>*Q*7Gx;D%b+oU!|X*%v-+>#e!I_J1hRj= z%U(ui1zsQk0dZUZuNXgD^!q*bGJN=>3ghpM%jUmA`$O#a=izzjZvpP67wwhZ|6Tge zR{TCLFMTfT7bt+rlAeFYdEv|b-1#rvC)7U(ivN?Ke>rSFi+Jh%0RHq}G$gN|b@{i~ z^|O$d4hqQ&5kNqRe-iT7Zj*vEI5^F}&cu5~`)?UPYIOc+zwKuM>oI;b4)0(0fb#%v fZGeE_gMa-D@B{!Z64(!44sZ?z(2w8$`rH2xT0RHC literal 0 HcmV?d00001 From 733012373b2e5c6b690425d4f61c76e6dcd88e3c Mon Sep 17 00:00:00 2001 From: Rudolf Weeber Date: Tue, 6 Aug 2024 16:47:02 +0200 Subject: [PATCH 6/9] update --- src/core/ml_metatensor/compute.hpp | 80 +++++++++++++++++++++++++++ src/core/ml_metatensor/load_model.hpp | 18 ++++++ 2 files changed, 98 insertions(+) create mode 100644 src/core/ml_metatensor/compute.hpp diff --git a/src/core/ml_metatensor/compute.hpp b/src/core/ml_metatensor/compute.hpp new file mode 100644 index 0000000000..bea9df4e5b --- /dev/null +++ b/src/core/ml_metatensor/compute.hpp @@ -0,0 +1,80 @@ + +torch_metatensor::TensorMapHolder run_model(metatensor_toch::System& system, + const metatensor_torch::ModelEvaluationOptions evaluation_options, + torch::dtype dtypem + torch::Device device) { + + + // only run the calculation for atoms actually in the current domain + auto options = torch::TensorOptions().dtype(torch::kInt32); + this->selected_atoms_values = torch::zeros({n_particles, 2}, options); + for (int i=0; i( + std::vector{"system", "atom"}, mts_data->selected_atoms_values + ); + evaluation_options->set_selected_atoms(selected_atoms->to(device)); + + torch::IValue result_ivalue; + model->forward({ + std::vector{system}, + evaluation_options, + check_consistency + }); + + auto result = result_ivalue.toGenericDict(); + return result.at("energy").toCustomClass(); +} + +double get_energy(torch_metatensor::TensorMapHolder& energy, bool energy_is_per_atom) { + auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0); + auto energy_tensor = energy_block->values(); + auto energy_detached = energy_tensor.detach().to(torch::kCPU).to(torch::kFloat64); + auto energy_samples = energy_block->samples(); + + // store the energy returned by the model + torch::Tensor global_energy; + if (energy_is_per_atom) { + assert(energy_samples->size() == 2); + assert(energy_samples->names()[0] == "system"); + assert(energy_samples->names()[1] == "atom"); + + auto samples_values = energy_samples->values().to(torch::kCPU); + auto samples = samples_values.accessor(); + +// int n_atoms = selected_atoms_values.sizes(); +// assert(samples_values.sizes() == selected_atoms_values.sizes()); + + auto energies = energy_detached.accessor(); + global_energy = energy_detached.sum(0); + assert(energy_detached.sizes() == std::vector({1})); + } else { + assert(energy_samples->size() == 1); + assert(energy_samples->names()[0] == "system"); + + assert(energy_detached.sizes() == std::vector({1, 1})); + global_energy = energy_detached.reshape({1}); + } + + return global_energy.item(); +} + + +torch::Tensor get_forces(torch_metatensor::TensorMap& energy, torch_metatensor::System& system) { + // reset gradients to zero before calling backward + system->positions.mutable_grad() = torch::Tensor(); + + auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0); + auto energy_tensor = energy_block->values(); + + // compute forces/virial with backward propagation + energy_tensor.backward(-torch::ones_like(energy_tensor)); + auto forces_tensor = sytem->positions.grad(); + assert(forces_tensor.is_cpu() && forces_tensor.scalar_type() == torch::kFloat64); + return forces_tensor; +} + + + diff --git a/src/core/ml_metatensor/load_model.hpp b/src/core/ml_metatensor/load_model.hpp index b749c41a3a..e358ff8162 100644 --- a/src/core/ml_metatensor/load_model.hpp +++ b/src/core/ml_metatensor/load_model.hpp @@ -45,3 +45,21 @@ std::vector get_requested_neighbor_lists(ModelPtr& model) { } return res; } + + +torch_metatensor::ModelEvaluationOptions +init_evaluation_optoins(std::string length_unit, std::string energy_unit, torch_metatensor::ModelCapabilities& capabilities) { + torch_metatensor::ModelEvaluationOptoins evaluaoitn_optoins = torch::make_intrusive(); + this->evaluation_options->set_length_unit(std::move(length_unit)); + + auto output = torch::make_intrusive(); + output->explicit_gradients = {}; + output->set_quantity("energy"); + output->set_unit(std::move(energy_unit)); + output->per_atom = capabilities->outputs.at("energy").per_atom; + + evaluation_options->outputs.insert("energy", output); + return evaluation_options; +} + + From 29e5f2f0ab8578ad3de49d4028c9d3d4329f6c4b Mon Sep 17 00:00:00 2001 From: Rudolf Weeber Date: Fri, 9 Aug 2024 18:47:56 +0200 Subject: [PATCH 7/9] metatensor dependencies cmake --- src/core/ml_metatensor/stub.cpp | 17 +++++++++++++++++ src/core/ml_metatensor/stub.hpp | 9 +++++++++ 2 files changed, 26 insertions(+) create mode 100644 src/core/ml_metatensor/stub.cpp create mode 100644 src/core/ml_metatensor/stub.hpp diff --git a/src/core/ml_metatensor/stub.cpp b/src/core/ml_metatensor/stub.cpp new file mode 100644 index 0000000000..67eb68ef77 --- /dev/null +++ b/src/core/ml_metatensor/stub.cpp @@ -0,0 +1,17 @@ +#include "config/config.hpp" + +#ifdef METATENSOR +#undef CUDA +#include +#include +#include + +#if TORCH_VERSION_MAJOR >= 2 + #include +#endif + +#include +#include +torch::Tensor test_tensor{}; + +#endif diff --git a/src/core/ml_metatensor/stub.hpp b/src/core/ml_metatensor/stub.hpp new file mode 100644 index 0000000000..e6828ec9c5 --- /dev/null +++ b/src/core/ml_metatensor/stub.hpp @@ -0,0 +1,9 @@ +#include "config/config.hpp" + +#ifdef METATENSOR +#undef CUDA + +#include +#include +#include +#endif From 319a352c4cb80eeaae4d330650e257cdd6c70a5c Mon Sep 17 00:00:00 2001 From: Rudolf Weeber Date: Mon, 12 Aug 2024 12:55:57 +0200 Subject: [PATCH 8/9] s --- CMakeLists.txt | 60 ++++++++++++++++--------- src/core/CMakeLists.txt | 6 +++ src/core/ml_metatensor/CMakeLists.txt | 3 +- src/core/ml_metatensor/load_model.hpp | 65 --------------------------- src/core/ml_metatensor/stub.cpp | 2 - 5 files changed, 47 insertions(+), 89 deletions(-) delete mode 100644 src/core/ml_metatensor/load_model.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 662ec25815..2cc06edb38 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -605,27 +605,45 @@ if(ESPRESSO_BUILD_WITH_METATENSOR) # expression from `metatensor_torch` find_package(Torch REQUIRED) - # cmake-format: off - set(METATENSOR_URL_BASE "https://github.com/lab-cosmo/metatensor/releases/download") - set(METATENSOR_CORE_VERSION "0.1.8") - - include(FetchContent) - FetchContent_Declare( - metatensor - URL "${METATENSOR_URL_BASE}/metatensor-core-v${METATENSOR_CORE_VERSION}/metatensor-core-cxx-${METATENSOR_CORE_VERSION}.tar.gz" - URL_HASH SHA1=3ed389770e5ec6dbb8cbc9ed88f84d6809b552ef - ) - - # workaround for https://gitlab.kitware.com/cmake/cmake/-/issues/21146 - if(NOT DEFINED metatensor_SOURCE_DIR OR NOT EXISTS "${metatensor_SOURCE_DIR}") - message(STATUS "Fetching metatensor v${METATENSOR_CORE_VERSION} from github") - FetchContent_Populate(metatensor) - endif() - # cmake-format: on - - set(BUILD_SHARED_LIBS on CACHE BOOL "") - set(METATENSOR_INSTALL_BOTH_STATIC_SHARED off CACHE BOOL "") - add_subdirectory("${metatensor_SOURCE_DIR}") +# # cmake-format: off +# set(METATENSOR_URL_BASE "https://github.com/lab-cosmo/metatensor/releases/download") +# set(METATENSOR_CORE_VERSION "0.1.8") +# set(METATENSOR_TORCH_VERSION "0.5.3") +# +# include(FetchContent) +# set(BUILD_SHARED_LIBS on CACHE BOOL "") +# FetchContent_Declare( +# metatensor +# URL "${METATENSOR_URL_BASE}/metatensor-core-v${METATENSOR_CORE_VERSION}/metatensor-core-cxx-${METATENSOR_CORE_VERSION}.tar.gz" +# URL_HASH SHA1=3ed389770e5ec6dbb8cbc9ed88f84d6809b552ef +# ) +# set(BUILD_SHARED_LIBS on CACHE BOOL "") +# +# # workaround for https://gitlab.kitware.com/cmake/cmake/-/issues/21146 +# if(NOT DEFINED metatensor_SOURCE_DIR OR NOT EXISTS "${metatensor_SOURCE_DIR}") +# message(STATUS "Fetching metatensor v${METATENSOR_CORE_VERSION} from github") +# FetchContent_Populate(metatensor) +# endif() +# set(BUILD_SHARED_LIBS on CACHE BOOL "") +# +# FetchContent_Declare( +# metatensor_torch +# URL "${METATENSOR_URL_BASE}/metatensor-torch-v${METATENSOR_TORCH_VERSION}/metatensor-torch-cxx-${METATENSOR_TORCH_VERSION}.tar.gz" +# ) +# set(BUILD_SHARED_LIBS on CACHE BOOL "") +# if(NOT DEFINED metatensor_torch_SOURCE_DIR OR NOT EXISTS "${metatensor_torch_SOURCE_DIR}") +# message(STATUS "Fetching metatensor torch v${METATENSOR_CORE_VERSION} from github") +# FetchContent_Populate(metatensor_torch) +# endif() +# # cmake-format: on +# set(BUILD_SHARED_LIBS on CACHE BOOL "") +# +# set(METATENSOR_INSTALL_BOTH_STATIC_SHARED on CACHE BOOL "") +# add_subdirectory("${metatensor_SOURCE_DIR}") +# add_subdirectory("${metatensor_torch_SOURCE_DIR}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") + find_package(metatensor) + find_package(metatensor_torch) endif() # diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index ce44167449..8e613015fe 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -73,6 +73,11 @@ if(ESPRESSO_BUILD_WITH_CUDA) install(TARGETS espresso_cuda LIBRARY DESTINATION ${ESPRESSO_INSTALL_PYTHON}/espressomd) endif() +if(ESPRESSO_BUILD_WITH_METATENSOR) + target_link_libraries(espresso_core PUBLIC "${TORCH_LIBRARIES}") + target_link_libraries(espresso_core PUBLIC metatensor::shared) + target_link_libraries(espresso_core PUBLIC metatensor_torch) +endif() install(TARGETS espresso_core LIBRARY DESTINATION ${ESPRESSO_INSTALL_PYTHON}/espressomd) @@ -111,6 +116,7 @@ add_subdirectory(immersed_boundary) add_subdirectory(integrators) add_subdirectory(io) add_subdirectory(lb) +add_subdirectory(ml_metatensor) add_subdirectory(magnetostatics) add_subdirectory(nonbonded_interactions) add_subdirectory(object-in-fluid) diff --git a/src/core/ml_metatensor/CMakeLists.txt b/src/core/ml_metatensor/CMakeLists.txt index fa3eb2b411..e7c3e6b474 100644 --- a/src/core/ml_metatensor/CMakeLists.txt +++ b/src/core/ml_metatensor/CMakeLists.txt @@ -17,4 +17,5 @@ # along with this program. If not, see . # -#target_sources(espresso_core PRIVATE bonded_interaction_data.cpp) +target_sources(espresso_core PRIVATE stub.cpp) +target_sources(espresso_core PRIVATE load_model.cpp) diff --git a/src/core/ml_metatensor/load_model.hpp b/src/core/ml_metatensor/load_model.hpp deleted file mode 100644 index e358ff8162..0000000000 --- a/src/core/ml_metatensor/load_model.hpp +++ /dev/null @@ -1,65 +0,0 @@ -using ModelPtr = std::unique_ptr; -using NeighborListRequest = - std::pair; - - -ModelPtr load_model(const std::string& path, const std::string& extensions_directory, torch::device device) { - - return std::make_unique( - metatensor_torch::load_atomistic_model(path, extensions) - ); -} - - -metatensor_torch::ModelCapabilitiesHolder -get_model_capabilites(const ModelPtr& model) { - auto capabilities_ivalue = model->run_method("capabilities"); - return capabilities_ivalue.toCustomClass(); -}; - -bool modle_provides_energy(const torch_metatensor::ModelCapabilitiesHolder& capabilities) { - return (capabilities->outputs().contains("energy")); -} - - -metatensor_toch::ModelMetadataHolder get_model_metadata(ModelPtr& model) { - auto metadata_ivalue = model->run_method("metadata"); - return metadata_ivalue.toCustomClass(); -} - - - -double required_range(const metatensor_torch::ModelCapabilities& capabilities, const metatensor_torch::ModelEvaluatoinOptoins& evaluation_optoins) { - return range = mts_data->capabilities->engine_interaction_range(evaluation_options->length_unit()); -} - -std::vector get_requested_neighbor_lists(ModelPtr& model) { - - std::vector res; - auto requested_nl = mts_data->model->run_method("requested_neighbor_lists"); - for (const auto& ivalue: requested_nl.toList()) { - auto options = ivalue.get().toCustomClass(); - auto cutoff = options->engine_cutoff(mts_data->evaluation_options->length_unit()); - - res.push_back({cutoff, options}); - } - return res; -} - - -torch_metatensor::ModelEvaluationOptions -init_evaluation_optoins(std::string length_unit, std::string energy_unit, torch_metatensor::ModelCapabilities& capabilities) { - torch_metatensor::ModelEvaluationOptoins evaluaoitn_optoins = torch::make_intrusive(); - this->evaluation_options->set_length_unit(std::move(length_unit)); - - auto output = torch::make_intrusive(); - output->explicit_gradients = {}; - output->set_quantity("energy"); - output->set_unit(std::move(energy_unit)); - output->per_atom = capabilities->outputs.at("energy").per_atom; - - evaluation_options->outputs.insert("energy", output); - return evaluation_options; -} - - diff --git a/src/core/ml_metatensor/stub.cpp b/src/core/ml_metatensor/stub.cpp index 67eb68ef77..806d0ec8a3 100644 --- a/src/core/ml_metatensor/stub.cpp +++ b/src/core/ml_metatensor/stub.cpp @@ -12,6 +12,4 @@ #include #include -torch::Tensor test_tensor{}; - #endif From 29ba46b1347d43c77ee001f756fbe9970ad92f41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Julian=20Ho=C3=9Fbach?= Date: Tue, 8 Oct 2024 10:12:24 +0200 Subject: [PATCH 9/9] cleanup --- CMakeLists.txt | 63 ++++---- src/core/CMakeLists.txt | 4 +- src/core/ml_metatensor/CMakeLists.txt | 2 +- src/core/ml_metatensor/add_neighbor_list.hpp | 114 +++++++-------- src/core/ml_metatensor/compute.hpp | 142 ++++++++++--------- src/core/ml_metatensor/stub.cpp | 8 +- src/core/ml_metatensor/stub.hpp | 4 +- src/core/ml_metatensor/system.hpp | 40 +++--- 8 files changed, 189 insertions(+), 188 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2cc06edb38..90bfb66f13 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -605,42 +605,33 @@ if(ESPRESSO_BUILD_WITH_METATENSOR) # expression from `metatensor_torch` find_package(Torch REQUIRED) -# # cmake-format: off -# set(METATENSOR_URL_BASE "https://github.com/lab-cosmo/metatensor/releases/download") -# set(METATENSOR_CORE_VERSION "0.1.8") -# set(METATENSOR_TORCH_VERSION "0.5.3") -# -# include(FetchContent) -# set(BUILD_SHARED_LIBS on CACHE BOOL "") -# FetchContent_Declare( -# metatensor -# URL "${METATENSOR_URL_BASE}/metatensor-core-v${METATENSOR_CORE_VERSION}/metatensor-core-cxx-${METATENSOR_CORE_VERSION}.tar.gz" -# URL_HASH SHA1=3ed389770e5ec6dbb8cbc9ed88f84d6809b552ef -# ) -# set(BUILD_SHARED_LIBS on CACHE BOOL "") -# -# # workaround for https://gitlab.kitware.com/cmake/cmake/-/issues/21146 -# if(NOT DEFINED metatensor_SOURCE_DIR OR NOT EXISTS "${metatensor_SOURCE_DIR}") -# message(STATUS "Fetching metatensor v${METATENSOR_CORE_VERSION} from github") -# FetchContent_Populate(metatensor) -# endif() -# set(BUILD_SHARED_LIBS on CACHE BOOL "") -# -# FetchContent_Declare( -# metatensor_torch -# URL "${METATENSOR_URL_BASE}/metatensor-torch-v${METATENSOR_TORCH_VERSION}/metatensor-torch-cxx-${METATENSOR_TORCH_VERSION}.tar.gz" -# ) -# set(BUILD_SHARED_LIBS on CACHE BOOL "") -# if(NOT DEFINED metatensor_torch_SOURCE_DIR OR NOT EXISTS "${metatensor_torch_SOURCE_DIR}") -# message(STATUS "Fetching metatensor torch v${METATENSOR_CORE_VERSION} from github") -# FetchContent_Populate(metatensor_torch) -# endif() -# # cmake-format: on -# set(BUILD_SHARED_LIBS on CACHE BOOL "") -# -# set(METATENSOR_INSTALL_BOTH_STATIC_SHARED on CACHE BOOL "") -# add_subdirectory("${metatensor_SOURCE_DIR}") -# add_subdirectory("${metatensor_torch_SOURCE_DIR}") + # # cmake-format: off set(METATENSOR_URL_BASE + # "https://github.com/lab-cosmo/metatensor/releases/download") + # set(METATENSOR_CORE_VERSION "0.1.8") set(METATENSOR_TORCH_VERSION "0.5.3") + # + # include(FetchContent) set(BUILD_SHARED_LIBS on CACHE BOOL "") + # FetchContent_Declare( metatensor URL + # "${METATENSOR_URL_BASE}/metatensor-core-v${METATENSOR_CORE_VERSION}/metatensor-core-cxx-${METATENSOR_CORE_VERSION}.tar.gz" + # URL_HASH SHA1=3ed389770e5ec6dbb8cbc9ed88f84d6809b552ef ) + # set(BUILD_SHARED_LIBS on CACHE BOOL "") + # + # # workaround for https://gitlab.kitware.com/cmake/cmake/-/issues/21146 + # if(NOT DEFINED metatensor_SOURCE_DIR OR NOT EXISTS + # "${metatensor_SOURCE_DIR}") message(STATUS "Fetching metatensor + # v${METATENSOR_CORE_VERSION} from github") FetchContent_Populate(metatensor) + # endif() set(BUILD_SHARED_LIBS on CACHE BOOL "") + # + # FetchContent_Declare( metatensor_torch URL + # "${METATENSOR_URL_BASE}/metatensor-torch-v${METATENSOR_TORCH_VERSION}/metatensor-torch-cxx-${METATENSOR_TORCH_VERSION}.tar.gz" + # ) set(BUILD_SHARED_LIBS on CACHE BOOL "") if(NOT DEFINED + # metatensor_torch_SOURCE_DIR OR NOT EXISTS "${metatensor_torch_SOURCE_DIR}") + # message(STATUS "Fetching metatensor torch v${METATENSOR_CORE_VERSION} from + # github") FetchContent_Populate(metatensor_torch) endif() # cmake-format: on + # set(BUILD_SHARED_LIBS on CACHE BOOL "") + # + # set(METATENSOR_INSTALL_BOTH_STATIC_SHARED on CACHE BOOL "") + # add_subdirectory("${metatensor_SOURCE_DIR}") + # add_subdirectory("${metatensor_torch_SOURCE_DIR}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") find_package(metatensor) find_package(metatensor_torch) diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 8e613015fe..89cb7620fd 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -75,8 +75,8 @@ if(ESPRESSO_BUILD_WITH_CUDA) endif() if(ESPRESSO_BUILD_WITH_METATENSOR) target_link_libraries(espresso_core PUBLIC "${TORCH_LIBRARIES}") - target_link_libraries(espresso_core PUBLIC metatensor::shared) - target_link_libraries(espresso_core PUBLIC metatensor_torch) + target_link_libraries(espresso_core PUBLIC metatensor::shared) + target_link_libraries(espresso_core PUBLIC metatensor_torch) endif() install(TARGETS espresso_core diff --git a/src/core/ml_metatensor/CMakeLists.txt b/src/core/ml_metatensor/CMakeLists.txt index e7c3e6b474..9c1cd2a3a1 100644 --- a/src/core/ml_metatensor/CMakeLists.txt +++ b/src/core/ml_metatensor/CMakeLists.txt @@ -18,4 +18,4 @@ # target_sources(espresso_core PRIVATE stub.cpp) -target_sources(espresso_core PRIVATE load_model.cpp) +# target_sources(espresso_core PRIVATE load_model.cpp) diff --git a/src/core/ml_metatensor/add_neighbor_list.hpp b/src/core/ml_metatensor/add_neighbor_list.hpp index ceb18cf1ee..2790ec9090 100644 --- a/src/core/ml_metatensor/add_neighbor_list.hpp +++ b/src/core/ml_metatensor/add_neighbor_list.hpp @@ -1,71 +1,71 @@ +#include "metatensor/torch/atomistic/system.hpp" +#include "utils/Vector.hpp" +#include + struct PairInfo { - int part_id_1, - int part_id_2, + int part_id_1; + int part_id_2; Utils::Vector3d distance; -} - -using Sample = std::array; -using Distances = - std::variant>, std::vector>>; +}; +using Sample = std::array; +using Distances = std::variant>, + std::vector>>; template -TorchTensorBlock neighbor_list_from_pairs(const metatensor_torch::System& system, const PairIterable& pairs) { - auto dtype = system->positions().scalar_type(); - auto device = system->positions().device(); - std::vector samples; - Distances distances; - if (dtype == torch::kFloat64) { - distances = {std::vector>()}; - } - else if (dtype == torch::kFloat32) { - distances = {std::vector>()}; - } - else { - throw std::runtime_error("Unsupported floating poitn data type"); - } +metatensor_torch::TorchTensorBlock +neighbor_list_from_pairs(const metatensor_torch::System &system, + const PairIterable &pairs) { + auto dtype = system->positions().scalar_type(); + auto device = system->positions().device(); + std::vector samples; + Distances distances; - for (auto const& pair: pairs) { - auto sample = Sample{ - pair.particle_id_1, pair.particle_id_2, 0, 0, 0}; - samples.push_back(sample); - (*distances).push_back(pair.distance); - } + if (dtype == torch::kFloat64) { + distances = {std::vector>()}; + } else if (dtype == torch::kFloat32) { + distances = {std::vector>()}; + } else { + throw std::runtime_error("Unsupported floating point data type"); + } + for (auto const &pair : pairs) { + samples.emplace_back(pair.particle_id_1, pair.particle_id_2, 0, 0, 0); + std::visit([&pair](auto &vec) { vec.push_back(pair.distance); }, distances); + } - int64_t n_pairs = samples.size(); - auto samples_tensor = torch::from_blob( - reinterpret_cast(samples.data()), - {n_pairs, 5}, - torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU) - ); + auto n_pairs = static_cast(samples.size()); - auto samples = torch::make_intrusive( - std::vector{"first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"}, - samples_values - ); + auto samples_tensor = torch::from_blob( + samples.data(), {n_pairs, 5}, + torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU)); - distances_vectors = torch::from_blob( - (*distances).data(), - {n_pairs, 3, 1}, - torch::TensorOptions().dtype(dtype).device(torch::kCPU) - ); - return neighbors = torch::make_intrusive( - distances_vectors.to(dtype).to(device), - samples->to(device), - std::vector{ - metatensor_torch::LabelsHolder::create({"xyz"}, {{0}, {1}, {2}})->to(device), - }, - metatensor_torch::LabelsHolder::create({"distance"}, {{0}})->to(device) - ); + auto samples_ptr = torch::make_intrusive( + std::vector{"first_atom", "second_atom", "cell_shift_a", + "cell_shift_b", "cell_shift_c"}, + samples); -} + auto distances_vectors = torch::from_blob( + std::visit([](auto &vec) { return vec.data(); }, distances), + {n_pairs, 3, 1}, torch::TensorOptions().dtype(dtype).device(torch::kCPU)); -void add_neighbor_list_to_system(MetatensorTorch::system& system, - const TorchTensorBlock& neighbors, - const NeighborListOptions& options) { - metatensor_torch::register_autograd_neighbors(system, neighbors, options_.check_consistency); - system->add_neighbor_list(options, neighbors); -} + auto neighbors = torch::make_intrusive( + distances_vectors.to(dtype).to(device), samples_ptr->to(device), + std::vector{ + metatensor_torch::LabelsHolder::create({"xyz"}, {{0}, {1}, {2}}) + ->to(device), + }, + metatensor_torch::LabelsHolder::create({"distance"}, {{0}})->to(device)); + return neighbors; +} +void add_neighbor_list_to_system( + metatensor_torch::System &system, + const metatensor_torch::TorchTensorBlock &neighbors, + const metatensor_torch::NeighborListOptions &options, + bool check_consistency) { + metatensor_torch::register_autograd_neighbors(system, neighbors, + check_consistency); + system->add_neighbor_list(options, neighbors); +} diff --git a/src/core/ml_metatensor/compute.hpp b/src/core/ml_metatensor/compute.hpp index bea9df4e5b..ba133c3631 100644 --- a/src/core/ml_metatensor/compute.hpp +++ b/src/core/ml_metatensor/compute.hpp @@ -1,80 +1,82 @@ - -torch_metatensor::TensorMapHolder run_model(metatensor_toch::System& system, - const metatensor_torch::ModelEvaluationOptions evaluation_options, - torch::dtype dtypem - torch::Device device) { - - - // only run the calculation for atoms actually in the current domain - auto options = torch::TensorOptions().dtype(torch::kInt32); - this->selected_atoms_values = torch::zeros({n_particles, 2}, options); - for (int i=0; i( - std::vector{"system", "atom"}, mts_data->selected_atoms_values - ); - evaluation_options->set_selected_atoms(selected_atoms->to(device)); - - torch::IValue result_ivalue; - model->forward({ - std::vector{system}, - evaluation_options, - check_consistency - }); - - auto result = result_ivalue.toGenericDict(); - return result.at("energy").toCustomClass(); +#include "metatensor/torch/atomistic/system.hpp" +#include +#include + +metatensor_torch::TensorMapHolder +run_model(metatensor_torch::System &system, int64_t n_particles, + const metatensor_torch::ModelEvaluationOptions evaluation_options, + torch::Dtype dtype, torch::Device device, bool check_consistency) { + + // only run the calculation for atoms actually in the current domain + auto options = torch::TensorOptions().dtype(torch::kInt32); + auto selected_atoms_values = torch::zeros({n_particles, 2}, options); + + for (int i = 0; i < n_particles; i++) { + selected_atoms_values[i][0] = 0; + selected_atoms_values[i][1] = i; + } + auto selected_atoms = torch::make_intrusive( + std::vector{"system", "atom"}, selected_atoms_values); + evaluation_options->set_selected_atoms(selected_atoms->to(device)); + + torch::IValue result_ivalue; + model->forward({std::vector{system}, + evaluation_options, check_consistency}); + + auto result = result_ivalue.toGenericDict(); + auto energy = + result.at("energy").toCustomClass(); + auto energy_tensor = + metatensor_torch::TensorMapHolder::block_by_id(energy, 0); } -double get_energy(torch_metatensor::TensorMapHolder& energy, bool energy_is_per_atom) { - auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0); - auto energy_tensor = energy_block->values(); - auto energy_detached = energy_tensor.detach().to(torch::kCPU).to(torch::kFloat64); - auto energy_samples = energy_block->samples(); - - // store the energy returned by the model - torch::Tensor global_energy; - if (energy_is_per_atom) { - assert(energy_samples->size() == 2); - assert(energy_samples->names()[0] == "system"); - assert(energy_samples->names()[1] == "atom"); - - auto samples_values = energy_samples->values().to(torch::kCPU); - auto samples = samples_values.accessor(); - -// int n_atoms = selected_atoms_values.sizes(); -// assert(samples_values.sizes() == selected_atoms_values.sizes()); - - auto energies = energy_detached.accessor(); - global_energy = energy_detached.sum(0); - assert(energy_detached.sizes() == std::vector({1})); - } else { - assert(energy_samples->size() == 1); - assert(energy_samples->names()[0] == "system"); - - assert(energy_detached.sizes() == std::vector({1, 1})); - global_energy = energy_detached.reshape({1}); - } +double get_energy(metatensor_torch::TensorMapHolder &energy, + bool energy_is_per_atom) { + auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0); + auto energy_tensor = energy_block->values(); + auto energy_detached = + energy_tensor.detach().to(torch::kCPU).to(torch::kFloat64); + auto energy_samples = energy_block->samples(); + + // store the energy returned by the model + torch::Tensor global_energy; + if (energy_is_per_atom) { + assert(energy_samples->size() == 2); + assert(energy_samples->names()[0] == "system"); + assert(energy_samples->names()[1] == "atom"); + + auto samples_values = energy_samples->values().to(torch::kCPU); + auto samples = samples_values.accessor(); + + // int n_atoms = selected_atoms_values.sizes(); + // assert(samples_values.sizes() == selected_atoms_values.sizes()); + + auto energies = energy_detached.accessor(); + global_energy = energy_detached.sum(0); + assert(energy_detached.sizes() == std::vector({1})); + } else { + assert(energy_samples->size() == 1); + assert(energy_samples->names()[0] == "system"); + + assert(energy_detached.sizes() == std::vector({1, 1})); + global_energy = energy_detached.reshape({1}); + } return global_energy.item(); } +torch::Tensor get_forces(metatensor::TensorMap &energy, + metatensor_torch::System &system) { + // reset gradients to zero before calling backward + system->positions().mutable_grad() = torch::Tensor(); -torch::Tensor get_forces(torch_metatensor::TensorMap& energy, torch_metatensor::System& system) { - // reset gradients to zero before calling backward - system->positions.mutable_grad() = torch::Tensor(); - - auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0); - auto energy_tensor = energy_block->values(); + auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0); + auto energy_tensor = energy_block->values(); - // compute forces/virial with backward propagation - energy_tensor.backward(-torch::ones_like(energy_tensor)); - auto forces_tensor = sytem->positions.grad(); - assert(forces_tensor.is_cpu() && forces_tensor.scalar_type() == torch::kFloat64); + // compute forces/virial with backward propagation + energy_tensor.backward(-torch::ones_like(energy_tensor)); + auto forces_tensor = system->positions().grad(); + assert(forces_tensor.is_cpu() && + forces_tensor.scalar_type() == torch::kFloat64); return forces_tensor; } - - - diff --git a/src/core/ml_metatensor/stub.cpp b/src/core/ml_metatensor/stub.cpp index 806d0ec8a3..9865c4214d 100644 --- a/src/core/ml_metatensor/stub.cpp +++ b/src/core/ml_metatensor/stub.cpp @@ -1,13 +1,13 @@ -#include "config/config.hpp" +#include "config/config.hpp" #ifdef METATENSOR #undef CUDA -#include -#include #include +#include +#include #if TORCH_VERSION_MAJOR >= 2 - #include +#include #endif #include diff --git a/src/core/ml_metatensor/stub.hpp b/src/core/ml_metatensor/stub.hpp index e6828ec9c5..b0517cb5d4 100644 --- a/src/core/ml_metatensor/stub.hpp +++ b/src/core/ml_metatensor/stub.hpp @@ -3,7 +3,7 @@ #ifdef METATENSOR #undef CUDA -#include -#include #include +#include +#include #endif diff --git a/src/core/ml_metatensor/system.hpp b/src/core/ml_metatensor/system.hpp index d9a5dac359..0c3dfa5452 100644 --- a/src/core/ml_metatensor/system.hpp +++ b/src/core/ml_metatensor/system.hpp @@ -1,32 +1,39 @@ -using ParticleTypeMap = std::unorderd_map; +#include "ATen/core/TensorBody.h" +#include "metatensor/torch/atomistic/system.hpp" +#include "utils/Vector.hpp" +#include -metatensor_torch::System - : system_from_lmp(const TypeMapping &type_map, - const std::vector &engine_positions, - const std::vector &engine_particle_types, - const Vector3d &box_size, bool do_virial, - torch::ScalarType dtype, torch::Device device) { +using ParticleTypeMap = std::unordered_map; + +metatensor_torch::System ::system_from_lmp( + const ParticleTypeMap &type_map, std::vector &engine_positions, + const std::vector + &engine_particle_types, // TODO: This should be std::vector? + const Utils::Vector3d &box_size, bool do_virial, torch::ScalarType dtype, + torch::Device device) { auto tensor_options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU); - if (engine_positions % 3 != 0) + if (engine_positions.size() % 3 != 0) throw std::runtime_error( - "Positoin array must have a multiple of 3 elements"); - const int n_particles = engine_positions.size() / 3; + "Position array must have a multiple of 3 elements"); + const auto n_particles = static_cast(engine_positions.size()) / 3; if (engine_particle_types.size() != n_particles) throw std::runtime_error( - "Length of positon and particle tyep arrays inconsistent"); + "Length of position and particle type arrays inconsistent"); auto positions = torch::from_blob( - engien_positions.data(), {n_particles, 3}, + engine_positions.data(), {n_particles, 3}, // requires_grad=true since we always need gradients w.r.t. positions tensor_options.requires_grad(true)); std::vector particle_types_ml; std::ranges::transform( - particle_types_engine, std::back_inserter(particle_types_ml), + engine_particle_types, std::back_inserter(particle_types_ml), [&type_map](int engine_type) { return type_map.at(engine_type); }); auto particle_types_ml_tensor = - Torch::Tensor(particle_types_ml, tensor_options.requires_grad(true)); + torch::from_blob(particle_types_ml.data(), + {static_cast(particle_types_ml.size())}, + tensor_options.requires_grad(true)); auto cell = torch::zeros({3, 3}, tensor_options); for (int i : {0, 1, 2}) @@ -35,6 +42,7 @@ metatensor_torch::System positions.to(dtype).to(device); cell = cell.to(dtype).to(device); - return system = torch::make_intrusive( - particle_types_ml_tensor.to(device), positions, cell); + auto system = torch::make_intrusive( + particle_types_ml_tensor.to(device), positions, cell); + return system; }