Skip to content

Commit

Permalink
Noncol+SOC may be working
Browse files Browse the repository at this point in the history
  • Loading branch information
YutaYahagi committed Oct 21, 2024
1 parent 29239d6 commit f916c6a
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 46 deletions.
63 changes: 41 additions & 22 deletions run_automated_wannier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# from ase.io import read as aseread
from pymatgen.core import Structure
from aiida_wannier90_workflows.workflows import Wannier90BandsWorkChain
from aiida_quantumespresso.common.types import SpinType
import numpy as np

# Please modify these according to your machine
Expand All @@ -21,10 +22,10 @@ def check_codes():
# will raise NotExistent error
try:
codes = dict(
pw_code=orm.load_code(str_pw),
pw2wannier90_code=orm.load_code(str_pw2wan),
projwfc_code=orm.load_code(str_projwfc),
wannier90_code=orm.load_code(str_wan),
pw=orm.load_code(str_pw),
pw2wannier90=orm.load_code(str_pw2wan),
projwfc=orm.load_code(str_projwfc),
wannier90=orm.load_code(str_wan),
)
except NotExistent as e:
print(e)
Expand All @@ -46,8 +47,8 @@ def parse_arugments():
parser.add_argument(
'-p',
"--protocol",
help="available protocols are 'theos-ht-1.0' and 'testing'",
default="testing"
help="available protocols are 'moderate', 'precise', and 'fast'",
default="fast"
)
parser.add_argument(
'-m',
Expand Down Expand Up @@ -75,6 +76,11 @@ def parse_arugments():
help="Retrieve Wannier Hamiltonian after the workflow finished",
action="store_true"
)
parser.add_argument(
"--soi",
help="Consider Spin-Orbit Interaction",
action="store_true"
)
args = parser.parse_args()
return args

Expand Down Expand Up @@ -168,7 +174,7 @@ def print_help(workchain, structure):

def submit_workchain(
cif_file, protocol, only_valence, do_disentanglement, do_mlwf,
retrieve_hamiltonian, group_name
retrieve_hamiltonian, group_name, soi
):
codes = check_codes()

Expand Down Expand Up @@ -201,21 +207,34 @@ def submit_workchain(
)
)

wannier90_workchain_parameters = {
"code": {
'pw': codes['pw_code'],
'pw2wannier90': codes['pw2wannier90_code'],
'projwfc': codes['projwfc_code'],
'wannier90': codes['wannier90_code']
},
"protocol": orm.Dict(dict={'name': protocol}),
"structure": structure,
"controls": controls
}

workchain = submit(
Wannier90BandsWorkChain, **wannier90_workchain_parameters
spintype = SpinType.NONE
if soi:
spintype = SpinType.SPIN_ORBIT

# wannier90_workchain_parameters = {
# "code": {
# 'pw': codes['pw_code'],
# 'pw2wannier90': codes['pw2wannier90_code'],
# 'projwfc': codes['projwfc_code'],
# 'wannier90': codes['wannier90_code']
# },
# "protocol": orm.Dict(dict={'name': protocol}),
# "structure": structure,
# "controls": controls,
# "spin_type": spintype
# }
# workchain = submit(
# Wannier90BandsWorkChain, **wannier90_workchain_parameters
# )
builder = Wannier90BandsWorkChain.get_builder_from_protocol(
codes,
structure,
protocol=protocol,
retrieve_hamiltonian=retrieve_hamiltonian,
# controls=controls,
spin_type=spintype
)
workchain = submit(builder)

add_to_group(workchain, group_name)
print_help(workchain, structure)
Expand All @@ -227,5 +246,5 @@ def submit_workchain(

submit_workchain(
args.cif, args.protocol, args.only_valence, args.do_disentanglement,
args.do_mlwf, args.retrieve_hamiltonian, group_name
args.do_mlwf, args.retrieve_hamiltonian, group_name, args.soi
)
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ retrieve_matrices:
- '*.amn'
- '*.mmn'
- '*.spn'
spin_collinear:
scf:
pw:
parameters:
SYSTEM:
nspin: 2
nscf:
pw:
parameters:
SYSTEM:
nspin: 2
spin_noncollinear:
scf:
pw:
Expand Down
39 changes: 15 additions & 24 deletions src/aiida_wannier90_workflows/workflows/wannier90.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,9 @@ def get_builder_from_protocol( # pylint: disable=unused-argument
elif spin_type == SpinType.SPIN_ORBIT:
overrides = recursive_merge(protocol_overrides["spin_orbit"], overrides)
pw_spin_type = SpinType.NONE
elif spin_type == SpinType.COLLINEAR:
overrides = recursive_merge(protocol_overrides["spin_collinear"], overrides)
pw_spin_type = SpinType.NONE
else:
pw_spin_type = spin_type

Expand Down Expand Up @@ -471,31 +474,19 @@ def get_builder_from_protocol( # pylint: disable=unused-argument
# Prepare SCF builder
scf_overrides = inputs.get("scf", {})
scf_overrides["pseudo_family"] = pseudo_family
if spin_type is SpinType.COLLINEAR:
scf_builder = PwBaseWorkChain.get_builder_from_protocol(
code=codes["pw"],
structure=structure,
protocol=protocol,
overrides=scf_overrides,
# Setting initial_magnetic_moments to scf is sufficient
initial_magnetic_moments=initial_magnetic_moments,
electronic_type=electronic_type,
spin_type=pw_spin_type,
)
else:
scf_builder = PwBaseWorkChain.get_builder_from_protocol(
code=codes["pw"],
structure=structure,
protocol=protocol,
overrides=scf_overrides,
electronic_type=electronic_type,
spin_type=pw_spin_type,
scf_builder = PwBaseWorkChain.get_builder_from_protocol(
code=codes["pw"],
structure=structure,
protocol=protocol,
overrides=scf_overrides,
electronic_type=electronic_type,
spin_type=pw_spin_type,
)
if initial_magnetic_moments:
# For non-collinear magnetism
scf_builder["pw"]["parameters"]["SYSTEM"].update(
initial_magnetic_moments
)
if initial_magnetic_moments:
# For non-collinear magnetism
scf_builder["pw"]["parameters"]["SYSTEM"].update(
initial_magnetic_moments
)
# Remove workchain excluded inputs
scf_builder["pw"].pop("structure", None)
scf_builder.pop("clean_workdir", None)
Expand Down

0 comments on commit f916c6a

Please sign in to comment.