Skip to content

Commit e9b27b6

Browse files
author
Miki Bonacci
committed
Merge branch 'fixing/compatibility' into main
2 parents 37598ca + 7abce1a commit e9b27b6

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

aiida_yambo_wannier90/workflows/__init__.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from aiida_wannier90_workflows.utils.kpoints import (
2727
get_explicit_kpoints,
2828
get_mesh_from_kpoints,
29+
get_path_from_kpoints
2930
)
3031
from aiida_wannier90_workflows.utils.workflows.builder.setter import set_kpoints
3132
from aiida_wannier90_workflows.workflows import (
@@ -626,9 +627,7 @@ def setup(self) -> None: # pylint: disable=inconsistent-return-statements
626627
"""Initialize context variables."""
627628

628629
self.ctx.current_structure = self.inputs.structure
629-
630-
if "bands_kpoints" in self.inputs:
631-
self.ctx.bands_kpoints = self.inputs.bands_kpoints
630+
632631

633632
# Converged mesh from YamboConvergence
634633
self.ctx.kpoints_gw_conv = None
@@ -676,7 +675,13 @@ def setup(self) -> None: # pylint: disable=inconsistent-return-statements
676675

677676
def should_run_seekpath(self):
678677
"""Run seekpath if the `inputs.bands_kpoints` is not provided."""
679-
return "bands_kpoints" not in self.inputs
678+
if "bands_kpoints" in self.inputs:
679+
self.ctx.current_kpoint_path = get_path_from_kpoints(
680+
self.inputs["bands_kpoints"]
681+
)
682+
return False
683+
else:
684+
return True
680685

681686
def run_seekpath(self):
682687
"""Run the structure through SeeKpath to get the primitive and normalized structure."""
@@ -692,7 +697,11 @@ def run_seekpath(self):
692697

693698
self.ctx.current_structure = result["primitive_structure"]
694699

695-
self.ctx.current_bands_kpoints = result["explicit_kpoints"]
700+
# Add `kpoint_path` for Wannier bands
701+
self.ctx.current_kpoint_path = get_path_from_kpoints(
702+
result["explicit_kpoints"]
703+
)
704+
696705

697706
structure_formula = self.inputs.structure.get_formula()
698707
primitive_structure_formula = result["primitive_structure"].get_formula()
@@ -1056,11 +1065,12 @@ def prepare_wannier90_pp_inputs(self) -> AttributeDict:
10561065

10571066
inputs.wannier90.structure = self.ctx.current_structure
10581067

1059-
#params = inputs.wannier90.parameters.get_dict()
1060-
#params["bands_plot"] = False
1061-
#inputs.wannier90.parameters = orm.Dict(params)
1068+
params = inputs.wannier90.parameters.get_dict()
1069+
params["bands_plot"] = False
1070+
inputs.wannier90.parameters = orm.Dict(params)
10621071

1063-
inputs.wannier90.bands_kpoints = self.ctx.current_bands_kpoints
1072+
if self.ctx.current_kpoint_path:
1073+
inputs.wannier90.kpoint_path = self.ctx.current_kpoint_path
10641074

10651075
# Use commensurate kmesh
10661076
if self.ctx.kpoints_w90_input != self.ctx.kpoints_w90:
@@ -1172,7 +1182,8 @@ def prepare_wannier90_inputs(self) -> AttributeDict:
11721182
)
11731183

11741184
inputs.structure = self.ctx.current_structure
1175-
inputs.bands_kpoints = self.ctx.current_bands_kpoints
1185+
if self.ctx.current_kpoint_path:
1186+
inputs.wannier90.wannier90.kpoint_path = self.ctx.current_kpoint_path
11761187

11771188
# Use commensurate kmesh
11781189
if self.ctx.kpoints_w90_input != self.ctx.kpoints_w90:
@@ -1258,7 +1269,8 @@ def prepare_wannier90_qp_inputs(self) -> AttributeDict:
12581269
)
12591270

12601271
inputs.wannier90.structure = self.ctx.current_structure
1261-
inputs.wannier90.bands_kpoints = self.ctx.current_bands_kpoints
1272+
if self.ctx.current_kpoint_path:
1273+
inputs.kpoint_path = self.ctx.current_kpoint_path
12621274

12631275
if self.ctx.kpoints_w90_input != self.ctx.kpoints_w90:
12641276
set_kpoints(

examples/example_01.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from aiida_wannier90_workflows.cli.params import RUN
1313
from aiida_wannier90_workflows.utils.workflows.builder.serializer import print_builder
1414
from aiida_wannier90_workflows.utils.kpoints import get_explicit_kpoints_from_mesh
15-
from aiida_wannier90_workflows.utils.workflows.builder.setter import set_parallelization, set_num_bands, set_kpoints
15+
from aiida_wannier90_workflows.utils.workflows.builder.setter import set_parallelization, set_num_bands
1616
from aiida_wannier90_workflows.utils.workflows.builder.submit import submit_and_add_group
1717
from aiida_wannier90_workflows.common.types import WannierProjectionType
1818
from aiida_wannier90_workflows.workflows import Wannier90BandsWorkChain

0 commit comments

Comments
 (0)