Skip to content

Commit

Permalink
Add STFT & Synshrosqueezed STFT
Browse files Browse the repository at this point in the history
**FEATURES**:
 - `stft`, `istft`, `ssq_stft`, and `issq_stft` implemented and validated
 - Added to `utils.py`: `buffer`, `unbuffer`, `window_norm`, `window_resolution`, and `window_area`
 - Replaced `numba.njit` with `numba.jit(nopython=True, cache=True)`, accelerating recomputing

**BREAKING**:
 - `cwt()` no longer returns `x_mean`
 - `padsignal` now only returns padded input by default; `get_params=True` for old behavior
 - Moved methods: `phase_cwt` & `phase_cwt_num` from `ssqueezing` to `_ssq_cwt`
 - _In future release_: return order of `cwt` and `stft` will be changed to have `Wx, dWx` and `Sx, dSx`, and `ssq_cwt` and `ssq_stft` to have `Tx, Wx` and `Tx, Sx`

**MISC**:
 - `wavelet` positional argument in `cwt` is now a keyword argument that defaults to `'morlet'`
 - Support for `padsignal(padtype='wrap')`
 - Docstring, comment cleanups
  • Loading branch information
OverLordGoldDragon authored Jan 14, 2021
1 parent 985d7a7 commit ac2fd4b
Show file tree
Hide file tree
Showing 16 changed files with 1,186 additions and 347 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ before_install:
script:
- >
pycodestyle --max-line-length=85
--ignore=E221,E241,E225,E226,E402,E722,E741,E272,E266,E302,E731,E702,E201,E129,W503,W504
--ignore=E221,E241,E225,E226,E402,E722,E741,E272,E266,E302,E731,E702,E201,E129,E203,W503,W504
ssqueezepy
- pytest -s --cov=ssqueezepy tests/

Expand Down
17 changes: 17 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
### 0.5.5 (1-14-2021): STFT & Synchrosqueezed STFT

#### FEATURES
- `stft`, `istft`, `ssq_stft`, and `issq_stft` implemented and validated
- Added to `utils.py`: `buffer`, `unbuffer`, `window_norm`, `window_resolution`, and `window_area`
- Replaced `numba.njit` with `numba.jit(nopython=True, cache=True)`, accelerating recomputing

#### BREAKING
- `cwt()` no longer returns `x_mean`
- `padsignal` now only returns padded input by default; `get_params=True` for old behavior
- Moved methods: `phase_cwt` & `phase_cwt_num` from `ssqueezing` to `_ssq_cwt`
- _In future release_: return order of `cwt` and `stft` will be changed to have `Wx, dWx` and `Sx, dSx`, and `ssq_cwt` and `ssq_stft` to have `Tx, Wx` and `Tx, Sx`

#### MISC
- `wavelet` positional argument in `cwt` is now a keyword argument that defaults to `'morlet'`
- Support for `padsignal(padtype='wrap')`
- Docstring, comment cleanups
45 changes: 27 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@ Synchrosqueezing is a powerful _reassignment method_ that focuses time-frequency


## Features
- Forward & inverse CWT-based Synchrosqueezing
- Forward & inverse Continuous Wavelet Transform (CWT)
- Continuous Wavelet Transform (CWT), forward & inverse, and its Synchrosqueezing
- Short-Time Fourier Transform (STFT), forward & inverse, and its Synchrosqueezing
- Clean code with explanations and learning references
- Wavelet visualizations

### Coming soon
- Forward & inverse Short-Time Fourier Transform (STFT)
- STFT-based Synchrosqueezing
- Generalized Morse Wavelets

## Installation
Expand All @@ -38,7 +36,7 @@ Synchrosqueezing is a powerful _reassignment method_ that focuses time-frequency

<img src="https://user-images.githubusercontent.com/16495490/99880110-c88f1180-2c2a-11eb-8932-90bf3406a20d.png">

<img src="https://user-images.githubusercontent.com/16495490/99880131-f1170b80-2c2a-11eb-9ace-807df257ad23.png">
<img src="https://user-images.githubusercontent.com/16495490/104537035-9f8b6b80-5632-11eb-9fa4-444efec6c9be.png">

## Introspection

Expand All @@ -59,27 +57,37 @@ Synchrosqueezing is a powerful _reassignment method_ that focuses time-frequency
```python
import numpy as np
import matplotlib.pyplot as plt
from ssqueezepy import ssq_cwt
from ssqueezepy import ssq_cwt, ssq_stft

def viz(x, Tx, Wx):
plt.plot(x); plt.show()
plt.imshow(np.abs(Wx), aspect='auto', cmap='jet')
plt.show()
plt.imshow(np.flipud(np.abs(Tx)), aspect='auto', vmin=0, vmax=.2, cmap='jet')
plt.show()

#%%# Define signal ####################################
N = 2048
t = np.linspace(0, 10, N, endpoint=False)
xo = np.cos(2 * np.pi * np.exp(t / 3))
x = xo + np.sqrt(4) * np.random.randn(N)
xo = np.cos(2 * np.pi * 2 * (np.exp(t / 2.2) - 1))
xo += xo[::-1]
x = xo + np.sqrt(2) * np.random.randn(N)

plt.plot(xo); plt.show()
plt.plot(x); plt.show()

#%%# CWT + SSQ CWT ####################################
Twxo, _, Wxo, *_ = ssq_cwt(xo, 'morlet')
viz(xo, Twxo, Wxo)

Twx, _, Wx, *_ = ssq_cwt(x, 'morlet')
viz(x, Twx, Wx)

#%%# SSQ CWT + CWT ####################################
Txo, _, Wxo, *_ = ssq_cwt(xo, 'morlet')
viz(xo, Txo, Wxo)
#%%# STFT + SSQ STFT ##################################
Tsxo, _, Sxo, *_ = ssq_stft(xo)
viz(xo, Tsxo, np.flipud(Sxo))

Tx, _, Wx, *_ = ssq_cwt(x, 'morlet')
viz(x, Tx, Wx)
Tsx, _, Sx, *_ = ssq_stft(x)
viz(x, Tsx, np.flipud(Sx))
```

## Learning resources
Expand All @@ -95,12 +103,13 @@ The Discrete Fourier Transform lays the foundation of signal processing with rea

## References

`ssqueezepy` was originally ported from MATLAB's [Synchrosqueezing Toolbox](https://github.com/ebrevdo/synchrosqueezing), authored by E. Brevdo and G. Thakur [1]. Synchrosqueezed Wavelet Transform was introduced by I. Daubechies and S. Maes [2], which was followed-up in [3]. Many implementation details draw from [4].
`ssqueezepy` was originally ported from MATLAB's [Synchrosqueezing Toolbox](https://github.com/ebrevdo/synchrosqueezing), authored by E. Brevdo and G. Thakur [1]. Synchrosqueezed Wavelet Transform was introduced by I. Daubechies and S. Maes [2], which was followed-up in [3], and adapted to STFT in [4]. Many implementation details draw from [5].

1. G. Thakur, E. Brevdo, N.-S. Fučkar, and H.-T. Wu. ["The Synchrosqueezing algorithm for time-varying spectral analysis: robustness properties and new paleoclimate applications"](https://arxiv.org/abs/1105.0010), Signal Processing 93:1079-1094, 2013.
2. I. Daubechies, S. Maes. ["A Nonlinear squeezing of the CWT Based on Auditory Nerve Models"](https://services.math.duke.edu/%7Eingrid/publications/DM96.pdf).
2. I. Daubechies, S. Maes. ["A Nonlinear squeezing of the Continuous Wavelet Transform Based on Auditory Nerve Models"](https://services.math.duke.edu/%7Eingrid/publications/DM96.pdf).
3. I. Daubechies, J. Lu, H.T. Wu. ["Synchrosqueezed Wavelet Transforms: a Tool for Empirical Mode Decomposition"](https://arxiv.org/pdf/0912.2437.pdf), Applied and Computational Harmonic Analysis 30(2):243-261, 2011.
4. Mallat, S. ["Wavelet Tour of Signal Processing 3rd ed"](https://www.di.ens.fr/~mallat/papiers/WaveletTourChap1-2-3.pdf).
4. G. Thakur, H.T. Wu. ["Synchrosqueezing-based Recovery of Instantaneous Frequency from Nonuniform Samples"](https://arxiv.org/abs/1006.2533), SIAM Journal on Mathematical Analysis, 43(5):2078-2095, 2011.
5. Mallat, S. ["Wavelet Tour of Signal Processing 3rd ed"](https://www.di.ens.fr/~mallat/papiers/WaveletTourChap1-2-3.pdf).

## License

Expand Down
7 changes: 6 additions & 1 deletion examples/ridge_chirp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
def echirp(N):
t = np.linspace(0, 10, N, False)
return np.cos(2 * np.pi * np.exp(t / 3)), t

#%%## Configure signal #######################################################
N = 2048
noise_var = 6 # noise variance; compare error against = 12
Expand All @@ -25,18 +26,22 @@ def echirp(N):
plot(xo); scat(xo, s=8, show=1)
plot(x); scat(x, s=8, show=1)
plot(axf, show=1)

#%%# Synchrosqueeze ##########################################################
kw = dict(wavelet=('morlet', {'mu': 4.5}), nv=32, scales='log')
Tx, *_ = ssq_cwt(x, t=ts, **kw)
Wx, *_ = cwt(x, t=ts, **kw)

#%%# Visualize ###############################################################
pkw = dict(abs=1, w=.86, h=.9, aspect='auto', cmap='bone')
pkw = dict(abs=1, cmap='bone')
_Tx = np.pad(Tx, [[4, 4]]) # improve display of top- & bottom-most freqs
imshow(Wx, **pkw)
imshow(np.flipud(_Tx), norm=(0, 4e-1), **pkw)

#%%# Estimate inversion ridge ###############################################
bw, slope, offset = .035, .45, .45
Cs, freqband = lin_band(Tx, slope, offset, bw, norm=(0, 4e-1))

#%%###########################################################################
xrec = issq_cwt(Tx, kw['wavelet'], Cs, freqband)[0]
plot(xo)
Expand Down
8 changes: 6 additions & 2 deletions ssqueezepy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,18 @@
"""


__version__ = '0.5.1'
__version__ = '0.5.5'
__title__ = 'ssqueezepy'
__author__ = 'OverLordGoldDragon'
__license__ = __doc__
__project_url__ = 'https://github.com/OverLordGoldDragon/ssqueezepy'


from . import ssqueezing
from . import _ssq_cwt
from . import _cwt
from . import _stft
from . import _ssq_cwt
from . import _ssq_stft
from . import wavelets
from . import utils
from . import toolkit
Expand All @@ -46,7 +48,9 @@

from .ssqueezing import *
from ._ssq_cwt import *
from ._ssq_stft import *
from ._cwt import *
from ._stft import *
from .wavelets import *


Expand Down
36 changes: 13 additions & 23 deletions ssqueezepy/_cwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .wavelets import Wavelet


def cwt(x, wavelet, scales='log', fs=None, t=None, nv=32, l1_norm=True,
def cwt(x, wavelet='morlet', scales='log', fs=None, t=None, nv=32, l1_norm=True,
derivative=False, padtype='reflect', rpadded=False, vectorized=True):
"""Continuous Wavelet Transform, discretized, as described in
Sec. 4.3.3 of [1] and Sec. IIIA of [2]. Uses a form of discretized
Expand Down Expand Up @@ -42,7 +42,7 @@ def cwt(x, wavelet, scales='log', fs=None, t=None, nv=32, l1_norm=True,
fs: float / None
Sampling frequency of `x`. Defaults to 1, which makes ssq
frequencies range from 1/dT to 0.5, i.e. as fraction of reference
frequencies range from 1/dT to 0.5*fs, i.e. as fraction of reference
sampling rate up to Nyquist limit; dT = total duration (N/fs).
Used to compute `dt`, which is only used if `derivative=True`.
Overridden by `t`, if provided.
Expand All @@ -64,10 +64,7 @@ def cwt(x, wavelet, scales='log', fs=None, t=None, nv=32, l1_norm=True,
Whether to compute and return `dWx`. Requires `fs` or `t`.
padtype: str
Pad scheme to apply on input. One of:
('zero', 'reflect', 'symmetric', 'replicate').
'zero' is most naive, while 'reflect' (default) partly mitigates
boundary effects. See `padsignal`.
Pad scheme to apply on input. See `help(utils.padsignal)`.
rpadded: bool (default False)
Whether to return padded Wx and dWx.
Expand All @@ -80,11 +77,9 @@ def cwt(x, wavelet, scales='log', fs=None, t=None, nv=32, l1_norm=True,
# Returns:
Wx: [na x n] np.ndarray (na = number of scales; n = len(x))
The CWT of `x`. (rows=scales, cols=timeshifts)
CWT of `x`. (rows=scales, cols=timeshifts)
scales: [na] np.ndarray
Scales at which CWT was computed.
x_mean: float
mean of `x` to use in inversion (CWT needs scale=inf to capture).
dWx: [na x n] np.ndarray
Returned only if `derivative=True`.
Time-derivative of the CWT of `x`, computed via frequency-domain
Expand Down Expand Up @@ -150,14 +145,12 @@ def _process_args(x, scales, nv, fs, t):

nv, dt = _process_args(x, scales, nv, fs, t)

x_mean = x.mean() # store original mean
n = len(x) # store original length
x, nup, n1, n2 = padsignal(x, padtype)
xp, nup, n1, _ = padsignal(x, padtype, get_params=True)

x -= x.mean()
xh = fft(x)
xp -= xp.mean()
xh = fft(xp)
wavelet = Wavelet._init_if_not_isinstance(wavelet)
scales = process_scales(scales, n, wavelet, nv=nv)
scales = process_scales(scales, len(x), wavelet, nv=nv)
pn = (-1)**np.arange(nup)

N_orig = wavelet.N
Expand All @@ -168,17 +161,17 @@ def _process_args(x, scales, nv, fs, t):

if not rpadded:
# shorten to pre-padded size
Wx = Wx[:, n1:n1 + n]
Wx = Wx[:, n1:n1 + len(x)]
if derivative:
dWx = dWx[:, n1:n1 + n]
dWx = dWx[:, n1:n1 + len(x)]
if not l1_norm:
# normalize energy per L2 wavelet norm, else already L1-normalized
Wx *= np.sqrt(scales)
if derivative:
dWx *= np.sqrt(scales)

return ((Wx, scales, x_mean, dWx) if derivative else
(Wx, scales, x_mean))
return ((Wx, scales, dWx) if derivative else
(Wx, scales))


def icwt(Wx, wavelet, scales='log', nv=None, one_int=True, x_len=None, x_mean=0,
Expand Down Expand Up @@ -217,10 +210,7 @@ def icwt(Wx, wavelet, scales='log', nv=None, one_int=True, x_len=None, x_mean=0,
infinite scale component). Default 0.
padtype: str
Pad scheme to apply on input. One of:
('zero', 'symmetric', 'replicate').
'zero' is most naive, while 'reflect' (default) partly mitigates
boundary effects. See `padsignal`.
Pad scheme to apply on input. See `help(utils.padsignal)`.
!!! currently uses only 'zero'
rpadded: bool (default False)
Expand Down
Loading

0 comments on commit ac2fd4b

Please sign in to comment.