-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add fourier #49
Add fourier #49
Changes from 13 commits
5e58d74
290f223
0a57674
8ebf2dc
2908a5c
497b31f
332d555
17a05a9
00aaa95
e94b606
aa67d9c
adb5efb
a62712e
07bd9f6
b704c98
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,8 +65,10 @@ | |
# ----------------- | ||
# Each basis type may necessitate specific hyperparameters for instantiation. For a comprehensive description, | ||
# please refer to the [Code References](../../../reference/nemos/basis). After instantiation, all classes | ||
# share the same syntax for basis evaluation. The following is an example of how to instantiate and | ||
# evaluate a log-spaced cosine raised function basis. | ||
# share the same syntax for basis evaluation. | ||
# | ||
# ### The Log-Spaced Raised Cosine Basis | ||
# The following is an example of how to instantiate and evaluate a log-spaced cosine raised function basis. | ||
|
||
# Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter | ||
raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10, width=1.5, time_scaling=50) | ||
|
@@ -81,3 +83,89 @@ | |
plt.plot(samples, eval_basis) | ||
plt.show() | ||
|
||
# %% | ||
# ### The Fourier Basis | ||
# Another type of basis available is the Fourier Basis. Fourier basis are ideal to capture periodic and | ||
# quasi-periodic patterns. Such oscillatory, rhythmic behavior is a common signature of many neural signals. | ||
# Additionally, the Fourier basis has the advantage of being orthogonal, which simplifies the estimation and | ||
# interpretation of the model parameters, each of which will represent the relative contribution of a specific | ||
# oscillation frequency to the overall signal. | ||
# | ||
# A Fourier basis can be instantiated with the following syntax: | ||
# the user can provide the maximum frequency of the cosine and negative | ||
# sine pairs by setting the `max_freq` parameter. | ||
# The sinusoidal basis elements will have frequencies from 0 to `max_freq`. | ||
|
||
|
||
fourier_basis = nmo.basis.FourierBasis(max_freq=3) | ||
|
||
# evaluate on equi-spaced samples | ||
samples, eval_basis = fourier_basis.evaluate_on_grid(1000) | ||
|
||
# plot the `sin` and `cos` separately | ||
plt.figure(figsize=(6, 3)) | ||
plt.subplot(121) | ||
plt.title("Cos") | ||
plt.plot(samples, eval_basis[:, :4]) | ||
plt.subplot(122) | ||
plt.title("Sin") | ||
plt.plot(samples, eval_basis[:, 4:]) | ||
plt.tight_layout() | ||
|
||
# %% | ||
# ## Fourier Basis Convolution and Fourier Transform | ||
BalzaniEdoardo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# The Fourier transform of a signal $ s(t) $ restricted to a temporal window $ [t_0,\;t_1] $ is | ||
# $$ \\hat{x}(\\omega) = \\int_{t_0}^{t_1} s(\\tau) e^{-j\\omega \\tau} d\\tau. $$ | ||
# where $ e^{-j\\omega \\tau} = \\cos(\\omega \\tau) - j \\sin (\\omega \\tau) $. | ||
# | ||
# When computing the cross-correlation of a signal with the Fourier basis functions, | ||
# we essentially measure how well the signal correlates with sinusoids of different frequencies, | ||
# within a specified temporal window. This process mirrors the operation performed by the Fourier transform. | ||
# Therefore, it becomes clear that computing the cross-correlation of a signal with the Fourier basis defined here | ||
# is equivalent to computing the discrete Fourier transform on a sliding window of the same size | ||
# as that of the basis. | ||
BalzaniEdoardo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
n_samples = 1000 | ||
max_freq = 20 | ||
|
||
# define a signal | ||
signal = np.random.normal(size=n_samples) | ||
|
||
# evaluate the basis | ||
_, eval_basis = nmo.basis.FourierBasis(max_freq=max_freq).evaluate_on_grid(n_samples) | ||
|
||
# compute the cross-corr with the signal and the basis | ||
# Note that we are inverting the time axis of the basis because we are aiming | ||
# for a cross-correlation, while np.convolve compute a convolution which would flip the time axis. | ||
Comment on lines
+139
to
+140
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't we just compute the correlation directly to avoid this confusion? It's true, but provides an extra hurdle for folks (and then we could call out this equivalency in an admonition) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. still think this |
||
xcorr = np.array( | ||
[ | ||
np.convolve(eval_basis[::-1, k], signal, mode="valid")[0] | ||
for k in range(2 * max_freq + 1) | ||
] | ||
) | ||
|
||
# compute the power (add back sin(0 * t) = 0) | ||
fft_complex = np.fft.fft(signal) | ||
fft_amplitude = np.abs(fft_complex[:max_freq + 1]) | ||
fft_phase = np.angle(fft_complex[:max_freq + 1]) | ||
# compute the phase and amplitude from the convolution | ||
xcorr_phase = np.arctan2(np.hstack([[0], xcorr[max_freq+1:]]), xcorr[:max_freq+1]) | ||
xcorr_aplitude = np.sqrt(xcorr[:max_freq+1] ** 2 + np.hstack([[0], xcorr[max_freq+1:]]) ** 2) | ||
|
||
fig, ax = plt.subplots(1, 2) | ||
ax[0].set_aspect("equal") | ||
ax[0].set_title("Signal amplitude") | ||
ax[0].scatter(fft_amplitude, xcorr_aplitude) | ||
ax[0].set_xlabel("FFT") | ||
ax[0].set_ylabel("cross-correlation") | ||
|
||
ax[1].set_aspect("equal") | ||
ax[1].set_title("Signal phase") | ||
ax[1].scatter(fft_phase, xcorr_phase) | ||
ax[1].set_xlabel("FFT") | ||
ax[1].set_ylabel("cross-correlation") | ||
plt.tight_layout() | ||
|
||
print(f"Max Error Amplitude: {np.abs(fft_amplitude - xcorr_aplitude).max()}") | ||
print(f"Max Error Phase: {np.abs(fft_phase - xcorr_phase).max()}") |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ | |
"OrthExponentialBasis", | ||
"AdditiveBasis", | ||
"MultiplicativeBasis", | ||
"FourierBasis", | ||
] | ||
|
||
|
||
|
@@ -103,7 +104,7 @@ def _check_evaluate_input(self, *xi: ArrayLike) -> Tuple[NDArray]: | |
# make sure array is at least 1d (so that we succeed when only | ||
# passed a scalar) | ||
xi = tuple(np.atleast_1d(np.asarray(x, dtype=float)) for x in xi) | ||
except TypeError: | ||
except (TypeError, ValueError): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what additionally is being caught here? |
||
raise TypeError("Input samples must be array-like of floats!") | ||
|
||
# check for non-empty samples | ||
|
@@ -1086,7 +1087,8 @@ def _check_rates(self): | |
"linearly dependent set of function for the basis." | ||
) | ||
|
||
def _check_sample_range(self, sample_pts: NDArray): | ||
@staticmethod | ||
def _check_sample_range(sample_pts: NDArray): | ||
""" | ||
Check if the sample points are all positive. | ||
|
||
|
@@ -1177,6 +1179,96 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: | |
return super().evaluate_on_grid(n_samples) | ||
|
||
|
||
class FourierBasis(Basis): | ||
"""Set of 1D Fourier basis. | ||
|
||
This class defines a cosine and negative sine basis (quadrature pair) | ||
with frequencies ranging 0 to max_freq. | ||
|
||
Parameters | ||
---------- | ||
max_freq | ||
Highest frequency of the cosine, negative sine pairs. | ||
The number of basis function will be 2*max_freq + 1. | ||
""" | ||
|
||
def __init__(self, max_freq: int): | ||
super().__init__(n_basis_funcs=2 * max_freq + 1) | ||
|
||
self._frequencies = np.arange(max_freq + 1, dtype=float) | ||
self._n_input_dimensionality = 1 | ||
|
||
def _check_n_basis_min(self) -> None: | ||
"""Check that the user required enough basis elements. | ||
|
||
Checks that the number of basis is at least 1. | ||
|
||
Raises | ||
------ | ||
ValueError | ||
If an insufficient number of basis element is requested for the basis type. | ||
""" | ||
if self.n_basis_funcs < 0: | ||
raise ValueError( | ||
f"Object class {self.__class__.__name__} requires >= 1 basis elements. " | ||
f"{self.n_basis_funcs} basis elements specified instead" | ||
) | ||
|
||
def evaluate(self, sample_pts: ArrayLike) -> NDArray: | ||
"""Generate basis functions with given spacing. | ||
|
||
Parameters | ||
---------- | ||
sample_pts | ||
Spacing for basis functions. | ||
|
||
Returns | ||
------- | ||
basis_funcs | ||
Evaluated Fourier basis, shape (n_samples, n_basis_funcs). | ||
|
||
Notes | ||
----- | ||
The frequencies are set to np.arange(max_freq+1), convolving a signal | ||
of length n_samples with this basis is equivalent, but slower, | ||
then computing the FFT truncated to the first max_freq components. | ||
|
||
Therefore, convolving a signal with this basis is equivalent | ||
to compute the FFT over a sliding window. | ||
|
||
Examples | ||
-------- | ||
>>> import nemos as nmo | ||
>>> import numpy as np | ||
>>> n_samples, max_freq = 1000, 10 | ||
>>> basis = nmo.basis.FourierBasis(max_freq) | ||
>>> eval_basis = basis.evaluate(np.linspace(0, 1, n_samples)) | ||
>>> sinusoid = np.cos(3 * np.arange(0, 1000) * np.pi * 2 / 1000.) | ||
>>> conv = [np.convolve(eval_basis[::-1, k], sinusoid, mode='valid')[0] for k in range(2*max_freq+1)] | ||
>>> fft = np.fft.fft(sinusoid) | ||
>>> print('FFT power: ', np.round(np.real(fft[:max_freq]), 4)) | ||
>>> print('Convolution: ', np.round(conv[:max_freq], 4)) | ||
Comment on lines
+1241
to
+1250
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a way in mkdocs to set this as a python codeblock? Can you remove the for example at the bottom here the arrows render and get copied: https://nemos.readthedocs.io/en/latest/reference/nemos/utils/#nemos.utils.pytree_map_and_reduce |
||
""" | ||
(sample_pts,) = self._check_evaluate_input(sample_pts) | ||
# assumes equi-spaced samples. | ||
if sample_pts.shape[0] / np.max(self._frequencies) < 2: | ||
raise ValueError("Not enough samples, aliasing likely to occur!") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe report |
||
|
||
# rescale to [0, 2pi) | ||
mn, mx = np.nanmin(sample_pts), np.nanmax(sample_pts) | ||
# first sample in 0, last sample in 2 pi - 2 pi / n_samples. | ||
sample_pts = ( | ||
2 | ||
* np.pi | ||
* (sample_pts - mn) | ||
/ (mx - mn) | ||
* (1.0 - 1.0 / sample_pts.shape[0]) | ||
) | ||
# create the basis | ||
angles = np.einsum("i,j->ij", sample_pts, self._frequencies) | ||
return np.concatenate([np.cos(angles), -np.sin(angles[:, 1:])], axis=1) | ||
|
||
|
||
def mspline(x: NDArray, k: int, i: int, T: NDArray): | ||
"""Compute M-spline basis function. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should probably explain orthogonal here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
still think this. at least a foot note or link