# Source code for cr.nimble._src.util
# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import jax.numpy as jnp
import jax
from jax import lax, random
from jax._src import dtypes
from jax.lib import xla_bridge
# Name of the default XLA backend platform, captured once at import time.
# ``jax.default_backend()`` is the public API that returns the same value as
# ``xla_bridge.get_backend().platform`` without going through ``jax.lib``.
platform = jax.default_backend()
def is_cpu():
    """Tells whether the platform detected at import time is a CPU

    Returns:
        True when the module-level ``platform`` value equals ``'cpu'``,
        False otherwise
    """
    on_cpu = platform == 'cpu'
    return on_cpu
def is_gpu():
    """Tells whether the platform detected at import time is a GPU

    Returns:
        True when the module-level ``platform`` value equals ``'gpu'``,
        False otherwise
    """
    on_gpu = platform == 'gpu'
    return on_gpu
def is_tpu():
    """Tells whether the platform detected at import time is a TPU

    Returns:
        True when the module-level ``platform`` value equals ``'tpu'``,
        False otherwise
    """
    on_tpu = platform == 'tpu'
    return on_tpu
# Base PRNG key built from the fixed seed 0
KEY0 = random.PRNGKey(0)
# 64 keys obtained by splitting KEY0
KEYS = random.split(KEY0, num=64)
def promote_arg_dtypes(*args):
    """Promotes `args` to a common inexact type.

    Every argument whose dtype is not already inexact (floating or complex)
    is treated as ``jnp.float_`` for the purpose of computing the common
    dtype; all arguments are then converted to that canonicalized dtype.

    Args:
        *args: JAX ndarrays (or Python scalars) to be promoted to a common
            inexact type

    Returns:
        The promoted array itself when exactly one argument is given.
        Otherwise a list with the promoted arrays, in the same order as
        `args` (an empty list when called without arguments).

    Example:

        Promoting a single argument::

            >>> cr.nimble.promote_arg_dtypes(jnp.arange(2))
            DeviceArray([0., 1.], dtype=float32)
            >>> from jax.config import config
            >>> config.update("jax_enable_x64", True)
            >>> cr.nimble.promote_arg_dtypes(jnp.arange(2))
            DeviceArray([0., 1.], dtype=float64)

        Promoting two arguments to common floating point type::

            >>> a = jnp.arange(2)
            >>> b = jnp.arange(1.5, 3.5)
            >>> a, b = cr.nimble.promote_arg_dtypes(a, b)
            >>> print(a)
            [0. 1.]
            >>> print(b)
            [1.5 2.5]

        A mix of real and complex types::

            >>> a = jnp.arange(2) + 0.j
            >>> b = jnp.arange(1.5, 3.5)
            >>> a, b = cr.nimble.promote_arg_dtypes(a, b)
            >>> print(a)
            [0.+0.j 1.+0.j]
            >>> print(b)
            [1.5+0.j 2.5+0.j]
    """
    # Nothing to promote: jnp.result_type() would raise on an empty list
    if not args:
        return []

    def _dtype_of(arg):
        # Arrays expose .dtype; Python scalars do not, so ask JAX for theirs
        dtype = getattr(arg, 'dtype', None)
        return jnp.result_type(arg) if dtype is None else dtype

    def _to_inexact_type(dtype):
        # Non-inexact dtypes (bool, integers) are replaced by the default float
        return dtype if jnp.issubdtype(dtype, jnp.inexact) else jnp.float_

    inexact_types = [_to_inexact_type(_dtype_of(arg)) for arg in args]
    # Canonicalization maps 64-bit types to 32-bit ones unless x64 is enabled
    dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types))
    promoted = [lax.convert_element_type(arg, dtype) for arg in args]
    # A lone argument is returned unwrapped for caller convenience
    return promoted[0] if len(promoted) == 1 else promoted
def canonicalize_dtype(dtype):
    """None-tolerant wrapper around dtypes.canonicalize_dtype

    Args:
        dtype: a dtype specification, or None

    Returns:
        None when `dtype` is None; otherwise the result of
        ``dtypes.canonicalize_dtype(dtype)``
    """
    return None if dtype is None else dtypes.canonicalize_dtype(dtype)
def promote_to_complex(arg):
    """Converts an argument to a complex dtype

    The target dtype is the promotion of `arg` with ``np.complex64``.

    Args:
        arg: value to be converted

    Returns:
        `arg` converted to the promoted dtype
    """
    return lax.convert_element_type(
        arg, dtypes.result_type(arg, np.complex64))
def promote_to_real(arg):
    """Converts an argument to a real floating point dtype

    The target dtype is the promotion of `arg` with ``np.float32``.
    NOTE(review): an argument that is already complex presumably stays
    complex under this promotion -- confirm if that matters to callers.

    Args:
        arg: value to be converted

    Returns:
        `arg` converted to the promoted dtype
    """
    return lax.convert_element_type(
        arg, dtypes.result_type(arg, np.float32))
# Unsigned and signed integer dtypes, from 8 to 64 bits
integer_types = tuple(
    scalar_type.dtype
    for scalar_type in (
        jnp.uint8,
        jnp.uint16,
        jnp.uint32,
        jnp.uint64,
        jnp.int8,
        jnp.int16,
        jnp.int32,
        jnp.int64,
    )
)

# (min, max) representable values for each integer dtype
integer_ranges = dict(
    (t, (jnp.iinfo(t).min, jnp.iinfo(t).max)) for t in integer_types
)

# (low, high) value ranges keyed by type: boolean types span (False, True),
# floating point and complex types use (-1, 1), and the integer dtypes
# reuse the limits from integer_ranges
dtype_ranges = {
    bool: (False, True),
    float: (-1, 1),
    jnp.bool_.dtype: (False, True),
    jnp.float_.dtype: (-1, 1),
    jnp.float16.dtype: (-1, 1),
    jnp.float32.dtype: (-1, 1),
    jnp.complex64.dtype: (-1, 1),
    jnp.complex128.dtype: (-1, 1),
    **integer_ranges,
}
def nbytes_live_buffers():
    """Returns the number of bytes consumed by the live buffers

    The buffers are those of the default backend. ``jax.live_arrays()`` is
    used when the installed JAX provides it; otherwise the backend's
    ``live_buffers()`` method is queried.

    Returns:
        int: sum of ``nbytes`` over all live buffers (0 when there are none)
    """
    live_arrays = getattr(jax, 'live_arrays', None)
    if live_arrays is not None:
        buffers = live_arrays()
    else:
        # Older JAX releases: ask the backend client directly
        buffers = jax.lib.xla_bridge.get_backend().live_buffers()
    # Builtin sum keeps the result an int, including 0 for an empty list
    # (np.sum of an empty list would give the float 0.0)
    return sum(int(buf.nbytes) for buf in buffers)