Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Calculate rtx #2

Merged
merged 4 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,9 @@
*.ogg
*.flac

/blind_rt60.egg-info/
/build/
/dist/
/.idea/misc.xml
/.idea/modules.xml
/.idea/inspectionProfiles/profiles_settings.xml
/.idea/inspectionProfiles/Project_Default.xml
/.idea/pyBlindRT.iml
/.idea/vcs.xml
/blind_rt60.egg-info/*
/build/*
/dist/*
/.idea/*
/local_history.patch
/notebooks/.ipynb_checkpoints/
/notebooks/.ipynb_checkpoints/*
16 changes: 16 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v17.0.6
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.12.1
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
args:
- --profile=black
1 change: 1 addition & 0 deletions blind_rt60/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@
"""

from .estimation import BlindRT60
from .utils import calculate_decay_time
from .version import __version__
147 changes: 104 additions & 43 deletions blind_rt60/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,34 @@
import scipy.signal as sig
from matplotlib.figure import Figure

from .utils import calculate_decay_time

# Constants
FRAME_LENGTH = 200e-3 # Frame length in seconds
EPS = np.finfo('float').eps
EPS = np.finfo("float").eps


class UpdateMethod:
NEWTON = 'newton'
BISECTED = 'bisected'
NEWTON = "newton"
BISECTED = "bisected"


class BlindRT60:
def __init__(self, fs: int = 8000, framelen: Optional[float] = None, hop: Optional[float] = None,
percentile: int = 50., a_init: int = 0.99, sigma2_init: int = 0.5, max_itr: int = 1000,
max_err: float = 1e-1, a_range: tuple = (0.99, 0.999999999), bisected_itr: int = 8,
sigma2_range: tuple = (0., np.inf), verbose: bool = False):
def __init__(
self,
fs: int = 8000,
framelen: Optional[float] = None,
hop: Optional[float] = None,
percentile: int = 50.0,
a_init: int = 0.99,
sigma2_init: int = 0.5,
max_itr: int = 1000,
max_err: float = 1e-1,
a_range: tuple = (0.99, 0.999999999),
bisected_itr: int = 8,
sigma2_range: tuple = (0.0, np.inf),
verbose: bool = False,
):
"""
Estimate the reverberation time (RT60) from the input signal.

Expand All @@ -38,7 +51,9 @@ def __init__(self, fs: int = 8000, framelen: Optional[float] = None, hop: Option
- sigma2_range: Range of valid values for 'sigma2'
"""
self.fs = fs
self.framelen = int(self.fs * FRAME_LENGTH) if framelen is None else int(self.fs * framelen)
self.framelen = (
int(self.fs * FRAME_LENGTH) if framelen is None else int(self.fs * framelen)
)
self.hop = int(self.framelen) // 4 if hop is None else int(self.fs * hop)
self.percentile = percentile
self.a_init = a_init
Expand Down Expand Up @@ -83,14 +98,17 @@ def sanity_check(self):
"""
Check the validity of input parameters.
"""
assert 0. <= self.percentile <= 100., 'gamma should be between 0 to 100'
assert self.framelen > 0, f'sigma2 should be larger than 0'
assert 0. < self.hop <= self.framelen, 'hop must be between 0 to framelen'
assert self.a_range[0] <= self.a_init < self.a_range[
1], f'a should be between {self.a_range[0]} to {self.a_range[1]}'
assert self.sigma2_init > 0., f'sigma2 should be larger than 0'

def likelihood_derivative(self, a: np.ndarray, x_frames: np.ndarray) -> (np.ndarray, np.ndarray, np.ndarray):
assert 0.0 <= self.percentile <= 100.0, "gamma should be between 0 to 100"
assert self.framelen > 0, f"sigma2 should be larger than 0"
assert 0.0 < self.hop <= self.framelen, "hop must be between 0 to framelen"
assert (
self.a_range[0] <= self.a_init < self.a_range[1]
), f"a should be between {self.a_range[0]} to {self.a_range[1]}"
assert self.sigma2_init > 0.0, f"sigma2 should be larger than 0"

def likelihood_derivative(
self, a: np.ndarray, x_frames: np.ndarray
) -> (np.ndarray, np.ndarray, np.ndarray):
"""
Calculate the first and second derivatives of the log-likelihood function with respect to 'a'.

Expand All @@ -103,13 +121,23 @@ def likelihood_derivative(self, a: np.ndarray, x_frames: np.ndarray) -> (np.ndar
- d2l_da2: Second derivative of the log-likelihood with respect to 'a'
- sigma2: Estimated variance of the signal
"""
a_x_prod = a ** (-2 * self.n) * x_frames ** 2
sigma2 = np.clip(np.mean(a_x_prod, axis=1, keepdims=True), a_min=self.sigma2_range[0],
a_max=self.sigma2_range[1])
dl_da = 1 / (a + EPS) * (
1 / (sigma2 + EPS) * np.sum(self.n * a_x_prod, axis=1, keepdims=True) - self.framelen_fac)
d2l_da2 = self.framelen_fac / (a ** 2 + EPS) + 1 / (sigma2 + EPS) * np.sum(
(1 - 2 * self.n) * self.n * a_x_prod, axis=1, keepdims=True)
a_x_prod = a ** (-2 * self.n) * x_frames**2
sigma2 = np.clip(
np.mean(a_x_prod, axis=1, keepdims=True),
a_min=self.sigma2_range[0],
a_max=self.sigma2_range[1],
)
dl_da = (
1
/ (a + EPS)
* (
1 / (sigma2 + EPS) * np.sum(self.n * a_x_prod, axis=1, keepdims=True)
- self.framelen_fac
)
)
d2l_da2 = self.framelen_fac / (a**2 + EPS) + 1 / (sigma2 + EPS) * np.sum(
(1 - 2 * self.n) * self.n * a_x_prod, axis=1, keepdims=True
)
return dl_da, d2l_da2, sigma2

def step(self, x_frames: np.ndarray, method: str) -> np.ndarray:
Expand Down Expand Up @@ -146,7 +174,9 @@ def step(self, x_frames: np.ndarray, method: str) -> np.ndarray:
self.a_lower[changed_sign] = middle_a[changed_sign]
self.a_upper[not_changed_sign] = middle_a[not_changed_sign]
else:
raise ValueError(f'method {method} should be {UpdateMethod.NEWTON} or {UpdateMethod.BISECTED}')
raise ValueError(
f"method {method} should be {UpdateMethod.NEWTON} or {UpdateMethod.BISECTED}"
)

self.a = np.clip(self.a, a_min=self.a_range[0], a_max=self.a_range[1])
dl_da, d2l_da2, self.sigma2 = self.likelihood_derivative(self.a, x_frames)
Expand Down Expand Up @@ -174,35 +204,49 @@ def visualize(self, x: np.ndarray, fs: int, ylim: Tuple = (0, 1)) -> Figure:
fig, axs = plt.subplots(nrows=1, ncols=2, width_ratios=(3, 1), sharey=True)

# Plot the input signal normalized to its maximum value
axs[0].plot(np.linspace(0.0, x_duration, len(x)), abs(x) / np.max(np.abs(x - np.mean(x))), label='Signal',
color='black')
axs[0].set_xlabel('Time [sec]')
axs[0].set_ylabel('Samples')
axs[0].tick_params(axis='y', labelcolor='black')
axs[0].plot(
np.linspace(0.0, x_duration, len(x)),
abs(x) / np.max(np.abs(x - np.mean(x))),
label="Signal",
color="black",
)
axs[0].set_xlabel("Time [sec]")
axs[0].set_ylabel("Samples")
axs[0].tick_params(axis="y", labelcolor="black")

# Plot the estimated reverberation times for each frame
axs0 = axs[0].twinx()
axs0.plot(np.linspace(0.0, x_duration, len(self.taus)), self.taus, label='Tau [sec]', color='c')
axs0.tick_params(axis='y', labelcolor='black')
axs0.set_ylabel('Time Constant [sec]')
axs0.plot(
np.linspace(0.0, x_duration, len(self.taus)),
self.taus,
label="Tau [sec]",
color="c",
)
axs0.tick_params(axis="y", labelcolor="black")
axs0.set_ylabel("Time Constant [sec]")

# Create a histogram of taus
bins = np.arange(ylim[0], ylim[1], 0.05)
axs[1].hist(self.taus, bins=bins, orientation='horizontal', color='c')
axs[1].axhline(self.tau, xmin=0.0, color='black')
axs[1].text(2, self.tau + 0.05, f'Tau {self.tau:.2f} sec', color='black')
axs[1].set_xlabel('Counts')
axs[1].hist(self.taus, bins=bins, orientation="horizontal", color="c")
axs[1].axhline(self.tau, xmin=0.0, color="black")
axs[1].text(2, self.tau + 0.05, f"Tau {self.tau:.2f} sec", color="black")
axs[1].set_xlabel("Counts")

# Set y-axis limits for all subplots
for ax in [axs[0], axs[1], axs0]:
ax.set_ylim(ylim)

# Add a title to the entire visualization
plt.suptitle(f'Blind RT60 Estimation | RT60 {self.rt60:.2f} sec')
plt.suptitle(f"Blind RT60 Estimation | RT60 {self.rt60:.2f} sec")

fig.tight_layout()
return fig

def calculate_rtx(self, decay_db):
    """
    Convert the estimated time constant into the time needed for a given decay.

    Parameters:
    - decay_db: Amount of decay in decibels, forwarded unchanged to
      `calculate_decay_time`.

    Returns:
    - Whatever `calculate_decay_time(decay_db, self.tau)` returns.

    Raises:
    - ValueError: If `self.tau` is still None, i.e. no estimate is available yet.
    """
    tau = self.tau
    if tau is not None:
        return calculate_decay_time(decay_db, tau)
    # Without an estimated time constant there is nothing to scale.
    raise ValueError("tau has to be estimated first.")

def estimate(self, x: np.ndarray, fs: int) -> float:
"""
Estimate the reverberation time (RT60) from the input signal.
Expand All @@ -217,24 +261,41 @@ def estimate(self, x: np.ndarray, fs: int) -> float:
self.sanity_check()
assert np.ndim(x) == 1

x = sig.decimate(x, int(fs // self.fs)) if fs > self.fs else sig.resample(x, int(len(x) * self.fs / fs))
x_frames = np.array([x[i:i + self.framelen] for i in range(0, len(x) - self.framelen + 1, self.hop)])
x = (
sig.decimate(x, int(fs // self.fs))
if fs > self.fs
else sig.resample(x, int(len(x) * self.fs / fs))
)
x_frames = np.array(
[
x[i : i + self.framelen]
for i in range(0, len(x) - self.framelen + 1, self.hop)
]
)
self.init_states(x_frames.shape[0])

itr = 0
while itr < self.bisected_itr or (itr < self.max_itr and np.any(np.bitwise_not(self.converged))):
method = UpdateMethod.BISECTED if itr < self.bisected_itr else UpdateMethod.NEWTON
while itr < self.bisected_itr or (
itr < self.max_itr and np.any(np.bitwise_not(self.converged))
):
method = (
UpdateMethod.BISECTED
if itr < self.bisected_itr
else UpdateMethod.NEWTON
)
dl_da = self.step(x_frames, method=method)
self.converged = np.abs(dl_da) <= self.max_err
itr += 1

self.taus = -1 / np.log(self.a) / self.fs
self.taus[np.bitwise_not(self.converged)] = np.nan
self.tau = np.percentile(self.taus[self.converged], q=self.percentile)
self.rt60 = -3 * self.tau / np.log10(np.e ** -1)
self.rt60 = -3 * self.tau / np.log10(np.e**-1)

if self.verbose:
print(f'Iteration {itr} / {self.max_itr}; rt60 {self.rt60:.2f} sec; tau {self.tau:.2f} sec')
print(
f"Iteration {itr} / {self.max_itr}; rt60 {self.rt60:.2f} sec; tau {self.tau:.2f} sec"
)

return self.rt60

Expand Down
28 changes: 28 additions & 0 deletions blind_rt60/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import numpy as np


def calculate_decay_time(decay_db: float, tau: float) -> float:
    """
    Calculate the time a signal needs to decay by `decay_db` decibels, given its time constant.

    The value returned is

        decay_time = decay_db / (20 * log10(e)) * tau

    so a positive `decay_db` yields a positive time for a positive `tau`
    (for example, `decay_db=60` gives roughly `6.91 * tau`).

    Parameters:
    - decay_db (float): The decay of the signal in decibels (dB). Positive values indicate
      attenuation, while negative values indicate amplification.
    - tau (float): The time constant of the system, which represents the time it takes for
      the signal to decay to 1/e (approximately 36.8%) of its initial value.

    Returns:
    float: The calculated decay time in the same units as `tau` (typically seconds, milliseconds, etc.).

    Raises:
    ValueError: If `decay_db` is not a finite number (i.e., NaN or Inf).
    """
    if not np.isfinite(decay_db):
        raise ValueError("decay_db must be a finite number (not NaN or Inf).")

    # 20 * log10(e) is the dB drop accumulated over one time constant;
    # dividing the requested decay by it gives the number of time constants.
    db_per_tau = 20 * np.log10(np.e)
    return decay_db / db_per_tau * tau
2 changes: 1 addition & 1 deletion blind_rt60/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# blind_rt60 version
__version__ = '0.1.0-a0'
__version__ = "0.1.1"
2 changes: 1 addition & 1 deletion notebooks/blind_rt60.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cells": [
"cells": [
{
"cell_type": "code",
"execution_count": 1,
Expand Down
20 changes: 9 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,27 @@

setuptools.setup(
name="blind_rt60",
version="0.1.0-a0",
version="0.1.1",
author="Asaf Zorea",
author_email="zoreasaf@gmail.com",
description="The BlindRT60 algorithm is used to estimate the reverberation time (RT60) "
"of a room based on the recorded audio signals from microphones",
"of a room based on the recorded audio signals from microphones",
long_description=long_description,
long_description_content_type="text/markdown",
license='MIT',
license="MIT",
url="https://github.com/nuniz/blind_rt60",
packages=setuptools.find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*", "tests.*"]),
packages=setuptools.find_packages(
exclude=["tests", "*.tests", "*.tests.*", "tests.*", "tests.*"]
),
include_package_data=True,
classifiers=[
"Programming Language :: Python :: 3",
'License :: OSI Approved :: MIT License',
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
python_requires=">=3.6",
install_requires=[
"scipy",
"numpy",
"matplotlib"
],
install_requires=["scipy", "numpy", "matplotlib"],
extras_require={
"dev": ["pyroomacoustics", "parameterized"],
"dev": ["pyroomacoustics", "parameterized", "pre-commit"],
},
)
Loading
Loading