cr.nimble.block_diag

cr.nimble.block_diag(A, b)

Extracts the b x b blocks along the diagonal of the given matrix.

Parameters
  • A (jax.numpy.ndarray) – A 2D matrix

  • b (int) – The size of each block

Returns

A 3D array of the diagonal blocks, of shape (m, b, b), where m is the number of blocks.

Return type

jax.numpy.ndarray
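To illustrate what the returned (m, b, b) array contains, here is a minimal sketch of the extraction. It assumes A is square and that its size is a multiple of b. The name block_diag_sketch is hypothetical, and this is not the library's own implementation. The idea is to view A as an m x m grid of b x b blocks and keep only the blocks whose block-row index equals their block-column index.

```python
import jax.numpy as jnp

def block_diag_sketch(A, b):
    """Hypothetical sketch: return the b x b diagonal blocks of A as (m, b, b)."""
    n, n2 = A.shape
    # Assumption of this sketch: A is square and its size is divisible by b.
    assert n == n2 and n % b == 0
    m = n // b  # number of diagonal blocks
    # Reshape to (m, b, m, b), then reorder to (m, m, b, b) so that
    # blocks[i, j] is the (i, j)-th b x b block of A.
    blocks = A.reshape(m, b, m, b).transpose(0, 2, 1, 3)
    # Keep only the blocks on the block diagonal (i == j).
    idx = jnp.arange(m)
    return blocks[idx, idx]

A = jnp.arange(16.0).reshape(4, 4)
D = block_diag_sketch(A, 2)
print(D.shape)  # (2, 2, 2)
# D[0] is [[0, 1], [4, 5]] and D[1] is [[10, 11], [14, 15]]
```

Entry D[k, p, q] equals A[k*b + p, k*b + q], so the off-diagonal blocks are discarded.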

Note

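b is a static argument: under jax.jit it must be a concrete Python int, and each distinct value of b triggers a separate compilation.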
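The block size determines array shapes (the reshape below), and JAX needs shapes to be known at trace time. That is why b has to be static. The following sketch shows the general pattern with jax.jit and static_argnums. The function block_diag_jit is a hypothetical stand-in, not the library's implementation. If b were traced rather than static, computing the shape m from it would fail under jit.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical sketch: mark the block size (argument index 1) as static so that
# jax.jit sees a concrete Python int and can compute shapes from it.
@partial(jax.jit, static_argnums=(1,))
def block_diag_jit(A, b):
    m = A.shape[0] // b  # shapes are concrete under jit, and b is static
    blocks = A.reshape(m, b, m, b).transpose(0, 2, 1, 3)
    idx = jnp.arange(m)
    return blocks[idx, idx]

A = jnp.eye(6)
print(block_diag_jit(A, 2).shape)  # (3, 2, 2)
print(block_diag_jit(A, 3).shape)  # (2, 3, 3), compiled separately for b=3
```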