Skip to content

Commit

Permalink
Add global slatm and tests
Browse files Browse the repository at this point in the history
todo: condensed version; mbtypes options
  • Loading branch information
briling committed Dec 4, 2023
1 parent 7593182 commit 6020610
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 40 deletions.
2 changes: 2 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,6 @@ dependencies:
- pytest==6.2.5
- scipy==1.7.3
- toml==0.10.2
- ase==3.22
- tqdm==4.66
- git+https://github.com/lab-cosmo/equistore.git@e5b9dc365369ba2584ea01e9d6a4d648008aaab8#subdirectory=python/equistore-core
1 change: 1 addition & 0 deletions qstack/qml/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from qstack.qml import slatm
8 changes: 7 additions & 1 deletion qstack/qml/slatm.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def get_cos(a, b, c):


def get_slatm(q, r, mbtypes, qml_compatible=True, stack_all=True,
global_repr=False,
r0=defaults.r0, rcut=defaults.rcut, sigma2=defaults.sigma2, dgrid2=defaults.dgrid2,
theta0=defaults.theta0, sigma3=defaults.sigma3, dgrid3=defaults.dgrid3):

Expand Down Expand Up @@ -157,18 +158,22 @@ def get_slatm(q, r, mbtypes, qml_compatible=True, stack_all=True,
if stack_all:
slatm = np.vstack(slatm)

if global_repr:
slatm = np.sum(slatm, axis=0)

return slatm



def get_slatm_for_dataset(molecules,
progress=False,
global_repr=False,
qml_mbtypes=True, qml_compatible=True, stack_all=True,
r0=defaults.r0, rcut=defaults.rcut, sigma2=defaults.sigma2, dgrid2=defaults.dgrid2,
theta0=defaults.theta0, sigma3=defaults.sigma3, dgrid3=defaults.dgrid3):

if isinstance(molecules[0], str):
import ase
import ase.io
molecules = [ase.io.read(xyz) for xyz in molecules]

qs = [mol.numbers for mol in molecules]
Expand All @@ -181,6 +186,7 @@ def get_slatm_for_dataset(molecules,
slatm = []
for mol in molecules:
slatm.append(get_slatm(mol.numbers, mol.positions, mbtypes,
global_repr=global_repr,
qml_compatible=qml_compatible, stack_all=stack_all,
r0=r0, rcut=rcut, sigma2=sigma2, dgrid2=dgrid2,
theta0=theta0, sigma3=sigma3, dgrid3=dgrid3))
Expand Down
39 changes: 0 additions & 39 deletions qstack/qml/test.py

This file was deleted.

2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@ numpy===1.22.3
scipy==1.7.3
toml==0.10.2
scikit-learn==0.24.2
ase==3.22
tqdm==4.66
equistore-core @ git+https://github.com/lab-cosmo/equistore.git@e5b9dc365369ba2584ea01e9d6a4d648008aaab8#subdirectory=python/equistore-core

19 changes: 19 additions & 0 deletions tests/data/slatm/0.xyz
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
17
charge = 0
C 1.334277 0.421896 -0.503330
O 1.606734 -0.382606 0.614935
C 2.562285 0.137308 1.501005
C 3.936133 0.288058 0.884991
C 2.115760 1.334063 2.366971
O 2.402190 0.606383 3.567262
C 2.549701 -0.616684 2.832718
H 0.474474 -0.017852 -1.008820
H 2.173000 0.455753 -1.208437
H 1.079495 1.450047 -0.212828
H 3.949111 1.048035 0.100338
H 4.264756 -0.662057 0.456512
H 4.653061 0.590591 1.650946
H 1.046433 1.556985 2.270572
H 2.698776 2.257263 2.297841
H 1.674223 -1.266707 2.933492
H 3.457962 -1.159075 3.110051
13 changes: 13 additions & 0 deletions tests/data/slatm/1.xyz
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
11
charge = 0
C 0.737797 0.000562 0.073413
C 2.064519 -0.031214 -0.079240
N 2.711390 0.065790 -1.302740
C 4.063267 0.040246 -1.527779
O 4.541353 0.130874 -2.627340
C 2.899768 -0.173867 1.077472
N 3.550775 -0.288994 2.019337
H 0.076383 0.107156 -0.777476
H 0.302610 -0.080503 1.057996
H 2.136563 0.166701 -2.127042
H 4.656784 -0.071049 -0.605791
17 changes: 17 additions & 0 deletions tests/data/slatm/2.xyz
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
15
charge = 0
O 2.634473 0.575247 -0.666752
C 2.184265 -0.046602 0.250982
O 2.882994 -0.627798 1.223051
C 4.301892 -0.493465 1.148105
C 4.942708 0.783769 1.698156
C 6.043112 -0.116490 2.298216
C 5.067033 -1.307379 2.185849
H 1.111814 -0.226951 0.418138
H 4.622887 -0.695969 0.124698
H 4.314974 1.214015 2.481754
H 5.218382 1.560136 0.985767
H 6.415691 0.135583 3.290388
H 6.892530 -0.222863 1.621469
H 4.470240 -1.413280 3.094053
H 5.452897 -2.283473 1.891117
15 changes: 15 additions & 0 deletions tests/data/slatm/3.xyz
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
13
charge = 0
N 0.922783 0.082122 -0.086886
C 2.194494 0.111758 -0.036301
N 2.983989 1.248023 -0.232815
O 4.204400 1.227224 0.484302
C 5.010475 0.201859 -0.046570
C 4.293889 -1.106183 0.007541
C 2.965284 -1.128936 0.067134
H 0.524250 1.016312 -0.182736
H 2.531718 2.113654 0.029945
H 5.903984 0.190610 0.581907
H 5.315589 0.440806 -1.077588
H 4.877672 -2.020729 -0.001784
H 2.387973 -2.041720 0.135840
17 changes: 17 additions & 0 deletions tests/data/slatm/4.xyz
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
15
charge = 0
C 1.333783 0.656278 0.501107
O 2.322436 0.083498 -0.324248
C 2.926464 0.857651 -1.230538
N 3.825087 0.402262 -2.004531
C 4.245574 1.517329 -2.841689
C 3.409274 2.756754 -2.410818
N 2.597611 2.180625 -1.360752
H 0.983716 -0.142924 1.151507
H 0.485446 1.029742 -0.082410
H 1.741307 1.460546 1.122987
H 4.085435 1.277443 -3.897289
H 5.317557 1.698653 -2.717549
H 2.793491 3.149736 -3.225927
H 4.034346 3.572761 -2.034656
H 1.915074 2.679067 -0.825776
Binary file added tests/data/slatm/slatm_global.npy
Binary file not shown.
Binary file added tests/data/slatm/slatm_local.npy
Binary file not shown.
25 changes: 25 additions & 0 deletions tests/test_slatm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import os
import glob
import numpy as np
from qstack.qml.slatm import get_slatm_for_dataset


def test_slatm_global():
path = os.path.dirname(os.path.realpath(__file__))
v0 = np.load(f'{path}/data/slatm/slatm_global.npy')
xyzs = [f for f in sorted(glob.glob(f"{path}/data/slatm/*.xyz"))]
v = get_slatm_for_dataset(xyzs, progress=False, global_repr=True)
assert(np.linalg.norm(v-v0)<1e-10)


def test_slatm_local():
path = os.path.dirname(os.path.realpath(__file__))
v0 = np.load(f'{path}/data/slatm/slatm_local.npy')
xyzs = [f for f in sorted(glob.glob(f"{path}/data/slatm/*.xyz"))]
v = get_slatm_for_dataset(xyzs, progress=False)
assert(np.linalg.norm(v-v0)<1e-10)


if __name__ == '__main__':
test_slatm_local()
test_slatm_global()

0 comments on commit 6020610

Please sign in to comment.