cr.nimble.mat_column_blocks

cr.nimble.mat_column_blocks(A, n_blocks)

Splits the columns of a matrix into blocks and returns them as a 3D array.

Parameters
  • A (jax.numpy.ndarray) – A 2D matrix

  • n_blocks (int) – The number of blocks

Returns

A 3D array of matrices, where each matrix is one block of columns of A.

Note

n_blocks is a static argument, so under JAX's jit it must be a concrete Python integer rather than a traced value. The number of columns in A must be a multiple of n_blocks.
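
Example

The behaviour described above can be illustrated with a minimal sketch. It does not call cr.nimble and is not the library's implementation: column_blocks_sketch is a hypothetical stand-in built on jax.numpy.hsplit and jax.numpy.stack. It also assumes that each block holds consecutive columns and that the leading axis of the result indexes the blocks, which the description above does not state explicitly.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for illustration only; NOT cr.nimble's code.
# n_blocks is marked static so it stays a plain Python int while tracing.
@partial(jax.jit, static_argnums=(1,))
def column_blocks_sketch(A, n_blocks):
    n_cols = A.shape[1]
    # Shapes are known at trace time, so this check runs before compilation.
    if n_cols % n_blocks != 0:
        raise ValueError("number of columns must be a multiple of n_blocks")
    # Assumption: blocks are groups of consecutive columns, stacked along
    # a new leading axis.
    return jnp.stack(jnp.hsplit(A, n_blocks))


A = jnp.arange(12).reshape(2, 6)      # 2 rows, 6 columns
blocks = column_blocks_sketch(A, 3)   # 3 blocks of 2 columns each
print(blocks.shape)                   # (3, 2, 2)
print(blocks[0])                      # columns 0 and 1 of A
```

Under these assumptions a 2 x 6 matrix split into 3 blocks yields an array of shape (3, 2, 2). Passing an n_blocks that does not divide the column count (for example 4) raises the ValueError while the function is being traced.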