# Source code for cr.nimble._src.norm

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""

References

- https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html
- https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.norm.html

"""

import jax.numpy as jnp

EPS = jnp.finfo(jnp.float32).eps

norm = jnp.linalg.norm

from .util import promote_arg_dtypes

def norm_l1(x):
    """Return the l_1 norm of a vector.

    The result is the sum of the absolute values of all entries of ``x``.
    """
    magnitudes = jnp.abs(x)
    return jnp.sum(magnitudes)
def sqr_norm_l2(x):
    """Return the squared l_2 norm of a vector.

    Computed as the product of the transpose of ``x`` with ``x`` itself.
    """
    x_t = x.T
    return x_t @ x
def norm_l2(x):
    """Return the l_2 norm of a vector.

    Computed as the square root of the product of ``x.T`` with ``x``.
    """
    energy = x.T @ x
    return jnp.sqrt(energy)
def norm_linf(x):
    """Return the l_inf norm of a vector.

    The result is the largest absolute value among the entries of ``x``.
    """
    magnitudes = jnp.abs(x)
    return jnp.max(magnitudes)
def norms_l1_cw(X):
    """Return the l_1 norm of every column of a matrix.

    Delegates to ``norm`` with ``ord=1`` and a reduction over axis 0.
    """
    column_norms = norm(X, ord=1, axis=0)
    return column_norms
def norms_l1_rw(X):
    """Return the l_1 norm of every row of a matrix.

    Delegates to ``norm`` with ``ord=1`` and a reduction over axis 1.
    """
    row_norms = norm(X, ord=1, axis=1)
    return row_norms
def norms_l2_cw(X):
    """Return the l_2 norm of every column of a matrix.

    Delegates to ``norm`` with ``ord=2`` and a reduction over axis 0;
    the reduced axis is dropped from the result (``keepdims=False``).
    """
    column_norms = norm(X, ord=2, axis=0, keepdims=False)
    return column_norms
def norms_l2_rw(X):
    """Return the l_2 norm of every row of a matrix.

    Delegates to ``norm`` with ``ord=2`` and a reduction over axis 1;
    the reduced axis is dropped from the result (``keepdims=False``).
    """
    row_norms = norm(X, ord=2, axis=1, keepdims=False)
    return row_norms
def norms_linf_cw(X):
    """Return the l_inf norm of every column of a matrix.

    Delegates to ``norm`` with ``ord=jnp.inf`` and a reduction over axis 0.
    """
    column_norms = norm(X, ord=jnp.inf, axis=0)
    return column_norms
def norms_linf_rw(X):
    """Return the l_inf norm of every row of a matrix.

    Delegates to ``norm`` with ``ord=jnp.inf`` and a reduction over axis 1.
    """
    row_norms = norm(X, ord=jnp.inf, axis=1)
    return row_norms
def sqr_norms_l2_cw(X):
    """Return the squared l_2 norm of every column of a matrix.

    Each entry is multiplied by itself and the products are summed
    along axis 0.
    """
    squares = X * X
    return jnp.sum(squares, axis=0)
def sqr_norms_l2_rw(X):
    """Return the squared l_2 norm of every row of a matrix.

    Each entry is multiplied by itself and the products are summed
    along axis 1. No square root is taken.
    """
    squares = X * X
    return jnp.sum(squares, axis=1)
######################################
# Normalization of vectors
######################################
def normalize_l1(x):
    """Normalizes a vector by its l_1-norm

    The input is first passed through ``promote_arg_dtypes`` (presumably
    to obtain a floating point dtype -- confirm in ``.util``), then every
    entry is divided by the sum of absolute values of the entries plus
    ``EPS``.

    Args:
        x: the vector to normalize.

    Returns:
        ``x`` divided by ``sum(|x|) + EPS``.
    """
    x = promote_arg_dtypes(x)
    # The l_1 norm is the sum of *absolute* values; summing the signed
    # entries would let positive and negative entries cancel and could
    # even produce a zero or negative divisor.
    # EPS keeps the divisor strictly positive for an all-zero vector.
    s = jnp.sum(jnp.abs(x)) + EPS
    return jnp.divide(x, s)
def normalize_l2(x):
    """Scale a vector by the inverse of its l_2-norm.

    The divisor is ``sqrt(sum(x ** 2)) + EPS``; the input is passed through
    ``promote_arg_dtypes`` before any arithmetic.
    """
    x = promote_arg_dtypes(x)
    energy = jnp.sum(x ** 2)
    length = jnp.sqrt(energy) + EPS
    return jnp.divide(x, length)
def normalize_linf(x):
    """Scale a vector by the inverse of its l_inf-norm.

    The divisor is ``max(|x|) + EPS``; the input is passed through
    ``promote_arg_dtypes`` before any arithmetic.
    """
    x = promote_arg_dtypes(x)
    magnitudes = jnp.abs(x)
    peak = jnp.max(magnitudes) + EPS
    return jnp.divide(x, peak)
######################################
# Normalization of rows and columns
######################################
def normalize_l1_cw(X):
    """Scale each column of X by the inverse of its l_1-norm.

    The per-column divisor is the sum of absolute values along axis 0
    plus ``EPS``; the input is passed through ``promote_arg_dtypes`` first.
    """
    X = promote_arg_dtypes(X)
    magnitudes = jnp.abs(X)
    column_scale = jnp.sum(magnitudes, axis=0) + EPS
    # the per-column divisors broadcast across the rows of X
    return jnp.divide(X, column_scale)
def normalize_l1_rw(X):
    """Scale each row of X by the inverse of its l_1-norm.

    The per-row divisor is the sum of absolute values along axis 1
    plus ``EPS``; the input is passed through ``promote_arg_dtypes`` first.
    """
    X = promote_arg_dtypes(X)
    magnitudes = jnp.abs(X)
    row_scale = jnp.sum(magnitudes, axis=1) + EPS
    # append a trailing axis so the per-row divisors form a column
    # and broadcast against the rows of X
    row_scale = jnp.expand_dims(row_scale, axis=-1)
    return jnp.divide(X, row_scale)
def normalize_l2_cw(X):
    """
    Normalize each column of X per l_2-norm

    The input is first passed through ``promote_arg_dtypes``. Each column
    is then divided by the square root of the sum of its squared entries.
    Columns whose norm is exactly zero are returned unchanged (all zeros)
    instead of being turned into NaN by a 0/0 division.

    Args:
        X: the matrix whose columns are normalized.

    Returns:
        A matrix of the same shape as ``X`` with each non-zero column
        divided by its l_2 norm.
    """
    X = promote_arg_dtypes(X)
    norms = jnp.sqrt(jnp.sum(jnp.square(X), axis=0))
    # Guard against division by zero: a zero-norm column is divided by 1,
    # which leaves it as zeros. Non-zero columns use exactly the same
    # divisor as before, so their results are unchanged.
    safe_norms = jnp.where(norms == 0, 1, norms)
    return jnp.divide(X, safe_norms)
def normalize_l2_rw(X):
    """
    Normalize each row of X per l_2-norm

    The input is first passed through ``promote_arg_dtypes``. Each row
    is then divided by the square root of the sum of its squared entries.
    Rows whose norm is exactly zero are returned unchanged (all zeros)
    instead of being turned into NaN by a 0/0 division.

    Args:
        X: the matrix whose rows are normalized.

    Returns:
        A matrix of the same shape as ``X`` with each non-zero row
        divided by its l_2 norm.
    """
    X = promote_arg_dtypes(X)
    norms = jnp.sqrt(jnp.sum(jnp.square(X), axis=1))
    # Guard against division by zero: a zero-norm row is divided by 1,
    # which leaves it as zeros. Non-zero rows use exactly the same
    # divisor as before, so their results are unchanged.
    safe_norms = jnp.where(norms == 0, 1, norms)
    # append a trailing axis so the per-row divisors form a column
    # and broadcast against the rows of X
    safe_norms = jnp.expand_dims(safe_norms, axis=-1)
    return jnp.divide(X, safe_norms)