Skip to content

Commit

Permalink
refactor: vectorized KKR internals; fixed typo
Browse files Browse the repository at this point in the history
  • Loading branch information
IruNikZe committed May 6, 2024
1 parent be5fa7b commit 652740e
Showing 1 changed file with 56 additions and 37 deletions.
93 changes: 56 additions & 37 deletions src/elli/kkr/kkr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,79 +29,87 @@

# pylint: disable=invalid-name
from typing import Callable

import numpy as np


def _integrate_im(im: np.ndarray, x: np.ndarray, x_i: float) -> np.ndarray:
def _integrate_im(im: np.ndarray, x: np.ndarray, x_i: np.ndarray) -> np.ndarray:
"""Calculate the discrete imaginary sum (integral) for the kkr.
Args:
im (numpy.ndarray): The imaginary values from which to calculate.
x (numpy.ndarray): The x-axis on which to calculate.
x_i (float): The current point around which to integrate.
im (numpy.ndarray): The imaginary values from which to calculate. (shape (1, n))
x (numpy.ndarray): The x-axis on which to calculate. (shape (1, n))
x_i (numpy.ndarray): The current points around which to integrate. (shape (m, 1))
Returns:
numpy.ndarray: The integral sum
numpy.ndarray: The integral sum. (shape (m,))
"""
return np.sum(x * im / (x**2 - x_i**2))

return np.sum(x * im / (x * x - x_i * x_i), axis=1)


def _integrate_im_reciprocal(im: np.ndarray, x: np.ndarray, x_i: float) -> np.ndarray:
def _integrate_im_reciprocal(
im: np.ndarray, x: np.ndarray, x_i: np.ndarray
) -> np.ndarray:
"""Calculate the discrete imaginary sum (integral) for the kkr.
This formulation uses an 1/x axis to transform a wavelength axis.
Args:
im (numpy.ndarray): The imaginary values from which to calculate.
x (numpy.ndarray): The reciprocal x-axis on which to calculate.
x_i (float): The current point around which to integrate.
im (numpy.ndarray): The imaginary values from which to calculate. (shape (1, n))
x (numpy.ndarray): The reciprocal x-axis on which to calculate. (shape (1, n))
x_i (numpy.ndarray): The current point around which to integrate. (shape (m, 1))
Returns:
numpy.ndarray: The integral sum
numpy.ndarray: The integral sum. (shape (m,))
"""
return np.sum(im / (x - x**3 / x_i**2))

return np.sum(im / (x * (1.0 - x * x / (x_i * x_i))), axis=1)

def _integrate_re(re: np.ndarray, x: np.ndarray, x_i: float) -> np.ndarray:

def _integrate_re(re: np.ndarray, x: np.ndarray, x_i: np.ndarray) -> np.ndarray:
"""Calculate the discrete real sum (integral) for the kkr.
Args:
re (numpy.ndarray): The real values from which to calculate.
x (numpy.ndarray): The x-axis on which to calculate.
x_i (float): The current point around which to integrate.
re (numpy.ndarray): The real values from which to calculate. (shape (1, n))
x (numpy.ndarray): The x-axis on which to calculate. (shape (1, n))
x_i (numpy.ndarray): The current point around which to integrate. (shape (m, 1))
Returns:
numpy.ndarray: The real sum
numpy.ndarray: The real sum. (shape (m,))
"""
return np.sum(x_i * re / (x**2 - x_i**2))
return np.sum(x_i * re / (x * x - x_i * x_i), axis=1)


def _integrate_re_reciprocal(re: np.ndarray, x: np.ndarray, x_i: float) -> np.ndarray:
def _integrate_re_reciprocal(
re: np.ndarray, x: np.ndarray, x_i: np.ndarray
) -> np.ndarray:
"""Calculate the discrete real sum (integral) for the kkr.
This formulation uses an 1/x axis to transform a wavelength axis.
Args:
re (numpy.ndarray): The real values from which to calculate.
x (numpy.ndarray): The reciprocal x-axis on which to calculate.
x_i (float): The current point around which to integrate.
re (numpy.ndarray): The real values from which to calculate. (shape (1, n))
x (numpy.ndarray): The reciprocal x-axis on which to calculate. (shape (1, n))
x_i (float): The current point around which to integrate. (shape (m, 1))
Returns:
numpy.ndarray: The real sum
numpy.ndarray: The real sum. (shape (m,))
"""
return np.sum(re / (x_i - x**2 / x_i))

return np.sum(re / (x_i - x * x / x_i), axis=1)


def _calc_kkr(
t: np.ndarray,
x: np.ndarray,
trafo: Callable[[np.ndarray, np.ndarray, float], np.ndarray],
trafo: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
) -> np.ndarray:
"""Calculates the kramers-kronig relation
"""Calculates the Kramers-Kronig relation
according to Maclaurin's formula.
Args:
t (np.ndarray): The y-axis on which to transform.
x (np.ndarray): The x-axis on which to transform.
trafo (Callable[[np.ndarray, np.ndarray, float], np.ndarray]):
t (numpy.ndarray): The y-axis on which to transform.
x (numpy.ndarray): The x-axis on which to transform.
trafo (Callable[[numpy.ndarray, numpy.ndarray, numpy.ndarray], numpy.ndarray]):
The transformation function.
Raises:
Expand All @@ -110,21 +118,28 @@ def _calc_kkr(
Returns:
np.ndarray: The kkr transformed y-axis
"""

if len(t) != len(x):
raise ValueError(
"y- and x-axes arrays must have the same length, "
f"but have lengths {len(t)} and {len(x)}."
)

integral = np.zeros(len(t))
integral = np.empty(len(t))
interval = np.diff(x, prepend=x[1] - x[0])
odd_y, odd_x = t[1::2], x[1::2]
even_y, even_x = t[::2], x[::2]
for i, x_i in enumerate(x):
if i % 2 == 0:
integral[i] = trafo(odd_y, odd_x, x_i)
else:
integral[i] = trafo(even_y, even_x, x_i)
odd_slice = slice(1, None, 2)
even_slice = slice(0, None, 2)

integral[even_slice] = trafo(
t[np.newaxis, odd_slice],
x[np.newaxis, odd_slice],
x[even_slice, np.newaxis],
)
integral[odd_slice] = trafo(
t[np.newaxis, even_slice],
x[np.newaxis, even_slice],
x[odd_slice, np.newaxis],
)

return 4 / np.pi * interval * integral

Expand All @@ -147,6 +162,7 @@ def re2im(re: np.ndarray, x: np.ndarray) -> np.ndarray:
Returns:
numpy.ndarray: The transformed imaginary part.
"""

return _calc_kkr(re, x, _integrate_re)


Expand All @@ -168,6 +184,7 @@ def im2re(im: np.ndarray, x: np.ndarray) -> np.ndarray:
Returns:
numpy.ndarray: The transformed real part.
"""

return _calc_kkr(im, x, _integrate_im)


Expand All @@ -190,6 +207,7 @@ def re2im_reciprocal(re: np.ndarray, x: np.ndarray) -> np.ndarray:
Returns:
numpy.ndarray: The transformed imaginary part.
"""

return _calc_kkr(re, x, _integrate_re_reciprocal)


Expand All @@ -212,4 +230,5 @@ def im2re_reciprocal(im: np.ndarray, x: np.ndarray) -> np.ndarray:
Returns:
numpy.ndarray: The transformed real part.
"""

return _calc_kkr(im, x, _integrate_im_reciprocal)

0 comments on commit 652740e

Please sign in to comment.