# Source code for cr.nimble._src.dsp.wht

# Copyright 2022-Present CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Fast Walsh Hadamard Transforms
"""

from jax import lax, jit
import jax.numpy as jnp



@jit
def fwht(X):
    """Computes the Fast Walsh Hadamard Transform over columns

    The transform is unnormalized: every coefficient is a plain sum of
    +/- signed input entries, with no scaling. Coefficients are returned in
    sequency (Walsh) order, i.e. row ``k`` of the implied transform matrix
    has exactly ``k`` sign changes.

    Args:
        X (jax.numpy.ndarray): The 1D real signal or 2D matrix where each
            column is a signal whose transform is to be computed. Its length
            ``n = X.shape[0]`` must be 0 or a power of 2.

    Returns:
        jax.numpy.ndarray: The Fast Walsh Hadamard Transform coefficients of
        (columns of) X, with the same shape as X.

    Raises:
        ValueError: If ``X.shape[0]`` is not a power of 2.
    """
    n = X.shape[0]
    # Shapes are static under jit, so these checks run in plain Python at
    # trace time, before any array computation is staged.
    if n & (n - 1):
        raise ValueError(
            f"fwht requires X.shape[0] to be a power of 2, got {n}")
    if n < 2:
        # Length-0 and length-1 transforms are the identity.
        return X
    tail = X.shape[1:]

    # Stage 1: butterflies between adjacent rows (2m, 2m + 1).
    pairs = X.reshape((n // 2, 2) + tail)
    even, odd = pairs[:, 0], pairs[:, 1]
    Y = jnp.stack([even + odd, even - odd], axis=1).reshape(X.shape)

    # Stages 2 .. log2(n). The gap doubles each stage (2, 4, ..., n/2).
    # The loop runs over a static count, so it unrolls during tracing.
    gap = 2
    while gap < n:
        # View each block of 2*gap rows starting at row J as
        # [half h][pair i][parity e], i.e. row J + h*gap + 2*i + e.
        quads = Y.reshape((n // (2 * gap), 2, gap // 2, 2) + tail)
        a = quads[:, 0, :, 0]  # row J + 2i
        b = quads[:, 1, :, 0]  # row J + 2i + gap
        c = quads[:, 0, :, 1]  # row J + 2i + 1
        d = quads[:, 1, :, 1]  # row J + 2i + 1 + gap
        # Outputs go to rows J + 4i .. J + 4i + 3. This ordering (note
        # c - d before c + d) yields sequency-ordered coefficients.
        Y = jnp.stack([a + b, a - b, c - d, c + d], axis=2).reshape(X.shape)
        gap *= 2
    return Y