26
26
from aiida_wannier90_workflows .utils .kpoints import (
27
27
get_explicit_kpoints ,
28
28
get_mesh_from_kpoints ,
29
+ get_path_from_kpoints
29
30
)
30
31
from aiida_wannier90_workflows .utils .workflows .builder .setter import set_kpoints
31
32
from aiida_wannier90_workflows .workflows import (
@@ -626,9 +627,7 @@ def setup(self) -> None: # pylint: disable=inconsistent-return-statements
626
627
"""Initialize context variables."""
627
628
628
629
self .ctx .current_structure = self .inputs .structure
629
-
630
- if "bands_kpoints" in self .inputs :
631
- self .ctx .bands_kpoints = self .inputs .bands_kpoints
630
+
632
631
633
632
# Converged mesh from YamboConvergence
634
633
self .ctx .kpoints_gw_conv = None
@@ -676,7 +675,13 @@ def setup(self) -> None: # pylint: disable=inconsistent-return-statements
676
675
677
676
def should_run_seekpath (self ):
678
677
"""Run seekpath if the `inputs.bands_kpoints` is not provided."""
679
- return "bands_kpoints" not in self .inputs
678
+ if "bands_kpoints" in self .inputs :
679
+ self .ctx .current_kpoint_path = get_path_from_kpoints (
680
+ self .inputs ["bands_kpoints" ]
681
+ )
682
+ return False
683
+ else :
684
+ return True
680
685
681
686
def run_seekpath (self ):
682
687
"""Run the structure through SeeKpath to get the primitive and normalized structure."""
@@ -692,7 +697,11 @@ def run_seekpath(self):
692
697
693
698
self .ctx .current_structure = result ["primitive_structure" ]
694
699
695
- self .ctx .current_bands_kpoints = result ["explicit_kpoints" ]
700
+ # Add `kpoint_path` for Wannier bands
701
+ self .ctx .current_kpoint_path = get_path_from_kpoints (
702
+ result ["explicit_kpoints" ]
703
+ )
704
+
696
705
697
706
structure_formula = self .inputs .structure .get_formula ()
698
707
primitive_structure_formula = result ["primitive_structure" ].get_formula ()
@@ -1056,11 +1065,12 @@ def prepare_wannier90_pp_inputs(self) -> AttributeDict:
1056
1065
1057
1066
inputs .wannier90 .structure = self .ctx .current_structure
1058
1067
1059
- # params = inputs.wannier90.parameters.get_dict()
1060
- # params["bands_plot"] = False
1061
- # inputs.wannier90.parameters = orm.Dict(params)
1068
+ params = inputs .wannier90 .parameters .get_dict ()
1069
+ params ["bands_plot" ] = False
1070
+ inputs .wannier90 .parameters = orm .Dict (params )
1062
1071
1063
- inputs .wannier90 .bands_kpoints = self .ctx .current_bands_kpoints
1072
+ if self .ctx .current_kpoint_path :
1073
+ inputs .wannier90 .kpoint_path = self .ctx .current_kpoint_path
1064
1074
1065
1075
# Use commensurate kmesh
1066
1076
if self .ctx .kpoints_w90_input != self .ctx .kpoints_w90 :
@@ -1172,7 +1182,8 @@ def prepare_wannier90_inputs(self) -> AttributeDict:
1172
1182
)
1173
1183
1174
1184
inputs .structure = self .ctx .current_structure
1175
- inputs .bands_kpoints = self .ctx .current_bands_kpoints
1185
+ if self .ctx .current_kpoint_path :
1186
+ inputs .wannier90 .wannier90 .kpoint_path = self .ctx .current_kpoint_path
1176
1187
1177
1188
# Use commensurate kmesh
1178
1189
if self .ctx .kpoints_w90_input != self .ctx .kpoints_w90 :
@@ -1258,7 +1269,8 @@ def prepare_wannier90_qp_inputs(self) -> AttributeDict:
1258
1269
)
1259
1270
1260
1271
inputs .wannier90 .structure = self .ctx .current_structure
1261
- inputs .wannier90 .bands_kpoints = self .ctx .current_bands_kpoints
1272
+ if self .ctx .current_kpoint_path :
1273
+ inputs .kpoint_path = self .ctx .current_kpoint_path
1262
1274
1263
1275
if self .ctx .kpoints_w90_input != self .ctx .kpoints_w90 :
1264
1276
set_kpoints (
0 commit comments