# Source code for cr.nimble._src.householder

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
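Householder reflections, their factored-form (FFM) and WY representations,
and Householder-based QR factorization (GVL4, Sections 5.1-5.2).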
"""

import jax
import jax.numpy as jnp

from cr.nimble import promote_arg_dtypes
from cr.nimble import to_row_vec, to_col_vec


def householder_vec(x):
    r"""Computes a Householder vector for :math:`x`.

    Returns ``(v, beta)`` with ``v[0] == 1`` such that
    :math:`P = I - \beta v v^T` satisfies :math:`P x = \|x\|_2 e_1`.
    Control flow uses ``jax.lax.cond`` so the function can be traced.

    GVL4: Algorithm 5.1.1

    Args:
        x: 1-D array (assumed real-valued: ``x_rest.T @ x_rest`` does not
            conjugate -- TODO confirm complex input is out of scope).

    Returns:
        ``(v, beta)``. For a length-1 input the function returns the
        scalars ``(0, 0)`` rather than a vector.
    """
    x = promote_arg_dtypes(x)
    m = len(x)
    if m == 1:
        return jnp.array(0), jnp.array(0)
    x_1 = x[0]
    x_rest = x[1:]
    sigma = x_rest.T @ x_rest
    v = jnp.hstack((1, x_rest))

    def non_zero_sigma(v):
        mu = jnp.sqrt(x_1 * x_1 + sigma)
        # Both expressions equal x_1 - mu (since (x_1 - mu)(x_1 + mu) = -sigma).
        # Pick the one free of cancellation: x_1 - mu is safe only for
        # x_1 <= 0; for x_1 > 0 the rationalized form is used.
        v_1 = jax.lax.cond(
            x_1 <= 0,
            lambda _: x_1 - mu,
            lambda _: -sigma / (x_1 + mu),
            operand=None,
        )
        v = v.at[0].set(v_1)
        beta = 2. * v_1 * v_1 / (sigma + v_1 * v_1)
        # Normalize so that v[0] == 1.
        v = v / v_1
        return v, beta

    def zero_sigma(v):
        # x is already a multiple of e_1 (v == e_1). If x_1 >= 0 nothing needs
        # to be done (P = I). Otherwise P = I - 2 e_1 e_1^T flips the sign,
        # giving P x = -x_1 e_1 = |x_1| e_1.
        beta = jax.lax.cond(x_1 >= 0, lambda _: 0., lambda _: 2., operand=None)
        return v, beta

    v, beta = jax.lax.cond(sigma == 0, zero_sigma, non_zero_sigma, operand=v)
    return v, beta
def householder_matrix(x):
    r"""Builds the explicit Householder reflection matrix for :math:`x`.

    The reflection is :math:`P = I - \beta v v^T`, where ``(v, beta)`` are
    produced by ``householder_vec(x)``.

    Args:
        x: 1-D array.

    Returns:
        The ``len(x) x len(x)`` matrix :math:`P`.
    """
    v, beta = householder_vec(x)
    identity = jnp.eye(len(x))
    rank_one = jnp.outer(v, v)
    return identity - beta * rank_one
def householder_premultiply(v, beta, A):
    r"""Applies a Householder reflection from the left: returns :math:`P A`.

    :math:`P = I - \beta v v^T` is never formed explicitly. The product is
    computed as the rank-1 update :math:`A - (\beta v)(v^T A)`.

    Args:
        v: 1-D Householder vector; ``len(v)`` must equal ``A.shape[0]``.
        beta: Householder scalar.
        A: 2-D array.

    Returns:
        The array :math:`P A`, with the same shape as ``A``.
    """
    assert v.ndim == 1
    assert A.ndim == 2
    row = to_row_vec(v)
    col = to_col_vec(v)
    # (1 x m) @ (m x n): the projection of each column of A onto v.
    projection = row @ A
    scaled_col = beta * col
    return A - scaled_col @ projection
def householder_postmultiply(v, beta, A):
    r"""Applies a Householder reflection from the right: returns :math:`A P`.

    :math:`P = I - \beta v v^T` is never formed explicitly. The product is
    computed as the rank-1 update :math:`A - (A v)(\beta v^T)`.

    Args:
        v: 1-D Householder vector; ``len(v)`` must equal ``A.shape[1]``.
        beta: Householder scalar.
        A: 2-D array.

    Returns:
        The array :math:`A P`, with the same shape as ``A``.
    """
    assert v.ndim == 1
    assert A.ndim == 2
    col = to_col_vec(v)
    row = to_row_vec(v)
    # (m x n) @ (n x 1): A applied to v.
    a_times_v = A @ col
    return A - a_times_v @ (beta * row)
def householder_vec_(x):
    r"""Computes a Householder vector for :math:`x` using Python control flow.

    Returns ``(v, beta)`` with ``v[0] == 1`` such that
    :math:`P = I - \beta v v^T` satisfies :math:`P x = \|x\|_2 e_1`
    (GVL4 Algorithm 5.1.1). Unlike ``householder_vec``, this version branches
    with Python ``if`` on array values. It therefore presumably needs concrete
    (non-traced) inputs; verify before using it under ``jax.jit``.

    Args:
        x: 1-D array (assumed real-valued).

    Returns:
        ``(v, beta)``. For a length-1 input the function returns the
        scalars ``(0, 0)``.
    """
    x = promote_arg_dtypes(x)
    m = len(x)
    if m == 1:
        return jnp.array(0), jnp.array(0)
    x_1 = x[0]
    x_rest = x[1:]
    sigma = x_rest.T @ x_rest
    v = jnp.hstack((1, x_rest))
    if sigma == 0:
        # x is already a multiple of e_1 (v == e_1). If x_1 >= 0 use the
        # identity. Otherwise use P = I - 2 e_1 e_1^T, which maps x to |x_1| e_1.
        beta = 0 if x_1 >= 0 else 2
        return v, beta
    mu = jnp.sqrt(x_1 * x_1 + sigma)
    # Both forms equal x_1 - mu; choose the one without cancellation.
    if x_1 <= 0:
        v_1 = x_1 - mu
    else:
        v_1 = -sigma / (x_1 + mu)
    v = v.at[0].set(v_1)
    v_1 = v[0]
    beta = 2 * v_1 * v_1 / (sigma + v_1 * v_1)
    # Normalize so that v[0] == 1.
    v = v / v_1
    return v, beta
def householder_ffm_jth_v_beta(A, j):
    r"""Decodes the j-th Householder vector and scalar from factored form.

    Column ``j`` of ``A`` holds the essential part ``A[j+1:, j]`` of the
    vector; the leading 1 is implicit. Following GVL4 eq. 5.1.4, the scalar
    is :math:`\beta = 2 / (1 + \|v_{ess}\|^2)`.

    NOTE: because beta is derived from the essential part, a zero essential
    part always decodes to ``beta == 2``. A "no reflection" step
    (``beta == 0``) cannot be represented in this form.

    Args:
        A: 2-D factored-form array.
        j: column index.

    Returns:
        ``(v, beta)`` where ``v`` has length ``A.shape[0] - j`` and
        ``v[0] == 1``.
    """
    essential = A[j + 1:, j]
    squared_norm = essential @ essential
    beta = 2 / (1 + squared_norm)
    return jnp.hstack((1, essential)), beta
def householder_ffm_premultiply(A, C):
    """Computes Q^T C, where Q is stored in factored form in A.

    Column ``j`` of ``A`` contains the essential part of the j-th Householder
    vector (GVL4 eq. 5.1.4). The reflections are applied to ``C`` in order
    j = 0, 1, ..., each acting only on rows ``j:`` of ``C``.

    Args:
        A: 2-D factored-form array.
        C: 2-D array whose row count matches ``A.shape[0]``.

    Returns:
        The updated copy of ``C``.
    """
    _, num_reflectors = A.shape
    for col in range(num_reflectors):
        v, beta = householder_ffm_jth_v_beta(A, col)
        reflected = householder_premultiply(v, beta, C[col:, :])
        C = C.at[col:, :].set(reflected)
    return C
def householder_ffm_backward_accum(A, k):
    """Computes the first k columns of Q from its factored form in A.

    Uses backward accumulation (GVL4 eq. 5.1.5). Starting from the first
    ``k`` columns of the identity, the reflections are applied from the last
    one to the first. Reflection ``j`` only needs to touch the trailing
    block ``Q[j:, j:]``.

    Args:
        A: 2-D factored-form array of shape ``(m, n)``.
        k: number of columns of Q to produce.

    Returns:
        An ``(m, k)`` array.
    """
    m, n = A.shape
    Q = jnp.eye(m, k)
    for col in reversed(range(n)):
        v, beta = householder_ffm_jth_v_beta(A, col)
        trailing = Q[col:, col:]
        Q = Q.at[col:, col:].set(householder_premultiply(v, beta, trailing))
    return Q
def householder_ffm_to_wy(A):
    """Computes the WY representation Q = I_m - W Y^T from the factored form.

    GVL4 Algorithm 5.1.2. Column ``j`` of ``A`` holds the essential part of
    the j-th Householder vector. All ``r = A.shape[1]`` reflections are
    folded in, so ``Q = H_1 H_2 ... H_r``.

    Args:
        A: 2-D factored-form array of shape ``(m, r)``.

    Returns:
        ``(W, Y)``, both of shape ``(m, r)``. Column ``j`` of ``Y`` is the
        j-th Householder vector zero-padded to length ``m``.
    """
    _, r = A.shape
    v, beta = householder_ffm_jth_v_beta(A, 0)
    Y = to_col_vec(v)
    W = beta * Y
    # Include the last reflector as well (the loop previously stopped at r-2).
    for j in range(1, r):
        v, beta = householder_ffm_jth_v_beta(A, j)
        v = to_col_vec(v)
        # Embed v (length m - j) into an m-vector with j leading zeros. The
        # zero block must be a (j, 1) column to stack with v.
        v_full = jnp.vstack((jnp.zeros((j, 1), dtype=v.dtype), v))
        # z = beta (I - W Y^T) v_full. Only rows j: of Y meet the non-zeros
        # of v_full. Forming Y^T v first avoids building an m x (m - j)
        # matrix.
        z = beta * (v_full - W @ (Y[j:, :].T @ v))
        W = jnp.hstack((W, z))
        Y = jnp.hstack((Y, v_full))
    return W, Y
def householder_qr_packed(A):
    """Computes the QR = A factorization of A using Householder reflections.

    Returns the packed factorization (GVL4 Algorithm 5.2.1). The upper
    triangle of the result holds R. For every column ``j`` with at least two
    entries on or below the diagonal, the entries strictly below the
    diagonal hold the essential part ``v[1:]`` of the j-th Householder
    vector (``v[0] == 1`` is implicit).

    The reflection applied at step ``j`` is exactly the one the stored
    essential part decodes to, ``beta = 2 / (1 + ||v[1:]||^2)``. This is the
    value ``householder_ffm_jth_v_beta`` recovers, so Q rebuilt from the
    factored form always agrees with R. As a consequence, diagonal entries
    of R are not guaranteed to be non-negative. A column that is already a
    non-negative multiple of e_1 gets its row sign flipped, because the
    factored form cannot encode ``beta == 0``.

    Args:
        A: 2-D array of shape ``(m, n)`` with ``m >= n``.

    Returns:
        The packed ``(m, n)`` array.
    """
    A = promote_arg_dtypes(A)
    m, n = A.shape
    assert m >= n
    # Every column with >= 2 entries from the diagonal down gets a reflection.
    # Only the last column of a square matrix (a single entry) is skipped.
    for j in range(min(n, m - 1)):
        x = A[j:, j]
        v, _ = householder_vec(x)
        essential = v[1:]
        beta = 2 / (1 + essential @ essential)
        A2 = householder_premultiply(v, beta, A[j:, j:])
        A = A.at[j:, j:].set(A2)
        # Place the essential part of the Householder vector below the diagonal.
        A = A.at[j + 1:, j].set(essential)
    return A


def householder_split_qf_r(A):
    """Splits a packed QR factorization into QF and R.

    Args:
        A: packed ``(m, n)`` array as produced by ``householder_qr_packed``.

    Returns:
        ``(QF, R)``. ``R = triu(A)`` has shape ``(m, n)``. ``QF`` has shape
        ``(m, min(n, m - 1))`` and keeps only the strictly-lower-triangular
        essential Householder vectors. Columns that carry no reflection (the
        last column of a square matrix) are excluded, since a zero column
        would decode to a spurious ``beta == 2`` reflection.
    """
    m, n = A.shape
    # The upper triangular part is R.
    R = jnp.triu(A)
    # The strictly lower part of the reflected columns is the factored form of Q.
    QF = jnp.tril(A[:, :min(n, m - 1)], -1)
    return QF, R


def householder_qr(A):
    """Computes the QR = A factorization of A using Householder reflections.

    GVL4 Algorithm 5.2.1 followed by backward accumulation of Q.

    Args:
        A: 2-D array of shape ``(m, n)`` with ``m >= n``.

    Returns:
        ``(Q, R)``, where ``Q`` is ``(m, n)`` with orthonormal columns and
        ``R`` is ``(n, n)`` upper triangular, so ``Q @ R`` reproduces ``A``
        up to rounding. For square ``A`` this is the full factorization.
    """
    m, n = A.shape
    packed = householder_qr_packed(A)
    QF, R = householder_split_qf_r(packed)
    Q = householder_ffm_backward_accum(QF, n)
    # Rows n: of triu(packed) are all zero; drop them so Q @ R is well-formed.
    return Q, R[:n, :]