Source code for cr.nimble._src.svdpack.lanbpro

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax import lax, jit, vmap, random
import jax.numpy as jnp
from jax.numpy.linalg import norm


import cr.nimble as cnb
from cr.nimble import AH_v

from .reorth import reorth_mgs
from .lanbpro_utils import (
    LanBDOptions,
    LanBProState,
    lanbpro_options_init,
    do_elr,
    update_mu,
    update_nu,
    compute_ind,
    bpro_norm_estimate
)

FUDGE = 1.01
M2 = 3/2
N2 = 3/2

KEY = random.PRNGKey(0)
KEYS = random.split(KEY, 20)


[docs]def lanbpro_init(A, k, p0, options: LanBDOptions): """Initialize the state with a starting vector """ # we follow steps from page 30 of Larsen paper m, n = A.shape U = jnp.zeros((m, k)) V = jnp.zeros((n, k)) alpha = jnp.zeros(k) beta = jnp.zeros(k+1) mu = jnp.zeros(k) nu = jnp.zeros(k) mumax = jnp.zeros(k) numax = jnp.zeros(k) indices = jnp.zeros(k, dtype=bool) anorms = jnp.zeros(k) # options delta = options.delta eps = options.eps gamma = options.gamma # whether full reorthogonalization will be done b_fro = delta == 0 # initial value of force reorthogonalization b_force_reorth = False # step 1 p_norm = norm(p0) # beta_0 beta = beta.at[0].set(p_norm) # U_0 u = cnb.vec_safe_divide_by_scalar(p0, p_norm) U = U.at[:, 0].set(u) # step 2 r update r = AH_v(A, u) # step 2.1b alpha update r_norm = norm(r) alpha = alpha.at[0].set(r_norm) anorm = FUDGE * r_norm # step 2.1b v update v = cnb.vec_safe_divide_by_scalar(r, r_norm) V = V.at[:, 0].set(v) # step 2.1b p update p = A @ v - alpha[0] * u # step 2.2b beta update p_norm = norm(p) p, p_norm, p_proj = do_elr(u, p, p_norm, options.gamma) # Check for convergence or failure to maintain semiorthogonality semiorth_cond = p_norm < max(m, n) * anorm * eps p, p_norm, b_force_reorth, indices = lax.cond(semiorth_cond, # compute a new random p vector orthogonal to previous U lambda _ : (new_p_vec(A, U, 1, gamma)[0], 0., True, indices.at[0].set(True)), lambda _ : (p, p_norm, b_force_reorth, indices), None) beta = beta.at[1].set(p_norm) # update anorm estimate anorm = jnp.maximum(anorm,FUDGE*jnp.hypot(alpha[0],beta[1])) # update mu for the first iteration before computation of U_1 eps = jnp.finfo(float).eps eps1 = 100*eps/2 T = eps1*(anorm + jnp.hypot(alpha[0],beta[1]) + jnp.hypot(alpha[0],beta[0]) ) # TODO this is problematic if p_norm is 0 mu0 = lax.cond(p_norm, lambda _ : T / p_norm, lambda _ : 0., None) mu = mu.at[0].set(mu0) # mumax update mumax = mumax.at[0].set(jnp.abs(mu[0])) # TODO add condition with elr > 0 mu = mu.at[0].set(M2 * eps) # prepare state after completion of one iteration anorms = anorms.at[0].set(anorm) return LanBProState(p=p, U=U, V=V, alpha=alpha, beta=beta, mu=mu, nu=nu, mumax=mumax, numax=numax, anorm=anorm, anorms=anorms, indices=indices, b_force_reorth=b_force_reorth, b_fro=b_fro, iterations=1, )
[docs]def lanbpro_iteration(A, state: LanBProState, options: LanBDOptions): """One single (j-th) iteration of Lanczos bidiagonalization with partial reorthogonalization algorithm """ m, n = A.shape max_m_n = max(m, n) # copy variables from the state p = state.p U = state.U V = state.V alpha = state.alpha beta = state.beta mu = state.mu nu = state.nu mumax = state.mumax numax = state.numax indices = state.indices anorm = state.anorm b_fro = state.b_fro b_force_reorth = state.b_force_reorth b_est_anorm = state.b_est_anorm # the total number of iterations k = len(alpha) # index for k idx = jnp.arange(k) # iteration number j = state.iterations # first j indices mask j_mask = idx < j jp1_mask = idx <= j # options gamma = options.gamma elr = options.elr eps = options.eps delta = options.delta eta = options.eta # carry out the work for one iteration of lanbpro beta_j = beta[j] # compute next left singular vector u = cnb.vec_safe_divide_by_scalar(p, beta_j) U = U.at[:, j].set(u) # once sufficient iterations are completed, we can # compute a better estimate of a norm anorm, b_est_anorm = lax.cond(j == 5, # Replace norm estimate with largest Ritz value. lambda _ : (FUDGE*bpro_norm_estimate(alpha, beta), False), # continue with current value lambda _ : (anorm, b_est_anorm), None ) # step 2 r update v_jm1 = V[:, j-1] r = AH_v(A, u) - beta_j * v_jm1 r_norm = norm(r) # elr condition b_no_fro = jnp.logical_not(b_fro) elr_cond = jnp.logical_and(jnp.logical_and(r_norm < gamma * beta_j, elr), b_no_fro) # extended local reorthogonalization of r w.r.t. previous v_j r, r_norm, proj = lax.cond(elr_cond, lambda _ : do_elr(v_jm1, r, r_norm, gamma), lambda _ : (r, r_norm, 0.), None ) # save updated r_norm in alpha_j alpha = alpha.at[j].set(r_norm) # make changes to beta_j if required. beta = beta.at[j].add(proj) # norm estimate anorm_up_1 = lambda anorm: jnp.maximum(anorm, FUDGE*jnp.sqrt(alpha[0]**2+beta[1]**2+alpha[1]*beta[1])) anorm_up_j = lambda anorm: jnp.maximum(anorm, FUDGE*jnp.sqrt(alpha[j-1]**2+beta[j]**2 + alpha[j-1]*beta[j-1] + alpha[j]*beta[j])) anorm = lax.cond(b_est_anorm, # We need to update norm estimate lambda anorm : lax.cond(j == 1, anorm_up_1, anorm_up_j, anorm), # no more norm estimation needed lambda anorm : anorm, anorm) # nu update condition nu_update_cond = jnp.logical_and(b_no_fro, r_norm != 0) nu, numax = lax.cond(nu_update_cond, lambda nu: update_nu(nu, numax, mu, j, alpha, beta, anorm), lambda nu : (nu, numax), nu ) # if elr is on, then current vector is orthogonalized against previous one nu = lax.cond(elr > 0, lambda nu: nu.at[j-1].set(N2 * eps), lambda nu: nu, nu) # condition for partial or full reorthogonalization reorth_cond = jnp.logical_or(b_fro, numax[j] > delta) reorth_cond = jnp.logical_or(reorth_cond, b_force_reorth) reorth_cond = jnp.logical_and(reorth_cond, alpha[j] != 0) # function to reorth r w.r.t. previous V def reorth_r(_): # identify the indices at which partial or full reorthogonalization will be done reorth_indices = lax.cond(jnp.logical_or(b_fro, eta == 0), # full reorthogonalization case lambda _ : j_mask, # partial reorthogonalization case lambda _ : lax.cond(b_force_reorth, lambda _ : indices, lambda _ : compute_ind(nu, delta, eta), None ), None ) # carry out reorthogonalization r2, r2_norm, iters = reorth_mgs(V, r, r_norm, reorth_indices, gamma) # reset nu in the entries which have been reorthogonalized nu2 = jnp.where(reorth_indices, N2*eps, nu) # if a reorthogonalization was forced. it won't be for next iteration b_force_reorth2 = jnp.logical_not(b_force_reorth) return r2, r2_norm, reorth_indices, nu2, b_force_reorth2 # reorthogonalize r if required r, r_norm, indices, nu, b_force_reorth = lax.cond(reorth_cond, reorth_r, lambda _ : (r, r_norm, indices, nu, b_force_reorth), None ) # Check for convergence or failure to maintain semiorthogonality # this is the case where r is in the column space of previous V vectors semiorth_cond = r_norm < max(m, n) * anorm * eps r, r_norm, b_force_reorth, indices = lax.cond(semiorth_cond, # compute a new random r vector orthogonal to previous V lambda _ : (new_r_vec(A, V, j, gamma)[0], 0., True, j_mask), lambda _ : (r, r_norm, b_force_reorth, indices), None) # update alpha_j again if required alpha = alpha.at[j].set(r_norm) # step 2.1b v update v = cnb.vec_safe_divide_by_scalar(r, r_norm) V = V.at[:, j].set(v) # Lanczos step to generate u_{j+1} # step 2.1b p update p = A @ v - alpha[j] * u # step 2.2b beta update p_norm = norm(p) # elr condition for p elr_cond = jnp.logical_and(jnp.logical_and(p_norm < gamma * r_norm, elr), b_no_fro) # extended local reorthogonalization of p w.r.t. previous u_j p, p_norm, proj = lax.cond(elr_cond, lambda _ : do_elr(u, p, p_norm, gamma), lambda _ :(p, p_norm, 0.), None ) # save updated p_norm in beta_{j+1} (if there are any changes) beta = beta.at[j+1].set(p_norm) # make changes to alpha_j if required. alpha = alpha.at[j].add(proj) anorm = lax.cond(b_est_anorm, # we need to update anorm estimate lambda anorm: jnp.maximum(anorm,FUDGE*jnp.sqrt(alpha[j]**2+beta[j+1]**2+alpha[j]*beta[j])), # no further need to update norm estimate lambda anorm: anorm, anorm) # mu update condition mu_update_cond = jnp.logical_and(b_no_fro, p_norm != 0) mu, mumax = lax.cond(mu_update_cond, lambda mu: update_mu(mu, mumax, nu, j, alpha, beta, anorm), lambda mu : (mu, mumax), mu ) # TODO add condition with elr > 0 mu = mu.at[j].set(M2 * eps) # condition for partial or full reorthogonalization reorth_cond = jnp.logical_or(b_fro, mumax[j] > delta) reorth_cond = jnp.logical_or(reorth_cond, b_force_reorth) reorth_cond = jnp.logical_and(reorth_cond, p_norm != 0) # function to reorth p w.r.t. previous U def reorth_p(_): # identify the indices at which partial or full reorthogonalization will be done # from U_0 to U_j [j+1 vectors] have already been computed reorth_indices = lax.cond(jnp.logical_or(b_fro, eta == 0), # full reorthogonalization case lambda _ : jp1_mask, # partial reorthogonalization case lambda _ : lax.cond(b_force_reorth, # for forced reorth, we need to add one more vec lambda _ : indices.at[k - jnp.argmax(indices[::-1])].set(True), lambda _ : compute_ind(mu, delta, eta), None ), None ) # carry out reorthogonalization p2, p2_norm, iters = reorth_mgs(U, p, p_norm, indices, gamma) # reset mu in the entries which have been reorthogonalized mu2 = jnp.where(reorth_indices, M2*eps, mu) # if a reorthogonalization was forced. it won't be for next iteration b_force_reorth2 = jnp.logical_not(b_force_reorth) return p2, p2_norm, reorth_indices, mu2, b_force_reorth2 # reorthogonalize p if required p, p_norm, indices, nu, b_force_reorth = lax.cond(reorth_cond, reorth_p, lambda _ : (p, p_norm, indices, nu, b_force_reorth), None ) # Check for convergence or failure to maintain semiorthogonality semiorth_cond = p_norm < max(m, n) * anorm * eps p, p_norm, b_force_reorth, indices = lax.cond(semiorth_cond, # compute a new random p vector orthogonal to previous U lambda _ : (new_p_vec(A, U, j+1, gamma)[0], 0., True, jp1_mask), lambda _ : (p, p_norm, b_force_reorth, indices), None) # save updated p_norm in beta_{j+1} (if there are any changes) beta = beta.at[j+1].set(p_norm) # track anorm for the current iteration anorms = state.anorms.at[j].set(anorm) # prepare the state for next iteration return LanBProState(p=p, U=U, V=V, alpha=alpha, beta=beta, mu=mu, nu=nu, mumax=mumax, numax=numax, anorm=anorm, anorms=anorms, indices=indices, b_fro=b_fro, b_force_reorth=b_force_reorth, b_est_anorm=b_est_anorm, iterations=j+1 )
lanbpro_iteration_jit = jit(lanbpro_iteration)
[docs]def lanbpro(A, k, p0): """K steps of the Lanczos bidiagonalization with partial reorthogonalization """ options = lanbpro_options_init(k) state = lanbpro_init(A, k, p0, options) def cond(state): return state.iterations < k def body(state): state = lanbpro_iteration(A, state, options, state.iterations) return state # state = lax.while_loop(cond, body, state) # while cond(state): # state = body(state) state = lax.fori_loop(1, k, lambda i, state: lanbpro_iteration(A, state, options), state) return state
lanbpro_jit = jit(lanbpro, static_argnums=(1,)) def new_p_vec(A, U, j, gamma): """Generates a new p vector which is orthogonal to all previous U vectors """ m, n = A.shape idx = jnp.arange(U.shape[1]) j_mask = idx < j def p_vec(i): p = random.uniform(KEYS[i], (n,)) p = A @ p p_norm = norm(p) p, p_norm, iters = reorth_mgs(U, p, p_norm, j_mask, gamma) return p, p_norm def init(): p, p_norm = p_vec(0) return p, p_norm, 1 def cond(state): p, p_norm, iterations = state cond = jnp.logical_and(iterations < 10, p_norm <1e-10) return cond def body(state): p, p_norm, i = state p, p_norm = p_vec(i) return p, p_norm, i+1 state = lax.while_loop(cond, body, init()) p, p_norm, i = state return p / p_norm, i new_p_vec_jit = jit(new_p_vec) def new_r_vec(A, V, j, gamma): """Generates a new r vector which is orthogonal to all previous V vectors """ m, n = A.shape idx = jnp.arange(V.shape[1]) j_mask = idx < j max_iters = 20 def r_vec(i): # This approach fails for large matrices with low rank # r = random.uniform(keys[i], (m,)) # r = AH_v(A, r) r = random.uniform(KEYS[i], (n,)) r_norm = norm(r) r, r_norm, iters = reorth_mgs(V, r, r_norm, j_mask, gamma) return r, r_norm def init(): r, r_norm = r_vec(0) return r, r_norm, 1 def cond(state): r, r_norm, iterations = state cond = jnp.logical_and(iterations < max_iters, r_norm <1e-10) return cond def body(state): r, r_norm, i = state r, r_norm = r_vec(i) return r, r_norm, i+1 state = lax.while_loop(cond, body, init()) r, r_norm, i = state # print(f'r_norm {r_norm}, iters: {i}') return r / r_norm, i new_r_vec_jit = jit(new_r_vec)