Skip to content

Commit 47fada0

Browse files
committed
Sketching unit tests for new factory functions
1 parent 011e6f0 commit 47fada0

File tree

2 files changed

+54
-19
lines changed

2 files changed

+54
-19
lines changed

abipy/abio/factories.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,16 +1381,17 @@ def scf_for_phonons(structure, pseudos, kppa=None, ecut=None, pawecutdg=None, nb
13811381

13821382
def ddkpert_from_gsinput(gs_input, ddk_pert, nband=None, use_symmetries=False, ddk_tol=None, manager=None) -> AbinitInput:
13831383
"""
1384-
Returns an |AbinitInput| to perform a DDK calculations for a specific perturbation and based on a ground state |AbinitInput|.
1384+
Returns an |AbinitInput| to perform a DDK calculations for a specific perturbation based on a ground state |AbinitInput|.
13851385
13861386
Args:
1387-
gs_input: an |AbinitInput| representing a ground state calculation, likely the SCF performed to get the WFK.
1387+
gs_input: an |AbinitInput| representing a ground state calculation, likely the SCF performed to get the WFK.
13881388
ddk_pert: dict with the Abinit variables defining the perturbation
13891389
Example: {'idir': 1, 'ipert': 4, 'qpt': [0.0, 0.0, 0.0]},
13901390
use_symmetries: boolean that determines if the irreducible components of the perturbation are used.
13911391
Default to False. (TODO: Should be implemented)
1392-
ddk_tol: a dictionary with a single key defining the type of tolerance used for the DDK calculations and its value. Default: {"tolvrs": 1.0e-22}.
1393-
manager: |TaskManager| of the task. If None, the manager is initialized from the config file.
1392+
ddk_tol: a dictionary with a single key defining the type of tolerance used for the DDK calculations and its value.
1393+
Default: {"tolvrs": 1.0e-22}.
1394+
manager: |TaskManager| of the task. If None, the manager is initialized from the config file.
13941395
"""
13951396
gs_input = gs_input.deepcopy()
13961397
gs_input.pop_irdvars()
@@ -1409,19 +1410,20 @@ def ddkpert_from_gsinput(gs_input, ddk_pert, nband=None, use_symmetries=False, d
14091410

14101411
return ddk_inp
14111412

1413+
14121414
def ddepert_from_gsinput(gs_input, dde_pert, use_symmetries=True, dde_tol=None, manager=None) -> AbinitInput:
14131415
"""
1414-
Returns an |AbinitInput| to perform a DDE calculations for a specific perturbation and based on a ground state |AbinitInput|.
1416+
Returns an |AbinitInput| to perform a DDE calculations for a specific perturbation based on a ground state |AbinitInput|.
14151417
14161418
Args:
1417-
gs_input: an |AbinitInput| representing a ground state calculation, likely the SCF performed to get the WFK.
1419+
gs_input: an |AbinitInput| representing a ground state calculation, likely the SCF performed to get the WFK.
14181420
dde_pert: dict with the Abinit variables defining the perturbation
1419-
Example: {'idir': 1, 'ipert': 4, 'qpt': [0.0, 0.0, 0.0]},
1421+
Example: {'idir': 1, 'ipert': 4, 'qpt': [0.0, 0.0, 0.0]},
14201422
use_symmetries: boolean that determines if the irreducible components of the perturbation are used.
14211423
Default to True. Should be set to False for nonlinear coefficients calculation.
1422-
dde_tol: a dictionary with a single key defining the type of tolerance used for the DDE calculations and
1424+
dde_tol: a dictionary with a single key defining the type of tolerance used for the DDE calculations and
14231425
its value. Default: {"tolvrs": 1.0e-22}.
1424-
manager: |TaskManager| of the task. If None, the manager is initialized from the config file.
1426+
manager: |TaskManager| of the task. If None, the manager is initialized from the config file.
14251427
"""
14261428
gs_input = gs_input.deepcopy()
14271429
gs_input.pop_irdvars()
@@ -1434,15 +1436,16 @@ def ddepert_from_gsinput(gs_input, dde_pert, use_symmetries=True, dde_tol=None,
14341436

14351437
return dde_inp
14361438

1439+
14371440
def dtepert_from_gsinput(gs_input, dte_pert, manager=None) -> AbinitInput:
14381441
"""
14391442
Returns an |AbinitInput| to perform a DTE calculations for a specific perturbation and based on a ground state |AbinitInput|.
14401443
14411444
Args:
1442-
gs_input: an |AbinitInput| representing a ground state calculation, likely the SCF performed to get the WFK.
1445+
gs_input: an |AbinitInput| representing a ground state calculation, likely the SCF performed to get the WFK.
14431446
dte_pert: dict with the Abinit variables defining the perturbation
1444-
Example: {'idir': 1, 'ipert': 4, 'qpt': [0.0, 0.0, 0.0]},
1445-
manager: |TaskManager| of the task. If None, the manager is initialized from the config file.
1447+
Example: {'idir': 1, 'ipert': 4, 'qpt': [0.0, 0.0, 0.0]},
1448+
manager: |TaskManager| of the task. If None, the manager is initialized from the config file.
14461449
"""
14471450
gs_input = gs_input.deepcopy()
14481451
gs_input.pop_irdvars()
@@ -1452,6 +1455,7 @@ def dtepert_from_gsinput(gs_input, dte_pert, manager=None) -> AbinitInput:
14521455

14531456
return dte_inp
14541457

1458+
14551459
def dte_from_gsinput(gs_input, use_phonons=True, ph_tol=None, ddk_tol=None, dde_tol=None,
14561460
skip_dte_permutations=False, manager=None) -> MultiDataset:
14571461
"""
@@ -1509,7 +1513,7 @@ def dte_from_gsinput(gs_input, use_phonons=True, ph_tol=None, ddk_tol=None, dde_
15091513
gs_input.set_vars(nband=nband)
15101514
gs_input.pop('nbdbuf', None)
15111515
multi_dte = gs_input.make_dte_inputs(phonon_pert=use_phonons, skip_permutations=skip_dte_permutations,
1512-
manager=manager)
1516+
manager=manager)
15131517
multi_dte.add_tags(atags.DTE)
15141518
multi.extend(multi_dte)
15151519

abipy/abio/tests/test_factories.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import abipy.data as abidata
23
import abipy.abilab as abilab
34

@@ -7,8 +8,9 @@
78
from abipy.abio.factories import *
89
from abipy.abio.factories import (BandsFromGsFactory, IoncellRelaxFromGsFactory, HybridOneShotFromGsFactory,
910
ScfForPhononsFactory, PhononsFromGsFactory, PiezoElasticFactory, PiezoElasticFromGsFactory, ShiftMode)
10-
from abipy.abio.factories import _find_nscf_nband_from_gsinput
11-
import json
11+
from abipy.abio.factories import _find_nscf_nband_from_gsinput, minimal_scf_input
12+
from abipy.abio.input_tags import DDK, DDE, PH_Q_PERT, STRAIN, DTE, PH_Q_PERT
13+
1214

1315
write_inputs_to_json = False
1416

@@ -304,7 +306,6 @@ def test_phonons_from_gsinput(self):
304306
with_bec=False, ph_tol=None, ddk_tol=None, dde_tol=None)
305307
self.abivalidate_multi(multi)
306308

307-
from abipy.abio.input_tags import DDK, DDE, PH_Q_PERT
308309
inp_ddk = multi.filter_by_tags(DDK)[0]
309310
inp_dde = multi.filter_by_tags(DDE)[0]
310311
inp_ph_q_pert_1 = multi.filter_by_tags(PH_Q_PERT)[0]
@@ -499,7 +500,7 @@ def test_dfpt_from_gsinput(self):
499500
do_dte=True, ph_tol=None, ddk_tol=None, dde_tol=None)
500501
self.abivalidate_multi(multi)
501502

502-
from abipy.abio.input_tags import DDK, DDE, PH_Q_PERT, STRAIN, DTE
503+
503504
inp_ddk = multi.filter_by_tags(DDK)[0]
504505
inp_dde = multi.filter_by_tags(DDE)[0]
505506
inp_ph_q_pert_1 = multi.filter_by_tags(PH_Q_PERT)[0]
@@ -530,10 +531,40 @@ def test_dfpt_from_gsinput(self):
530531
self.assert_input_equality('dfpt_from_gsinput_dte.json', inp_dte)
531532

532533
def test_minimal_scf_input(self):
533-
from abipy.abio.factories import minimal_scf_input
534534
inp = minimal_scf_input(self.si_structure, self.si_pseudo)
535535

536536
self.abivalidate_input(inp)
537-
538537
self.assertEqual(inp["nband"], 1)
539538
self.assertEqual(inp["nstep"], 0)
539+
540+
def test_ddkpert_from_gsinput(self):
541+
gs_inp = gs_input(self.si_structure, self.si_pseudo, kppa=None, ecut=2, spin_mode="unpolarized")
542+
gs_inp["nband"] = 4
543+
gs_inp["autoparal"] = 1
544+
gs_inp["npfft"] = 10
545+
546+
ddk_pert = {'idir': 1, 'ipert': 3, 'qpt': [0.0, 0.0, 0.0]}
547+
ddk_input = ddkpert_from_gsinput(gs_inp, ddk_pert)
548+
assert ddk_input["tolwfr"] == 1.0e-22
549+
assert "autoparal" not in ddk_input
550+
assert "npfft" not in ddk_input
551+
self.abivalidate_input(ddk_input)
552+
553+
dde_pert = {'idir': 1, 'ipert': 4, 'qpt': [0.0, 0.0, 0.0]}
554+
dde_input = ddepert_from_gsinput(gs_inp, dde_pert)
555+
assert "autoparal" not in dde_input
556+
assert "npfft" not in dde_input
557+
assert dde_input["tolvrs"] == 1.0e-22
558+
self.abivalidate_input(dde_input)
559+
560+
#dte_pert = {'i1dir': 1, 'ipert': 4, 'qpt': [0.0, 0.0, 0.0]}
561+
#dte_input = dtepert_from_gsinput(gs_inp, dte_pert)
562+
#assert "autoparal" not in dte_input
563+
#assert "npfft" not in dte_input
564+
#assert dte_input["tolvrs"] == 1.0e-22
565+
#self.abivalidate_input(dte_input)
566+
567+
#def test_dte_from_gsinput(self):
568+
# gs_inp = gs_input(self.si_structure, self.si_pseudo, kppa=None, ecut=2, spin_mode="unpolarized")
569+
# multi = dte_from_gsinput(gs_inp, use_phonons=True)
570+
# self.abivalidate_input(multi)

0 commit comments

Comments
 (0)