# Source code for cr.nimble._src.dsp.dct
# Copyright 2022-Present CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Discrete Cosine Transforms (Type-II DCT and its inverse) computed via the FFT.

Adapted from:

* http://www-personal.umich.edu/~mejn/computational-physics/dcst.py
* https://dsp.stackexchange.com/questions/2807/fast-cosine-transform-via-fft
"""
from jax import jit
import jax.numpy as jnp
import jax.numpy.fft as jfft
def dct(y):
    """Compute the unnormalized Type-II DCT of ``y`` along its first axis.

    The signal is mirrored to length ``2n`` (``[y, reversed(y)]``) and passed
    through a real FFT. Each of the first ``n`` bins is then rotated by
    ``exp(-i*pi*k/(2n))``. The real part of the result is

        X[k] = 2 * sum_m y[m] * cos(pi * k * (2m + 1) / (2n))

    Args:
        y (jax.numpy.ndarray): The 1D real signal. Arrays of higher rank are
            transformed independently along axis 0.

    Returns:
        jax.numpy.ndarray: The Type-II Discrete Cosine Transform coefficients of y
    """
    length = y.shape[0]
    mirrored = jnp.concatenate((y, jnp.flip(y, axis=0)))
    spectrum = jfft.rfft(mirrored, axis=0)[:length]
    twiddle = jnp.exp(-1j * jnp.pi * jnp.arange(length) / (2 * length))
    # Shape the twiddle as (n, 1, ..., 1) so it scales axis 0 for any rank.
    twiddle = twiddle.reshape((length,) + (1,) * (spectrum.ndim - 1))
    return jnp.real(twiddle * spectrum)
def idct(a):
    """Invert :func:`dct` along the first axis (a scaled Type-III DCT).

    Each coefficient is rotated by ``exp(+i*pi*k/(2n))``. A zero Nyquist bin
    is appended, and ``irfft`` then yields a length-``2n`` signal whose first
    ``n`` samples are

        y[m] = (a[0] + 2 * sum_{k>=1} a[k] * cos(pi * k * (2m + 1) / (2n))) / (2n)

    Args:
        a (jax.numpy.ndarray): The Type-II DCT transform coefficients of a 1D
            real signal. Arrays of higher rank are processed along axis 0.

    Returns:
        jax.numpy.ndarray: The 1D real signal y s.t. a = dct(y)
    """
    count = a.shape[0]
    twiddle = jnp.exp(1j * jnp.pi * jnp.arange(count) / (2 * count))
    # Broadcast the twiddle along axis 0 regardless of the input rank.
    twiddle = twiddle.reshape((count,) + (1,) * (a.ndim - 1))
    nyquist = jnp.zeros((1,) + a.shape[1:])
    half_spectrum = jnp.concatenate((twiddle * a, nyquist))
    signal = jfft.irfft(half_spectrum, axis=0)
    return signal[:count]
def orthonormal_dct(y):
    """Computes the 1D Type-II DCT transform such that the transform is orthonormal

    Along axis 0 this computes

        a[k] = s_k * sum_m y[m] * cos(pi * k * (2m + 1) / (2n))

    with ``s_0 = sqrt(1/n)`` and ``s_k = sqrt(2/n)`` for ``k >= 1``. This is the
    same convention as SciPy's ``dct(y, type=2, norm="ortho")``.

    Args:
        y (jax.numpy.ndarray): The 1D real signal. Arrays of higher rank are
            transformed independently along axis 0.

    Returns:
        jax.numpy.ndarray: The orthonormal Type-II Discrete Cosine Transform coefficients of y

    Orthonormality ensures that, for ``a = orthonormal_dct(y)``,

    .. math::

        \\langle a, a \\rangle = \\langle y, y \\rangle
    """
    n = y.shape[0]
    ks = jnp.arange(n)
    phi = jnp.exp(-1j * jnp.pi * ks / (2 * n))
    # The mirrored-FFT construction below yields 2 * sum(y * cos(...)).
    # Scaling every bin by sqrt(1/(2n)) turns that into sqrt(2/n) * sum(...).
    # The DC bin needs an extra 1/sqrt(2) so that its basis vector also has
    # unit norm.
    phi = phi.at[0].set(phi[0] / jnp.sqrt(2))
    phi = phi * jnp.sqrt(1 / (2 * n))
    # Even (mirrored) extension of length 2n, so the FFT behaves like a cosine transform.
    y2 = jnp.concatenate((y, y[::-1]))
    c = jfft.rfft(y2, axis=0)[:n]
    # The transposes let phi (length n) broadcast along axis 0 for any input rank.
    return jnp.real(phi * c.T).T
def orthonormal_idct(a):
    """Invert :func:`orthonormal_dct` (an orthonormal Type-III DCT).

    Along axis 0 this computes

        y[m] = a[0] / sqrt(n) + sqrt(2/n) * sum_{k>=1} a[k] * cos(pi * k * (2m + 1) / (2n))

    Args:
        a (jax.numpy.ndarray): The orthonormal Type-II DCT transform coefficients
            of a 1D real signal. Arrays of higher rank are processed along axis 0.

    Returns:
        jax.numpy.ndarray: The 1D real signal y s.t. a = orthonormal_dct(y)
    """
    count = a.shape[0]
    twiddle = jnp.exp(1j * jnp.pi * jnp.arange(count) / (2 * count))
    # irfft divides by 2n. Pre-multiplying by sqrt(2n), with a further sqrt(2)
    # on the DC term, leaves exactly the orthonormal DCT-III scaling.
    twiddle = twiddle * jnp.sqrt(2 * count)
    twiddle = twiddle.at[0].set(twiddle[0] * jnp.sqrt(2))
    # Broadcast the twiddle along axis 0 regardless of the input rank.
    twiddle = twiddle.reshape((count,) + (1,) * (a.ndim - 1))
    # Append a zero Nyquist bin so that irfft produces 2n samples.
    half_spectrum = jnp.concatenate((twiddle * a, jnp.zeros((1,) + a.shape[1:])))
    return jfft.irfft(half_spectrum, axis=0)[:count]