cr.nimble.mat_column_blocks
- cr.nimble.mat_column_blocks(A, n_blocks)
Splits the columns of a matrix into n_blocks equal-width blocks and returns them stacked as a 3D array.
- Parameters
A (jax.numpy.ndarray) – A 2D matrix of shape (m, n)
n_blocks (int) – The number of column blocks to split A into
- Returns
A 3D array of shape (n_blocks, m, n // n_blocks) for an m × n input A, in which the i-th matrix is the i-th block of consecutive columns of A
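To make the return layout concrete, here is a minimal sketch of the same operation in plain jax.numpy. It is a hypothetical re-implementation (column_blocks_sketch is not part of cr.nimble), written under the assumption that the blocks are contiguous runs of columns stacked along a new leading axis; it is not the library's source.

```python
import jax.numpy as jnp

def column_blocks_sketch(A, n_blocks):
    """Hypothetical sketch: split the columns of A into n_blocks contiguous
    blocks and stack them along a new leading axis."""
    m, n = A.shape
    if n % n_blocks != 0:
        raise ValueError("number of columns must be a multiple of n_blocks")
    cols_per_block = n // n_blocks
    # (m, n) -> (m, n_blocks, cols_per_block) groups consecutive columns;
    # transposing moves the block axis to the front.
    return A.reshape(m, n_blocks, cols_per_block).transpose(1, 0, 2)

A = jnp.arange(12).reshape(2, 6)   # [[0 1 2 3 4 5], [6 7 8 9 10 11]]
blocks = column_blocks_sketch(A, 3)
print(blocks.shape)                # (3, 2, 2)
print(blocks[0])                   # [[0 1], [6 7]], i.e. A[:, 0:2]
```

Under the same assumption, the result matches jnp.stack(jnp.split(A, n_blocks, axis=1)); the reshape form avoids building a Python list of intermediate arrays.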
Note
n_blocks is a static argument: under jax.jit it must be a concrete Python int, because it determines the shape of the output. The number of columns in A must be a multiple of n_blocks.
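The note implies the function is compiled with n_blocks marked static. A minimal sketch of that pattern, using a hypothetical column_blocks_jit (not the library's code), might look like this. Because n_blocks is static, n // n_blocks is an ordinary Python int at trace time and can be passed to reshape; a traced n_blocks would not be allowed to fix an output shape.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical sketch: argument 1 (n_blocks) is static, so each distinct
# value triggers a separate trace and compilation.
@partial(jax.jit, static_argnums=1)
def column_blocks_jit(A, n_blocks):
    m, n = A.shape  # shapes are static under jit
    if n % n_blocks != 0:
        raise ValueError("number of columns must be a multiple of n_blocks")
    return A.reshape(m, n_blocks, n // n_blocks).transpose(1, 0, 2)

A = jnp.ones((4, 8))
print(column_blocks_jit(A, 2).shape)  # (2, 4, 4)
print(column_blocks_jit(A, 4).shape)  # (4, 4, 2), compiled separately
```

The trade-off of a static argument is recompilation: calling the function with many different n_blocks values compiles one version per value.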