2025-08-11 12:24:21 +08:00

384 lines
14 KiB
Python

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Any
import numpy as np
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import tree_util
from jax._src import xla_bridge
from jax._src.lax import lax
from jax._src.lib import xla_client as xc
from jax._src.numpy import util
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.sharding import Sharding
# Decorator applied to the public constructors defined in this module.
# NOTE(review): judging by its name, util.set_module presumably sets the
# decorated function's __module__ to 'jax.numpy' so it is reported under the
# public namespace — confirm against jax._src.numpy.util.
export = util.set_module('jax.numpy')
# Locate the optional CUDA plugin extension. The standalone
# `jax_cuda12_plugin` package is probed first, then `jax.jaxlib.cuda`.
# The first successful import wins. If neither module can be imported,
# `cuda_plugin_extension` ends up as None.
for pkg_name in ['jax_cuda12_plugin', 'jax.jaxlib.cuda']:
  try:
    cuda_plugin_extension = importlib.import_module(
        f'{pkg_name}.cuda_plugin_extension'
    )
    break
  except ImportError:
    cuda_plugin_extension = None  # type: ignore
def _supports_buffer_protocol(obj):
try:
view = memoryview(obj)
except TypeError:
return False
else:
return True
def _make_string_array(
object: np.ndarray,
dtype: DTypeLike | None = None,
ndmin: int = 0,
device: xc.Device | Sharding | None = None,
) -> Array:
if not isinstance(object, np.ndarray):
raise TypeError(
"Currently, string arrays can only be made from NumPy"
f" arrays. Got: {type(object)}."
)
if dtype is not None and (
dtypes.is_string_dtype(object.dtype) != dtypes.is_string_dtype(dtype)
):
raise TypeError(
f"Cannot make an array with dtype {dtype} from an object with dtype"
f" {object.dtype}."
)
if ndmin > object.ndim:
raise TypeError(
f"ndmin {ndmin} cannot be greater than object's ndims"
f" {object.ndim} for string arrays."
)
# Just do a device_put since XLA does not support string as a data type.
return api.device_put(x=object, device=device)
@export
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
          order: str | None = "K", ndmin: int = 0,
          *, device: xc.Device | Sharding | None = None) -> Array:
  """Convert an object to a JAX array.

  JAX implementation of :func:`numpy.array`.

  Args:
    object: an object that is convertible to an array. This includes JAX
      arrays, NumPy arrays, Python scalars, Python collections like lists
      and tuples, objects with an ``__array__`` method, and objects
      supporting the Python buffer protocol.
    dtype: optionally specify the dtype of the output array. If not
      specified it will be inferred from the input.
    copy: specify whether to force a copy of the input. Default: True.
    order: not implemented in JAX. Only ``"K"`` (the default) or ``None`` is
      accepted; any other value raises ``NotImplementedError``.
    ndmin: integer specifying the minimum number of dimensions in the
      output array.
    device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
      to which the created array will be committed.

  Returns:
    A JAX array constructed from the input.

  See also:
    - :func:`jax.numpy.asarray`: like `array`, but by default only copies
      when necessary.
    - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object
      that implements the dlpack interface.
    - :func:`jax.numpy.frombuffer`: construct a JAX array from an object
      that implements the buffer interface.

  Examples:
    Constructing JAX arrays from Python scalars:

    >>> jnp.array(True)
    Array(True, dtype=bool)
    >>> jnp.array(42)
    Array(42, dtype=int32, weak_type=True)
    >>> jnp.array(3.5)
    Array(3.5, dtype=float32, weak_type=True)
    >>> jnp.array(1 + 1j)
    Array(1.+1.j, dtype=complex64, weak_type=True)

    Constructing JAX arrays from Python collections:

    >>> jnp.array([1, 2, 3])  # list of ints -> 1D array
    Array([1, 2, 3], dtype=int32)
    >>> jnp.array([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
    Array([[1, 2, 3],
           [4, 5, 6]], dtype=int32)
    >>> jnp.array(range(5))
    Array([0, 1, 2, 3, 4], dtype=int32)

    Constructing JAX arrays from NumPy arrays:

    >>> jnp.array(np.linspace(0, 2, 5))
    Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)

    Constructing a JAX array via the Python buffer interface, using Python's
    built-in :mod:`array` module.

    >>> from array import array
    >>> pybuffer = array('i', [2, 3, 5, 7])
    >>> jnp.array(pybuffer)
    Array([2, 3, 5, 7], dtype=int32)
  """
  # Only order='K' (or None) is accepted.
  if order is not None and order != "K":
    raise NotImplementedError("Only implemented for order='K'")

  # check if the given dtype is compatible with JAX
  dtypes.check_user_dtype_supported(dtype, "array")

  # Here we make a judgment call: we only return a weakly-typed array when the
  # input object itself is weakly typed. That ensures asarray(x) is a no-op
  # whenever x is weak, but avoids introducing weak types with something like
  # array([1, 2, 3])
  weak_type = dtype is None and dtypes.is_weakly_typed(object)
  # A traced input with no explicit device inherits the tracer's sharding,
  # unless that sharding's mesh is empty. Otherwise the requested device (or
  # None) is normalized to a sharding.
  if device is None and isinstance(object, core.Tracer):
    sharding = object.aval.sharding
    sharding = None if sharding.mesh.empty else sharding
  else:
    sharding = util.canonicalize_device_to_sharding(device)

  # Use device_put to avoid a copy for ndarray inputs.
  if (not copy and isinstance(object, np.ndarray) and
      (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim) and
      device is None):
    # Keep the output uncommitted.
    return api.device_put(object)

  # String arrays need separate handling because XLA does not support string
  # as a data type.
  if dtypes.is_string_dtype(dtype) or (
      hasattr(object, "dtype") and dtypes.is_string_dtype(object.dtype)
  ):
    return _make_string_array(
        object=object, dtype=dtype, ndmin=ndmin, device=device
    )

  # For Python scalar literals, call coerce_to_array to catch any overflow
  # errors. We don't use dtypes.is_python_scalar because we don't want this
  # triggering for traced values. We do this here because it matters whether or
  # not dtype is None. We don't assign the result because we want the raw object
  # to be used for type inference below.
  if isinstance(object, (bool, int, float, complex)):
    _ = dtypes.coerce_to_array(object, dtype)
  elif not isinstance(object, Array):
    # Check if object supports any of the data exchange protocols
    # (except dlpack, see data-apis/array-api#301). If it does,
    # consume the object as jax array and continue (but not return) so
    # that other array() arguments get processed against the input
    # object.
    #
    # Notice that data exchange protocols define dtype in the
    # corresponding data structures and it may not be available as
    # object.dtype. So, we'll resolve the protocols here before
    # evaluating object.dtype.
    if hasattr(object, '__jax_array__'):
      object = object.__jax_array__()
    elif hasattr(object, '__cuda_array_interface__'):
      cai = object.__cuda_array_interface__
      backend = xla_bridge.get_backend("cuda")
      # When the CUDA plugin extension could not be imported, no device
      # ordinal is looked up and device_id is passed as None.
      if cuda_plugin_extension is None:
        device_id = None
      else:
        device_id = cuda_plugin_extension.get_device_ordinal(cai["data"][0])
      object = xc._xla.cuda_array_interface_to_buffer(
          cai=cai, gpu_backend=backend, device_id=device_id)

  # Flatten nested containers. is_leaf makes None a leaf so that it appears
  # in `leaves` and can be rejected explicitly.
  leaves, treedef = tree_util.tree_flatten(object, is_leaf=lambda x: x is None)
  if any(leaf is None for leaf in leaves):
    raise ValueError("None is not a valid value for jnp.array")

  # Replace every leaf that defines __jax_array__ with the result of calling it.
  leaves = [
      leaf
      if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None
      else leaf_jax_array()
      for leaf in leaves
  ]

  if dtype is None:
    # Use lattice_result_type rather than result_type to avoid canonicalization.
    # Otherwise, weakly-typed inputs would have their dtypes canonicalized.
    # An input with no leaves (e.g. an empty list) defaults to dtypes.float_.
    try:
      dtype = dtypes._lattice_result_type(*leaves)[0] if leaves else dtypes.float_
    except TypeError:
      # This happens if, e.g. one of the entries is a memoryview object.
      # This is rare, so we only handle it if the normal path fails.
      leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
      dtype = dtypes._lattice_result_type(*leaves)[0]

  # Weakly-typed results keep the uncanonicalized lattice dtype.
  if not weak_type:
    dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]

  # Rebuild the original container structure from the (possibly replaced) leaves.
  object = treedef.unflatten(leaves)
  out: ArrayLike
  if all(not isinstance(leaf, Array) for leaf in leaves):
    # TODO(jakevdp): falling back to numpy here fails to overflow for lists
    # containing large integers; see discussion in
    # https://github.com/jax-ml/jax/pull/6047. More correct would be to call
    # coerce_to_array on each leaf, but this may have performance implications.
    out = np.asarray(object, dtype=dtype)
  elif isinstance(object, Array):
    assert object.aval is not None
    out = lax._array_copy(object) if copy else object
  elif isinstance(object, (list, tuple)):
    # At least one leaf is a JAX Array. Convert each element with a recursive
    # array() call, give each a new leading axis, and concatenate along it.
    if object:
      arrs = (array(elt, dtype=dtype, copy=False) for elt in object)
      arrays_out = [lax.expand_dims(arr, [0]) for arr in arrs]
      # lax.concatenate can be slow to compile for wide concatenations, so form a
      # tree of concatenations as a workaround especially for op-by-op mode.
      # (https://github.com/jax-ml/jax/issues/653).
      k = 16
      while len(arrays_out) > k:
        arrays_out = [lax.concatenate(arrays_out[i:i+k], 0)
                      for i in range(0, len(arrays_out), k)]
      out = lax.concatenate(arrays_out, 0)
    else:
      # NOTE(review): this branch looks unreachable. An empty list/tuple has
      # no leaves, so the all(...) test above already took the NumPy path.
      # Confirm before relying on it.
      out = np.array([], dtype=dtype)
  elif _supports_buffer_protocol(object):
    # Reached only when some leaf is a JAX Array but `object` itself is
    # neither an Array nor a list/tuple.
    object = memoryview(object)
    # TODO(jakevdp): update this once we support NumPy 2.0 semantics for the copy arg.
    out = np.array(object) if copy else np.asarray(object)
  else:
    raise TypeError(f"Unexpected input type for array: {type(object)}")

  # Cast to the resolved dtype / weak type and apply the sharding chosen above.
  out_array: Array = lax._convert_element_type(
      out, dtype, weak_type=weak_type, sharding=sharding)
  # Insert leading size-1 axes until the result has at least `ndmin` dims.
  if ndmin > np.ndim(out_array):
    out_array = lax.expand_dims(out_array, range(ndmin - np.ndim(out_array)))
  return out_array
def _get_platform(
device_or_sharding: xc.Device | Sharding | None | str) -> str:
"""Get device_or_sharding platform or look up config.default_device.value."""
if isinstance(device_or_sharding, xc.Device):
return device_or_sharding.platform
elif isinstance(device_or_sharding, Sharding):
return list(device_or_sharding.device_set)[0].platform
elif isinstance(device_or_sharding, str):
return device_or_sharding
elif device_or_sharding is None:
if config.default_device.value is None:
return xla_bridge.default_backend()
else:
return _get_platform(config.default_device.value)
else:
raise ValueError(f"`{device_or_sharding = }` was passed to"
"`canonicalize_or_get_default_platform`, only xc.Device,"
" Sharding, None or str values are supported.")
def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
try:
dtypes.dtype(x)
except TypeError:
return np.asarray(x)
else:
return x
@export
def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
            *, copy: bool | None = None,
            device: xc.Device | Sharding | None = None) -> Array:
  """Convert an object to a JAX array.

  JAX implementation of :func:`numpy.asarray`.

  Args:
    a: an object that is convertible to an array. This includes JAX
      arrays, NumPy arrays, Python scalars, Python collections like lists
      and tuples, objects with an ``__array__`` method, and objects
      supporting the Python buffer protocol.
    dtype: optionally specify the dtype of the output array. If not
      specified it will be inferred from the input.
    order: not implemented in JAX. It is forwarded to :func:`jax.numpy.array`,
      which accepts only ``None`` or ``"K"``.
    copy: optional boolean specifying the copy mode. If True, then always
      return a copy. If False, then error if a copy is necessary. Default is
      None, which will only copy when necessary.
    device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
      to which the created array will be committed.

  Returns:
    A JAX array constructed from the input.

  See also:
    - :func:`jax.numpy.array`: like `asarray`, but defaults to `copy=True`.
    - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object
      that implements the dlpack interface.
    - :func:`jax.numpy.frombuffer`: construct a JAX array from an object
      that implements the buffer interface.

  Examples:
    Constructing JAX arrays from Python scalars:

    >>> jnp.asarray(True)
    Array(True, dtype=bool)
    >>> jnp.asarray(42)
    Array(42, dtype=int32, weak_type=True)
    >>> jnp.asarray(3.5)
    Array(3.5, dtype=float32, weak_type=True)
    >>> jnp.asarray(1 + 1j)
    Array(1.+1.j, dtype=complex64, weak_type=True)

    Constructing JAX arrays from Python collections:

    >>> jnp.asarray([1, 2, 3])  # list of ints -> 1D array
    Array([1, 2, 3], dtype=int32)
    >>> jnp.asarray([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
    Array([[1, 2, 3],
           [4, 5, 6]], dtype=int32)
    >>> jnp.asarray(range(5))
    Array([0, 1, 2, 3, 4], dtype=int32)

    Constructing JAX arrays from NumPy arrays:

    >>> jnp.asarray(np.linspace(0, 2, 5))
    Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)

    Constructing a JAX array via the Python buffer interface, using Python's
    built-in :mod:`array` module.

    >>> from array import array
    >>> pybuffer = array('i', [2, 3, 5, 7])
    >>> jnp.asarray(pybuffer)
    Array([2, 3, 5, 7], dtype=int32)
  """
  # Under copy=False the array API requires a ValueError when a copy would be
  # needed. array() consumes buffer-protocol objects through NumPy, so a copy
  # is treated as required only when the target platform (from `device` or
  # the default) is not 'cpu'.
  if copy is False and not isinstance(a, Array):
    target_platform = _get_platform(device)
    if target_platform != "cpu" and _supports_buffer_protocol(a):
      raise ValueError(f"jnp.asarray: cannot convert object of type {type(a)} to JAX Array "
                       f"on platform={target_platform} with "
                       "copy=False. Consider using copy=None or copy=True instead.")

  dtypes.check_user_dtype_supported(dtype, "asarray")
  canonical_dtype = (
      None if dtype is None
      else dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
  # copy=None ("only when necessary") maps to array(copy=False).
  return array(a, dtype=canonical_dtype, copy=bool(copy), order=order,
               device=device)