Source code for cr.nimble._src.ndarray

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Utility functions for ND arrays
"""

from jax import jit
import jax.numpy as jnp

from cr.nimble import promote_arg_dtypes


def arr_largest_index(x):
    """Locate the entry of largest magnitude in an n-d array.

    Args:
        x (jax.numpy.ndarray): An nd-array (anything accepted by
            ``jnp.asarray``).

    Returns:
        tuple: One index per dimension of ``x`` identifying the entry
        with the largest absolute value.
    """
    arr = jnp.asarray(x)
    # argmax works on the flattened array, so the flat position has to be
    # mapped back to an n-dim index using the original shape.
    flat_pos = jnp.argmax(jnp.abs(arr))
    return jnp.unravel_index(flat_pos, arr.shape)
def arr_l1norm(x):
    """Compute the l1-norm of an array, treating it as one long vector.

    Args:
        x (jax.numpy.ndarray): An nd-array

    Returns:
        (float): Sum of the absolute values of all entries of ``x``.
    """
    # NOTE(review): promote_arg_dtypes is defined outside this module;
    # presumably it promotes the input to an inexact dtype -- confirm.
    promoted = promote_arg_dtypes(jnp.asarray(x))
    magnitudes = jnp.abs(promoted)
    # Summing without an axis reduces over every dimension at once.
    return jnp.sum(magnitudes)
def arr_l2norm(x):
    """Compute the l2-norm of an array, treating it as one long vector.

    Args:
        x (jax.numpy.ndarray): An nd-array

    Returns:
        Square root of the magnitude of ``vdot(x, x)``.
    """
    # NOTE(review): promote_arg_dtypes is defined outside this module;
    # presumably it promotes the input to an inexact dtype -- confirm.
    promoted = promote_arg_dtypes(jnp.asarray(x))
    # vdot flattens its arguments and conjugates the first one.
    energy = jnp.vdot(promoted, promoted)
    # abs() discards any residual imaginary part before the square root.
    return jnp.sqrt(jnp.abs(energy))
def arr_l2norm_sqr(x):
    """Compute the squared l2-norm of an array, treating it as one long vector.

    Args:
        x (jax.numpy.ndarray): An nd-array

    Returns:
        The value of ``vdot(x, x)``.
    """
    # NOTE(review): promote_arg_dtypes is defined outside this module;
    # presumably it promotes the input to an inexact dtype -- confirm.
    promoted = promote_arg_dtypes(jnp.asarray(x))
    # vdot flattens its arguments and conjugates the first one.
    # NOTE(review): for complex input the result is returned as-is, so it
    # presumably keeps a complex dtype (unlike arr_l2norm) -- confirm
    # whether callers rely on that.
    return jnp.vdot(promoted, promoted)
def arr_vdot(x, y):
    """Compute the inner product of two arrays, treating each as one long vector.

    Args:
        x (jax.numpy.ndarray): An nd-array
        y (jax.numpy.ndarray): An nd-array

    Returns:
        The value of ``vdot(x, y)``.
    """
    # NOTE(review): promote_arg_dtypes is defined outside this module;
    # presumably it brings both inputs to a common inexact dtype -- confirm.
    lhs, rhs = promote_arg_dtypes(jnp.asarray(x), jnp.asarray(y))
    # vdot flattens its arguments and conjugates the first one.
    return jnp.vdot(lhs, rhs)
@jit
def arr_rdot(x, y):
    """Compute the real inner product Re(x^H y) of two flattened arrays.

    Args:
        x (jax.numpy.ndarray): An nd-array
        y (jax.numpy.ndarray): An nd-array

    Returns:
        Real part of the inner product of the flattened ``x`` and ``y``.
    """
    u = jnp.ravel(jnp.asarray(x))
    v = jnp.ravel(jnp.asarray(y))
    # isrealobj depends only on the dtype, so these are plain Python bools
    # and the branches below are selected when the function is traced.
    u_is_real = jnp.isrealobj(u)
    v_is_real = jnp.isrealobj(v)
    if u_is_real and v_is_real:
        # Both real: the ordinary real inner product.
        return jnp.sum(u * v)
    if u_is_real or v_is_real:
        # Exactly one operand is complex; its imaginary part cannot
        # contribute to the real part of the product.
        return jnp.sum(jnp.real(u) * jnp.real(v))
    # Both complex: form x^H y and keep only the real part.
    return jnp.real(jnp.sum(jnp.conjugate(u) * v))
@jit
def arr_rnorm_sqr(x):
    """Compute the squared norm of x induced by the real inner product Re(x^H x).

    Args:
        x (jax.numpy.ndarray): An nd-array

    Returns:
        The value of ``arr_rdot(x, x)``.
    """
    squared = arr_rdot(x, x)
    return squared


@jit
def arr_rnorm(x):
    """Compute the norm of x induced by the real inner product Re(x^H x).

    Args:
        x (jax.numpy.ndarray): An nd-array

    Returns:
        Square root of ``arr_rdot(x, x)``.
    """
    squared = arr_rdot(x, x)
    return jnp.sqrt(squared)
@jit
def arr2vec(x):
    """Flatten an nd-array into a 1-d vector.

    Args:
        x (jax.numpy.ndarray): An nd-array

    Returns:
        jax.numpy.ndarray: The entries of ``x`` as a 1-d array.
    """
    arr = jnp.asarray(x)
    flattened = jnp.ravel(arr)
    return flattened
@jit
def log_pos(x):
    """Compute the natural log of x, assuming its values are positive.

    Entries smaller than the machine epsilon of Python's ``float`` type are
    raised to that epsilon before the logarithm is taken.

    Args:
        x (jax.numpy.ndarray): An nd-array

    Returns:
        Element-wise ``log(max(x, eps))``.
    """
    floor = jnp.finfo(float).eps
    # Clamping from below keeps zero and negative entries from
    # producing -inf or nan.
    clamped = jnp.maximum(x, floor)
    return jnp.log(clamped)