Sets the diagonal elements of a matrix to a specific value.
Parameters:
A (jax.numpy.ndarray) – A 2D matrix.
value (float) – The value assigned to each diagonal element.
Returns:
Matrix with updated diagonal.

Return type:
jax.numpy.ndarray
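Because JAX arrays are immutable, an operation like this one has to return a new array rather than write into A. The sketch below shows one way to do that with the `.at[...].set()` update syntax. The function name `set_diagonal` and its body are hypothetical illustrations, not the documented implementation. The sketch assumes "diagonal" means the main diagonal, which for an (m, n) matrix has min(m, n) entries.

```python
import jax.numpy as jnp


def set_diagonal(A: jnp.ndarray, value: float) -> jnp.ndarray:
    """Hypothetical helper: return a copy of A with its main diagonal set to value."""
    if A.ndim != 2:
        raise ValueError(f"expected a 2D matrix, got shape {A.shape}")
    # The main diagonal of an (m, n) matrix has min(m, n) entries.
    idx = jnp.arange(min(A.shape))
    # JAX arrays are immutable: .at[...].set() returns a new array
    # instead of modifying A in place.
    return A.at[idx, idx].set(value)


A = jnp.zeros((3, 3))
B = set_diagonal(A, 5.0)  # A itself is left unchanged
```

If an additive update is wanted instead, `.at[idx, idx].add(value)` is the corresponding JAX update method.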