Skip to content

Commit 7abce1a

Browse files
author
Miki Bonacci
committed
fixing kpoint path
1 parent e54dd58 commit 7abce1a

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

aiida_yambo_wannier90/workflows/__init__.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from aiida_wannier90_workflows.utils.kpoints import (
2727
get_explicit_kpoints,
2828
get_mesh_from_kpoints,
29+
get_path_from_kpoints
2930
)
3031
from aiida_wannier90_workflows.utils.workflows.builder.setter import set_kpoints
3132
from aiida_wannier90_workflows.workflows import (
@@ -626,9 +627,7 @@ def setup(self) -> None: # pylint: disable=inconsistent-return-statements
626627
"""Initialize context variables."""
627628

628629
self.ctx.current_structure = self.inputs.structure
629-
630-
if "bands_kpoints" in self.inputs:
631-
self.ctx.bands_kpoints = self.inputs.bands_kpoints
630+
632631

633632
# Converged mesh from YamboConvergence
634633
self.ctx.kpoints_gw_conv = None
@@ -676,7 +675,13 @@ def setup(self) -> None: # pylint: disable=inconsistent-return-statements
676675

677676
def should_run_seekpath(self):
678677
"""Run seekpath if the `inputs.bands_kpoints` is not provided."""
679-
return "bands_kpoints" not in self.inputs
678+
if "bands_kpoints" in self.inputs:
679+
self.ctx.current_kpoint_path = get_path_from_kpoints(
680+
self.inputs["bands_kpoints"]
681+
)
682+
return False
683+
else:
684+
return True
680685

681686
def run_seekpath(self):
682687
"""Run the structure through SeeKpath to get the primitive and normalized structure."""
@@ -692,7 +697,11 @@ def run_seekpath(self):
692697

693698
self.ctx.current_structure = result["primitive_structure"]
694699

695-
self.ctx.current_bands_kpoints = result["explicit_kpoints"]
700+
# Add `kpoint_path` for Wannier bands
701+
self.ctx.current_kpoint_path = get_path_from_kpoints(
702+
result["explicit_kpoints"]
703+
)
704+
696705

697706
structure_formula = self.inputs.structure.get_formula()
698707
primitive_structure_formula = result["primitive_structure"].get_formula()
@@ -1060,7 +1069,8 @@ def prepare_wannier90_pp_inputs(self) -> AttributeDict:
10601069
params["bands_plot"] = False
10611070
inputs.wannier90.parameters = orm.Dict(params)
10621071

1063-
#inputs.wannier90.bands_kpoints = self.ctx.current_bands_kpoints
1072+
if self.ctx.current_kpoint_path:
1073+
inputs.wannier90.kpoint_path = self.ctx.current_kpoint_path
10641074

10651075
# Use commensurate kmesh
10661076
if self.ctx.kpoints_w90_input != self.ctx.kpoints_w90:
@@ -1172,7 +1182,8 @@ def prepare_wannier90_inputs(self) -> AttributeDict:
11721182
)
11731183

11741184
inputs.structure = self.ctx.current_structure
1175-
inputs.bands_kpoints = self.ctx.current_bands_kpoints
1185+
if self.ctx.current_kpoint_path:
1186+
inputs.wannier90.wannier90.kpoint_path = self.ctx.current_kpoint_path
11761187

11771188
# Use commensurate kmesh
11781189
if self.ctx.kpoints_w90_input != self.ctx.kpoints_w90:
@@ -1258,7 +1269,8 @@ def prepare_wannier90_qp_inputs(self) -> AttributeDict:
12581269
)
12591270

12601271
inputs.wannier90.structure = self.ctx.current_structure
1261-
inputs.wannier90.bands_kpoints = self.ctx.current_bands_kpoints
1272+
if self.ctx.current_kpoint_path:
1273+
inputs.kpoint_path = self.ctx.current_kpoint_path
12621274

12631275
if self.ctx.kpoints_w90_input != self.ctx.kpoints_w90:
12641276
set_kpoints(

0 commit comments

Comments
 (0)