# Copyright 2022 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import abc from collections.abc import Callable, Sequence from types import ModuleType from typing import Any, Protocol, runtime_checkable, Union import numpy as np from jax._src.partition_spec import PartitionSpec as P from jax._src.named_sharding import NamedSharding from jax._src.sharding import Sharding # TODO(jakevdp) de-duplicate this with the DTypeLike definition in typing.py. # We redefine these here to prevent circular imports. @runtime_checkable class SupportsDType(Protocol): @property def dtype(self) -> np.dtype: ... DTypeLike = Union[str, type[Any], np.dtype, SupportsDType] Axis = Union[int, Sequence[int], None] Shard = Any # TODO: alias this to xla_client.Traceback Device = Any Traceback = Any # TODO(jakevdp): fix import cycles and import this from jax._src.lax. PrecisionLike = Any # TODO(slebedev): Remove the metaclass once ``jax_extension_version >= 325``. class Array(metaclass=abc.ABCMeta): aval: Any @property def dtype(self) -> np.dtype: ... @property def ndim(self) -> int: ... @property def size(self) -> int: ... @property def itemsize(self) -> int: ... @property def shape(self) -> tuple[int, ...]: ... def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None, order=None): raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly." " Use jax.numpy.array, or jax.numpy.zeros instead.") def __array_namespace__(self, *, api_version: None | str = ...) -> ModuleType: ... def __getitem__(self, key) -> Array: ... def __setitem__(self, key, value) -> None: ... def __len__(self) -> int: ... def __iter__(self) -> Any: ... def __reversed__(self) -> Any: ... def __round__(self, ndigits=None) -> Array: ... # Comparisons # these return bool for object, so ignore override errors. def __lt__(self, other) -> Array: ... def __le__(self, other) -> Array: ... def __eq__(self, other) -> Array: ... # type: ignore[override] def __ne__(self, other) -> Array: ... # type: ignore[override] def __gt__(self, other) -> Array: ... def __ge__(self, other) -> Array: ... # Unary arithmetic def __neg__(self) -> Array: ... def __pos__(self) -> Array: ... def __abs__(self) -> Array: ... def __invert__(self) -> Array: ... # Binary arithmetic def __add__(self, other) -> Array: ... def __sub__(self, other) -> Array: ... def __mul__(self, other) -> Array: ... def __matmul__(self, other) -> Array: ... def __truediv__(self, other) -> Array: ... def __floordiv__(self, other) -> Array: ... def __mod__(self, other) -> Array: ... def __divmod__(self, other) -> tuple[Array, Array]: ... def __pow__(self, other) -> Array: ... def __lshift__(self, other) -> Array: ... def __rshift__(self, other) -> Array: ... def __and__(self, other) -> Array: ... def __xor__(self, other) -> Array: ... def __or__(self, other) -> Array: ... def __radd__(self, other) -> Array: ... def __rsub__(self, other) -> Array: ... def __rmul__(self, other) -> Array: ... def __rmatmul__(self, other) -> Array: ... def __rtruediv__(self, other) -> Array: ... def __rfloordiv__(self, other) -> Array: ... def __rmod__(self, other) -> Array: ... def __rdivmod__(self, other) -> Array: ... def __rpow__(self, other) -> Array: ... def __rlshift__(self, other) -> Array: ... def __rrshift__(self, other) -> Array: ... def __rand__(self, other) -> Array: ... def __rxor__(self, other) -> Array: ... def __ror__(self, other) -> Array: ... def __bool__(self) -> bool: ... def __complex__(self) -> complex: ... def __int__(self) -> int: ... def __float__(self) -> float: ... def __index__(self) -> int: ... def __buffer__(self, flags: int) -> memoryview: ... def __release_buffer__(self, view: memoryview) -> None: ... # np.ndarray methods: def all(self, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ... def any(self, axis: Axis = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ... def argmax(self, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: ... def argmin(self, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: ... def argpartition(self, kth: int, axis: int = -1) -> Array: ... def argsort(self, axis: int | None = -1, *, kind: None = None, order: None = None, stable: bool = True, descending: bool = False) -> Array: ... def astype(self, dtype: DTypeLike | None = None, copy: bool = False, device: Device | Sharding | None = None) -> Array: ... def choose(self, choices: Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: ... def clip(self, min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: ... def compress(self, condition: ArrayLike, axis: int | None = None, *, out: None = None, size: int | None = None, fill_value: ArrayLike = 0) -> Array: ... def conj(self) -> Array: ... def conjugate(self) -> Array: ... def copy(self) -> Array: ... def cumprod(self, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None) -> Array: ... def cumsum(self, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None) -> Array: ... def diagonal(self, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: ... def dot(self, b: ArrayLike, *, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: ... def flatten(self, order: str = "C") -> Array: ... @property def imag(self) -> Array: ... def item(self, *args: int) -> Any: ... def max(self, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: ... def mean(self, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ... def min(self, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: ... @property def nbytes(self) -> int: ... def nonzero(self, *, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None, size: int | None = None) -> tuple[Array, ...]: ... def prod(self, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: ... def ptp(self, axis: Axis = None, out: None = None, keepdims: bool = False) -> Array: ... def ravel(self, order: str = 'C', *, out_sharding: NamedSharding | P | None = ...) -> Array: ... @property def real(self) -> Array: ... def repeat(self, repeats: ArrayLike, axis: int | None = None, *, total_repeat_length: int | None = None, out_sharding: NamedSharding | P | None = None) -> Array: ... def reshape(self, *args: Any, order: str = "C", out_sharding: NamedSharding | P | None = ...) -> Array: ... def round(self, decimals: int = 0, out: None = None) -> Array: ... def searchsorted(self, v: ArrayLike, side: str = 'left', sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: ... def sort(self, axis: int | None = -1, *, kind: None = None, order: None = None, stable: bool = True, descending: bool = False) -> Array: ... def squeeze(self, axis: Axis = None) -> Array: ... def std(self, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None, correction: int | float | None = None) -> Array: ... def sum(self, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None, promote_integers: bool = True) -> Array: ... def swapaxes(self, axis1: int, axis2: int) -> Array: ... def take(self, indices: ArrayLike, axis: int | None = None, out: None = None, mode: str | None = None, unique_indices: bool = False, indices_are_sorted: bool = False, fill_value: StaticScalar | None = None) -> Array: ... def tobytes(self, order: str = 'C') -> bytes: ... def tolist(self) -> list[Any]: ... def trace(self, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1, dtype: DTypeLike | None = None, out: None = None) -> Array: ... def transpose(self, *args: Any) -> Array: ... @property def T(self) -> Array: ... @property def mT(self) -> Array: ... def var(self, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, ddof: int = 0, keepdims: bool = False, *, where: ArrayLike | None = None, correction: int | float | None = None) -> Array: ... def view(self, dtype: DTypeLike | None = None, type: None = None) -> Array: ... # Even though we don't always support the NumPy array protocol, e.g., for # tracer types, for type checking purposes we must declare support so we # implement the NumPy ArrayLike protocol. def __array__(self, dtype: np.dtype | None = ..., copy: bool | None = ...) -> np.ndarray: ... def __dlpack__(self) -> Any: ... # JAX extensions @property def at(self) -> _IndexUpdateHelper: ... @property def weak_type(self) -> bool: ... # Methods defined on ArrayImpl, but not on Tracers def addressable_data(self, index: int) -> Array: ... def block_until_ready(self) -> Array: ... def copy_to_host_async(self) -> None: ... def delete(self) -> None: ... def devices(self) -> set[Device]: ... @property def sharding(self) -> Sharding: ... @property def committed(self) -> bool: ... @property def device(self) -> Device | Sharding: ... @property def addressable_shards(self) -> Sequence[Shard]: ... @property def global_shards(self) -> Sequence[Shard]: ... def is_deleted(self) -> bool: ... @property def is_fully_addressable(self) -> bool: ... @property def is_fully_replicated(self) -> bool: ... def on_device_size_in_bytes(self) -> int: ... @property def traceback(self) -> Traceback: ... def unsafe_buffer_pointer(self) -> int: ... def to_device(self, device: Device | Sharding, *, stream: int | Any | None = ...) -> Array: ... StaticScalar = Union[ np.bool_, np.number, # NumPy scalar types bool, int, float, complex, # Python scalar types ] ArrayLike = Union[ Array, # JAX array type np.ndarray, # NumPy array type StaticScalar, # valid scalars ] # TODO: restructure to avoid re-defining this here? # from jax._src.numpy.lax_numpy import _IndexUpdateHelper class _IndexUpdateHelper: def __getitem__(self, index: Any) -> _IndexUpdateRef: ... class _IndexUpdateRef: def get(self, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None, fill_value: StaticScalar | None = None, out_sharding: Sharding | P | None = None) -> Array: ... def set(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ... def add(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None) -> Array: ... def subtract(self, values: Any, *, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None) -> Array: ... def mul(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None) -> Array: ... def multiply(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None) -> Array: ... def divide(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None) -> Array: ... def power(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None) -> Array: ... def min(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None) -> Array: ... def max(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None) -> Array: ... def apply(self, func: Callable[[ArrayLike], ArrayLike], indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None) -> Array: ...