# Source code for cr.nimble._src.toeplitz

# Copyright 2022 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax import jit
import jax.numpy as jnp
import jax.numpy.fft as jfft

def toeplitz_mat(c, r):
    """Constructs a Toeplitz matrix from its first column and first row.

    Args:
        c: First column of the matrix, length ``m``. Flattened before use, so
            row/column vectors of shape ``(m, 1)`` or ``(1, m)`` are accepted.
        r: First row of the matrix, length ``n``. Flattened before use.
            ``r[0]`` is ignored: the diagonal entries are taken from ``c[0]``.

    Returns:
        An ``(m, n)`` array ``T`` with ``T[i, j] = c[i - j]`` for ``i >= j``
        and ``T[i, j] = r[j - i]`` for ``j > i``.
    """
    c = jnp.asarray(c).ravel()
    r = jnp.asarray(r).ravel()
    m = len(c)
    n = len(r)
    # Lay out all diagonals in one vector: w[m - 1 - k] == c[k] for k >= 0 and
    # w[m - 1 + k] == r[k] for k >= 1, so entry (i, j) lives at w[m - 1 - i + j].
    w = jnp.concatenate((c[::-1], r[1:]))
    rows = jnp.arange(m)[:, None]
    cols = jnp.arange(n)[None, :]
    return w[(m - 1) - rows + cols]
def toeplitz_mult(w, x):
    """Multiplies a Toeplitz matrix with a vector or a matrix.

    The Toeplitz matrix ``T`` (shape ``(m, n)``) is never formed explicitly.
    Instead it is embedded in the top-left corner of a circulant matrix of
    size ``p = m + n - 1``. The product of that circulant with the zero-padded
    ``x`` is evaluated with FFTs, and its first ``m`` rows are ``T @ x``.

    Args:
        w: A pair ``(c, r)``. ``c`` (length ``m``) is the first column and
            ``r`` (length ``n``) is the first row of ``T``. ``r[0]`` is
            ignored; the diagonal is taken from ``c[0]``.
        x: A vector of length ``n`` or a matrix of shape ``(n, k)``.

    Returns:
        ``T @ x``: a vector of length ``m`` if ``x`` is 1-D, otherwise a
        matrix of shape ``(m, k)``.

    Raises:
        ValueError: if ``x`` is not 1-D or 2-D, or ``x.shape[0] != n``.

    Note:
        Real inputs use the real FFT (``rfft``/``irfft``). When the matrix or
        ``x`` is complex, the full complex FFT is used so that imaginary parts
        are not discarded.
    """
    c, r = w
    c = jnp.asarray(c).ravel()
    r = jnp.asarray(r).ravel()
    x = jnp.asarray(x)
    m = len(c)
    n = len(r)
    if x.ndim not in (1, 2):
        raise ValueError(f"x must be 1-D or 2-D, got {x.ndim} dimensions")
    if x.shape[0] != n:
        raise ValueError(
            f"x has {x.shape[0]} rows but the Toeplitz matrix has {n} columns"
        )
    is_vector = x.ndim == 1
    if is_vector:
        x = x[:, None]
    p = m + n - 1
    # First column of the p x p circulant whose top-left m x n corner is T.
    col = jnp.concatenate((c, r[-1:0:-1]))
    if jnp.iscomplexobj(col) or jnp.iscomplexobj(x):
        y = jfft.ifft(jfft.fft(col)[:, None] * jfft.fft(x, n=p, axis=0), axis=0)
    else:
        y = jfft.irfft(
            jfft.rfft(col)[:, None] * jfft.rfft(x, n=p, axis=0), n=p, axis=0
        )
    # Rows beyond m belong to the circulant embedding, not to T.
    y = y[:m]
    # Remove only the axis added above, so m == 1 still yields shape (1,).
    return y[:, 0] if is_vector else y
def circulant_mat(c):
    """Builds the circulant matrix whose first column is ``c``.

    ``c`` is flattened first. With ``m = len(c)``, the result is the
    ``(m, m)`` array ``C`` with ``C[i, j] = c[(i - j) mod m]``, so each
    column is the previous one rotated down by one position.
    """
    col = jnp.asarray(c).ravel()
    size = col.shape[0]
    row_ids = jnp.arange(size)[:, None]
    col_ids = jnp.arange(size)[None, :]
    # jnp's % follows Python semantics, so negative offsets wrap to [0, size).
    return col[(row_ids - col_ids) % size]
def circulant_mult(c, x):
    """Multiplies a circulant matrix with a vector or a matrix.

    The circulant matrix ``C`` (shape ``(m, m)``, first column ``c``) is never
    formed explicitly. ``C @ x`` equals the circular convolution of ``c`` with
    each column of ``x``, which is evaluated with FFTs.

    Args:
        c: First column of the circulant matrix, length ``m``. Flattened
            before use.
        x: A vector of length ``m`` or a matrix of shape ``(m, k)``.

    Returns:
        ``C @ x``: a vector of length ``m`` if ``x`` is 1-D, otherwise a
        matrix of shape ``(m, k)``.

    Raises:
        ValueError: if ``x`` is not 1-D or 2-D, or ``x.shape[0] != m``.

    Note:
        Real inputs use the real FFT (``rfft``/``irfft``). When ``c`` or ``x``
        is complex, the full complex FFT is used so that imaginary parts are
        not discarded.
    """
    c = jnp.asarray(c).ravel()
    x = jnp.asarray(x)
    m = len(c)
    if x.ndim not in (1, 2):
        raise ValueError(f"x must be 1-D or 2-D, got {x.ndim} dimensions")
    if x.shape[0] != m:
        raise ValueError(
            f"x has {x.shape[0]} rows but the circulant matrix has {m} columns"
        )
    is_vector = x.ndim == 1
    if is_vector:
        x = x[:, None]
    if jnp.iscomplexobj(c) or jnp.iscomplexobj(x):
        y = jfft.ifft(jfft.fft(c)[:, None] * jfft.fft(x, n=m, axis=0), axis=0)
    else:
        y = jfft.irfft(
            jfft.rfft(c)[:, None] * jfft.rfft(x, n=m, axis=0), n=m, axis=0
        )
    # Remove only the axis added above, so m == 1 still yields shape (1,).
    return y[:, 0] if is_vector else y