384 lines
14 KiB
Python
384 lines
14 KiB
Python
# Copyright 2025 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import importlib
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
from jax._src import api
|
|
from jax._src import config
|
|
from jax._src import core
|
|
from jax._src import dtypes
|
|
from jax._src import tree_util
|
|
from jax._src import xla_bridge
|
|
from jax._src.lax import lax
|
|
from jax._src.lib import xla_client as xc
|
|
from jax._src.numpy import util
|
|
from jax._src.typing import Array, ArrayLike, DTypeLike
|
|
from jax._src.sharding import Sharding
|
|
|
|
|
|
# ``export`` is used below as a decorator on the public functions of this
# module (``array`` and ``asarray``). It is built by ``util.set_module`` for
# the ``jax.numpy`` namespace (presumably so the decorated functions report
# ``jax.numpy`` as their module -- confirm against ``util.set_module``).
export = util.set_module('jax.numpy')
|
|
|
|
# Look for the optional CUDA plugin extension module. Candidate packages are
# tried in order and the first one that imports successfully is kept; when
# none can be imported, ``cuda_plugin_extension`` stays ``None``.
cuda_plugin_extension = None
for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib.cuda']:
  try:
    cuda_plugin_extension = importlib.import_module(
        f'{pkg_name}.cuda_plugin_extension'
    )
  except ImportError:
    # This candidate is unavailable; fall through to the next one.
    continue
  break
|
|
|
|
|
|
def _supports_buffer_protocol(obj):
|
|
try:
|
|
view = memoryview(obj)
|
|
except TypeError:
|
|
return False
|
|
else:
|
|
return True
|
|
|
|
|
|
def _make_string_array(
|
|
object: np.ndarray,
|
|
dtype: DTypeLike | None = None,
|
|
ndmin: int = 0,
|
|
device: xc.Device | Sharding | None = None,
|
|
) -> Array:
|
|
if not isinstance(object, np.ndarray):
|
|
raise TypeError(
|
|
"Currently, string arrays can only be made from NumPy"
|
|
f" arrays. Got: {type(object)}."
|
|
)
|
|
if dtype is not None and (
|
|
dtypes.is_string_dtype(object.dtype) != dtypes.is_string_dtype(dtype)
|
|
):
|
|
raise TypeError(
|
|
f"Cannot make an array with dtype {dtype} from an object with dtype"
|
|
f" {object.dtype}."
|
|
)
|
|
if ndmin > object.ndim:
|
|
raise TypeError(
|
|
f"ndmin {ndmin} cannot be greater than object's ndims"
|
|
f" {object.ndim} for string arrays."
|
|
)
|
|
|
|
# Just do a device_put since XLA does not support string as a data type.
|
|
return api.device_put(x=object, device=device)
|
|
|
|
|
|
@export
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
          order: str | None = "K", ndmin: int = 0,
          *, device: xc.Device | Sharding | None = None) -> Array:
  """Convert an object to a JAX array.

  JAX implementation of :func:`numpy.array`.

  Args:
    object: an object that is convertible to an array. This includes JAX
      arrays, NumPy arrays, Python scalars, Python collections like lists
      and tuples, objects with an ``__array__`` method, and objects
      supporting the Python buffer protocol.
    dtype: optionally specify the dtype of the output array. If not
      specified it will be inferred from the input.
    copy: specify whether to force a copy of the input. Default: True.
    order: not implemented in JAX. Only ``None`` and ``"K"`` are accepted;
      any other value raises ``NotImplementedError``.
    ndmin: integer specifying the minimum number of dimensions in the
      output array.
    device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
      to which the created array will be committed.

  Returns:
    A JAX array constructed from the input.

  Raises:
    NotImplementedError: if ``order`` is neither ``None`` nor ``"K"``.
    ValueError: if ``object`` is, or contains, ``None``.
    TypeError: if ``object`` is of a type that cannot be converted.

  See also:
    - :func:`jax.numpy.asarray`: like `array`, but by default only copies
      when necessary.
    - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object
      that implements the dlpack interface.
    - :func:`jax.numpy.frombuffer`: construct a JAX array from an object
      that implements the buffer interface.

  Examples:
    Constructing JAX arrays from Python scalars:

    >>> jnp.array(True)
    Array(True, dtype=bool)
    >>> jnp.array(42)
    Array(42, dtype=int32, weak_type=True)
    >>> jnp.array(3.5)
    Array(3.5, dtype=float32, weak_type=True)
    >>> jnp.array(1 + 1j)
    Array(1.+1.j, dtype=complex64, weak_type=True)

    Constructing JAX arrays from Python collections:

    >>> jnp.array([1, 2, 3])  # list of ints -> 1D array
    Array([1, 2, 3], dtype=int32)
    >>> jnp.array([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
    Array([[1, 2, 3],
           [4, 5, 6]], dtype=int32)
    >>> jnp.array(range(5))
    Array([0, 1, 2, 3, 4], dtype=int32)

    Constructing JAX arrays from NumPy arrays:

    >>> jnp.array(np.linspace(0, 2, 5))
    Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)

    Constructing a JAX array via the Python buffer interface, using Python's
    built-in :mod:`array` module.

    >>> from array import array
    >>> pybuffer = array('i', [2, 3, 5, 7])
    >>> jnp.array(pybuffer)
    Array([2, 3, 5, 7], dtype=int32)
  """
  if order is not None and order != "K":
    raise NotImplementedError("Only implemented for order='K'")

  # check if the given dtype is compatible with JAX
  dtypes.check_user_dtype_supported(dtype, "array")

  # Here we make a judgment call: we only return a weakly-typed array when the
  # input object itself is weakly typed. That ensures asarray(x) is a no-op
  # whenever x is weak, but avoids introducing weak types with something like
  # array([1, 2, 3])
  # NOTE: this must be computed here, before ``dtype`` and ``object`` are
  # reassigned further down.
  weak_type = dtype is None and dtypes.is_weakly_typed(object)
  # With no explicit device, a tracer input keeps its own sharding (or no
  # sharding at all when that sharding's mesh is empty); otherwise the
  # sharding is derived from the ``device`` argument.
  if device is None and isinstance(object, core.Tracer):
    sharding = object.aval.sharding
    sharding = None if sharding.mesh.empty else sharding
  else:
    sharding = util.canonicalize_device_to_sharding(device)

  # Use device_put to avoid a copy for ndarray inputs.
  if (not copy and isinstance(object, np.ndarray) and
      (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim) and
      device is None):
    # Keep the output uncommitted.
    return api.device_put(object)

  # String arrays need separate handling because XLA does not support string
  # as a data type.
  if dtypes.is_string_dtype(dtype) or (
      hasattr(object, "dtype") and dtypes.is_string_dtype(object.dtype)
  ):
    return _make_string_array(
        object=object, dtype=dtype, ndmin=ndmin, device=device
    )

  # For Python scalar literals, call coerce_to_array to catch any overflow
  # errors. We don't use dtypes.is_python_scalar because we don't want this
  # triggering for traced values. We do this here because it matters whether or
  # not dtype is None. We don't assign the result because we want the raw object
  # to be used for type inference below.
  if isinstance(object, (bool, int, float, complex)):
    _ = dtypes.coerce_to_array(object, dtype)
  elif not isinstance(object, Array):
    # Check if object supports any of the data exchange protocols
    # (except dlpack, see data-apis/array-api#301). If it does,
    # consume the object as jax array and continue (but not return) so
    # that other array() arguments get processed against the input
    # object.
    #
    # Notice that data exchange protocols define dtype in the
    # corresponding data structures and it may not be available as
    # object.dtype. So, we'll resolve the protocols here before
    # evaluating object.dtype.
    if hasattr(object, '__jax_array__'):
      object = object.__jax_array__()
    elif hasattr(object, '__cuda_array_interface__'):
      cai = object.__cuda_array_interface__
      backend = xla_bridge.get_backend("cuda")
      # The device ordinal can only be looked up when the CUDA plugin
      # extension module was importable; otherwise None is passed through.
      if cuda_plugin_extension is None:
        device_id = None
      else:
        device_id = cuda_plugin_extension.get_device_ordinal(cai["data"][0])
      object = xc._xla.cuda_array_interface_to_buffer(
          cai=cai, gpu_backend=backend, device_id=device_id)

  # Flatten (possibly nested) containers into leaves. ``None`` is treated as a
  # leaf so that it shows up in ``leaves`` and can be rejected explicitly.
  leaves, treedef = tree_util.tree_flatten(object, is_leaf=lambda x: x is None)
  if any(leaf is None for leaf in leaves):
    raise ValueError("None is not a valid value for jnp.array")
  # Replace every leaf that provides ``__jax_array__`` with the result of
  # calling it; other leaves are kept unchanged.
  leaves = [
      leaf
      if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None
      else leaf_jax_array()
      for leaf in leaves
  ]
  if dtype is None:
    # Use lattice_result_type rather than result_type to avoid canonicalization.
    # Otherwise, weakly-typed inputs would have their dtypes canonicalized.
    try:
      dtype = dtypes._lattice_result_type(*leaves)[0] if leaves else dtypes.float_
    except TypeError:
      # This happens if, e.g. one of the entries is a memoryview object.
      # This is rare, so we only handle it if the normal path fails.
      leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
      dtype = dtypes._lattice_result_type(*leaves)[0]

  # Only non-weak results get a canonicalized dtype; a weakly-typed result
  # keeps the dtype as inferred above.
  if not weak_type:
    dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]

  # Rebuild the original structure from the (possibly converted) leaves.
  object = treedef.unflatten(leaves)
  out: ArrayLike
  if all(not isinstance(leaf, Array) for leaf in leaves):
    # No leaf is a JAX Array: let NumPy assemble the whole structure.
    # TODO(jakevdp): falling back to numpy here fails to overflow for lists
    # containing large integers; see discussion in
    # https://github.com/jax-ml/jax/pull/6047. More correct would be to call
    # coerce_to_array on each leaf, but this may have performance implications.
    out = np.asarray(object, dtype=dtype)
  elif isinstance(object, Array):
    assert object.aval is not None
    out = lax._array_copy(object) if copy else object
  elif isinstance(object, (list, tuple)):
    if object:
      # Convert each element recursively, give it a new leading axis of size
      # one, and concatenate the pieces along that axis.
      arrs = (array(elt, dtype=dtype, copy=False) for elt in object)
      arrays_out = [lax.expand_dims(arr, [0]) for arr in arrs]
      # lax.concatenate can be slow to compile for wide concatenations, so form a
      # tree of concatenations as a workaround especially for op-by-op mode.
      # (https://github.com/jax-ml/jax/issues/653).
      k = 16
      while len(arrays_out) > k:
        arrays_out = [lax.concatenate(arrays_out[i:i+k], 0)
                      for i in range(0, len(arrays_out), k)]
      out = lax.concatenate(arrays_out, 0)
    else:
      out = np.array([], dtype=dtype)
  elif _supports_buffer_protocol(object):
    object = memoryview(object)
    # TODO(jakevdp): update this once we support NumPy 2.0 semantics for the copy arg.
    out = np.array(object) if copy else np.asarray(object)
  else:
    raise TypeError(f"Unexpected input type for array: {type(object)}")
  # Apply the final dtype, weak-type flag and sharding in a single conversion.
  out_array: Array = lax._convert_element_type(
      out, dtype, weak_type=weak_type, sharding=sharding)
  # Honor ``ndmin`` by inserting new axes at positions
  # 0 .. (ndmin - ndim - 1).
  if ndmin > np.ndim(out_array):
    out_array = lax.expand_dims(out_array, range(ndmin - np.ndim(out_array)))
  return out_array
|
|
|
|
|
|
def _get_platform(
|
|
device_or_sharding: xc.Device | Sharding | None | str) -> str:
|
|
"""Get device_or_sharding platform or look up config.default_device.value."""
|
|
if isinstance(device_or_sharding, xc.Device):
|
|
return device_or_sharding.platform
|
|
elif isinstance(device_or_sharding, Sharding):
|
|
return list(device_or_sharding.device_set)[0].platform
|
|
elif isinstance(device_or_sharding, str):
|
|
return device_or_sharding
|
|
elif device_or_sharding is None:
|
|
if config.default_device.value is None:
|
|
return xla_bridge.default_backend()
|
|
else:
|
|
return _get_platform(config.default_device.value)
|
|
else:
|
|
raise ValueError(f"`{device_or_sharding = }` was passed to"
|
|
"`canonicalize_or_get_default_platform`, only xc.Device,"
|
|
" Sharding, None or str values are supported.")
|
|
|
|
|
|
def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
|
|
try:
|
|
dtypes.dtype(x)
|
|
except TypeError:
|
|
return np.asarray(x)
|
|
else:
|
|
return x
|
|
|
|
|
|
@export
def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
            *, copy: bool | None = None,
            device: xc.Device | Sharding | None = None) -> Array:
  """Convert an object to a JAX array.

  JAX implementation of :func:`numpy.asarray`.

  Args:
    a: an object that is convertible to an array. This includes JAX
      arrays, NumPy arrays, Python scalars, Python collections like lists
      and tuples, objects with an ``__array__`` method, and objects
      supporting the Python buffer protocol.
    dtype: optional dtype of the output array. When omitted, the dtype is
      inferred from the input.
    order: not implemented in JAX; the value is forwarded unchanged to
      :func:`jax.numpy.array`.
    copy: optional boolean copy mode. ``True`` always returns a copy,
      ``False`` raises an error if a copy would be required, and ``None``
      (the default) copies only when necessary.
    device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
      to which the created array will be committed.

  Returns:
    A JAX array constructed from the input.

  Raises:
    ValueError: if ``copy=False``, ``a`` is not already a JAX array, ``a``
      supports the buffer protocol, and the target platform is not ``cpu``.

  See also:
    - :func:`jax.numpy.array`: like `asarray`, but defaults to `copy=True`.
    - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object
      that implements the dlpack interface.
    - :func:`jax.numpy.frombuffer`: construct a JAX array from an object
      that implements the buffer interface.

  Examples:
    Constructing JAX arrays from Python scalars:

    >>> jnp.asarray(True)
    Array(True, dtype=bool)
    >>> jnp.asarray(42)
    Array(42, dtype=int32, weak_type=True)
    >>> jnp.asarray(3.5)
    Array(3.5, dtype=float32, weak_type=True)
    >>> jnp.asarray(1 + 1j)
    Array(1.+1.j, dtype=complex64, weak_type=True)

    Constructing JAX arrays from Python collections:

    >>> jnp.asarray([1, 2, 3])  # list of ints -> 1D array
    Array([1, 2, 3], dtype=int32)
    >>> jnp.asarray([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
    Array([[1, 2, 3],
           [4, 5, 6]], dtype=int32)
    >>> jnp.asarray(range(5))
    Array([0, 1, 2, 3, 4], dtype=int32)

    Constructing JAX arrays from NumPy arrays:

    >>> jnp.asarray(np.linspace(0, 2, 5))
    Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)

    Constructing a JAX array via the Python buffer interface, using Python's
    built-in :mod:`array` module.

    >>> from array import array
    >>> pybuffer = array('i', [2, 3, 5, 7])
    >>> jnp.asarray(pybuffer)
    Array([2, 3, 5, 7], dtype=int32)
  """
  # The array API requires a ValueError for copy=False when the input supports
  # the buffer protocol but a copy is unavoidable. array() consumes buffers
  # through NumPy, so that situation only arises when the target platform is
  # not 'cpu'.
  if copy is False and not isinstance(a, Array):
    platform = _get_platform(device)
    if platform != "cpu" and _supports_buffer_protocol(a):
      raise ValueError(
          f"jnp.asarray: cannot convert object of type {type(a)} to JAX Array "
          f"on platform={platform} with "
          "copy=False. Consider using copy=None or copy=True instead.")

  dtypes.check_user_dtype_supported(dtype, "asarray")
  if dtype is None:
    target_dtype = None
  else:
    target_dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)
  # copy=None is handed to array() as copy=False.
  return array(a, dtype=target_dtype, copy=bool(copy), order=order,
               device=device)
|