Source code for cr.nimble._src.metrics

# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from jax import jit, lax
import jax.numpy as jnp

from cr.nimble import promote_arg_dtypes, check_shapes_are_equal, dtype_ranges

[docs]@jit def mean_squared(array): """Returns the mean squared value of an array """ # We need to handle both real and complex cases sqr = jnp.conj(array) * array mean_sqr = jnp.mean(sqr) # Make sure that we are down to float data type return jnp.abs(mean_sqr)
[docs]@jit def mean_squared_error(array1, array2): """Returns the mean square error between two arrays """ # check shape compatibility check_shapes_are_equal(array1, array2) # promote to same inexact type (real or complex) array1, array2 = promote_arg_dtypes(array1, array2) diff = array1 - array2 # We need to handle both real and complex cases sqr_error = jnp.conj(diff) * diff mse = jnp.mean(sqr_error) # Make sure that we are down to float data type return jnp.abs(mse)
[docs]def root_mean_squared(array): """Returns the root mean squared value of an array """ return jnp.sqrt(mean_squared(array))
[docs]def root_mse(array1, array2): """Returns the root mean square error between two arrays """ return jnp.sqrt(mean_squared_error(array1, array2))
@partial(jit, static_argnums=(1,)) def normalization_factor(array, normalization : str): """Returns the normalization factor based on the contents of an array """ normalization = normalization.lower() if normalization == 'euclidean': return root_mean_squared(array) elif normalization == 'min-max': return jnp.abs(jnp.max(array) - jnp.min(array)) elif normalization == 'mean': return jnp.abs(jnp.mean(array)) elif normalization == 'median': return jnp.abs(jnp.median(array)) else: raise ValueError("Unsupported normalization type")
[docs]@partial(jit, static_argnames=("normalization",)) def normalized_root_mse(reference_arr, test_arr, normalization='euclidean'): """Returns the normalized root mean square error between two arrays """ rmse = root_mse(reference_arr, test_arr) denom = normalization_factor(reference_arr, normalization) # make sure that denominator is non-zero eps = jnp.finfo(float).eps denom = denom + eps return rmse / denom
[docs]def normalized_mse(reference_arr, test_arr): """Returns the normalized mean square error between two arrays """ # check shape compatibility check_shapes_are_equal(reference_arr, test_arr) # promote to same inexact type (real or complex) reference_arr, test_arr = promote_arg_dtypes(reference_arr, test_arr) diff = reference_arr - test_arr # We need to handle both real and complex cases numer = jnp.sum(jnp.conj(diff) * diff) denom = jnp.sum(jnp.conj(reference_arr) * reference_arr) # make sure that values are real numer = jnp.abs(numer) denom = jnp.abs(denom) # make sure that denominator is non-zero eps = jnp.finfo(float).eps denom = denom + eps return numer / denom
[docs]@jit def peak_signal_noise_ratio(reference_arr, test_arr): """Returns the Peak Signal to Noise Ratio between two arrays """ min_val, max_val = dtype_ranges[reference_arr.dtype] data_min = jnp.min(reference_arr) zero = jnp.zeros_like(min_val) # min_val below 0 is considered only if the data really has negative values min_val = jnp.where(data_min >= 0, zero, min_val) drange = max_val - min_val mse = mean_squared_error(reference_arr, test_arr) eps = jnp.finfo(float).eps mse = lax.cond(mse, lambda _ : mse, lambda _ : eps, None) return 10 * jnp.log10((drange ** 2) / mse)
[docs]@jit def signal_noise_ratio(reference_arr, test_arr): """Returns the signal to noise ratio between a reference array and a test array Args: reference_arr (jax.numpy.ndarray): Reference signal (can be ND array) test_arr (jax.numpy.ndarray): Test signal (can be ND array) Returns: Signal to Noise Ratio between reference and test signals Example: :: >>> x = jnp.ones(10) >>> y = 0.9 * x >>> signal_noise_ratio(x, y) DeviceArray(19.999998, dtype=float32) """ reference_arr, test_arr = promote_arg_dtypes(reference_arr, test_arr) ref_energy = jnp.abs(jnp.vdot(reference_arr, reference_arr)) error = reference_arr - test_arr err_energy = jnp.abs(jnp.vdot(error, error)) eps = jnp.finfo(float).eps # make sure that error energy is non-zero err_energy = lax.cond(err_energy, lambda _ : err_energy, lambda _ : eps, None) # make sure that ref energy is non-zero ref_energy = lax.cond(ref_energy, lambda _ : ref_energy, lambda _ : eps, None) return 10 * jnp.log10(ref_energy/ err_energy)
def prd(reference_arr, test_arr): """Returns the percentage root mean square difference Args: reference_arr (jax.numpy.ndarray): Reference signal (can be ND array) test_arr (jax.numpy.ndarray): Test signal (can be ND array) Returns: (float): PRD measure """ reference_arr, test_arr = promote_arg_dtypes(reference_arr, test_arr) error = reference_arr - test_arr err_energy = jnp.abs(jnp.vdot(error, error)) ref_energy = jnp.abs(jnp.vdot(reference_arr, reference_arr)) PRD = jnp.sqrt(err_energy/ref_energy) return PRD * 100 percent_rms_diff = prd
[docs]def percent_space_saving(original_bits, compressed_bits): """Returns the space saving ratio as percentage """ return 100. * (original_bits - compressed_bits) / original_bits
[docs]def compression_ratio(original_bits, compressed_bits): """ Returns the compression ratio as n x factor """ return original_bits / compressed_bits
[docs]def prd_to_snr(prd): """ Converts the PRD value to SNR value """ return - 20 * jnp.log10(prd/100)
def cr_to_pss(cr): """Converts compression ratio to percentage space savings """ # invert it cr = 1. / cr # subtract original bits ss = 1 - cr # convert to percentage pss = 100 * ss return pss def pss_to_cr(pss): """Converts percentage space savings to compression ratio """ ss = pss / 100 cr = 1 - ss cr = 1. / cr return cr