Source code for cr.nimble._src.vector

"""
Utility functions for working with vectors
"""

from jax import jit, lax, vmap
import jax.numpy as jnp

from .util import promote_arg_dtypes

def is_scalar(x):
    """Tells whether x is a scalar.

    Args:
        x (jax.numpy.ndarray): A JAX array.

    Returns:
        bool: True when x has no dimensions (``x.ndim == 0``).
    """
    rank = x.ndim
    return rank == 0
def is_vec(x):
    """Tells whether x is a line vector, a row vector or a column vector.

    Args:
        x (jax.numpy.ndarray): A JAX array.

    Returns:
        bool: True when x is 1-D, or 2-D with at least one of its two
        dimensions equal to 1.
    """
    if x.ndim == 1:
        return True
    if x.ndim != 2:
        return False
    rows, cols = x.shape
    return rows == 1 or cols == 1
def is_line_vec(x):
    """Tells whether x is a line vector (a 1-D array).

    Args:
        x (jax.numpy.ndarray): A JAX array.

    Returns:
        bool: True when ``x.ndim == 1``.
    """
    rank = x.ndim
    return rank == 1
def is_row_vec(x):
    """Tells whether x is a row vector.

    Args:
        x (jax.numpy.ndarray): A JAX array.

    Returns:
        bool: True when x is 2-D and has exactly one row.
    """
    if x.ndim != 2:
        return False
    return x.shape[0] == 1
def is_col_vec(x):
    """Tells whether x is a column vector.

    Note that a 1-D (line) vector is also accepted as a column vector.

    Args:
        x (jax.numpy.ndarray): A JAX array.

    Returns:
        bool: True when x is 1-D, or 2-D with exactly one column.
    """
    if x.ndim == 1:
        return True
    if x.ndim != 2:
        return False
    return x.shape[1] == 1
def is_increasing_vec(x):
    """Tells whether the values of x are strictly increasing.

    Args:
        x (jax.numpy.ndarray): A JAX array.

    Returns:
        A boolean JAX scalar: True when every consecutive difference
        of x is positive.
    """
    steps = jnp.diff(x)
    return jnp.all(steps > 0)
def is_decreasing_vec(x):
    """Tells whether the values of x are strictly decreasing.

    Args:
        x (jax.numpy.ndarray): A JAX array.

    Returns:
        A boolean JAX scalar: True when every consecutive difference
        of x is negative.
    """
    steps = jnp.diff(x)
    return jnp.all(steps < 0)
def is_nonincreasing_vec(x):
    """Tells whether the values of x never increase.

    Args:
        x (jax.numpy.ndarray): A JAX array.

    Returns:
        A boolean JAX scalar: True when no consecutive difference of x
        is positive.
    """
    steps = jnp.diff(x)
    return jnp.all(steps <= 0)
def is_nondecreasing_vec(x):
    """Tells whether the values of x never decrease.

    Args:
        x (jax.numpy.ndarray): A JAX array.

    Returns:
        A boolean JAX scalar: True when no consecutive difference of x
        is negative.
    """
    steps = jnp.diff(x)
    return jnp.all(steps >= 0)
def has_equal_values_vec(x):
    """Tells whether all entries of x have the same value.

    Args:
        x (jax.numpy.ndarray): A JAX array.

    Returns:
        A boolean JAX scalar: True when every entry of x equals its
        first entry.
    """
    first = x[0]
    return jnp.all(x == first)
def to_row_vec(x):
    """Turns a line vector into a row vector.

    Args:
        x (jax.numpy.ndarray): A line vector (ndim == 1).

    Returns:
        jax.numpy.ndarray: A row vector of shape (1, n).
    """
    assert x.ndim == 1
    # A new leading axis makes x the single row.
    row = jnp.expand_dims(x, axis=0)
    return row
def to_col_vec(x):
    """Turns a line vector into a column vector.

    Args:
        x (jax.numpy.ndarray): A line vector (ndim == 1).

    Returns:
        jax.numpy.ndarray: A column vector of shape (n, 1).
    """
    assert x.ndim == 1
    # A new trailing axis makes x the single column.
    col = jnp.expand_dims(x, axis=1)
    return col
def vec_unit(n, i):
    """Builds the i-th unit vector of the standard basis.

    Args:
        n (int): Length of the vector.
        i (int): Position of the single non-zero entry.

    Returns:
        jax.numpy.ndarray: A line vector of length n that is zero
        everywhere except for a one at position i.
    """
    base = jnp.zeros(n)
    return base.at[i].set(1)


# Both the length and the position are compile-time constants.
vec_unit_jit = jit(vec_unit, static_argnums=(0, 1))
def vec_shift_right(x):
    """Shifts the contents of a vector one place to the right.

    The first entry becomes zero and the last entry is dropped.

    Args:
        x (jax.numpy.ndarray): A line vector.

    Returns:
        jax.numpy.ndarray: x shifted right by one place.
    """
    shifted = jnp.zeros_like(x)
    head = x[:-1]
    return shifted.at[1:].set(head)
def vec_rotate_right(x):
    """Circularly shifts the contents of a vector one place to the right.

    Args:
        x (jax.numpy.ndarray): A line vector.

    Returns:
        jax.numpy.ndarray: x rotated right by one place.
    """
    rotated = jnp.roll(x, shift=1)
    return rotated
def vec_shift_left(x):
    """Shifts the contents of a vector one place to the left.

    The last entry becomes zero and the first entry is dropped.

    Args:
        x (jax.numpy.ndarray): A line vector.

    Returns:
        jax.numpy.ndarray: x shifted left by one place.
    """
    shifted = jnp.zeros_like(x)
    tail = x[1:]
    return shifted.at[:-1].set(tail)
def vec_rotate_left(x):
    """Circularly shifts the contents of a vector one place to the left.

    Args:
        x (jax.numpy.ndarray): A line vector.

    Returns:
        jax.numpy.ndarray: x rotated left by one place.
    """
    rotated = jnp.roll(x, shift=-1)
    return rotated
def vec_shift_right_n(x, n):
    """Shifts the contents of a vector n places to the right.

    The first n entries become zero and the last n entries are dropped.

    Args:
        x (jax.numpy.ndarray): A line vector.
        n (int): Number of positions to shift. Must be a concrete
            (non-traced) integer because it is used in slicing.

    Returns:
        jax.numpy.ndarray: x shifted right by n places.
    """
    if n == 0:
        # x[:-0] is an empty slice, so the generic slice assignment below
        # would fail with a shape mismatch; a zero shift leaves x as it is.
        return jnp.asarray(x)
    return jnp.zeros_like(x).at[n:].set(x[:-n])
def vec_rotate_right_n(x, n):
    """Circularly shifts the contents of a vector n places to the right.

    Args:
        x (jax.numpy.ndarray): A line vector.
        n (int): Number of positions to rotate.

    Returns:
        jax.numpy.ndarray: x rotated right by n places.
    """
    rotated = jnp.roll(x, shift=n)
    return rotated
def vec_shift_left_n(x, n):
    """Shifts the contents of a vector n places to the left.

    The last n entries become zero and the first n entries are dropped.

    Args:
        x (jax.numpy.ndarray): A line vector.
        n (int): Number of positions to shift. Must be a concrete
            (non-traced) integer because it is used in slicing.

    Returns:
        jax.numpy.ndarray: x shifted left by n places.
    """
    if n == 0:
        # .at[0:-0] addresses an empty slice while x[0:] is the whole
        # vector, so the generic assignment below would fail with a shape
        # mismatch; a zero shift leaves x as it is.
        return jnp.asarray(x)
    return jnp.zeros_like(x).at[0:-n].set(x[n:])
def vec_rotate_left_n(x, n):
    """Circularly shifts the contents of a vector n places to the left.

    Args:
        x (jax.numpy.ndarray): A line vector.
        n (int): Number of positions to rotate.

    Returns:
        jax.numpy.ndarray: x rotated left by n places.
    """
    backwards = -n
    return jnp.roll(x, backwards)
def vec_safe_divide_by_scalar(x, alpha):
    """Divides a vector by a scalar, leaving it unchanged if the scalar is zero.

    Args:
        x (jax.numpy.ndarray): A vector.
        alpha: The scalar divisor.

    Returns:
        jax.numpy.ndarray: ``x / alpha`` when alpha is non-zero, otherwise
        the values of x (in the dtype that the division produces).
    """
    # Substituting 1 for a zero divisor keeps both outcomes in the same
    # dtype. A lax.cond with an identity branch fails for integer x because
    # the identity branch stays integer while the division branch is float.
    divisor = jnp.where(alpha == 0, 1, alpha)
    return x / divisor


vec_safe_divide_by_scalar_jit = jit(vec_safe_divide_by_scalar)
def vec_repeat_at_end(x, p):
    """Periodically extends a vector by appending p samples at its end.

    Args:
        x (jax.numpy.ndarray): A line vector.
        p (int): Number of samples to append.

    Returns:
        jax.numpy.ndarray: x followed by p samples that repeat x from
        its start, wrapping around as often as needed.
    """
    period = x.shape[0]
    # Positions 0, 1, ..., p-1 folded back into the valid index range.
    wrapped = jnp.arange(p) % period
    return jnp.concatenate((x, x[wrapped]))


# The extension length fixes the output shape, so it must be static.
vec_repeat_at_end_jit = jit(vec_repeat_at_end, static_argnums=(1,))
def vec_repeat_at_start(x, p):
    """Periodically extends a vector by prepending p samples at its start.

    Args:
        x (jax.numpy.ndarray): A line vector.
        p (int): Number of samples to prepend.

    Returns:
        jax.numpy.ndarray: p samples taken from the end of the periodic
        repetition of x, followed by x itself.
    """
    period = x.shape[0]
    # Positions n-p, ..., n-1 folded back into the valid index range.
    wrapped = (jnp.arange(p) + period - p) % period
    return jnp.concatenate((x[wrapped], x))


# The extension length fixes the output shape, so it must be static.
vec_repeat_at_start_jit = jit(vec_repeat_at_start, static_argnums=(1,))
def vec_centered(x, length):
    """Extracts the central part of a vector.

    Args:
        x (jax.numpy.ndarray): A line vector.
        length (int): Number of central samples to keep. Values larger
            than the length of x are reduced to the length of x.

    Returns:
        jax.numpy.ndarray: The central ``length`` samples of x.
    """
    total = len(x)
    keep = min(total, length)
    # Any odd leftover sample ends up on the right-hand side.
    first = (total - keep) // 2
    return x[first:first + keep]


vec_centered_jit = jit(vec_centered, static_argnums=(1,))


########################################################
#  Energy
########################################################
@jit
def vec_mag_desc(a):
    """Sorts the magnitudes of the coefficients in descending order.

    Args:
        a (jax.numpy.ndarray): A vector of coefficients.

    Returns:
        jax.numpy.ndarray: The absolute values of a, largest first.
    """
    magnitudes = jnp.abs(a)
    ascending = jnp.sort(magnitudes)
    return ascending[::-1]
@jit
def vec_to_pmf(a):
    """Turns a vector into a probability mass function.

    Args:
        a (jax.numpy.ndarray): A vector of coefficients.

    Returns:
        jax.numpy.ndarray: a divided by the sum of its entries.
    """
    # Multiplying by a float makes the total floating point.
    total = jnp.sum(a) * 1.
    pmf = a / total
    return pmf
@jit
def vec_to_cmf(a):
    """Turns a vector into a cumulative mass function.

    Args:
        a (jax.numpy.ndarray): A vector of coefficients.

    Returns:
        jax.numpy.ndarray: Running sum of a after dividing it by the
        sum of its entries.
    """
    # Multiplying by a float makes the total floating point.
    total = jnp.sum(a) * 1.
    pmf = a / total
    return jnp.cumsum(pmf)
@jit
def cmf_find_quantile_index(a, q):
    """Locates a quantile in a cumulative mass function.

    Args:
        a (jax.numpy.ndarray): A vector holding the values of a CMF.
        q (float): The quantile value to look for.

    Returns:
        Index of the first entry of a which is greater than or equal to q,
        as given by ``argmax`` over the comparison mask.
    """
    reached = a >= q
    return jnp.argmax(reached)
def num_largest_coeffs_for_energy_percent(a, p):
    """Counts the largest components needed to hold a given share of energy.

    Args:
        a (jax.numpy.ndarray): A vector of coefficients.
        p (float): Percentage of the total energy to be covered.

    Returns:
        Number of largest-energy components whose cumulative energy
        reaches p percent of the total.
    """
    # Energy of every coefficient, largest first.
    energies = jnp.conj(a) * a
    energies = jnp.sort(energies)[::-1]
    # Normalise by the total energy and accumulate into a CMF.
    total = jnp.sum(energies) * 1.
    cmf = jnp.cumsum(energies / total)
    # Slightly lowered threshold, expressed as a fraction.
    threshold = (p - 1e-10) / 100
    position = jnp.argmax(cmf >= threshold)
    return position + 1
def vec_swap_entries(x, i, j):
    """Exchanges two entries of a vector.

    Args:
        x (jax.numpy.ndarray): A vector.
        i: Index of the first entry.
        j: Index of the second entry.

    Returns:
        jax.numpy.ndarray: A copy of x with the entries at i and j swapped.
    """
    # Read both values before writing anything.
    at_i, at_j = x[i], x[j]
    return x.at[i].set(at_j).at[j].set(at_i)
########################################################
#  Sliding Windows
########################################################


def vec_to_windows(x, wlen):
    """Splits a vector into consecutive windows of equal length.

    Args:
        x (jax.numpy.ndarray): A line vector.
        wlen (int): Length of each window.

    Returns:
        jax.numpy.ndarray: A matrix of shape (wlen, m), one window per
        column, where m is the number of complete windows.

    Notes:
        - Samples at the end that do not fill a complete window are dropped.
    """
    num_windows = len(x) // wlen
    kept = x[:num_windows * wlen]
    # One window per row first, then transpose to one window per column.
    return jnp.reshape(kept, (num_windows, wlen)).T


vec_to_windows_jit = jit(vec_to_windows, static_argnums=(1,))


########################################################
#  Circular Buffer
########################################################
def cbuf_push_left(buf, val):
    """Pushes a value into the buffer from the left.

    The existing contents move one place to the right, the last entry is
    dropped and val becomes the first entry.

    Args:
        buf (jax.numpy.ndarray): A circular buffer.
        val: The value to be pushed into the buffer from the left.

    Returns:
        jax.numpy.ndarray: The modified buffer.
    """
    moved = buf.at[1:].set(buf[:-1])
    return moved.at[0].set(val)
def cbuf_push_right(buf, val):
    """Pushes a value into the buffer from the right.

    The existing contents move one place to the left, the first entry is
    dropped and val becomes the last entry.

    Args:
        buf (jax.numpy.ndarray): A circular buffer.
        val: The value to be pushed into the buffer from the right.

    Returns:
        jax.numpy.ndarray: The modified buffer.
    """
    moved = buf.at[:-1].set(buf[1:])
    return moved.at[-1].set(val)
######################################################## # Heap ########################################################
def is_min_heap(x):
    """Checks whether x satisfies the min-heap property.

    Args:
        x (jax.numpy.ndarray): A vector holding an implicit binary tree.

    Returns:
        A boolean JAX scalar: True when every entry is at least as large
        as its parent.
    """
    size = len(x)
    # Every node except the root, paired with its parent.
    children = jnp.arange(1, size, dtype=int)
    parents = (children - 1) // 2
    return jnp.all(x[parents] <= x[children])
def is_max_heap(x):
    """Checks whether x satisfies the max-heap property.

    Args:
        x (jax.numpy.ndarray): A vector holding an implicit binary tree.

    Returns:
        A boolean JAX scalar: True when every entry is at most as large
        as its parent.
    """
    size = len(x)
    # Every node except the root, paired with its parent.
    children = jnp.arange(1, size, dtype=int)
    parents = (children - 1) // 2
    return jnp.all(x[parents] >= x[children])
def left_child_idx(idx):
    """Computes the index of the left child of a heap node.

    Args:
        idx: Index of the node.

    Returns:
        ``2 * idx + 1``, computed with a bit shift.
    """
    doubled = idx << 1
    return doubled + 1
def right_child_idx(idx):
    """Computes the index of the right child of a heap node.

    Args:
        idx: Index of the node.

    Returns:
        ``2 * (idx + 1)``, computed with a bit shift.
    """
    successor = idx + 1
    return successor << 1
def parent_idx(idx):
    """Computes the index of the parent of a heap node.

    Args:
        idx: Index of the node.

    Returns:
        ``(idx - 1)`` shifted right by one bit; for the root (index 0)
        this evaluates to -1.
    """
    predecessor = idx - 1
    return predecessor >> 1
def build_max_heap(x):
    """Rearranges x into a max heap.

    Entries are inserted one at a time; each new entry is moved up towards
    the root for as long as it is larger than its parent.

    Args:
        x (jax.numpy.ndarray): A vector.

    Returns:
        jax.numpy.ndarray: The entries of x arranged as a max heap.
    """

    def should_move_up(carry):
        heap, child, parent = carry
        # Stop at the root or once the parent is not smaller.
        return jnp.logical_and(heap[child] > heap[parent], child > 0)

    def move_up(carry):
        heap, child, parent = carry
        child_val, parent_val = heap[child], heap[parent]
        heap = heap.at[child].set(parent_val).at[parent].set(child_val)
        # The moved entry now sits at the old parent position.
        return heap, parent, (parent - 1) >> 1

    def insert(i, heap):
        start = (heap, i, (i - 1) >> 1)
        heap, _, _ = lax.while_loop(should_move_up, move_up, start)
        return heap

    return lax.fori_loop(1, len(x), insert, x)
def largest_plr(x, idx):
    """Finds the index of the largest value among a parent and its children.

    Args:
        x (jax.numpy.ndarray): A vector holding an implicit binary tree.
        idx: Index of the parent node.

    Returns:
        The index (idx, its left child or its right child) holding the
        largest value. A child position beyond the end of x is never
        returned; a node without children yields idx itself.
    """
    n = x.shape[0]
    last = n - 1
    left = (idx << 1) + 1
    right = (idx + 1) << 1
    # Reads are clamped to the last entry so a missing child never indexes
    # past the end; the bounds checks then exclude it from the comparison.
    left_val = x[jnp.minimum(left, last)]
    pick_left = jnp.logical_and(left < n, x[idx] < left_val)
    largest = jnp.where(pick_left, left, idx)
    right_val = x[jnp.minimum(right, last)]
    pick_right = jnp.logical_and(right < n, x[largest] < right_val)
    largest = jnp.where(pick_right, right, largest)
    return largest
def heapify_subtree(x, idx):
    """Restores the max-heap order in the subtree rooted at a given node.

    The entry at idx is repeatedly swapped with its largest child until
    no swap is needed or a node without children is reached.

    Args:
        x (jax.numpy.ndarray): A vector holding an implicit binary tree.
        idx: Index of the root of the subtree.

    Returns:
        jax.numpy.ndarray: x after the entry at idx has been sifted down.
    """
    # Nodes from this index onwards have no children.
    first_leaf = len(x) >> 1

    def keep_going(carry):
        _, node, swapped = carry
        return jnp.logical_and(node < first_leaf, swapped)

    def sift_once(carry):
        heap, node, _ = carry
        winner = largest_plr(heap, node)
        swapped = winner != node
        heap = lax.cond(
            swapped,
            lambda h: vec_swap_entries(h, winner, node),
            lambda h: h,
            heap,
        )
        # Continue from the child that received the old parent value.
        return heap, winner, swapped

    heap, _, _ = lax.while_loop(keep_going, sift_once, (x, idx, True))
    return heap
def delete_top_from_max_heap(x):
    """Removes the top element from a max heap, keeping the heap structure.

    Args:
        x (jax.numpy.ndarray): A max heap.

    Returns:
        jax.numpy.ndarray: A heap with one entry fewer than x.
    """
    # Move the last entry to the root, then drop the last slot.
    tail_value = x[-1]
    shortened = x.at[0].set(tail_value)[:-1]
    # Sift the new root down to its place.
    return heapify_subtree(shortened, 0)
def build_max_heap2(x, w=20):
    """Builds independent max heaps over consecutive chunks of x.

    Args:
        x (jax.numpy.ndarray): A vector that can be reshaped into rows
            of w entries.
        w (int): Number of entries per chunk.

    Returns:
        jax.numpy.ndarray: A flat vector in which every consecutive chunk
        of w entries has been passed through ``build_max_heap``.
    """
    chunks = jnp.reshape(x, (-1, w))
    heaps = vmap(build_max_heap)(chunks)
    return jnp.ravel(heaps)