384 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			384 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright 2025 The JAX Authors.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     https://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| import importlib
 | |
| from typing import Any
 | |
| 
 | |
| import numpy as np
 | |
| 
 | |
| from jax._src import api
 | |
| from jax._src import config
 | |
| from jax._src import core
 | |
| from jax._src import dtypes
 | |
| from jax._src import tree_util
 | |
| from jax._src import xla_bridge
 | |
| from jax._src.lax import lax
 | |
| from jax._src.lib import xla_client as xc
 | |
| from jax._src.numpy import util
 | |
| from jax._src.typing import Array, ArrayLike, DTypeLike
 | |
| from jax._src.sharding import Sharding
 | |
| 
 | |
| 
 | |
# Decorator applied to the public functions defined in this file (`array` and
# `asarray`). NOTE(review): presumably it makes them report 'jax.numpy' as their
# module -- confirm against `jax._src.numpy.util.set_module`.
export = util.set_module('jax.numpy')
 | |
| 
 | |
# Probe the candidate packages in order for the optional CUDA plugin extension
# module. The first one that imports successfully is kept; if none can be
# imported, `cuda_plugin_extension` is left as None.
for pkg_name in ('jax_cuda12_plugin', 'jax.jaxlib.cuda'):
  try:
    cuda_plugin_extension = importlib.import_module(
        pkg_name + '.cuda_plugin_extension')
    break
  except ImportError:
    cuda_plugin_extension = None  # type: ignore
 | |
| 
 | |
| 
 | |
| def _supports_buffer_protocol(obj):
 | |
|   try:
 | |
|     view = memoryview(obj)
 | |
|   except TypeError:
 | |
|     return False
 | |
|   else:
 | |
|     return True
 | |
| 
 | |
| 
 | |
| def _make_string_array(
 | |
|     object: np.ndarray,
 | |
|     dtype: DTypeLike | None = None,
 | |
|     ndmin: int = 0,
 | |
|     device: xc.Device | Sharding | None = None,
 | |
| ) -> Array:
 | |
|   if not isinstance(object, np.ndarray):
 | |
|     raise TypeError(
 | |
|         "Currently, string arrays can only be made from NumPy"
 | |
|         f" arrays. Got:  {type(object)}."
 | |
|     )
 | |
|   if dtype is not None and (
 | |
|       dtypes.is_string_dtype(object.dtype) != dtypes.is_string_dtype(dtype)
 | |
|   ):
 | |
|     raise TypeError(
 | |
|         f"Cannot make an array with dtype {dtype} from an object with dtype"
 | |
|         f" {object.dtype}."
 | |
|     )
 | |
|   if ndmin > object.ndim:
 | |
|     raise TypeError(
 | |
|         f"ndmin {ndmin} cannot be greater than object's ndims"
 | |
|         f" {object.ndim} for string arrays."
 | |
|     )
 | |
| 
 | |
|   # Just do a device_put since XLA does not support string as a data type.
 | |
|   return api.device_put(x=object, device=device)
 | |
| 
 | |
| 
 | |
@export
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
          order: str | None = "K", ndmin: int = 0,
          *, device: xc.Device | Sharding | None = None) -> Array:
  """Convert an object to a JAX array.

  JAX implementation of :func:`numpy.array`.

  Args:
    object: an object that is convertible to an array. This includes JAX
      arrays, NumPy arrays, Python scalars, Python collections like lists
      and tuples, objects with an ``__array__`` method, and objects
      supporting the Python buffer protocol.
    dtype: optionally specify the dtype of the output array. If not
      specified it will be inferred from the input.
    copy: specify whether to force a copy of the input. Default: True.
    order: not implemented in JAX
    ndmin: integer specifying the minimum number of dimensions in the
      output array.
    device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
      to which the created array will be committed.

  Returns:
    A JAX array constructed from the input.

  Raises:
    NotImplementedError: if ``order`` is anything other than ``None`` or
      ``"K"``.
    ValueError: if ``object`` is, or contains, ``None``.
    TypeError: if ``object`` is of a type that cannot be converted.

  See also:
    - :func:`jax.numpy.asarray`: like `array`, but by default only copies
      when necessary.
    - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object
      that implements the dlpack interface.
    - :func:`jax.numpy.frombuffer`: construct a JAX array from an object
      that implements the buffer interface.

  Examples:
    Constructing JAX arrays from Python scalars:

    >>> jnp.array(True)
    Array(True, dtype=bool)
    >>> jnp.array(42)
    Array(42, dtype=int32, weak_type=True)
    >>> jnp.array(3.5)
    Array(3.5, dtype=float32, weak_type=True)
    >>> jnp.array(1 + 1j)
    Array(1.+1.j, dtype=complex64, weak_type=True)

    Constructing JAX arrays from Python collections:

    >>> jnp.array([1, 2, 3])  # list of ints -> 1D array
    Array([1, 2, 3], dtype=int32)
    >>> jnp.array([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
    Array([[1, 2, 3],
           [4, 5, 6]], dtype=int32)
    >>> jnp.array(range(5))
    Array([0, 1, 2, 3, 4], dtype=int32)

    Constructing JAX arrays from NumPy arrays:

    >>> jnp.array(np.linspace(0, 2, 5))
    Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)

    Constructing a JAX array via the Python buffer interface, using Python's
    built-in :mod:`array` module.

    >>> from array import array
    >>> pybuffer = array('i', [2, 3, 5, 7])
    >>> jnp.array(pybuffer)
    Array([2, 3, 5, 7], dtype=int32)
  """
  if order is not None and order != "K":
    raise NotImplementedError("Only implemented for order='K'")

  # check if the given dtype is compatible with JAX
  dtypes.check_user_dtype_supported(dtype, "array")

  # Here we make a judgment call: we only return a weakly-typed array when the
  # input object itself is weakly typed. That ensures asarray(x) is a no-op
  # whenever x is weak, but avoids introducing weak types with something like
  # array([1, 2, 3])
  weak_type = dtype is None and dtypes.is_weakly_typed(object)
  # With no explicit device, a traced input keeps the sharding recorded on its
  # aval (unless that sharding's mesh is empty); otherwise the sharding is
  # derived from the `device` argument.
  if device is None and isinstance(object, core.Tracer):
    sharding = object.aval.sharding
    sharding = None if sharding.mesh.empty else sharding
  else:
    sharding = util.canonicalize_device_to_sharding(device)

  # Use device_put to avoid a copy for ndarray inputs.
  if (not copy and isinstance(object, np.ndarray) and
      (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim) and
      device is None):
    # Keep the output uncommitted.
    return api.device_put(object)

  # String arrays need separate handling because XLA does not support string
  # as a data type.
  if dtypes.is_string_dtype(dtype) or (
      hasattr(object, "dtype") and dtypes.is_string_dtype(object.dtype)
  ):
    return _make_string_array(
        object=object, dtype=dtype, ndmin=ndmin, device=device
    )

  # For Python scalar literals, call coerce_to_array to catch any overflow
  # errors. We don't use dtypes.is_python_scalar because we don't want this
  # triggering for traced values. We do this here because it matters whether or
  # not dtype is None. We don't assign the result because we want the raw object
  # to be used for type inference below.
  if isinstance(object, (bool, int, float, complex)):
    _ = dtypes.coerce_to_array(object, dtype)
  elif not isinstance(object, Array):
    # Check if object supports any of the data exchange protocols
    # (except dlpack, see data-apis/array-api#301). If it does,
    # consume the object as jax array and continue (but not return) so
    # that other array() arguments get processed against the input
    # object.
    #
    # Notice that data exchange protocols define dtype in the
    # corresponding data structures and it may not be available as
    # object.dtype. So, we'll resolve the protocols here before
    # evaluating object.dtype.
    if hasattr(object, '__jax_array__'):
      object = object.__jax_array__()
    elif hasattr(object, '__cuda_array_interface__'):
      cai = object.__cuda_array_interface__
      backend = xla_bridge.get_backend("cuda")
      # device_id stays None when the CUDA plugin extension module could not be
      # imported; otherwise it is looked up from the first entry of cai["data"].
      if cuda_plugin_extension is None:
        device_id = None
      else:
        device_id = cuda_plugin_extension.get_device_ordinal(cai["data"][0])
      object = xc._xla.cuda_array_interface_to_buffer(
          cai=cai, gpu_backend=backend, device_id=device_id)

  # Flatten the input so every leaf can be inspected. None is treated as a leaf
  # so that it can be rejected with an explicit error below.
  leaves, treedef = tree_util.tree_flatten(object, is_leaf=lambda x: x is None)
  if any(leaf is None for leaf in leaves):
    raise ValueError("None is not a valid value for jnp.array")
  # Resolve the __jax_array__ protocol on each individual leaf as well.
  leaves = [
      leaf
      if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None
      else leaf_jax_array()
      for leaf in leaves
  ]
  if dtype is None:
    # Use lattice_result_type rather than result_type to avoid canonicalization.
    # Otherwise, weakly-typed inputs would have their dtypes canonicalized.
    # An input with no leaves falls back to dtypes.float_.
    try:
      dtype = dtypes._lattice_result_type(*leaves)[0] if leaves else dtypes.float_
    except TypeError:
      # This happens if, e.g. one of the entries is a memoryview object.
      # This is rare, so we only handle it if the normal path fails.
      leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
      dtype = dtypes._lattice_result_type(*leaves)[0]

  if not weak_type:
    dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]

  # Rebuild the original structure from the (possibly converted) leaves.
  object = treedef.unflatten(leaves)
  out: ArrayLike
  if all(not isinstance(leaf, Array) for leaf in leaves):
    # TODO(jakevdp): falling back to numpy here fails to overflow for lists
    # containing large integers; see discussion in
    # https://github.com/jax-ml/jax/pull/6047. More correct would be to call
    # coerce_to_array on each leaf, but this may have performance implications.
    out = np.asarray(object, dtype=dtype)
  elif isinstance(object, Array):
    assert object.aval is not None
    out = lax._array_copy(object) if copy else object
  elif isinstance(object, (list, tuple)):
    if object:
      # Convert each element recursively, give it a new leading axis, and
      # concatenate along that axis.
      arrs = (array(elt, dtype=dtype, copy=False) for elt in object)
      arrays_out = [lax.expand_dims(arr, [0]) for arr in arrs]
      # lax.concatenate can be slow to compile for wide concatenations, so form a
      # tree of concatenations as a workaround especially for op-by-op mode.
      # (https://github.com/jax-ml/jax/issues/653).
      k = 16
      while len(arrays_out) > k:
        arrays_out = [lax.concatenate(arrays_out[i:i+k], 0)
                      for i in range(0, len(arrays_out), k)]
      out = lax.concatenate(arrays_out, 0)
    else:
      out = np.array([], dtype=dtype)
  elif _supports_buffer_protocol(object):
    object = memoryview(object)
    # TODO(jakevdp): update this once we support NumPy 2.0 semantics for the copy arg.
    out = np.array(object) if copy else np.asarray(object)
  else:
    raise TypeError(f"Unexpected input type for array: {type(object)}")
  out_array: Array = lax._convert_element_type(
      out, dtype, weak_type=weak_type, sharding=sharding)
  # Honor ndmin by inserting new axes at positions 0 .. (ndmin - ndim - 1).
  if ndmin > np.ndim(out_array):
    out_array = lax.expand_dims(out_array, range(ndmin - np.ndim(out_array)))
  return out_array
 | |
| 
 | |
| 
 | |
| def _get_platform(
 | |
|     device_or_sharding: xc.Device | Sharding | None | str) -> str:
 | |
|   """Get device_or_sharding platform or look up config.default_device.value."""
 | |
|   if isinstance(device_or_sharding, xc.Device):
 | |
|     return device_or_sharding.platform
 | |
|   elif isinstance(device_or_sharding, Sharding):
 | |
|     return list(device_or_sharding.device_set)[0].platform
 | |
|   elif isinstance(device_or_sharding, str):
 | |
|     return device_or_sharding
 | |
|   elif device_or_sharding is None:
 | |
|     if config.default_device.value is None:
 | |
|       return xla_bridge.default_backend()
 | |
|     else:
 | |
|       return _get_platform(config.default_device.value)
 | |
|   else:
 | |
|     raise ValueError(f"`{device_or_sharding = }` was passed to"
 | |
|                      "`canonicalize_or_get_default_platform`, only xc.Device,"
 | |
|                      " Sharding, None or str values are supported.")
 | |
| 
 | |
| 
 | |
| def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
 | |
|   try:
 | |
|     dtypes.dtype(x)
 | |
|   except TypeError:
 | |
|     return np.asarray(x)
 | |
|   else:
 | |
|     return x
 | |
| 
 | |
| 
 | |
@export
def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
            *, copy: bool | None = None,
            device: xc.Device | Sharding | None = None) -> Array:
  """Convert an object to a JAX array.

  JAX implementation of :func:`numpy.asarray`.

  Args:
    a: an object that is convertible to an array. This includes JAX
      arrays, NumPy arrays, Python scalars, Python collections like lists
      and tuples, objects with an ``__array__`` method, and objects
      supporting the Python buffer protocol.
    dtype: optional dtype for the output array. When omitted, the dtype is
      inferred from the input.
    order: not implemented in JAX
    copy: optional boolean selecting the copy mode. ``True`` always returns
      a copy, ``False`` raises an error if a copy would be required, and
      ``None`` (the default) copies only when necessary.
    device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
      to which the created array will be committed.

  Returns:
    A JAX array constructed from the input.

  Raises:
    ValueError: if ``copy=False`` is requested for a non-JAX input that
      supports the buffer protocol while the target platform is not
      ``"cpu"``.

  See also:
    - :func:`jax.numpy.array`: like `asarray`, but defaults to `copy=True`.
    - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object
      that implements the dlpack interface.
    - :func:`jax.numpy.frombuffer`: construct a JAX array from an object
      that implements the buffer interface.

  Examples:
    Constructing JAX arrays from Python scalars:

    >>> jnp.asarray(True)
    Array(True, dtype=bool)
    >>> jnp.asarray(42)
    Array(42, dtype=int32, weak_type=True)
    >>> jnp.asarray(3.5)
    Array(3.5, dtype=float32, weak_type=True)
    >>> jnp.asarray(1 + 1j)
    Array(1.+1.j, dtype=complex64, weak_type=True)

    Constructing JAX arrays from Python collections:

    >>> jnp.asarray([1, 2, 3])  # list of ints -> 1D array
    Array([1, 2, 3], dtype=int32)
    >>> jnp.asarray([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
    Array([[1, 2, 3],
           [4, 5, 6]], dtype=int32)
    >>> jnp.asarray(range(5))
    Array([0, 1, 2, 3, 4], dtype=int32)

    Constructing JAX arrays from NumPy arrays:

    >>> jnp.asarray(np.linspace(0, 2, 5))
    Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)

    Constructing a JAX array via the Python buffer interface, using Python's
    built-in :mod:`array` module.

    >>> from array import array
    >>> pybuffer = array('i', [2, 3, 5, 7])
    >>> jnp.asarray(pybuffer)
    Array([2, 3, 5, 7], dtype=int32)
  """
  # The array API requires copy=False to raise a ValueError when the input
  # supports the buffer protocol but a copy cannot be avoided. array() consumes
  # buffers through NumPy, so that situation only arises when the target
  # platform is not 'cpu'.
  if copy is False and not isinstance(a, Array):
    platform = _get_platform(device)
    if platform != "cpu" and _supports_buffer_protocol(a):
      raise ValueError(
          f"jnp.asarray: cannot convert object of type {type(a)} to JAX Array "
          f"on platform={platform} with "
          "copy=False. Consider using copy=None or copy=True instead.")

  dtypes.check_user_dtype_supported(dtype, "asarray")
  if dtype is not None:
    dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]
  # copy=None is forwarded to array() as copy=False.
  force_copy = bool(copy)
  return array(a, dtype=dtype, copy=force_copy, order=order, device=device)
 |