Source code for cr.nimble._src.norm
# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
References
- https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html
- https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.norm.html
"""
import jax.numpy as jnp
# Machine epsilon of float32. Several normalization routines in this module
# add it to their denominators so an all-zero input does not divide by zero.
EPS = jnp.finfo("float32").eps
# Short alias for JAX's general-purpose vector/matrix norm routine.
norm = jnp.linalg.norm
from .util import promote_arg_dtypes
def norm_l1(x):
    """
    Return the l_1 norm of a vector: the sum of the absolute values of its entries.
    """
    magnitudes = jnp.abs(x)
    return magnitudes.sum()
def sqr_norm_l2(x):
    """
    Computes the squared l_2 norm of a vector.

    The left operand is conjugated, so a complex vector yields
    sum |x_i|^2 (a real value) rather than sum x_i^2. For real input,
    ``conj`` and ``real`` are no-ops and the result equals ``x.T @ x``.

    Args:
        x: the input vector.

    Returns:
        The squared Euclidean norm of ``x``.
    """
    return jnp.real(jnp.conj(x).T @ x)
def norm_l2(x):
    """
    Computes the l_2 norm of a vector.

    The inner product conjugates the left operand, so complex vectors give
    their true Euclidean length. For real input this is ``sqrt(x.T @ x)``.

    Args:
        x: the input vector.

    Returns:
        The Euclidean norm of ``x``.
    """
    return jnp.sqrt(jnp.real(jnp.conj(x).T @ x))
def norm_linf(x):
    """
    Return the l_inf norm of a vector: the largest absolute value among its entries.
    """
    magnitudes = jnp.abs(x)
    return magnitudes.max()
def norms_l1_cw(X):
    """
    Return the l_1 norm of every column of ``X``.

    Delegates to ``jnp.linalg.norm`` with ``ord=1`` reduced along axis 0.
    """
    return jnp.linalg.norm(X, axis=0, ord=1)
def norms_l1_rw(X):
    """
    Return the l_1 norm of every row of ``X``.

    Delegates to ``jnp.linalg.norm`` with ``ord=1`` reduced along axis 1.
    """
    return jnp.linalg.norm(X, axis=1, ord=1)
def norms_l2_cw(X):
    """
    Return the Euclidean (l_2) norm of every column of ``X``.

    Delegates to ``jnp.linalg.norm`` with ``ord=2`` reduced along axis 0;
    the reduced axis is dropped from the result.
    """
    return jnp.linalg.norm(X, axis=0, ord=2)
def norms_l2_rw(X):
    """
    Return the Euclidean (l_2) norm of every row of ``X``.

    Delegates to ``jnp.linalg.norm`` with ``ord=2`` reduced along axis 1;
    the reduced axis is dropped from the result.
    """
    return jnp.linalg.norm(X, axis=1, ord=2)
def norms_linf_cw(X):
    """
    Return the l_inf norm (largest absolute entry) of every column of ``X``.
    """
    return jnp.linalg.norm(X, axis=0, ord=jnp.inf)
def norms_linf_rw(X):
    """
    Return the l_inf norm (largest absolute entry) of every row of ``X``.
    """
    return jnp.linalg.norm(X, axis=1, ord=jnp.inf)
def sqr_norms_l2_cw(X):
    """
    Computes the squared l_2 norm of each column of a matrix.

    Each entry is multiplied by its complex conjugate before summing, so
    complex input yields sum_i |X_ij|^2 (a real value), matching
    ``norms_l2_cw(X) ** 2``. For real input this is exactly
    ``jnp.sum(X * X, axis=0)``.

    Args:
        X: the input matrix.

    Returns:
        One squared norm per column (axis 0 is reduced).
    """
    return jnp.sum(jnp.real(X * jnp.conj(X)), axis=0)
def sqr_norms_l2_rw(X):
    """
    Computes the squared l_2 norm of each row of a matrix.

    Each entry is multiplied by its complex conjugate before summing, so
    complex input yields sum_j |X_ij|^2 (a real value), matching
    ``norms_l2_rw(X) ** 2``. For real input this is exactly
    ``jnp.sum(X * X, axis=1)``.

    Args:
        X: the input matrix.

    Returns:
        One squared norm per row (axis 1 is reduced).
    """
    return jnp.sum(jnp.real(X * jnp.conj(X)), axis=1)
######################################
# Normalization of vectors
######################################
def normalize_l1(x):
    """Normalizes a vector by its l_1-norm.

    The vector is divided by the sum of the absolute values of its entries.
    ``EPS`` is added to that sum, so an all-zero vector maps to zeros instead
    of NaNs.

    Args:
        x: the input vector.

    Returns:
        ``x / (sum(|x|) + EPS)``.
    """
    x = promote_arg_dtypes(x)
    # The l_1 norm sums magnitudes, not signed values: summing x directly
    # lets negative entries cancel and can shrink the denominator toward EPS.
    s = jnp.sum(jnp.abs(x)) + EPS
    return jnp.divide(x, s)
def normalize_l2(x):
    """Normalizes a vector by its l_2-norm.

    Squared magnitudes are computed as ``x * conj(x)``, so complex vectors
    are scaled by their true Euclidean length. For real vectors this equals
    ``x ** 2``. ``EPS`` is added to the norm, so an all-zero vector maps to
    zeros instead of NaNs.

    Args:
        x: the input vector.

    Returns:
        ``x / (||x||_2 + EPS)``.
    """
    x = promote_arg_dtypes(x)
    s = jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)))) + EPS
    return jnp.divide(x, s)
def normalize_linf(x):
    """Scale a vector by its l_inf-norm (its largest absolute entry).

    ``EPS`` is added to the denominator, so an all-zero vector maps to zeros
    instead of NaNs.
    """
    x = promote_arg_dtypes(x)
    peak = jnp.abs(x).max() + EPS
    return jnp.divide(x, peak)
######################################
# Normalization of rows and columns
######################################
def normalize_l1_cw(X):
    """
    Divide every column of X by its l_1-norm (plus ``EPS`` to avoid 0 / 0).
    """
    X = promote_arg_dtypes(X)
    col_l1 = jnp.abs(X).sum(axis=0) + EPS
    # Per-column sums have shape (n_cols,) and broadcast across the rows.
    return jnp.divide(X, col_l1)
def normalize_l1_rw(X):
    """
    Divide every row of X by its l_1-norm (plus ``EPS`` to avoid 0 / 0).
    """
    X = promote_arg_dtypes(X)
    # Appending a trailing singleton axis turns the per-row sums into a
    # column, so they broadcast across each row during the division.
    row_l1 = (jnp.abs(X).sum(axis=1) + EPS)[..., None]
    return jnp.divide(X, row_l1)
def normalize_l2_cw(X):
    """
    Normalize each column of X per l_2-norm.

    Squared magnitudes are computed as ``X * conj(X)``, so complex columns are
    scaled by their true Euclidean length. For real input this equals
    ``jnp.square(X)``. A column whose norm is exactly zero is left as zeros
    rather than turned into NaNs. Every other column is divided by its exact
    norm.

    Args:
        X: the input matrix.

    Returns:
        ``X`` with each non-zero column scaled to unit l_2-norm.
    """
    X = promote_arg_dtypes(X)
    col_norms = jnp.sqrt(jnp.sum(jnp.real(X * jnp.conj(X)), axis=0))
    # Dividing an all-zero column by 1 instead of 0 keeps it at zero
    # (0 / 0 would give NaN); non-zero columns are unaffected.
    col_norms = jnp.where(col_norms == 0, 1, col_norms)
    return jnp.divide(X, col_norms)
def normalize_l2_rw(X):
    """
    Normalize each row of X per l_2-norm.

    Squared magnitudes are computed as ``X * conj(X)``, so complex rows are
    scaled by their true Euclidean length. For real input this equals
    ``jnp.square(X)``. A row whose norm is exactly zero is left as zeros
    rather than turned into NaNs. Every other row is divided by its exact
    norm.

    Args:
        X: the input matrix.

    Returns:
        ``X`` with each non-zero row scaled to unit l_2-norm.
    """
    X = promote_arg_dtypes(X)
    row_norms = jnp.sqrt(jnp.sum(jnp.real(X * jnp.conj(X)), axis=1))
    # Dividing an all-zero row by 1 instead of 0 keeps it at zero
    # (0 / 0 would give NaN); non-zero rows are unaffected.
    row_norms = jnp.where(row_norms == 0, 1, row_norms)
    # Row-wise norms must be a column vector so they broadcast across each row.
    row_norms = jnp.expand_dims(row_norms, axis=-1)
    return jnp.divide(X, row_norms)