Skip to content

Commit

Permalink
Change block subworkflow to only do one block; rejigged the way that …
Browse files Browse the repository at this point in the history
…workflow names work
  • Loading branch information
elinscott committed Jun 27, 2024
1 parent 2117cb1 commit 775efec
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 93 deletions.
3 changes: 3 additions & 0 deletions src/koopmans/processes/wannier.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ def __init__(self, merge_function: Callable[[List[List[str]]], List[str]], **kwa
super().__init__(**kwargs)

def _run(self):
if len(self.inputs.src_files) == 0:
raise ValueError('No input files provided to merge.')

filecontents = [utils.get_content(calc, relpath) for calc, relpath in self.inputs.src_files]

merged_filecontents = self.merge_function(filecontents)
Expand Down
14 changes: 10 additions & 4 deletions src/koopmans/projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ProjectionBlock(object):
def __init__(self,
projections: List[Union[str, Dict[str, Any]]],
spin: Optional[str] = None,
directory: Optional[Path] = None,
name: Optional[str] = None,
num_wann: Optional[int] = None,
num_bands: Optional[int] = None,
include_bands: Optional[List[int]] = None,
Expand All @@ -25,7 +25,7 @@ def __init__(self,
proj = proj_string_to_dict(proj)
self.projections.append(proj)
self.spin = spin
self.directory = directory
self.name = name
self.num_wann = num_wann
self.num_bands = num_bands
self.include_bands = include_bands
Expand Down Expand Up @@ -114,6 +114,12 @@ def __eq__(self, other):
return self.__dict__ == other.__dict__
return False

def __getitem__(self, key):
return self._blocks[key]

def __setitem__(self, key, value):
self._blocks[key] = value

def divisions(self, spin: Optional[str]) -> List[int]:
# This algorithm works out the size of individual "blocks" in the set of bands
divs: List[int] = []
Expand Down Expand Up @@ -166,11 +172,11 @@ def blocks(self):
if len(to_exclude) > 0:
b.exclude_bands = list_to_formatted_str(to_exclude)

# Construct directory
# Construct name
label = f'block_{iblock + 1}'
if spin:
label = f'spin_{spin}_{label}'
b.directory = Path(label)
b.name = label

return self._blocks

Expand Down
2 changes: 1 addition & 1 deletion src/koopmans/workflows/_koopmans_dfpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def _run(self):
# 1) Create the calculators
for band in self.bands.to_solve:
kc_screen_calc = self.new_calculator('kc_screen', i_orb=band.index)
kc_screen_calc.prefix += f'_band_{band.index}'
kc_screen_calc.prefix += f'_orbital_{band.index}'
kc_screen_calcs.append(kc_screen_calc)

# 2) Run the calculators (possibly in parallel)
Expand Down
148 changes: 75 additions & 73 deletions src/koopmans/workflows/_wannierize.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,12 @@ def _run(self):

if self.parameters.init_orbitals in ['mlwfs', 'projwfs'] \
and self.parameters.init_empty_orbitals in ['mlwfs', 'projwfs']:
wannierize_blocks_subworkflow = WannierizeBlocksWorkflow.fromparent(self, force_nspin2=self._force_nspin2)
wannierize_blocks_subworkflow.run()

for block in self.projections:
wannierize_block_subworkflow = WannierizeBlockWorkflow.fromparent(
self, force_nspin2=self._force_nspin2, block=block)
wannierize_block_subworkflow.name = f'wannierize_{block.name}'
wannierize_block_subworkflow.run()

# Merging Hamiltonian files, U matrix files, centers files if necessary
if self.parent is not None:
Expand Down Expand Up @@ -199,7 +203,7 @@ def _run(self):
# Calculate a projected DOS
pseudos = [read_pseudo_file(calc_pw_bands.directory / calc_pw_bands.parameters.pseudo_dir / p) for p in
self.pseudopotentials.values()]
if all([p['header']['number_of_wfc'] > 0 for p in pseudos]) and False:
if all([p['header']['number_of_wfc'] > 0 for p in pseudos]):
calc_dos = self.new_calculator('projwfc', filpdos=self.name)
calc_dos.directory = 'pdos'
calc_dos.pseudopotentials = self.pseudopotentials
Expand All @@ -213,7 +217,6 @@ def _run(self):
dos = copy.deepcopy(calc_dos.results['dos'])
else:
# Skip if the pseudos don't have the requisite PP_PSWFC blocks
utils.warn('Manually disabling pDOS')
utils.warn('Some of the pseudopotentials do not have PP_PSWFC blocks, which means a projected DOS '
'calculation is not possible. Skipping...')
dos = None
Expand Down Expand Up @@ -301,82 +304,81 @@ def merge_wannier_files(self, block: List[projections.ProjectionBlock], filling_
self.run_process(merge_centers_proc)


class WannierizeBlocksWorkflow(Workflow):
class WannierizeBlockWorkflow(Workflow):

def __init__(self, *args, force_nspin2=False, **kwargs):
def __init__(self, *args, block: projections.ProjectionBlock, force_nspin2=False, **kwargs):
self._force_nspin2 = force_nspin2
self.block = block
super().__init__(*args, **kwargs)

def _run(self):
# Loop over the various subblocks that we must wannierize separately
for block in self.projections:
n_occ_bands = self.number_of_electrons(block.spin)
if not block.spin:
n_occ_bands /= 2

if max(block.include_bands) <= n_occ_bands:
# Block consists purely of occupied bands
init_orbs = self.parameters.init_orbitals
elif min(block.include_bands) > n_occ_bands:
# Block consists purely of empty bands
init_orbs = self.parameters.init_empty_orbitals
n_occ_bands = self.number_of_electrons(self.block.spin)
if not self.block.spin:
n_occ_bands /= 2

if max(self.block.include_bands) <= n_occ_bands:
# Block consists purely of occupied bands
init_orbs = self.parameters.init_orbitals
elif min(self.block.include_bands) > n_occ_bands:
# Block consists purely of empty bands
init_orbs = self.parameters.init_empty_orbitals
else:
# Block contains both occupied and empty bands
raise ValueError(f'{self.block} contains both occupied and empty bands. This should not happen.')
# Store the number of electrons in the ProjectionBlocks object so that it can work out which blocks to
# merge with one another
self.projections.num_occ_bands[self.block.spin] = n_occ_bands

calc_type = 'w90'
if self.block.spin:
calc_type += f'_{self.block.spin}'

# 1) pre-processing Wannier90 calculation
calc_w90_pp = self.new_calculator(calc_type, init_orbitals=init_orbs, **self.block.w90_kwargs)
calc_w90_pp.prefix = 'wannier90_preproc'
calc_w90_pp.command.flags = '-pp'
self.run_calculator(calc_w90_pp)

# 2) standard pw2wannier90 calculation
calc_p2w = self.new_calculator('pw2wannier', spin_component=self.block.spin)
calc_p2w.prefix = 'pw2wannier90'
calc_nscf = [c for c in self.calculations if isinstance(
c, calculators.PWCalculator) and c.parameters.calculation == 'nscf'][-1]
self.link(calc_nscf, calc_nscf.parameters.outdir, calc_p2w, calc_p2w.parameters.outdir)
self.link(calc_w90_pp, calc_w90_pp.prefix + '.nnkp', calc_p2w, calc_p2w.parameters.seedname + '.nnkp')
self.run_calculator(calc_p2w)

# 3) Wannier90 calculation
calc_w90 = self.new_calculator(calc_type, init_orbitals=init_orbs,
bands_plot=self.parameters.calculate_bands, **self.block.w90_kwargs)
calc_w90.prefix = 'wannier90'
for ext in ['.eig', '.amn', '.eig', '.mmn']:
self.link(calc_p2w, calc_p2w.parameters.seedname + ext, calc_w90, calc_w90.prefix + ext)
self.run_calculator(calc_w90)
self.block.w90_calc = calc_w90

if hasattr(self, 'bands'):
# Add centers and spreads info to self.bands
if self.block.spin is None:
remaining_bands = [b for b in self.bands if b.center is None and b.spin == 0]
else:
# Block contains both occupied and empty bands
raise ValueError(f'{block} contains both occupied and empty bands. This should not happen.')
# Store the number of electrons in the ProjectionBlocks object so that it can work out which blocks to
# merge with one another
self.projections.num_occ_bands[block.spin] = n_occ_bands

calc_type = 'w90'
if block.spin:
calc_type += f'_{block.spin}'

# 1) pre-processing Wannier90 calculation
calc_w90_pp = self.new_calculator(calc_type, init_orbitals=init_orbs, **block.w90_kwargs)
calc_w90_pp.prefix = 'wannier90_preproc'
calc_w90_pp.command.flags = '-pp'
self.run_calculator(calc_w90_pp)

# 2) standard pw2wannier90 calculation
calc_p2w = self.new_calculator('pw2wannier', spin_component=block.spin)
calc_p2w.prefix = 'pw2wannier90'
calc_nscf = [c for c in self.calculations if isinstance(
c, calculators.PWCalculator) and c.parameters.calculation == 'nscf'][-1]
self.link(calc_nscf, calc_nscf.parameters.outdir, calc_p2w, calc_p2w.parameters.outdir)
self.link(calc_w90_pp, calc_w90_pp.prefix + '.nnkp', calc_p2w, calc_p2w.parameters.seedname + '.nnkp')
self.run_calculator(calc_p2w)

# 3) Wannier90 calculation
calc_w90 = self.new_calculator(calc_type, init_orbitals=init_orbs,
bands_plot=self.parameters.calculate_bands, **block.w90_kwargs)
calc_w90.prefix = 'wannier90'
for ext in ['.eig', '.amn', '.eig', '.mmn']:
self.link(calc_p2w, calc_p2w.parameters.seedname + ext, calc_w90, calc_w90.prefix + ext)
self.run_calculator(calc_w90)
block.w90_calc = calc_w90

if hasattr(self, 'bands'):
# Add centers and spreads info to self.bands
if block.spin is None:
remaining_bands = [b for b in self.bands if b.center is None and b.spin == 0]
if self.block.spin == 'up':
i_spin = 0
else:
if block.spin == 'up':
i_spin = 0
else:
i_spin = 1
remaining_bands = [b for b in self.bands if b.center is None and b.spin == i_spin]

centers = calc_w90.results['centers']
spreads = calc_w90.results['spreads']
for band, center, spread in zip(remaining_bands, centers, spreads):
band.center = center
band.spread = spread

if block.spin is None and len(self.bands.get(spin=1)) > 0:
# Copy over spin-up results to spin-down
[match] = [b for b in self.bands if b.index == band.index and b.spin == 1]
match.center = center
match.spread = spread
i_spin = 1
remaining_bands = [b for b in self.bands if b.center is None and b.spin == i_spin]

centers = calc_w90.results['centers']
spreads = calc_w90.results['spreads']
for band, center, spread in zip(remaining_bands, centers, spreads):
band.center = center
band.spread = spread

if self.block.spin is None and len(self.bands.get(spin=1)) > 0:
# Copy over spin-up results to spin-down
[match] = [b for b in self.bands if b.index == band.index and b.spin == 1]
match.center = center
match.spread = spread

def new_calculator(self, calc_type, *args, **kwargs) -> CalcExtType: # type: ignore[type-var, misc]
init_orbs = kwargs.pop('init_orbitals', None)
Expand Down
22 changes: 7 additions & 15 deletions src/koopmans/workflows/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(self, atoms: Atoms,
pseudopotentials: Dict[str, str] = {},
kpoints: Optional[Kpoints] = None,
projections: Optional[ProjectionBlocks] = None,
name: str = 'koopmans_workflow',
name: Optional[str] = None,
parameters: Union[Dict[str, Any], settings.WorkflowSettingsDict] = {},
calculator_parameters: Optional[Union[Dict[str, Dict[str, Any]],
Dict[str, settings.SettingsDict]]] = None,
Expand All @@ -126,7 +126,7 @@ def __init__(self, atoms: Atoms,
if self.parameters.is_valid(key):
self.parameters[key] = value
self.atoms: Atoms = atoms
self.name = name
self.name: str = self.__class__.__name__.lower() if name is None else name
self.calculations: List[calculators.Calc] = []
self.processes: List[Process] = []
self.silent = False
Expand Down Expand Up @@ -406,10 +406,9 @@ def fromparent(cls: Type[W], parent_wf: Workflow, **kwargs: Any) -> W:
wf_kwargs = dict(atoms=copy.deepcopy(parent_wf.atoms),
parameters=parameters,
calculator_parameters=copy.deepcopy(parent_wf.calculator_parameters),
name=copy.deepcopy(parent_wf.name),
pseudopotentials=copy.deepcopy(parent_wf.pseudopotentials),
kpoints=copy.deepcopy(parent_wf.kpoints),
projections=copy.deepcopy(parent_wf.projections),
projections=parent_wf.projections,
plotting=copy.deepcopy(parent_wf.plotting),
ml=copy.deepcopy(parent_wf.ml))
wf_kwargs.update(**{k: v for k, v in kwargs.items() if not parameters.is_valid(k)})
Expand Down Expand Up @@ -802,7 +801,7 @@ def _pre_run_calculator(self, qe_calc: calculators.Calc) -> bool:
is_complete = self.load_old_calculator(qe_calc)
if is_complete:
if not self.silent:
self.print(f'Not running {os.path.relpath(calc_file)} as it is already complete')
self.print(f'Not running {os.path.relpath(qe_calc.directory)} as it is already complete')

# Check the convergence of the calculation
qe_calc.check_convergence()
Expand Down Expand Up @@ -970,10 +969,6 @@ def _parent_context(self, subdirectory: Optional[str] = None,

assert self.parent is not None

# Automatically pass along the name of the overall workflow
if self.name is None:
self.name = self.parent.name

# Increase the indent level
self.print_indent = self.parent.print_indent + 2

Expand Down Expand Up @@ -1008,7 +1003,7 @@ def _parent_context(self, subdirectory: Optional[str] = None,
b.alpha_history = [b.alpha]
b.error_history = []

subdirectory = self.__class__.__name__.lower() if subdirectory is None else subdirectory
subdirectory = self.name if subdirectory is None else subdirectory

try:
# Prepend the step counter to the subdirectory name
Expand Down Expand Up @@ -1050,9 +1045,6 @@ def _parent_context(self, subdirectory: Optional[str] = None,
# Copy the entire bands object
self.parent.bands = self.bands

# Make sure any updates to the projections are passed along
self.parent.projections = self.projections

def todict(self):
# Shallow copy
dct = dict(self.__dict__)
Expand Down Expand Up @@ -1510,10 +1502,10 @@ def plot_bandstructure(self,
plt.subplots_adjust(right=0.85, wspace=0.05)

# Saving the figure to file (as png and also in editable form)
workflow_name = self.__class__.__name__.lower()
workflow_name = self.name
for s in ['workflow', 'mock', 'benchgen', 'stumbling', 'check']:
workflow_name = workflow_name.replace(s, '')
filename = filename if filename is not None else f'{self.name}_{workflow_name}_bandstructure'
filename = filename if filename is not None else f'{workflow_name}_bandstructure'
legends = [ax.get_legend() for ax in axes if ax.get_legend() is not None]
utils.savefig(fname=filename + '.png', bbox_extra_artists=legends, bbox_inches='tight')

Expand Down

0 comments on commit 775efec

Please sign in to comment.