Skip to content

Commit

Permalink
remove old block style
Browse files Browse the repository at this point in the history
  • Loading branch information
maxscheurer committed Oct 16, 2024
1 parent 82cb34d commit 094916f
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 167 deletions.
83 changes: 9 additions & 74 deletions adcc/AdcMatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def __init__(self, method, hf_or_mp, block_orders=None, intermediates=None,
for block, order in self.block_orders.items() if order is not None
}
# TODO Rename to self.block in 0.16.0
self.blocks_ph = {bl: blocks[bl].apply for bl in blocks}
self.blocks = {bl: blocks[bl].apply for bl in blocks}
if diagonal_precomputed:
self.__diagonal = diagonal_precomputed
else:
Expand All @@ -185,16 +185,16 @@ def __iadd__(self, other):
"""
if not isinstance(other, AdcExtraTerm):
return NotImplemented
if not all(k in self.blocks_ph for k in other.blocks):
if not all(k in self.blocks for k in other.blocks):
raise ValueError("Can only add to blocks of"
" AdcMatrix that already exist.")
for sp in other.blocks:
orig_app = self.blocks_ph[sp]
orig_app = self.blocks[sp]
other_app = other.blocks[sp].apply

def patched_apply(ampl, original=orig_app, other=other_app):
return sum(app(ampl) for app in (original, other))
self.blocks_ph[sp] = patched_apply
self.blocks[sp] = patched_apply
other_diagonal = sum(bl.diagonal for bl in other.blocks.values()
if bl.diagonal)
self.__diagonal = self.__diagonal + other_diagonal
Expand Down Expand Up @@ -232,7 +232,7 @@ def __init_space_data(self, diagonal):
"""Update the cached data regarding the spaces of the ADC matrix"""
self.axis_spaces = {}
self.axis_lengths = {}
for block in diagonal.blocks_ph:
for block in diagonal.blocks:
self.axis_spaces[block] = getattr(diagonal, block).subspaces
self.axis_lengths[block] = np.prod([
self.mospaces.n_orbs(sp) for sp in self.axis_spaces[block]
Expand All @@ -249,27 +249,6 @@ def __repr__(self):
def __len__(self):
return self.shape[0]

@property
def blocks(self):
# TODO Remove in 0.16.0
return self.__diagonal.blocks

def has_block(self, block):
warnings.warn("The has_block function is deprecated and "
"will be removed in 0.16.0. "
"Use `in matrix.axis_blocks` in the future.")
return self.block_spaces(block) is not None

def block_spaces(self, block):
warnings.warn("The block_spaces function is deprecated and "
"will be removed in 0.16.0. "
"Use `matrix.axis_spaces[block]` in the future.")
return {
"s": self.axis_spaces.get("ph", None),
"d": self.axis_spaces.get("pphh", None),
"t": self.axis_spaces.get("ppphhh", None),
}[block]

@property
def axis_blocks(self):
"""
Expand All @@ -278,27 +257,11 @@ def axis_blocks(self):
"""
return list(self.axis_spaces.keys())

def diagonal(self, block=None):
@property
def diagonal(self):
"""Return the diagonal of the ADC matrix"""
if block is not None:
warnings.warn("Support for the block argument will be dropped "
"in 0.16.0.")
if block == "s":
return self.__diagonal.ph
if block == "d":
return self.__diagonal.pphh
return self.__diagonal

def compute_apply(self, block, tensor):
warnings.warn("The compute_apply function is deprecated and "
"will be removed in 0.16.0.")
if block in ("ss", "sd", "ds", "dd"):
warnings.warn("The singles-doubles interface is deprecated and "
"will be removed in 0.16.0.")
block = {"ss": "ph_ph", "sd": "ph_pphh",
"ds": "pphh_ph", "dd": "pphh_pphh"}[block]
return self.block_apply(block, tensor)

def block_apply(self, block, tensor):
"""
Compute the application of a block of the ADC matrix
Expand All @@ -311,7 +274,7 @@ def block_apply(self, block, tensor):
with self.timer.record(f"apply/{block}"):
outblock, inblock = block.split("_")
ampl = AmplitudeVector(**{inblock: tensor})
ret = self.blocks_ph[block](ampl)
ret = self.blocks[block](ampl)
return getattr(ret, outblock)

@timed_member_call()
Expand All @@ -320,21 +283,12 @@ def matvec(self, v):
Compute the matrix-vector product of the ADC matrix
with an excitation amplitude and return the result.
"""
return sum(block(v) for block in self.blocks_ph.values())
return sum(block(v) for block in self.blocks.values())

def rmatvec(self, v):
# ADC matrix is symmetric
return self.matvec(v)

def compute_matvec(self, ampl):
"""
Compute the matrix-vector product of the ADC matrix
with an excitation amplitude and return the result.
"""
warnings.warn("The compute_matvec function is deprecated and "
"will be removed in 0.16.0.")
return self.matvec(ampl)

def __matmul__(self, other):
if isinstance(other, AmplitudeVector):
return self.matvec(other)
Expand Down Expand Up @@ -565,25 +519,6 @@ def to_ndarray(self, out=None):
return out


class AdcBlockView(AdcMatrix):
def __init__(self, fullmatrix, block):
warnings.warn("The AdcBlockView class got deprecated and will be "
"removed in 0.16.0. Use the matrix.block_view "
"function instead.")
assert isinstance(fullmatrix, AdcMatrix)

self.__fullmatrix = fullmatrix
self.__block = block
if block == "s":
block_orders = dict(ph_ph=fullmatrix.block_orders["ph_ph"],
ph_pphh=None, pphh_ph=None, pphh_pphh=None)
else:
raise NotImplementedError(f"Block {block} not implemented")
super().__init__(fullmatrix.method, fullmatrix.ground_state,
block_orders=block_orders,
intermediates=fullmatrix.intermediates)


class AdcMatrixShifted(AdcMatrix):
def __init__(self, matrix, shift=0.0):
"""
Expand Down
65 changes: 6 additions & 59 deletions adcc/AmplitudeVector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,12 @@


class AmplitudeVector(dict):
def __init__(self, *args, **kwargs):
def __init__(self, **kwargs):
"""
Construct an AmplitudeVector. Typical use cases are
``AmplitudeVector(ph=tensor_singles, pphh=tensor_doubles)``.
"""
if args:
warnings.warn("Using the list interface of AmplitudeVector is "
"deprecated and will be removed in version 0.16.0. Use "
"AmplitudeVector(ph=tensor_singles, pphh=tensor_doubles) "
"instead.")
if len(args) == 1:
super().__init__(ph=args[0])
elif len(args) == 2:
super().__init__(ph=args[0], pphh=args[1])
else:
super().__init__(**kwargs)
super().__init__(**kwargs)

def __getattr__(self, key):
if self.__contains__(key):
Expand All @@ -53,54 +43,11 @@ def __setattr__(self, key, item):

@property
def blocks(self):
warnings.warn("The blocks attribute will change behaviour in 0.16.0.")
if sorted(self.blocks_ph) == ["ph", "pphh"]:
return ["s", "d"]
if sorted(self.blocks_ph) == ["pphh"]:
return ["d"]
elif sorted(self.blocks_ph) == ["ph"]:
return ["s"]
elif sorted(self.blocks_ph) == []:
return []
else:
raise NotImplementedError(self.blocks_ph)

@property
def blocks_ph(self):
"""
Return the blocks which are used inside the vector.
Note: This is a temporary name. The attribute will be removed in 0.16.0.
"""
return sorted(self.keys())

def __getitem__(self, index):
if index in (0, 1, "s", "d"):
warnings.warn("Using the list interface of AmplitudeVector is "
"deprecated and will be removed in version 0.16.0. Use "
"block labels like 'ph', 'pphh' instead.")
if index in (0, "s"):
return self.__getitem__("ph")
elif index in (1, "d"):
return self.__getitem__("pphh")
else:
raise KeyError(index)
else:
return super().__getitem__(index)

def __setitem__(self, index, item):
if index in (0, 1, "s", "d"):
warnings.warn("Using the list interface of AmplitudeVector is "
"deprecated and will be removed in version 0.16.0. Use "
"block labels like 'ph', 'pphh' instead.")
if index in (0, "s"):
return self.__setitem__("ph", item)
elif index in (1, "d"):
return self.__setitem__("pphh", item)
else:
raise KeyError(index)
else:
return super().__setitem__(index, item)

def copy(self):
"""Return a copy of the AmplitudeVector"""
return AmplitudeVector(**{k: t.copy() for k, t in self.items()})
Expand All @@ -112,7 +59,7 @@ def evaluate(self):

@property
def needs_evaluation(self):
return any(t.needs_evaluation for k, t in self.items())
return any(t.needs_evaluation for _, t in self.items())

def ones_like(self):
"""Return an empty AmplitudeVector of the same shape and symmetry"""
Expand Down Expand Up @@ -158,7 +105,7 @@ def __matmul__(self, other):

def __forward_to_blocks(self, fname, other):
if isinstance(other, AmplitudeVector):
if sorted(other.blocks_ph) != sorted(self.blocks_ph):
if sorted(other.blocks) != sorted(self.blocks):
raise ValueError("Blocks of both AmplitudeVector objects "
f"need to agree to perform {fname}")
ret = {k: getattr(tensor, fname)(other[k])
Expand Down Expand Up @@ -198,13 +145,13 @@ def __itruediv__(self, other):
return self.__forward_to_blocks("__itruediv__", other)

def __repr__(self):
return "AmplitudeVector(" + "=..., ".join(self.blocks_ph) + "=...)"
return "AmplitudeVector(" + "=..., ".join(self.blocks) + "=...)"

# __add__ is special because we want to be able to add AmplitudeVectors
# with missing blocks
def __add__(self, other):
if isinstance(other, AmplitudeVector):
allblocks = sorted(set(self.blocks_ph).union(other.blocks_ph))
allblocks = sorted(set(self.blocks).union(other.blocks))
ret = {k: self.get(k, 0) + other.get(k, 0) for k in allblocks}
ret = {k: v for k, v in ret.items() if v != 0}
else:
Expand Down
52 changes: 26 additions & 26 deletions adcc/ExcitedStates.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,29 +540,29 @@ def excitations(self):


# deprecated property names of ExcitedStates
deprecated = {
"excitation_energies": "excitation_energy",
"transition_dipole_moments": "transition_dipole_moment",
"transition_dms": "transition_dm",
"transition_dipole_moments_velocity":
"transition_dipole_moment_velocity",
"transition_magnetic_dipole_moments":
"transition_magnetic_dipole_moment",
"state_dipole_moments": "state_dipole_moment",
"state_dms": "state_dm",
"state_diffdms": "state_diffdm",
"oscillator_strengths": "oscillator_strength",
"oscillator_stenths_velocity": "oscillator_strength_velocity",
"rotatory_strengths": "rotatory_strength",
"excitation_vectors": "excitation_vector",
}

for dep_property in deprecated:
new_key = deprecated[dep_property]

def deprecated_property(self, key=new_key, old_key=dep_property):
warnings.warn(f"Property '{old_key}' is deprecated "
" and will be removed in version 0.16.0."
f" Please use '{key}' instead.")
return getattr(self, key)
setattr(ExcitedStates, dep_property, property(deprecated_property))
# deprecated = {
# "excitation_energies": "excitation_energy",
# "transition_dipole_moments": "transition_dipole_moment",
# "transition_dms": "transition_dm",
# "transition_dipole_moments_velocity":
# "transition_dipole_moment_velocity",
# "transition_magnetic_dipole_moments":
# "transition_magnetic_dipole_moment",
# "state_dipole_moments": "state_dipole_moment",
# "state_dms": "state_dm",
# "state_diffdms": "state_diffdm",
# "oscillator_strengths": "oscillator_strength",
# "oscillator_stenths_velocity": "oscillator_strength_velocity",
# "rotatory_strengths": "rotatory_strength",
# "excitation_vectors": "excitation_vector",
# }

# for dep_property in deprecated:
# new_key = deprecated[dep_property]

# def deprecated_property(self, key=new_key, old_key=dep_property):
# warnings.warn(f"Property '{old_key}' is deprecated "
# " and will be removed in version 0.16.0."
# f" Please use '{key}' instead.")
# return getattr(self, key)
# setattr(ExcitedStates, dep_property, property(deprecated_property))
5 changes: 2 additions & 3 deletions adcc/IsrMatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ def __init__(self, method, hf_or_mp, operator, block_orders=None):
variant=variant)
for block, order in self.block_orders.items() if order is not None
} for op in self.operator]
# TODO Rename to self.block in 0.16.0
self.blocks_ph = [{
self.blocks = [{
b: bl[b].apply for b in bl
} for bl in blocks]

Expand All @@ -136,7 +135,7 @@ def matvec(self, v):
"""
ret = [
sum(block(v) for block in bl_ph.values())
for bl_ph in self.blocks_ph
for bl_ph in self.blocks
]
if len(ret) == 1:
return ret[0]
Expand Down
4 changes: 2 additions & 2 deletions adcc/adc_pp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ def check_doubles_amplitudes(spaces, *amplitudes):


def check_have_singles_block(*amplitudes):
if any("ph" not in amplitude.blocks_ph for amplitude in amplitudes):
if any("ph" not in amplitude.blocks for amplitude in amplitudes):
raise ValueError("ADC(0) level and "
"beyond expects an excitation amplitude with a "
"singles part.")


def check_have_doubles_block(*amplitudes):
if any("pphh" not in amplitude.blocks_ph for amplitude in amplitudes):
if any("pphh" not in amplitude.blocks for amplitude in amplitudes):
raise ValueError("ADC(2) level and "
"beyond expects an excitation amplitude with a "
"singles and a doubles part.")
Expand Down
2 changes: 1 addition & 1 deletion adcc/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def lincomb(coefficients, tensors, evaluate=False):
return AmplitudeVector(**{
block: lincomb(coefficients, [ten[block] for ten in tensors],
evaluate=evaluate)
for block in tensors[0].blocks_ph
for block in tensors[0].blocks
})
elif not isinstance(tensors[0], libadcc.Tensor):
raise TypeError("Tensor type not supported")
Expand Down
4 changes: 2 additions & 2 deletions adcc/solver/explicit_symmetrisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def symmetrise(self, new_vectors):
if not isinstance(vec, AmplitudeVector):
raise TypeError("new_vectors has to be an "
"iterable of AmplitudeVector")
for b in vec.blocks_ph:
for b in vec.blocks:
if b not in self.symmetrisation_functions:
continue
vec[b] = evaluate(self.symmetrisation_functions[b](vec[b]))
Expand All @@ -88,7 +88,7 @@ def symmetrise(self, new_vectors):
# Only work on the doubles part
# the other blocks are not yet implemented
# or nothing needs to be done ("ph" block)
if "pphh" in vec.blocks_ph:
if "pphh" in vec.blocks:
# TODO: Note that the "d" is needed here because the C++ side
# does not yet understand ph and pphh
amplitude_vector_enforce_spin_kind(
Expand Down

0 comments on commit 094916f

Please sign in to comment.