Skip to content

Commit

Permalink
fix broadband import bug
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiim committed Jul 31, 2024
1 parent 2eb17f5 commit ff9fabc
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 33 deletions.
2 changes: 1 addition & 1 deletion classical_doa/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = "1.1.1"
__version__ = "1.1.2"
__author__ = "Qian Xu"
2 changes: 1 addition & 1 deletion classical_doa/algorithm/broadband.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from classical_doa.algorithm.music import music
from classical_doa.algorithm.music_based import music
from classical_doa.algorithm.utils import (
divide_into_fre_bins,
get_noise_space,
Expand Down
2 changes: 1 addition & 1 deletion classical_doa/algorithm/sparse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import cvxpy as cp
import numpy as np

C = 3e8

Expand Down
78 changes: 48 additions & 30 deletions classical_doa/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@

def plot_spatial_spectrum(
spectrum,
ground_truth,
angle_grids,
num_signal,
ground_truth=None,
x_label="Angle",
y_label="Spectrum",
):
"""Plot spatial spectrum
Args:
spectrum: Spatial spectrum estimated by the algorithm
ground_truth: True incident angles
angle_grids: Angle grids corresponding to the spatial spectrum
num_signal: Number of signals
ground_truth: True incident angles
x_label: x-axis label
y_label: y-axis label
"""
Expand Down Expand Up @@ -54,11 +54,15 @@ def plot_spatial_spectrum(
ax.annotate(angle, xy=(angle, heights[i]))

# ground truth
for angle in ground_truth:
ax.axvline(x=angle, color="green", linestyle="--")
if ground_truth is not None:
for angle in ground_truth:
ax.axvline(x=angle, color="green", linestyle="--")

# set labels
ax.legend(["Spectrum", "Estimated", "Ground Truth"])
if ground_truth is not None:
ax.legend(["Spectrum", "Estimated", "Ground Truth"])
else:
ax.legend(["Spectrum", "Estimated"])

ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
Expand All @@ -68,21 +72,21 @@ def plot_spatial_spectrum(

def plot_estimated_value(
estimates,
ground_truth,
ticks_min=-90,
ticks_max=90,
ground_truth=None,
x_label="Angle",
y_label="Spectrum",
):
"""Display estimated angle values
Args:
estimates: Angle estimates
ground_truth: True incident angles
ticks_min (int, optional): Minimum value for x-axis ticks.
Defaults to -90.
ticks_max (int, optional): Maximum value for x-axis ticks.
Defaults to 90.
ground_truth: True incident angles
x_label (str, optional): x-axis label. Defaults to "Angle".
y_label (str, optional): y-axis label. Defaults to "Spetrum".
"""
Expand All @@ -97,8 +101,9 @@ def plot_estimated_value(
ax.xaxis.set_minor_locator(plt.MultipleLocator(minor_space))

# ground truth
for angle in ground_truth:
truth_line = ax.axvline(x=angle, color="c", linestyle="--")
if ground_truth is not None:
for angle in ground_truth:
truth_line = ax.axvline(x=angle, color="c", linestyle="--")

# plot estimates
for angle in estimates:
Expand All @@ -109,16 +114,20 @@ def plot_estimated_value(
ax.set_ylabel(y_label)

# set legend
ax.legend([truth_line, estimate_line], ["Ground Truth", "Estimated"])
if ground_truth is not None:
ax.legend([truth_line, estimate_line], ["Ground Truth", "Estimated"])
else:
ax.legend([estimate_line], ["Estimated"])

plt.show()


def plot_spatial_spectrum_2d(
spectrum,
ground_truth,
azimuth_grids,
elevation_grids,
num_signal,
ground_truth=None,
x_label="Elevation",
y_label="Azimuth",
z_label="Spectrum",
Expand All @@ -127,17 +136,18 @@ def plot_spatial_spectrum_2d(
Args:
spectrum: Spatial spectrum estimated by the algorithm
ground_truth: True incident angles
azimuth_grids : Azimuth grids corresponding to the spatial spectrum
elevation_grids : Elevation grids corresponding to the spatial spectrum
num_signal: Number of signals
ground_truth: True incident angles
x_label: x-axis label
y_label: y-axis label
z_label : x-axis label. Defaults to "Spectrum".
"""
x, y = np.meshgrid(elevation_grids, azimuth_grids)
spectrum = spectrum / spectrum.max()
# Find the peaks in the surface
peaks = peak_local_max(spectrum, num_peaks=ground_truth.shape[1])
peaks = peak_local_max(spectrum, num_peaks=num_signal)
spectrum = np.log(spectrum + 1e-10)

fig = plt.figure()
Expand All @@ -161,19 +171,24 @@ def plot_spatial_spectrum_2d(
"({}, {})".format(x[peak[0], peak[1]], y[peak[0], peak[1]]),
)
# plot ground truth
truth_lines = ax.stem(
ground_truth[1],
ground_truth[0],
np.ones_like(ground_truth[0]),
bottom=spectrum.min(),
linefmt="g--",
markerfmt=" ",
basefmt=" ",
)

ax.legend(
[surf, truth_lines, peak_dot], ["Spectrum", "Estimated", "Ground Truth"]
)
if ground_truth is not None:
truth_lines = ax.stem(
ground_truth[1],
ground_truth[0],
np.ones_like(ground_truth[0]),
bottom=spectrum.min(),
linefmt="g--",
markerfmt=" ",
basefmt=" ",
)

if ground_truth is not None:
ax.legend(
[surf, truth_lines, peak_dot],
["Spectrum", "Estimated", "Ground Truth"],
)
else:
ax.legend([surf, peak_dot], ["Spectrum", "Estimated"])

ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
Expand All @@ -185,7 +200,8 @@ def plot_spatial_spectrum_2d(
def plot_estimated_value_2d(
estimated_azimuth,
estimated_elevation,
ground_truth,
num_signal,
ground_truth=None,
unit="deg",
x_label="Angle",
y_label="Spectrum",
Expand All @@ -194,6 +210,7 @@ def plot_estimated_value_2d(
Args:
estimates: Angle estimates
num_signal: Number of signals
ground_truth: True incident angles
ticks_min (int, optional): Minimum value for x-axis ticks.
Defaults to -90.
Expand All @@ -204,13 +221,14 @@ def plot_estimated_value_2d(
"""
if unit == "deg":
estimated_azimuth = estimated_azimuth / 180 * np.pi
ground_truth = ground_truth.astype(float)
ground_truth[0] = ground_truth[0] / 180 * np.pi

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection="polar")

ax.scatter(ground_truth[0], ground_truth[1], marker="o", color="g")
if ground_truth is not None:
ground_truth = ground_truth.astype(float)
ground_truth[0] = ground_truth[0] / 180 * np.pi
ax.scatter(ground_truth[0], ground_truth[1], marker="o", color="g")
ax.scatter(estimated_azimuth, estimated_elevation, marker="x", color="r")

ax.set_rlabel_position(90)
Expand Down

0 comments on commit ff9fabc

Please sign in to comment.