cr.nimble.block_diag
- cr.nimble.block_diag(A, b)
Extracts the diagonal blocks, each of size b x b, from the given matrix.
- Parameters
A (jax.numpy.ndarray) – A 2D matrix
b (int) – The size of each block
- Returns
An array of the diagonal blocks, a 3D array of shape m x b x b where m is the number of blocks
- Return type
jax.numpy.ndarray
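To make the output layout concrete, here is a minimal sketch of what "extracting the diagonal blocks" means. It is not the library's implementation. It assumes A is square and that its side length is a multiple of b; the name block_diag_loop is hypothetical. Block k of the result is the b x b submatrix whose top-left corner sits at (k*b, k*b).

```python
import jax.numpy as jnp

def block_diag_loop(A, b):
    """Hypothetical reference sketch: stack the b x b blocks on the diagonal of A.

    Assumes A is square and A.shape[0] is divisible by b.
    """
    nb = A.shape[0] // b  # number of diagonal blocks (m in the docs above)
    # Slice out each diagonal block and stack them along a new leading axis.
    return jnp.stack([A[k * b:(k + 1) * b, k * b:(k + 1) * b] for k in range(nb)])

A = jnp.arange(16.0).reshape(4, 4)
blocks = block_diag_loop(A, 2)
print(blocks.shape)  # (2, 2, 2)
print(blocks[1])     # the block at rows 2-3, columns 2-3
```

For a 4 x 4 input with b = 2 the result has shape 2 x 2 x 2: blocks[0] holds A[0:2, 0:2] and blocks[1] holds A[2:4, 2:4].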
Note
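b is a static argument: it must be a concrete Python int, and each distinct value of b triggers a separate JIT compilation.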
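The reason b has to be static can be seen in a jitted sketch. This is an illustration under assumptions, not the library's code: block_diag_jit is a hypothetical name, and A is assumed square with a side length divisible by b. The block size determines output shapes (through reshape), and JAX requires shapes to be known at trace time. Marking b with static_argnums keeps it a plain Python int, so nb = n // b is concrete.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=1)  # b (positional index 1) is static
def block_diag_jit(A, b):
    """Hypothetical jitted sketch: extract the diagonal b x b blocks of square A."""
    n = A.shape[0]
    nb = n // b  # concrete int, since A.shape and b are both static
    # (n, n) -> (nb, b, nb, b) -> (nb, nb, b, b): entry [p, q] is block (p, q).
    blocks = A.reshape(nb, b, nb, b).transpose(0, 2, 1, 3)
    idx = jnp.arange(nb)
    # Pick the blocks with equal block-row and block-column index.
    return blocks[idx, idx]

A = jnp.arange(36.0).reshape(6, 6)
print(block_diag_jit(A, 3).shape)  # (2, 3, 3)
print(block_diag_jit(A, 2).shape)  # (3, 2, 2); new b value, so a fresh compile
```

If b were traced instead of static, n // b would be an abstract value and the reshape would fail during tracing.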