# Source code for cr.nimble._src.dsp.thresholding
# Copyright 2021 CR.Sparse Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
from jax import jit
def hard_threshold(x, K):
    """Returns the indices and values of the K largest-magnitude entries of a vector x

    Args:
        x (jax.numpy.ndarray): A sparse/compressible signal
        K (int): The number of largest entries to be kept in x

    Returns:
        (jax.numpy.ndarray, jax.numpy.ndarray): A tuple comprising of:

            * The indices of the K largest (in magnitude) entries in x,
              ordered from largest to smallest magnitude
            * Corresponding entries in x

    Note:
        Entries are ranked purely by magnitude, so zero entries can be
        returned when x has fewer than K non-zero entries. K is used as a
        slice bound, so it must be a concrete integer (static under JIT).

    See Also:
        :func:`hard_threshold_sorted`
        :func:`hard_threshold_by`
    """
    # Positions of x ordered from the smallest to the largest magnitude
    ascending = jnp.argsort(jnp.abs(x))
    # Flip to largest-first and keep only the leading K positions
    top = ascending[::-1][:K]
    return top, x[top]
def hard_threshold_sorted(x, K):
    """Returns the K largest-magnitude entries of a vector x, with their indices in ascending order

    Args:
        x (jax.numpy.ndarray): A sparse/compressible signal
        K (int): The number of largest entries to be kept in x

    Returns:
        (jax.numpy.ndarray, jax.numpy.ndarray): A tuple comprising of:

            * The indices of the K largest (in magnitude) entries in x,
              sorted in ascending order of position
            * Corresponding entries in x

    Note:
        Entries are ranked purely by magnitude, so zero entries can be
        returned when x has fewer than K non-zero entries. K must be a
        concrete integer because it is used as a slice bound.

    See Also:
        :func:`hard_threshold`
    """
    # Rank positions from the largest to the smallest magnitude
    ranked = jnp.argsort(jnp.abs(x))[::-1]
    # Keep the K strongest positions and put them back in positional order
    support = jnp.sort(ranked[:K])
    return support, x[support]
def hard_threshold_by(x, t):
    """
    Zeroes out every entry of x whose magnitude is strictly below t

    Args:
        x (jax.numpy.ndarray): A sparse/compressible signal
        t (float): The threshold value

    Returns:
        (jax.numpy.ndarray): A copy of x in which entries with ``|x| < t``
        are replaced by 0 and entries with ``|x| >= t`` are left unchanged

    Note:
        This function doesn't change the length of x and can be JIT compiled

    See Also:
        :func:`hard_threshold`
    """
    # The boolean keep-mask acts as 1/0 in the product, zeroing small entries
    return jnp.multiply(x, jnp.abs(x) >= t)
def largest_indices_by(x, t):
    """
    Returns the positions of all entries in x whose magnitude is at least t

    Args:
        x (jax.numpy.ndarray): A sparse/compressible signal
        t (float): The threshold value

    Returns:
        (jax.numpy.ndarray): An index vector (along the first axis) of all
        entries in x satisfying ``|x| >= t``

    Note:
        This function cannot be JIT compiled as the length of output is data dependent

    See Also:
        :func:`hard_threshold_by`
    """
    above = jnp.abs(x) >= t
    # nonzero returns one index array per axis; the first one is reported
    return jnp.nonzero(above)[0]
def energy_threshold(signal, fraction):
    """
    Keeps only as many coefficients of a signal as are needed to capture a
    given fraction of its energy

    Coefficients are ranked by energy (squared magnitude) in descending
    order. The shortest prefix of that ranking whose cumulative share of the
    total energy reaches ``fraction`` is kept; all other coefficients are
    set to zero.

    Args:
        signal (jax.numpy.ndarray): A (real or complex) signal; assumed 1-D
        fraction (float): The fraction of energy to be preserved

    Returns:
        (jax.numpy.ndarray, jax.numpy.ndarray): A tuple comprising of:

            * Signal after thresholding
            * A binary (0/1 integer) mask of the indices to be kept

    Note:
        This function doesn't change the length of signal and can be JIT compiled.

        If ``fraction`` is never reached (e.g. ``fraction > 1``, floating
        point round-off leaving the cumulative energy just below 1.0, or an
        all-zero signal), every coefficient is kept.

    See Also:
        :func:`hard_threshold`
    """
    # signal length
    n = signal.size
    # energy of each coefficient; abs() keeps this real and correct for complex signals
    energies = jnp.abs(signal) ** 2
    # positions sorted by energy in descending order
    idx = jnp.argsort(energies)[::-1]
    energies = energies[idx]
    # total energy (forced to floating point)
    s = jnp.sum(energies) * 1.
    # cumulative share of the energy captured by the top-k coefficients
    cmf = jnp.cumsum(energies / s)
    reached = cmf >= fraction
    # argmax of an all-False vector is 0, which would silently keep a single
    # coefficient; when the target is never reached keep everything instead
    index = jnp.where(jnp.any(reached), jnp.argmax(reached), n - 1)
    # mask in sorted order: the first index+1 coefficients are kept
    mask = jnp.where(jnp.arange(n) <= index, 1, 0)
    # scatter the sorted-order mask back to the original positions
    mask = mask.at[idx].set(mask)
    return signal * mask, mask