# Source code for cr.nimble._src.subspaces

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax.numpy as jnp
from jax import jit, lax

from jax.numpy.linalg import norm

from .svd_utils import singular_values
from .array import hermitian
from .matrix import AH_v

def orth_complement(A, B):
    """Returns an orthonormal basis for the part of span([A, B]) beyond A

    The columns of A and B are stacked side by side and orthonormalized
    with a QR factorization. The leading ``A.shape[1]`` columns of the
    Q factor are discarded and the remaining columns are returned.

    Args:
        A (jax.numpy.ndarray): Matrix whose columns come first in the stack
        B (jax.numpy.ndarray): Matrix whose columns are appended after A

    Returns:
        (jax.numpy.ndarray): Columns of the Q factor from index
        ``A.shape[1]`` onwards
    """
    # number of leading Q columns that correspond to the columns of A
    n_skip = A.shape[1]
    stacked = jnp.hstack([A, B])
    # only the orthonormal factor is needed; the triangular factor is unused
    q_factor = jnp.linalg.qr(stacked)[0]
    return q_factor[:, n_skip:]

def principal_angles_cos(A, B):
    """Returns the cosines of principal angles between two subspaces

    Args:
        A (jax.numpy.ndarray): ONB for the first subspace
        B (jax.numpy.ndarray): ONB for the second subspace

    Returns:
        (jax.numpy.ndarray): The cosines of the principal angles between
        the two subspaces, i.e. the singular values of ``A^H B`` capped
        at 1. They appear in the order produced by ``singular_values``.
    """
    # the singular values of the cross-Gram matrix A^H B are the cosines
    cosines = singular_values(jnp.conjugate(A.T) @ B)
    # ensure that no value exceeds 1 before it is handed to arccos
    return jnp.minimum(1, cosines)


principal_angles_cos_jit = jit(principal_angles_cos)
def principal_angles_rad(A, B):
    """Returns the principal angles between two subspaces in radians

    Args:
        A (jax.numpy.ndarray): ONB for the first subspace
        B (jax.numpy.ndarray): ONB for the second subspace

    Returns:
        (jax.numpy.ndarray): The principal angles between the two
        subspaces in radians, obtained as the arccos of the values
        returned by ``principal_angles_cos``
    """
    return jnp.arccos(principal_angles_cos(A, B))


principal_angles_rad_jit = jit(principal_angles_rad)
def principal_angles_deg(A, B):
    """Returns the principal angles between two subspaces in degrees

    Args:
        A (jax.numpy.ndarray): ONB for the first subspace
        B (jax.numpy.ndarray): ONB for the second subspace

    Returns:
        (jax.numpy.ndarray): The principal angles between the two
        subspaces in degrees, converted from the result of
        ``principal_angles_rad``
    """
    return jnp.rad2deg(principal_angles_rad(A, B))


principal_angles_deg_jit = jit(principal_angles_deg)
def smallest_principal_angle_cos(A, B):
    """Returns the cosine of smallest principal angle between two subspaces

    Args:
        A (jax.numpy.ndarray): ONB for the first subspace
        B (jax.numpy.ndarray): ONB for the second subspace

    Returns:
        (float): Cosine of the smallest principal angle between the two
        subspaces, taken as the first entry of ``principal_angles_cos``
    """
    # the first entry is taken as the cosine of the smallest angle
    return principal_angles_cos(A, B)[0]


smallest_principal_angle_cos_jit = jit(smallest_principal_angle_cos)
def smallest_principal_angle_rad(A, B):
    """Returns the smallest principal angle between two subspaces in radians

    Args:
        A (jax.numpy.ndarray): ONB for the first subspace
        B (jax.numpy.ndarray): ONB for the second subspace

    Returns:
        (float): The smallest principal angle between the two subspaces
        in radians
    """
    cosine = smallest_principal_angle_cos(A, B)
    return jnp.arccos(cosine)


smallest_principal_angle_rad_jit = jit(smallest_principal_angle_rad)
def smallest_principal_angle_deg(A, B):
    """Returns the smallest principal angle between two subspaces in degrees

    Args:
        A (jax.numpy.ndarray): ONB for the first subspace
        B (jax.numpy.ndarray): ONB for the second subspace

    Returns:
        (float): The smallest principal angle between the two subspaces
        in degrees
    """
    radians = smallest_principal_angle_rad(A, B)
    return jnp.rad2deg(radians)


smallest_principal_angle_deg_jit = jit(smallest_principal_angle_deg)
def smallest_principal_angles_cos(subspaces):
    """Returns the cosines of the smallest principal angles between each pair of subspaces

    Args:
        subspaces (jax.numpy.ndarray): An array of ONBs for the subspaces,
            stacked along the first axis (one ONB matrix per subspace)

    Returns:
        (jax.numpy.ndarray): A symmetric matrix containing the cosine of the
        smallest principal angles between each pair of subspaces, with ones
        on the diagonal

    Further reading on implementation:

    * `Vectorizing computations on pairs of elements in an nd-array <https://towardsdatascience.com/vectorizing-computations-on-pairs-of-elements-in-an-nd-array-326b5a648ad6>`_
    * `SO: How to vectorize a 2 level loop in NumPy <https://stackoverflow.com/questions/69391894/how-to-vectorize-a-2-level-loop-in-numpy>`_
    """
    subspaces = jnp.asarray(subspaces)
    # Number of subspaces
    k = subspaces.shape[0]
    if k < 2:
        # no pairs to compare: avoid a batched SVD over an empty batch
        return jnp.eye(k)
    # Indices for the strictly upper triangular part (all pairs i < j)
    i, j = jnp.triu_indices(k, k=1)
    # prepare all the possible pairs of A and B
    A = subspaces[i]
    B = subspaces[j]
    # batched conjugate transpose: swap the last two axes of every ONB
    AH = jnp.conjugate(jnp.transpose(A, axes=(0, 2, 1)))
    M = jnp.matmul(AH, B)
    s = jnp.linalg.svd(M, compute_uv=False)
    # keep only the first (largest) singular value of every pair
    s = s[:, 0]
    # prepare the returning matrix: fill the upper triangle, then mirror it
    identity = jnp.eye(k)
    r = identity.at[i, j].set(s)
    # r + r.T doubles the diagonal; subtracting the identity restores the ones
    r = r + r.T - identity
    # make sure that there is no overflow
    r = jnp.minimum(r, 1.)
    return r


smallest_principal_angles_cos_jit = jit(smallest_principal_angles_cos)
def smallest_principal_angles_rad(subspaces):
    """Returns the smallest principal angles between each pair of subspaces in radians

    Args:
        subspaces (jax.numpy.ndarray): An array of ONBs for the subspaces

    Returns:
        (jax.numpy.ndarray): A symmetric matrix containing the smallest
        principal angles between each pair of subspaces in radians
    """
    cosines = smallest_principal_angles_cos(subspaces)
    return jnp.arccos(cosines)


smallest_principal_angles_rad_jit = jit(smallest_principal_angles_rad)
def smallest_principal_angles_deg(subspaces):
    """Returns the smallest principal angles between each pair of subspaces in degrees

    Args:
        subspaces (jax.numpy.ndarray): An array of ONBs for the subspaces

    Returns:
        (jax.numpy.ndarray): A symmetric matrix containing the smallest
        principal angles between each pair of subspaces in degrees
    """
    radians = smallest_principal_angles_rad(subspaces)
    return jnp.rad2deg(radians)


smallest_principal_angles_deg_jit = jit(smallest_principal_angles_deg)
def subspace_distance(A, B):
    r"""Returns the Grassmannian distance between two subspaces

    Args:
        A (jax.numpy.ndarray): ONB for the first subspace
        B (jax.numpy.ndarray): ONB for the second subspace

    Returns:
        (float): Distance between the two subspaces, computed as the
        operator (spectral) norm of the difference between their
        projection operators

    The `Grassmannian <https://en.wikipedia.org/wiki/Grassmannian>`_
    is a space that parameterizes all k dimensional linear subspaces
    of a vector space V. A `metric <https://math.stackexchange.com/questions/198111/distance-between-real-finite-dimensional-linear-subspaces>`_
    can be defined over this space. We can use this metric to compute
    the distance between two subspaces.
    """
    # Compute the projection operators for the two subspaces
    PA = A @ jnp.conjugate(A.T)
    PB = B @ jnp.conjugate(B.T)
    # Difference between the projection operators
    D = PA - PB
    # Return the operator norm of D. ord=2 must be given explicitly:
    # the default for a matrix is the Frobenius norm.
    return jnp.linalg.norm(D, ord=2)


subspace_distance_jit = jit(subspace_distance)
def project_to_subspace(U, v):
    """Projects a vector to a subspace

    Args:
        U (jax.numpy.ndarray): ONB for the subspace
        v (jax.numpy.ndarray): A vector in the ambient space

    Returns:
        (jax.numpy.ndarray): Projection of v onto the subspace spanned by U

    Example:
        >>> A = jnp.eye(6)[:, :3]
        >>> v = jnp.arange(6) + 0.
        >>> u = project_to_subspace(A, v)
        >>> print(A)
        [[1. 0. 0.]
         [0. 1. 0.]
         [0. 0. 1.]
         [0. 0. 0.]
         [0. 0. 0.]
         [0. 0. 0.]]
        >>> print(v)
        [0. 1. 2. 3. 4. 5.]
        >>> print(u)
        [0. 1. 2. 0. 0. 0.]
    """
    # AH_v supplies the coefficient vector; U maps it back to the ambient space
    return U @ AH_v(U, v)
def is_in_subspace(U, v, tol=1e-6):
    """Checks whether a vector v is in the subspace spanned by an ONB U or not

    The vector is projected onto the subspace and the norm of the residual
    is compared against the norm of v scaled by ``tol``.

    Args:
        U (jax.numpy.ndarray): ONB for the subspace
        v (jax.numpy.ndarray): A vector in the ambient space
        tol (float): Relative tolerance on the residual norm; v is accepted
            when ``norm(projection - v) <= tol * norm(v)``. Defaults to 1e-6.

    Returns:
        (bool): True if v lies in the subspace spanned by U, False otherwise

    Example:
        >>> A = jnp.eye(6)[:, :3]
        >>> v = jnp.arange(6) + 0.
        >>> print(is_in_subspace(A, v))
        False
        >>> u = project_to_subspace(A, v)
        >>> print(is_in_subspace(A, u))
        True
    """
    # Compute the projection
    p = project_to_subspace(U, v)
    # Compute the error
    e = p - v
    nv = norm(v)
    ne = norm(e)
    # relative test: the residual must be negligible compared to v itself
    return ne <= tol * nv