Skip to content

Commit

Permalink
BlackmanWindow ivy-llc#19480 (ivy-llc#19882)
Browse files Browse the repository at this point in the history
Co-authored-by: ivy-branch <ivy.branch@lets-unify.ai>
Co-authored-by: Samsam Lee <106169847+jieunboy0516@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 15, 2023
1 parent 5102d09 commit 74c13af
Show file tree
Hide file tree
Showing 10 changed files with 334 additions and 0 deletions.
46 changes: 46 additions & 0 deletions ivy/data_classes/array/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,49 @@ def unsorted_segment_sum(
`segment_ids` equals to segment ID.
"""
return ivy.unsorted_segment_sum(self._data, segment_ids, num_segments)

def blackman_window(
self: ivy.Array,
/,
*,
periodic: bool = True,
dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype]] = None,
device: Optional[Union[ivy.Device, ivy.NativeDevice]] = None,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.blackman_window. This method simply wraps the
function, and so the docstring for ivy.blackman_window also applies to this method with
minimal changes.
Parameters
----------
self
int.
periodic
If True, returns a window to be used as periodic function.
If False, return a symmetric window.
Default: ``True``.
dtype
output array data type. If ``dtype`` is ``None``, the output array data type
must be inferred from ``self``. Default: ``None``.
device
device on which to place the created array. If ``device`` is ``None``, the
output array device must be inferred from ``self``. Default: ``None``.
out
optional output array, for writing the result to. It must have a shape that
the inputs broadcast to.
Returns
-------
ret
The array containing the window.
Examples
--------
>>> ivy.blackman_window(4, periodic = True)
ivy.array([-1.38777878e-17, 3.40000000e-01, 1.00000000e+00, 3.40000000e-01])
>>> ivy.blackman_window(7, periodic = False)
ivy.array([-1.38777878e-17, 1.30000000e-01, 6.30000000e-01, 1.00000000e+00,
6.30000000e-01, 1.30000000e-01, -1.38777878e-17])
"""
return ivy.blackman_window(
self._data, periodic=periodic, dtype=dtype, device=device, out=out
)
93 changes: 93 additions & 0 deletions ivy/data_classes/container/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,3 +954,96 @@ def unsorted_segment_sum(
segment_ids,
num_segments,
)

@staticmethod
def static_blackman_window(
window_length: Union[int, ivy.Container],
periodic: bool = True,
dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype]] = None,
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.blackman_window. This method simply
wraps the function, and so the docstring for ivy.blackman_window also applies to
this method with minimal changes.
Parameters
----------
window_length
container including multiple window sizes.
periodic
If True, returns a window to be used as periodic function.
If False, return a symmetric window.
dtype
The data type to produce. Must be a floating point type.
out
optional output container, for writing the result to.
Returns
-------
ret
The container that contains the Blackman windows.
Examples
--------
With one :class:`ivy.Container` input:
>>> x = ivy.Container(a=3, b=5)
>>> ivy.Container.static_blackman_window(x)
{
a: ivy.array([-1.38777878e-17, 6.30000000e-01, 6.30000000e-01])
b: ivy.array([-1.38777878e-17, 2.00770143e-01, 8.49229857e-01,
8.49229857e-01, 2.00770143e-01])
}
"""
return ContainerBase.cont_multi_map_in_function(
"blackman_window",
window_length,
periodic,
dtype,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

def blackman_window(
self: ivy.Container,
periodic: bool = True,
dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype]] = None,
*,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.blackman_window. This method simply
wraps the function, and so the docstring for ivy.blackman_window also applies to
this method with minimal changes.
Parameters
----------
self
input container with window sizes.
periodic
If True, returns a window to be used as periodic function.
If False, return a symmetric window.
dtype
The data type to produce. Must be a floating point type.
out
optional output container, for writing the result to.
Returns
-------
ret
The container containing the Blackman windows.
Examples
--------
With one :class:`ivy.Container` input:
>>> x = ivy.Container(a=3, b=5)
>>> ivy.blackman_window(x)
{
a: ivy.array([-1.38777878e-17, 6.30000000e-01, 6.30000000e-01])
b: ivy.array([-1.38777878e-17, 2.00770143e-01, 8.49229857e-01,
8.49229857e-01, 2.00770143e-01])
}
"""
return self.static_blackman_window(self, periodic, dtype, out=out)
20 changes: 20 additions & 0 deletions ivy/functional/backends/jax/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,23 @@ def unsorted_segment_sum(
data, segment_ids, num_segments
)
return jax.ops.segment_sum(data, segment_ids, num_segments)


def blackman_window(
size: int,
/,
*,
periodic: bool = True,
dtype: Optional[jnp.dtype] = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
if size < 2:
return jnp.ones([size], dtype=dtype)
if periodic:
count = jnp.arange(size) / size
else:
count = jnp.linspace(start=0, stop=size, num=size)
return (0.42 - 0.5 * jnp.cos(2 * jnp.pi * count)) + (
0.08 * jnp.cos(2 * jnp.pi * 2 * count)
)

11 changes: 11 additions & 0 deletions ivy/functional/backends/mxnet/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,14 @@ def tril_indices(
n_rows: int, n_cols: Optional[int] = None, k: int = 0, /, *, device: str
) -> Tuple[(Union[(None, mx.ndarray.NDArray)], ...)]:
raise IvyNotImplementedException()


def blackman_window(
size: int,
/,
*,
periodic: bool = True,
dtype: Optional[None] = None,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
) -> Union[(None, mx.ndarray.NDArray)]:
raise IvyNotImplementedException()
24 changes: 24 additions & 0 deletions ivy/functional/backends/numpy/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,29 @@ def unsorted_segment_min(
return res


def blackman_window(
size: int,
/,
*,
periodic: bool = True,
dtype: Optional[np.dtype] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if size < 2:
return np.ones([size], dtype=dtype)
if periodic:
count = np.arange(size) / size
else:
count = np.linspace(start=0, stop=size, num=size)

return (
(0.42 - 0.5 * np.cos(2 * np.pi * count))
+ (0.08 * np.cos(2 * np.pi * 2 * count))
).astype(dtype)


blackman_window.support_native_out = False

def unsorted_segment_sum(
data: np.ndarray,
segment_ids: np.ndarray,
Expand All @@ -134,3 +157,4 @@ def unsorted_segment_sum(
res[i] = np.sum(data[mask_index], axis=0)

return res

22 changes: 22 additions & 0 deletions ivy/functional/backends/paddle/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,27 @@ def unsorted_segment_min(
return res




def blackman_window(
size: int,
/,
*,
periodic: Optional[bool] = True,
dtype: Optional[paddle.dtype] = None,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
if size < 2:
return paddle.ones([size], dtype=dtype)
if periodic:
count = paddle.arange(size) / size
else:
count = paddle.linspace(start=0, stop=size, num=size)
return (
(0.42 - 0.5 * paddle.cos(2 * math.pi * count))
+ (0.08 * paddle.cos(2 * math.pi * 2 * count))
).cast(dtype)

def unsorted_segment_sum(
data: paddle.Tensor,
segment_ids: paddle.Tensor,
Expand Down Expand Up @@ -165,3 +186,4 @@ def unsorted_segment_sum(
res = paddle.cast(res, "int32")

return res

20 changes: 20 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,29 @@ def unsorted_segment_min(
return tf.math.unsorted_segment_min(data, segment_ids, num_segments)


def blackman_window(
size: int,
/,
*,
periodic: bool = True,
dtype: Optional[tf.DType] = None,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
if size < 2:
return tnp.ones([size], dtype=tnp.result_type(size, 0.0))
if periodic:
count = tnp.arange(size) / size
else:
count = tnp.linspace(start=0, stop=size, num=size)

return (0.42 - 0.5 * tnp.cos(2 * tnp.pi * count)) + (
0.08 * tnp.cos(2 * tnp.pi * 2 * count)
)

def unsorted_segment_sum(
data: tf.Tensor,
segment_ids: tf.Tensor,
num_segments: Union[int, tf.Tensor],
) -> tf.Tensor:
return tf.math.unsorted_segment_sum(data, segment_ids, num_segments)

20 changes: 20 additions & 0 deletions ivy/functional/backends/torch/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,24 @@ def unsorted_segment_min(
return res


@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version)
def blackman_window(
size: int,
/,
*,
periodic: bool = True,
dtype: Optional[torch.dtype] = None,
out: Optional[torch.tensor] = None,
) -> torch.tensor:
return torch.blackman_window(
size,
periodic=periodic,
dtype=dtype,
)


blackman_window.support_native_out = False

def unsorted_segment_sum(
data: torch.Tensor,
segment_ids: torch.Tensor,
Expand All @@ -175,3 +193,5 @@ def unsorted_segment_sum(
res[i] = torch.sum(data[mask_index], dim=0)

return res


49 changes: 49 additions & 0 deletions ivy/functional/ivy/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,7 @@ def unsorted_segment_min(
return ivy.current_backend().unsorted_segment_min(data, segment_ids, num_segments)



@handle_exceptions
@handle_nestable
@to_native_arrays_and_back
Expand Down Expand Up @@ -687,6 +688,53 @@ def unsorted_segment_sum(
return ivy.current_backend().unsorted_segment_sum(data, segment_ids, num_segments)



@handle_exceptions
@handle_nestable
@handle_out_argument
@to_native_arrays_and_back
@infer_dtype
@handle_device_shifting
def blackman_window(
size: int,
*,
periodic: bool = True,
dtype: Optional[Union[ivy.Dtype, ivy.NativeDtype]] = None,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Generate a Blackman window. The Blackman window is a taper formed by using the first
three terms of a summation of cosines. It was designed to have close to the minimal
leakage possible. It is close to optimal, only slightly worse than a Kaiser window.
Parameters
----------
window_length
the window_length of the returned window.
periodic
If True, returns a window to be used as periodic function.
If False, return a symmetric window.
dtype
The data type to produce. Must be a floating point type.
out
optional output array, for writing the result to.
Returns
-------
ret
The array containing the window.
Functional Examples
-------------------
>>> ivy.blackman_window(4, periodic = True)
ivy.array([-1.38777878e-17, 3.40000000e-01, 1.00000000e+00, 3.40000000e-01])
>>> ivy.blackman_window(7, periodic = False)
ivy.array([-1.38777878e-17, 1.30000000e-01, 6.30000000e-01, 1.00000000e+00,
6.30000000e-01, 1.30000000e-01, -1.38777878e-17])
"""
return ivy.current_backend().blackman_window(
size, periodic=periodic, dtype=dtype, out=out
)



@handle_exceptions
@handle_nestable
@infer_dtype
Expand Down Expand Up @@ -755,3 +803,4 @@ def random_tucker(
return ivy.TuckerTensor.tucker_to_tensor((core, factors))
else:
return ivy.TuckerTensor((core, factors))

Loading

0 comments on commit 74c13af

Please sign in to comment.