# Copyright 2018 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from collections.abc import Sequence from functools import partial import itertools import math import numpy as np import operator from typing import Literal, NamedTuple, overload from jax import lax from jax._src.api import jit from jax._src import config from jax._src.custom_derivatives import custom_jvp from jax._src import deprecations from jax._src.lax import lax as lax_internal from jax._src.lax.lax import PrecisionLike from jax._src.lax import linalg as lax_linalg from jax._src.numpy import einsum from jax._src.numpy import indexing from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import reductions, tensor_contractions, ufuncs from jax._src.numpy.util import promote_dtypes_inexact, ensure_arraylike from jax._src.util import canonicalize_axis, set_module from jax._src.typing import ArrayLike, Array, DTypeLike, DeprecatedArg export = set_module('jax.numpy.linalg') class EighResult(NamedTuple): eigenvalues: Array eigenvectors: Array class QRResult(NamedTuple): Q: Array R: Array class SlogdetResult(NamedTuple): sign: Array logabsdet: Array class SVDResult(NamedTuple): U: Array S: Array Vh: Array def _H(x: ArrayLike) -> Array: return ufuncs.conjugate(jnp.matrix_transpose(x)) def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 @export @partial(jit, static_argnames=['upper', 'symmetrize_input']) def cholesky(a: ArrayLike, *, upper: bool = False, symmetrize_input: bool = True) -> Array: """Compute the Cholesky decomposition of a matrix. JAX implementation of :func:`numpy.linalg.cholesky`. The Cholesky decomposition of a matrix `A` is: .. math:: A = U^HU or .. math:: A = LL^H where `U` is an upper-triangular matrix and `L` is a lower-triangular matrix, and :math:`X^H` is the Hermitian transpose of `X`. Args: a: input array, representing a (batched) positive-definite hermitian matrix. Must have shape ``(..., N, N)``. upper: if True, compute the upper Cholesky decomposition `U`. if False (default), compute the lower Cholesky decomposition `L`. symmetrize_input: if True (default) then input is symmetrized, which leads to better behavior under automatic differentiation. Note that when this is set to True, both the upper and lower triangles of the input will be used in computing the decomposition. Returns: array of shape ``(..., N, N)`` representing the Cholesky decomposition of the input. If the input is not Hermitian positive-definite, The result will contain NaN entries. See also: - :func:`jax.scipy.linalg.cholesky`: SciPy-style Cholesky API - :func:`jax.lax.linalg.cholesky`: XLA-style Cholesky API Examples: A small real Hermitian positive-definite matrix: >>> x = jnp.array([[2., 1.], ... [1., 2.]]) Lower Cholesky factorization: >>> jnp.linalg.cholesky(x) Array([[1.4142135 , 0. ], [0.70710677, 1.2247449 ]], dtype=float32) Upper Cholesky factorization: >>> jnp.linalg.cholesky(x, upper=True) Array([[1.4142135 , 0.70710677], [0. 
, 1.2247449 ]], dtype=float32) Reconstructing ``x`` from its factorization: >>> L = jnp.linalg.cholesky(x) >>> jnp.allclose(x, L @ L.T) Array(True, dtype=bool) """ a = ensure_arraylike("jnp.linalg.cholesky", a) a, = promote_dtypes_inexact(a) L = lax_linalg.cholesky(a, symmetrize_input=symmetrize_input) return L.mT.conj() if upper else L @overload def svd( a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[True], hermitian: bool = False, subset_by_index: tuple[int, int] | None = None, ) -> SVDResult: ... @overload def svd( a: ArrayLike, full_matrices: bool, compute_uv: Literal[True], hermitian: bool = False, subset_by_index: tuple[int, int] | None = None, ) -> SVDResult: ... @overload def svd( a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False], hermitian: bool = False, subset_by_index: tuple[int, int] | None = None, ) -> Array: ... @overload def svd( a: ArrayLike, full_matrices: bool, compute_uv: Literal[False], hermitian: bool = False, subset_by_index: tuple[int, int] | None = None, ) -> Array: ... @overload def svd( a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, hermitian: bool = False, subset_by_index: tuple[int, int] | None = None, ) -> Array | SVDResult: ... @export @partial( jit, static_argnames=( "full_matrices", "compute_uv", "hermitian", "subset_by_index", ), ) def svd( a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, hermitian: bool = False, subset_by_index: tuple[int, int] | None = None, ) -> Array | SVDResult: r"""Compute the singular value decomposition. JAX implementation of :func:`numpy.linalg.svd`, implemented in terms of :func:`jax.lax.linalg.svd`. The SVD of a matrix `A` is given by .. math:: A = U\Sigma V^H - :math:`U` contains the left singular vectors and satisfies :math:`U^HU=I` - :math:`V` contains the right singular vectors and satisfies :math:`V^HV=I` - :math:`\Sigma` is a diagonal matrix of singular values. Args: a: input array, of shape ``(..., N, M)`` full_matrices: if True (default) compute the full matrices; i.e. ``u`` and ``vh`` have shape ``(..., N, N)`` and ``(..., M, M)``. If False, then the shapes are ``(..., N, K)`` and ``(..., K, M)`` with ``K = min(N, M)``. compute_uv: if True (default), return the full SVD ``(u, s, vh)``. If False then return only the singular values ``s``. hermitian: if True, assume the matrix is hermitian, which allows for a more efficient implementation (default=False) subset_by_index: (TPU-only) Optional 2-tuple [start, end] indicating the range of indices of singular values to compute. For example, if ``[n-2, n]`` then ``svd`` computes the two largest singular values and their singular vectors. Only compatible with ``full_matrices=False``. Returns: A tuple of arrays ``(u, s, vh)`` if ``compute_uv`` is True, otherwise the array ``s``. - ``u``: left singular vectors of shape ``(..., N, N)`` if ``full_matrices`` is True or ``(..., N, K)`` otherwise. - ``s``: singular values of shape ``(..., K)`` - ``vh``: conjugate-transposed right singular vectors of shape ``(..., M, M)`` if ``full_matrices`` is True or ``(..., K, M)`` otherwise. where ``K = min(N, M)``. See also: - :func:`jax.scipy.linalg.svd`: SciPy-style SVD API - :func:`jax.lax.linalg.svd`: XLA-style SVD API Examples: Consider the SVD of a small real-valued array: >>> x = jnp.array([[1., 2., 3.], ... [6., 5., 4.]]) >>> u, s, vt = jnp.linalg.svd(x, full_matrices=False) >>> s # doctest: +SKIP Array([9.361919 , 1.8315067], dtype=float32) The singular vectors are in the columns of ``u`` and ``v = vt.T``. 
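Because ``full_matrices=False`` was passed, the factors have the reduced shapes described above:

    >>> u.shape, s.shape, vt.shape
    ((2, 2), (2,), (2, 3))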
These vectors are orthonormal, which can be demonstrated by comparing the matrix product with the identity matrix: >>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5) Array(True, dtype=bool) >>> v = vt.T >>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5) Array(True, dtype=bool) Given the SVD, ``x`` can be reconstructed via matrix multiplication: >>> x_reconstructed = u @ jnp.diag(s) @ vt >>> jnp.allclose(x_reconstructed, x) Array(True, dtype=bool) """ a = ensure_arraylike("jnp.linalg.svd", a) a, = promote_dtypes_inexact(a) if hermitian: w, v = lax_linalg.eigh(a, subset_by_index=subset_by_index) s = lax.abs(v) if compute_uv: sign = lax.sign(v) idxs = lax.broadcasted_iota(np.int64, s.shape, dimension=s.ndim - 1) s, idxs, sign = lax.sort((s, idxs, sign), dimension=-1, num_keys=1) s = lax.rev(s, dimensions=[s.ndim - 1]) idxs = lax.rev(idxs, dimensions=[s.ndim - 1]) sign = lax.rev(sign, dimensions=[s.ndim - 1]) u = indexing.take_along_axis(w, idxs[..., None, :], axis=-1) vh = _H(u * sign[..., None, :].astype(u.dtype)) return SVDResult(u, s, vh) else: return lax.rev(lax.sort(s, dimension=-1), dimensions=[s.ndim-1]) if compute_uv: u, s, vh = lax_linalg.svd( a, full_matrices=full_matrices, compute_uv=True, subset_by_index=subset_by_index, ) return SVDResult(u, s, vh) else: return lax_linalg.svd( a, full_matrices=full_matrices, compute_uv=False, subset_by_index=subset_by_index, ) @export @partial(jit, static_argnames=('n',)) def matrix_power(a: ArrayLike, n: int) -> Array: """Raise a square matrix to an integer power. JAX implementation of :func:`numpy.linalg.matrix_power`, implemented via repeated squarings. Args: a: array of shape ``(..., M, M)`` to be raised to the power `n`. n: the integer exponent to which the matrix should be raised. Returns: Array of shape ``(..., M, M)`` containing the matrix power of a to the n. Examples: >>> a = jnp.array([[1., 2.], ... [3., 4.]]) >>> jnp.linalg.matrix_power(a, 3) Array([[ 37., 54.], [ 81., 118.]], dtype=float32) >>> a @ a @ a # equivalent evaluated directly Array([[ 37., 54.], [ 81., 118.]], dtype=float32) This also supports zero powers: >>> jnp.linalg.matrix_power(a, 0) Array([[1., 0.], [0., 1.]], dtype=float32) and also supports negative powers: >>> with jnp.printoptions(precision=3): ... jnp.linalg.matrix_power(a, -2) Array([[ 5.5 , -2.5 ], [-3.75, 1.75]], dtype=float32) Negative powers are equivalent to matmul of the inverse: >>> inv_a = jnp.linalg.inv(a) >>> with jnp.printoptions(precision=3): ... inv_a @ inv_a Array([[ 5.5 , -2.5 ], [-3.75, 1.75]], dtype=float32) """ arr = ensure_arraylike("jnp.linalg.matrix_power", a) if arr.ndim < 2: raise TypeError("{}-dimensional array given. Array must be at least " "two-dimensional".format(arr.ndim)) if arr.shape[-2] != arr.shape[-1]: raise TypeError("Last 2 dimensions of the array must be square") try: n = operator.index(n) except TypeError as err: raise TypeError(f"exponent must be an integer, got {n}") from err if n == 0: return jnp.broadcast_to(jnp.eye(arr.shape[-2], dtype=arr.dtype), arr.shape) elif n < 0: arr = inv(arr) n = abs(n) if n == 1: return arr elif n == 2: return arr @ arr elif n == 3: return (arr @ arr) @ arr z = result = None while n > 0: z = arr if z is None else (z @ z) # type: ignore[operator] n, bit = divmod(n, 2) if bit: result = z if result is None else (result @ z) assert result is not None return result @export @jit def matrix_rank( M: ArrayLike, rtol: ArrayLike | None = None, *, tol: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array: """Compute the rank of a matrix. 
JAX implementation of :func:`numpy.linalg.matrix_rank`. The rank is calculated via the Singular Value Decomposition (SVD), and determined by the number of singular values greater than the specified tolerance. Args: M: array of shape ``(..., N, K)`` whose rank is to be computed. rtol: optional array of shape ``(...)`` specifying the tolerance. Singular values smaller than ``rtol * largest_singular_value`` are considered to be zero. If ``rtol`` is None (the default), a reasonable default is chosen based on the floating point precision of the input. tol: deprecated alias of the ``rtol`` argument. Will result in a :class:`DeprecationWarning` if used. Returns: array of shape ``M.shape[:-2]`` giving the matrix rank. Notes: The rank calculation may be inaccurate for matrices with very small singular values or those that are numerically ill-conditioned. Consider adjusting the ``rtol`` parameter or using a more specialized rank computation method in such cases. Examples: >>> a = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.linalg.matrix_rank(a) Array(2, dtype=int32) >>> b = jnp.array([[1, 0], # Rank-deficient matrix ... [0, 0]]) >>> jnp.linalg.matrix_rank(b) Array(1, dtype=int32) """ M = ensure_arraylike("jnp.linalg.matrix_rank", M) # TODO(micky774): deprecated 2024-5-14, remove after deprecation expires. if not isinstance(tol, DeprecatedArg): rtol = tol del tol deprecations.warn( "jax-numpy-linalg-matrix_rank-tol", ("The tol argument for linalg.matrix_rank is deprecated. " "Please use rtol instead."), stacklevel=2 ) M, = promote_dtypes_inexact(M) if M.ndim < 2: return (M != 0).any().astype(np.int32) S = svd(M, full_matrices=False, compute_uv=False) if rtol is None: rtol = S.max(-1) * np.max(M.shape[-2:]).astype(S.dtype) * jnp.finfo(S.dtype).eps rtol = jnp.expand_dims(rtol, np.ndim(rtol)) return reductions.sum(S > rtol, axis=-1) @custom_jvp def _slogdet_lu(a: Array) -> tuple[Array, Array]: dtype = lax.dtype(a) lu, pivot, _ = lax_linalg.lu(a) diag = jnp.diagonal(lu, axis1=-2, axis2=-1) is_zero = reductions.any(diag == jnp.array(0, dtype=dtype), axis=-1) iota = lax.expand_dims(jnp.arange(a.shape[-1], dtype=pivot.dtype), range(pivot.ndim - 1)) parity = reductions.count_nonzero(pivot != iota, axis=-1) if jnp.iscomplexobj(a): sign = reductions.prod(diag / ufuncs.abs(diag).astype(diag.dtype), axis=-1) else: sign = jnp.array(1, dtype=dtype) parity = parity + reductions.count_nonzero(diag < 0, axis=-1) sign = jnp.where(is_zero, jnp.array(0, dtype=dtype), sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype)) logdet = jnp.where( is_zero, jnp.array(-np.inf, dtype=dtype), reductions.sum(ufuncs.log(ufuncs.abs(diag)).astype(dtype), axis=-1)) return sign, ufuncs.real(logdet) @custom_jvp def _slogdet_qr(a: Array) -> tuple[Array, Array]: # Implementation of slogdet using QR decomposition. One reason we might prefer # QR decomposition is that it is more amenable to a fast batched # implementation on TPU because of the lack of row pivoting. if jnp.issubdtype(lax.dtype(a), np.complexfloating): raise NotImplementedError("slogdet method='qr' not implemented for complex " "inputs") n = a.shape[-1] a, taus = lax_linalg.geqrf(a) # The determinant of a triangular matrix is the product of its diagonal # elements. We are working in log space, so we compute the magnitude as the # trace of the log-absolute values, and we compute the sign separately.
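  # Concretely, writing a = q @ r with r upper triangular, we use
  #   log|det(a)| = sum_i log|r_ii|
  #   sign(det(a)) = sign(det(q)) * prod_i sign(r_ii)
  # where sign(det(q)) is (-1) raised to the number of nontrivial Householder
  # reflections (those with tau != 0), since each reflection has determinant -1.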
a_diag = jnp.diagonal(a, axis1=-2, axis2=-1) log_abs_det = reductions.sum(ufuncs.log(ufuncs.abs(a_diag)), axis=-1) sign_diag = reductions.prod(ufuncs.sign(a_diag), axis=-1) # The determinant of a Householder reflector is -1. So whenever we actually # made a reflection (tau != 0), multiply the result by -1. sign_taus = reductions.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype) return sign_diag * sign_taus, log_abs_det @export @partial(jit, static_argnames=('method',)) def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult: """ Compute the sign and (natural) logarithm of the determinant of an array. JAX implementation of :func:`numpy.linalg.slogdet`. Args: a: array of shape ``(..., M, M)`` for which to compute the sign and log determinant. method: the method to use for determinant computation. Options are - ``'lu'`` (default): use the LU decomposition. - ``'qr'``: use the QR decomposition. Returns: A tuple of arrays ``(sign, logabsdet)``, each of shape ``a.shape[:-2]`` - ``sign`` is the sign of the determinant. - ``logabsdet`` is the natural log of the determinant's absolute value. See also: :func:`jax.numpy.linalg.det`: direct computation of determinant Examples: >>> a = jnp.array([[1, 2], ... [3, 4]]) >>> sign, logabsdet = jnp.linalg.slogdet(a) >>> sign # -1 indicates negative determinant Array(-1., dtype=float32) >>> jnp.exp(logabsdet) # Absolute value of determinant Array(2., dtype=float32) """ a = ensure_arraylike("jnp.linalg.slogdet", a) a, = promote_dtypes_inexact(a) a_shape = np.shape(a) if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]: raise ValueError(f"Argument to slogdet() must have shape [..., n, n], got {a_shape}") if method is None or method == "lu": return SlogdetResult(*_slogdet_lu(a)) elif method == "qr": return SlogdetResult(*_slogdet_qr(a)) else: raise ValueError(f"Unknown slogdet method '{method}'. Supported methods " "are 'lu' (`None`), and 'qr'.") def _slogdet_jvp(primals, tangents): x, = primals g, = tangents sign, ans = slogdet(x) ans_dot = jnp.trace(solve(x, g), axis1=-1, axis2=-2) if jnp.issubdtype(jnp._dtype(x), np.complexfloating): sign_dot = (ans_dot - ufuncs.real(ans_dot).astype(ans_dot.dtype)) * sign ans_dot = ufuncs.real(ans_dot) else: sign_dot = jnp.zeros_like(sign) return (sign, ans), (sign_dot, ans_dot) _slogdet_lu.defjvp(_slogdet_jvp) _slogdet_qr.defjvp(_slogdet_jvp) def _cofactor_solve(a: ArrayLike, b: ArrayLike) -> tuple[Array, Array]: """Equivalent to det(a)*solve(a, b) for nonsingular mat. Intermediate function used for jvp and vjp of det. This function borrows heavily from jax.numpy.linalg.solve and jax.numpy.linalg.slogdet to compute the gradient of the determinant in a way that is well defined even for low rank matrices. This function handles two different cases: * rank(a) == n or n-1 * rank(a) < n-1 For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix. Rather than computing det(a)*solve(a, b), which would return NaN, we work directly with the LU decomposition. If a = p @ l @ u, then det(a)*solve(a, b) = prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b = prod(diag(u)) * triangular_solve(u, solve(p @ l, b)) If a is rank n-1, then the lower right corner of u will be zero and the triangular_solve will fail. Let x = solve(p @ l, b) and y = det(a)*solve(a, b). Then y_{n} x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) = x_{n} * prod_{i=1...n-1}(u_{ii}) So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1 we can avoid the triangular_solve failing. 
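  Written out in one line, the identity used for the last row is
  y_{n} = x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) = x_{n} * prod_{i=1...n-1}(u_{ii}),
  so after the replacement the final triangular solve yields the correct y_{n}
  directly, without dividing by a possibly zero u_{nn}.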
To correctly compute the rest of y_{i} for i != n, we simply multiply x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1. For the second case, a check is done on the matrix to see if `solve` returns NaN or Inf, and gives a matrix of zeros as a result, as the gradient of the determinant of a matrix with rank less than n-1 is 0. This will still return the correct value for rank n-1 matrices, as the check is applied *after* the lower right corner of u has been updated. Args: a: A square matrix or batch of matrices, possibly singular. b: A matrix, or batch of matrices of the same dimension as a. Returns: det(a) and cofactor(a)^T*b, aka adjugate(a)*b """ a, b = ensure_arraylike("jnp.linalg._cofactor_solve", a, b) a, = promote_dtypes_inexact(a) b, = promote_dtypes_inexact(b) a_shape = np.shape(a) b_shape = np.shape(b) a_ndims = len(a_shape) if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2] and b_shape[-2:] == a_shape[-2:]): msg = ("The arguments to _cofactor_solve must have shapes " "a=[..., m, m] and b=[..., m, m]; got a={} and b={}") raise ValueError(msg.format(a_shape, b_shape)) if a_shape[-1] == 1: return a[..., 0, 0], b # lu contains u in the upper triangular matrix and l in the strict lower # triangular matrix. # The diagonal of l is set to ones without loss of generality. lu, pivots, permutation = lax_linalg.lu(a) dtype = lax.dtype(a) batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2]) x = jnp.broadcast_to(b, batch_dims + b.shape[-2:]) lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:]) # Compute (partial) determinant, ignoring last diagonal of LU diag = jnp.diagonal(lu, axis1=-2, axis2=-1) iota = lax.expand_dims(jnp.arange(a_shape[-1], dtype=pivots.dtype), range(pivots.ndim - 1)) parity = reductions.count_nonzero(pivots != iota, axis=-1) sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype) # partial_det[:, -1] contains the full determinant and # partial_det[:, -2] contains det(u) / u_{nn}. 
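  # For example, if diag = [u_00, u_11, u_22] and the permutation sign is s, the
  # cumulative product below gives
  # partial_det = [s*u_00, s*u_00*u_11, s*u_00*u_11*u_22], whose last entry is
  # det(a) and whose second-to-last entry is det(a) / u_22.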
partial_det = reductions.cumprod(diag, axis=-1) * sign[..., None] lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2]) permutation = jnp.broadcast_to(permutation, (*batch_dims, a_shape[-1])) iotas = jnp.ix_(*(lax.iota(np.int32, b) for b in (*batch_dims, 1))) # filter out any matrices that are not full rank d = jnp.ones(x.shape[:-1], x.dtype) d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False) d = reductions.any(ufuncs.logical_or(ufuncs.isnan(d), ufuncs.isinf(d)), axis=-1) d = jnp.tile(d[..., None, None], d.ndim*(1,) + x.shape[-2:]) x = jnp.where(d, jnp.zeros_like(x), x) # first filter x = x[iotas[:-1] + (permutation, slice(None))] x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True) x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None], x[..., -1:, :]), axis=-2) x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False) x = jnp.where(d, jnp.zeros_like(x), x) # second filter return partial_det[..., -1], x def _det_2x2(a: Array) -> Array: return (a[..., 0, 0] * a[..., 1, 1] - a[..., 0, 1] * a[..., 1, 0]) def _det_3x3(a: Array) -> Array: return (a[..., 0, 0] * a[..., 1, 1] * a[..., 2, 2] + a[..., 0, 1] * a[..., 1, 2] * a[..., 2, 0] + a[..., 0, 2] * a[..., 1, 0] * a[..., 2, 1] - a[..., 0, 2] * a[..., 1, 1] * a[..., 2, 0] - a[..., 0, 0] * a[..., 1, 2] * a[..., 2, 1] - a[..., 0, 1] * a[..., 1, 0] * a[..., 2, 2]) @custom_jvp def _det(a): sign, logdet = slogdet(a) return sign * ufuncs.exp(logdet).astype(sign.dtype) @_det.defjvp def _det_jvp(primals, tangents): x, = primals g, = tangents y, z = _cofactor_solve(x, g) return y, jnp.trace(z, axis1=-1, axis2=-2) @export @jit def det(a: ArrayLike) -> Array: """ Compute the determinant of an array. JAX implementation of :func:`numpy.linalg.det`. Args: a: array of shape ``(..., M, M)`` for which to compute the determinant. Returns: An array of determinants of shape ``a.shape[:-2]``. See also: :func:`jax.scipy.linalg.det`: Scipy-style API for determinant. Examples: >>> a = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.linalg.det(a) Array(-2., dtype=float32) """ a = ensure_arraylike("jnp.linalg.det", a) a, = promote_dtypes_inexact(a) a_shape = np.shape(a) if len(a_shape) >= 2 and a_shape[-1] == 2 and a_shape[-2] == 2: return _det_2x2(a) elif len(a_shape) >= 2 and a_shape[-1] == 3 and a_shape[-2] == 3: return _det_3x3(a) elif len(a_shape) >= 2 and a_shape[-1] == a_shape[-2]: return _det(a) else: msg = "Argument to _det() must have shape [..., n, n], got {}" raise ValueError(msg.format(a_shape)) @export def eig(a: ArrayLike) -> tuple[Array, Array]: """ Compute the eigenvalues and eigenvectors of a square array. JAX implementation of :func:`numpy.linalg.eig`. Args: a: array of shape ``(..., M, M)`` for which to compute the eigenvalues and vectors. Returns: A tuple ``(eigenvalues, eigenvectors)`` with - ``eigenvalues``: an array of shape ``(..., M)`` containing the eigenvalues. - ``eigenvectors``: an array of shape ``(..., M, M)``, where column ``v[:, i]`` is the eigenvector corresponding to the eigenvalue ``w[i]``. Notes: - This differs from :func:`numpy.linalg.eig` in that the return type of :func:`jax.numpy.linalg.eig` is always complex64 for 32-bit input, and complex128 for 64-bit input. - At present, non-symmetric eigendecomposition is only implemented on the CPU and GPU backends. For more details about the GPU implementation, see the documentation for :func:`jax.lax.linalg.eig`. See also: - :func:`jax.numpy.linalg.eigh`: eigenvectors and eigenvalues of a Hermitian matrix. 
- :func:`jax.numpy.linalg.eigvals`: compute eigenvalues only. Examples: >>> a = jnp.array([[1., 2.], ... [2., 1.]]) >>> w, v = jnp.linalg.eig(a) >>> with jax.numpy.printoptions(precision=4): ... w Array([ 3.+0.j, -1.+0.j], dtype=complex64) >>> v Array([[ 0.70710677+0.j, -0.70710677+0.j], [ 0.70710677+0.j, 0.70710677+0.j]], dtype=complex64) """ a = ensure_arraylike("jnp.linalg.eig", a) a, = promote_dtypes_inexact(a) w, v = lax_linalg.eig(a, compute_left_eigenvectors=False) return w, v @export @jit def eigvals(a: ArrayLike) -> Array: """ Compute the eigenvalues of a general matrix. JAX implementation of :func:`numpy.linalg.eigvals`. Args: a: array of shape ``(..., M, M)`` for which to compute the eigenvalues. Returns: An array of shape ``(..., M)`` containing the eigenvalues. See also: - :func:`jax.numpy.linalg.eig`: computes eigenvalues eigenvectors of a general matrix. - :func:`jax.numpy.linalg.eigh`: computes eigenvalues eigenvectors of a Hermitian matrix. Notes: - This differs from :func:`numpy.linalg.eigvals` in that the return type of :func:`jax.numpy.linalg.eigvals` is always complex64 for 32-bit input, and complex128 for 64-bit input. - At present, non-symmetric eigendecomposition is only implemented on the CPU backend. Examples: >>> a = jnp.array([[1., 2.], ... [2., 1.]]) >>> w = jnp.linalg.eigvals(a) >>> with jnp.printoptions(precision=2): ... w Array([ 3.+0.j, -1.+0.j], dtype=complex64) """ a = ensure_arraylike("jnp.linalg.eigvals", a) a, = promote_dtypes_inexact(a) return lax_linalg.eig(a, compute_left_eigenvectors=False, compute_right_eigenvectors=False)[0] @export @partial(jit, static_argnames=('UPLO', 'symmetrize_input')) def eigh(a: ArrayLike, UPLO: str | None = None, symmetrize_input: bool = True) -> EighResult: """ Compute the eigenvalues and eigenvectors of a Hermitian matrix. JAX implementation of :func:`numpy.linalg.eigh`. Args: a: array of shape ``(..., M, M)``, containing the Hermitian (if complex) or symmetric (if real) matrix. UPLO: specifies whether the calculation is done with the lower triangular part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). symmetrize_input: if True (default) then input is symmetrized, which leads to better behavior under automatic differentiation. Note that when this is set to True, both the upper and lower triangles of the input will be used in computing the decomposition. Returns: A namedtuple ``(eigenvalues, eigenvectors)`` where - ``eigenvalues``: an array of shape ``(..., M)`` containing the eigenvalues, sorted in ascending order. - ``eigenvectors``: an array of shape ``(..., M, M)``, where column ``v[:, i]`` is the normalized eigenvector corresponding to the eigenvalue ``w[i]``. See also: - :func:`jax.numpy.linalg.eig`: general eigenvalue decomposition. - :func:`jax.numpy.linalg.eigvalsh`: compute eigenvalues only. - :func:`jax.scipy.linalg.eigh`: SciPy API for Hermitian eigendecomposition. - :func:`jax.lax.linalg.eigh`: XLA API for Hermitian eigendecomposition. Examples: >>> a = jnp.array([[1, -2j], ... [2j, 1]]) >>> w, v = jnp.linalg.eigh(a) >>> w Array([-1., 3.], dtype=float32) >>> with jnp.printoptions(precision=3): ... v Array([[-0.707+0.j , -0.707+0.j ], [ 0. +0.707j, 0. 
-0.707j]], dtype=complex64) """ a = ensure_arraylike("jnp.linalg.eigh", a) if UPLO is None or UPLO == "L": lower = True elif UPLO == "U": lower = False else: msg = f"UPLO must be one of None, 'L', or 'U', got {UPLO}" raise ValueError(msg) a, = promote_dtypes_inexact(a) v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input) return EighResult(w, v) @export @partial(jit, static_argnames=('UPLO', 'symmetrize_input')) def eigvalsh(a: ArrayLike, UPLO: str | None = 'L', *, symmetrize_input: bool = True) -> Array: """ Compute the eigenvalues of a Hermitian matrix. JAX implementation of :func:`numpy.linalg.eigvalsh`. Args: a: array of shape ``(..., M, M)``, containing the Hermitian (if complex) or symmetric (if real) matrix. UPLO: specifies whether the calculation is done with the lower triangular part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``). symmetrize_input: if True (default) then input is symmetrized, which leads to better behavior under automatic differentiation. Note that when this is set to True, both the upper and lower triangles of the input will be used in computing the decomposition. Returns: An array of shape ``(..., M)`` containing the eigenvalues, sorted in ascending order. See also: - :func:`jax.numpy.linalg.eig`: general eigenvalue decomposition. - :func:`jax.numpy.linalg.eigh`: computes eigenvalues and eigenvectors of a Hermitian matrix. Examples: >>> a = jnp.array([[1, -2j], ... [2j, 1]]) >>> w = jnp.linalg.eigvalsh(a) >>> w Array([-1., 3.], dtype=float32) """ a = ensure_arraylike("jnp.linalg.eigvalsh", a) a, = promote_dtypes_inexact(a) w, _ = eigh(a, UPLO, symmetrize_input=symmetrize_input) return w # TODO(micky774): deprecated 2024-5-14, remove wrapper after deprecation expires. @export def pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False, *, rcond: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array: """Compute the (Moore-Penrose) pseudo-inverse of a matrix. JAX implementation of :func:`numpy.linalg.pinv`. Args: a: array of shape ``(..., M, N)`` containing matrices to pseudo-invert. rtol: float or array_like of shape ``a.shape[:-2]``. Specifies the cutoff for small singular values; singular values smaller than ``rtol * largest_singular_value`` are treated as zero. The default is determined based on the floating point precision of the dtype. hermitian: if True, then the input is assumed to be Hermitian, and a more efficient algorithm is used (default: False) rcond: deprecated alias of the ``rtol`` argument. Will result in a :class:`DeprecationWarning` if used. Returns: An array of shape ``(..., N, M)`` containing the pseudo-inverse of ``a``. See also: - :func:`jax.numpy.linalg.inv`: multiplicative inverse of a square matrix. Notes: :func:`jax.numpy.linalg.pinv` differs from :func:`numpy.linalg.pinv` in the default value of ``rtol``: in NumPy (where the argument is named ``rcond``), the default is ``1e-15``. In JAX, the default is ``10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps``. Examples: >>> a = jnp.array([[1, 2], ... [3, 4], ...
[5, 6]]) >>> a_pinv = jnp.linalg.pinv(a) >>> a_pinv # doctest: +SKIP Array([[-1.333332 , -0.33333257, 0.6666657 ], [ 1.0833322 , 0.33333272, -0.41666582]], dtype=float32) The pseudo-inverse operates as a multiplicative inverse so long as the output is not rank-deficient: >>> jnp.allclose(a_pinv @ a, jnp.eye(2), atol=1E-4) Array(True, dtype=bool) """ if not isinstance(rcond, DeprecatedArg): rtol = rcond del rcond deprecations.warn( "jax-numpy-linalg-pinv-rcond", ("The rcond argument for linalg.pinv is deprecated. " "Please use rtol instead."), stacklevel=2 ) return _pinv(a, rtol, hermitian) @partial(custom_jvp, nondiff_argnums=(1, 2)) @partial(jit, static_argnames=('hermitian')) def _pinv(a: ArrayLike, rtol: ArrayLike | None = None, hermitian: bool = False) -> Array: # Uses same algorithm as # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979 a = ensure_arraylike("jnp.linalg.pinv", a) arr, = promote_dtypes_inexact(a) m, n = arr.shape[-2:] if m == 0 or n == 0: return jnp.empty(arr.shape[:-2] + (n, m), arr.dtype) arr = ufuncs.conj(arr) if rtol is None: max_rows_cols = max(arr.shape[-2:]) rtol = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps) rtol = jnp.asarray(rtol) u, s, vh = svd(arr, full_matrices=False, hermitian=hermitian) # Singular values less than or equal to ``rtol * largest_singular_value`` # are set to zero. rtol = lax.expand_dims(rtol[..., np.newaxis], range(s.ndim - rtol.ndim - 1)) cutoff = rtol * s[..., 0:1] s = jnp.where(s > cutoff, s, np.inf).astype(u.dtype) res = tensor_contractions.matmul(vh.mT, ufuncs.divide(u.mT, s[..., np.newaxis]), precision=lax.Precision.HIGHEST) return lax.convert_element_type(res, arr.dtype) @_pinv.defjvp @config.default_matmul_precision("float32") def _pinv_jvp(rtol, hermitian, primals, tangents): # The Differentiation of Pseudo-Inverses and Nonlinear Least Squares Problems # Whose Variables Separate. Author(s): G. H. Golub and V. Pereyra. SIAM # Journal on Numerical Analysis, Vol. 10, No. 2 (Apr., 1973), pp. 413-432. # (via https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Derivative) a, = primals # m x n a_dot, = tangents p = pinv(a, rtol=rtol, hermitian=hermitian) # n x m if hermitian: # svd(..., hermitian=True) symmetrizes its input, and the JVP must match. a = _symmetrize(a) a_dot = _symmetrize(a_dot) # TODO(phawkins): this could be simplified in the Hermitian case if we # supported triangular matrix multiplication. m, n = a.shape[-2:] if m >= n: s = (p @ _H(p)) @ _H(a_dot) # nxm t = (_H(a_dot) @ _H(p)) @ p # nxm p_dot = -(p @ a_dot) @ p + s - (s @ a) @ p + t - (p @ a) @ t else: # m < n s = p @ (_H(p) @ _H(a_dot)) t = _H(a_dot) @ (_H(p) @ p) p_dot = -p @ (a_dot @ p) + s - s @ (a @ p) + t - p @ (a @ t) return p, p_dot @export @jit def inv(a: ArrayLike) -> Array: """Return the inverse of a square matrix JAX implementation of :func:`numpy.linalg.inv`. Args: a: array of shape ``(..., N, N)`` specifying square array(s) to be inverted. Returns: Array of shape ``(..., N, N)`` containing the inverse of the input. Notes: In most cases, explicitly computing the inverse of a matrix is ill-advised. For example, to compute ``x = inv(A) @ b``, it is more performant and numerically precise to use a direct solve, such as :func:`jax.scipy.linalg.solve`. See Also: - :func:`jax.scipy.linalg.inv`: SciPy-style API for matrix inverse - :func:`jax.numpy.linalg.solve`: direct linear solver Examples: Compute the inverse of a 3x3 matrix >>> a = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... 
[3., 2., 1.]]) >>> a_inv = jnp.linalg.inv(a) >>> a_inv # doctest: +SKIP Array([[ 0. , -0.25 , 0.5 ], [-0.25 , 0.5 , -0.25000003], [ 0.5 , -0.25 , 0. ]], dtype=float32) Check that multiplying with the inverse gives the identity: >>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5) Array(True, dtype=bool) Multiply the inverse by a vector ``b``, to find a solution to ``a @ x = b`` >>> b = jnp.array([1., 4., 2.]) >>> a_inv @ b Array([ 0. , 1.25, -0.5 ], dtype=float32) Note, however, that explicitly computing the inverse in such a case can lead to poor performance and loss of precision as the size of the problem grows. Instead, you should use a direct solver like :func:`jax.numpy.linalg.solve`: >>> jnp.linalg.solve(a, b) Array([ 0. , 1.25, -0.5 ], dtype=float32) """ arr = ensure_arraylike("jnp.linalg.inv", a) if arr.ndim < 2 or arr.shape[-1] != arr.shape[-2]: raise ValueError( f"Argument to inv must have shape [..., n, n], got {arr.shape}.") return solve( arr, lax.broadcast(jnp.eye(arr.shape[-1], dtype=arr.dtype), arr.shape[:-2])) @export @partial(jit, static_argnames=('ord', 'axis', 'keepdims')) def norm(x: ArrayLike, ord: int | str | None = None, axis: None | tuple[int, ...] | int = None, keepdims: bool = False) -> Array: """Compute the norm of a matrix or vector. JAX implementation of :func:`numpy.linalg.norm`. Args: x: N-dimensional array for which the norm will be computed. ord: specify the kind of norm to take. Default is Frobenius norm for matrices, and the 2-norm for vectors. For other options, see Notes below. axis: integer or sequence of integers specifying the axes over which the norm will be computed. For a single axis, compute a vector norm. For two axes, compute a matrix norm. Defaults to all axes of ``x``. keepdims: if True, the output array will have the same number of dimensions as the input, with the size of reduced axes replaced by ``1`` (default: False). Returns: array containing the specified norm of x. Notes: The flavor of norm computed depends on the value of ``ord`` and the number of axes being reduced. For **vector norms** (i.e. a single axis reduction): - ``ord=None`` (default) computes the 2-norm - ``ord=inf`` computes ``max(abs(x))`` - ``ord=-inf`` computes min(abs(x))`` - ``ord=0`` computes ``sum(x!=0)`` - for other numerical values, computes ``sum(abs(x) ** ord)**(1/ord)`` For **matrix norms** (i.e. two axes reductions): - ``ord='fro'`` or ``ord=None`` (default) computes the Frobenius norm - ``ord='nuc'`` computes the nuclear norm, or the sum of the singular values - ``ord=1`` computes ``max(abs(x).sum(0))`` - ``ord=-1`` computes ``min(abs(x).sum(0))`` - ``ord=2`` computes the 2-norm, i.e. the largest singular value - ``ord=-2`` computes the smallest singular value In the special case of ``ord=None`` and ``axis=None``, this function accepts an array of any dimension and computes the vector 2-norm of the flattened array. Examples: Vector norms: >>> x = jnp.array([3., 4., 12.]) >>> jnp.linalg.norm(x) Array(13., dtype=float32) >>> jnp.linalg.norm(x, ord=1) Array(19., dtype=float32) >>> jnp.linalg.norm(x, ord=0) Array(3., dtype=float32) Matrix norms: >>> x = jnp.array([[1., 2., 3.], ... 
[4., 5., 7.]]) >>> jnp.linalg.norm(x) # Frobenius norm Array(10.198039, dtype=float32) >>> jnp.linalg.norm(x, ord='nuc') # nuclear norm Array(10.762535, dtype=float32) >>> jnp.linalg.norm(x, ord=1) # 1-norm Array(10., dtype=float32) Batched vector norm: >>> jnp.linalg.norm(x, axis=1) Array([3.7416575, 9.486833 ], dtype=float32) """ x = ensure_arraylike("jnp.linalg.norm", x) x, = promote_dtypes_inexact(x) x_shape = np.shape(x) ndim = len(x_shape) if axis is None: # NumPy has an undocumented behavior that admits arbitrary rank inputs if # `ord` is None: https://github.com/numpy/numpy/issues/14215 if ord is None: return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), keepdims=keepdims)) axis = tuple(range(ndim)) elif isinstance(axis, tuple): axis = tuple(canonicalize_axis(x, ndim) for x in axis) else: axis = (canonicalize_axis(axis, ndim),) num_axes = len(axis) if num_axes == 1: return vector_norm(x, ord=2 if ord is None else ord, axis=axis, keepdims=keepdims) elif num_axes == 2: row_axis, col_axis = axis # pytype: disable=bad-unpacking if ord is None or ord in ('f', 'fro'): return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis, keepdims=keepdims)) elif ord == 1: if not keepdims and col_axis > row_axis: col_axis -= 1 return reductions.amax(reductions.sum(ufuncs.abs(x), axis=row_axis, keepdims=keepdims), axis=col_axis, keepdims=keepdims, initial=0) elif ord == -1: if not keepdims and col_axis > row_axis: col_axis -= 1 return reductions.amin(reductions.sum(ufuncs.abs(x), axis=row_axis, keepdims=keepdims), axis=col_axis, keepdims=keepdims) elif ord == np.inf: if not keepdims and row_axis > col_axis: row_axis -= 1 return reductions.amax(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims), axis=row_axis, keepdims=keepdims, initial=0) elif ord == -np.inf: if not keepdims and row_axis > col_axis: row_axis -= 1 return reductions.amin(reductions.sum(ufuncs.abs(x), axis=col_axis, keepdims=keepdims), axis=row_axis, keepdims=keepdims) elif ord in ('nuc', 2, -2): x = jnp.moveaxis(x, axis, (-2, -1)) s = svd(x, compute_uv=False) if ord == 2: y = reductions.amax(s, axis=-1, initial=0) elif ord == -2: y = reductions.amin(s, axis=-1) else: y = reductions.sum(s, axis=-1) if keepdims: y = jnp.expand_dims(y, axis) return y else: raise ValueError(f"Invalid order '{ord}' for matrix norm.") else: raise ValueError(f"Improper number of axes for norm: {axis=}. Pass one axis to" " compute a vector-norm, or two axes to compute a matrix-norm.") @overload def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ... @overload def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: ... @export @partial(jit, static_argnames=('mode',)) def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: """Compute the QR decomposition of an array JAX implementation of :func:`numpy.linalg.qr`. The QR decomposition of a matrix `A` is given by .. math:: A = QR Where `Q` is a unitary matrix (i.e. :math:`Q^HQ=I`) and `R` is an upper-triangular matrix. Args: a: array of shape (..., M, N) mode: Computational mode. Supported values are: - ``"reduced"`` (default): return `Q` of shape ``(..., M, K)`` and `R` of shape ``(..., K, N)``, where ``K = min(M, N)``. - ``"complete"``: return `Q` of shape ``(..., M, M)`` and `R` of shape ``(..., M, N)``. - ``"raw"``: return lapack-internal representations of shape ``(..., M, N)`` and ``(..., K)``. - ``"r"``: return `R` only. 
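      For example, for an input of shape ``(4, 3)``:

      >>> a = jnp.ones((4, 3))
      >>> q, r = jnp.linalg.qr(a)  # mode="reduced" (default)
      >>> q.shape, r.shape
      ((4, 3), (3, 3))
      >>> q, r = jnp.linalg.qr(a, mode="complete")
      >>> q.shape, r.shape
      ((4, 4), (4, 3))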
Returns: A tuple ``(Q, R)`` (if ``mode`` is not ``"r"``) otherwise an array ``R``, where: - ``Q`` is an orthogonal matrix of shape ``(..., M, K)`` (if ``mode`` is ``"reduced"``) or ``(..., M, M)`` (if ``mode`` is ``"complete"``). - ``R`` is an upper-triangular matrix of shape ``(..., M, N)`` (if ``mode`` is ``"r"`` or ``"complete"``) or ``(..., K, N)`` (if ``mode`` is ``"reduced"``) with ``K = min(M, N)``. See also: - :func:`jax.scipy.linalg.qr`: SciPy-style QR decomposition API - :func:`jax.lax.linalg.qr`: XLA-style QR decomposition API Examples: Compute the QR decomposition of a matrix: >>> a = jnp.array([[1., 2., 3., 4.], ... [5., 4., 2., 1.], ... [6., 3., 1., 5.]]) >>> Q, R = jnp.linalg.qr(a) >>> Q # doctest: +SKIP Array([[-0.12700021, -0.7581426 , -0.6396022 ], [-0.63500065, -0.43322435, 0.63960224], [-0.7620008 , 0.48737738, -0.42640156]], dtype=float32) >>> R # doctest: +SKIP Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ], [ 0. , -1.7870499, -2.6534991, -1.028908 ], [ 0. , 0. , -1.0660033, -4.050814 ]], dtype=float32) Check that ``Q`` is orthonormal: >>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5) Array(True, dtype=bool) Reconstruct the input: >>> jnp.allclose(Q @ R, a) Array(True, dtype=bool) """ a = ensure_arraylike("jnp.linalg.qr", a) a, = promote_dtypes_inexact(a) if mode == "raw": a, taus = lax_linalg.geqrf(a) return QRResult(a.mT, taus) if mode in ("reduced", "r", "full"): full_matrices = False elif mode == "complete": full_matrices = True else: raise ValueError(f"Unsupported QR decomposition mode '{mode}'") q, r = lax_linalg.qr(a, pivoting=False, full_matrices=full_matrices) if mode == "r": return r return QRResult(q, r) @export @jit def solve(a: ArrayLike, b: ArrayLike) -> Array: """Solve a linear system of equations. JAX implementation of :func:`numpy.linalg.solve`. This solves a (batched) linear system of equations ``a @ x = b`` for ``x`` given ``a`` and ``b``. If ``a`` is singular, this will return ``nan`` or ``inf`` values. Args: a: array of shape ``(..., N, N)``. b: array of shape ``(N,)`` (for 1-dimensional right-hand-side) or ``(..., N, M)`` (for batched 2-dimensional right-hand-side). Returns: An array containing the result of the linear solve if ``a`` is non-singular. The result has shape ``(..., N)`` if ``b`` is of shape ``(N,)``, and has shape ``(..., N, M)`` otherwise. If ``a`` is singular, the result contains ``nan`` or ``inf`` values. See also: - :func:`jax.scipy.linalg.solve`: SciPy-style API for solving linear systems. - :func:`jax.lax.custom_linear_solve`: matrix-free linear solver. Examples: A simple 3x3 linear system: >>> A = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) >>> b = jnp.array([14., 16., 10.]) >>> x = jnp.linalg.solve(A, b) >>> x Array([1., 2., 3.], dtype=float32) Confirming that the result solves the system: >>> jnp.allclose(A @ x, b) Array(True, dtype=bool) """ a, b = ensure_arraylike("jnp.linalg.solve", a, b) a, b = promote_dtypes_inexact(a, b) if a.ndim < 2: raise ValueError( f"left hand array must be at least two dimensional; got {a.shape=}") # Check for invalid inputs that previously would have led to a batched 1D solve: if (b.ndim > 1 and a.ndim == b.ndim + 1 and a.shape[-1] == b.shape[-1] and a.shape[-1] != b.shape[-2]): raise ValueError( f"Invalid shapes for solve: {a.shape}, {b.shape}. Prior to JAX v0.5.0," " this would have been treated as a batched 1-dimensional solve." 
" To recover this behavior, use solve(a, b[..., None]).squeeze(-1).") signature = "(m,m),(m)->(m)" if b.ndim == 1 else "(m,m),(m,n)->(m,n)" return jnp.vectorize(lax_linalg._solve, signature=signature)(a, b) def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *, numpy_resid: bool = False) -> tuple[Array, Array, Array, Array]: # TODO: add lstsq to lax_linalg and implement this function via those wrappers. # TODO: add custom jvp rule for more robust lstsq differentiation a, b = promote_dtypes_inexact(a, b) if a.shape[0] != b.shape[0]: raise ValueError("Leading dimensions of input arrays must match") b_orig_ndim = b.ndim if b_orig_ndim == 1: b = b[:, None] if a.ndim != 2: raise TypeError( f"{a.ndim}-dimensional array given. Array must be two-dimensional") if b.ndim != 2: raise TypeError( f"{b.ndim}-dimensional array given. Array must be one or two-dimensional") m, n = a.shape dtype = a.dtype if a.size == 0: s = jnp.empty(0, dtype=a.dtype) rank = jnp.array(0, dtype=int) x = jnp.empty((n, *b.shape[1:]), dtype=a.dtype) else: if rcond is None: rcond = float(jnp.finfo(dtype).eps) * max(n, m) else: rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond) u, s, vt = svd(a, full_matrices=False) mask = s >= jnp.array(rcond, dtype=s.dtype) * s[0] rank = mask.sum() safe_s = jnp.where(mask, s, 1).astype(a.dtype) s_inv = jnp.where(mask, 1 / safe_s, 0)[:, np.newaxis] uTb = tensor_contractions.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST) x = tensor_contractions.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST) # Numpy returns empty residuals in some cases. To allow compilation, we # default to returning full residuals in all cases. if numpy_resid and (rank < n or m <= n): resid = jnp.asarray([]) else: b_estimate = tensor_contractions.matmul(a, x, precision=lax.Precision.HIGHEST) resid = norm(b - b_estimate, axis=0) ** 2 if b_orig_ndim == 1: x = x.ravel() return x, resid, rank, s _jit_lstsq = jit(partial(_lstsq, numpy_resid=False)) @export def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *, numpy_resid: bool = False) -> tuple[Array, Array, Array, Array]: """ Return the least-squares solution to a linear equation. JAX implementation of :func:`numpy.linalg.lstsq`. Args: a: array of shape ``(M, N)`` representing the coefficient matrix. b: array of shape ``(M,)`` or ``(M, K)`` representing the right-hand side. rcond: Cut-off ratio for small singular values. Singular values smaller than ``rcond * largest_singular_value`` are treated as zero. If None (default), the optimal value will be used to reduce floating point errors. numpy_resid: If True, compute and return residuals in the same way as NumPy's `linalg.lstsq`. This is necessary if you want to precisely replicate NumPy's behavior. If False (default), a more efficient method is used to compute residuals. Returns: Tuple of arrays ``(x, resid, rank, s)`` where - ``x`` is a shape ``(N,)`` or ``(N, K)`` array containing the least-squares solution. - ``resid`` is the sum of squared residual of shape ``()`` or ``(K,)``. - ``rank`` is the rank of the matrix ``a``. - ``s`` is the singular values of the matrix ``a``. Examples: >>> a = jnp.array([[1, 2], ... [3, 4]]) >>> b = jnp.array([5, 6]) >>> x, _, _, _ = jnp.linalg.lstsq(a, b) >>> with jnp.printoptions(precision=3): ... print(x) [-4. 
4.5] """ a, b = ensure_arraylike("jnp.linalg.lstsq", a, b) if numpy_resid: return _lstsq(a, b, rcond, numpy_resid=True) return _jit_lstsq(a, b, rcond) @export def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1): r"""Compute the cross-product of two 3D vectors JAX implementation of :func:`numpy.linalg.cross` Args: x1: N-dimensional array, with ``x1.shape[axis] == 3`` x2: N-dimensional array, with ``x2.shape[axis] == 3``, and other axes broadcast-compatible with ``x1``. axis: axis along which to take the cross product (default: -1). Returns: array containing the result of the cross-product See Also: :func:`jax.numpy.cross`: more flexible cross-product API. Examples: Showing that :math:`\hat{x} \times \hat{y} = \hat{z}`: >>> x = jnp.array([1., 0., 0.]) >>> y = jnp.array([0., 1., 0.]) >>> jnp.linalg.cross(x, y) Array([0., 0., 1.], dtype=float32) Cross product of :math:`\hat{x}` with all three standard unit vectors, via broadcasting: >>> xyz = jnp.eye(3) >>> jnp.linalg.cross(x, xyz, axis=-1) Array([[ 0., 0., 0.], [ 0., 0., 1.], [ 0., -1., 0.]], dtype=float32) """ x1, x2 = ensure_arraylike("jnp.linalg.outer", x1, x2) if x1.shape[axis] != 3 or x2.shape[axis] != 3: raise ValueError( "Both input arrays must be (arrays of) 3-dimensional vectors, " f"but they have {x1.shape[axis]=} and {x2.shape[axis]=}" ) return jnp.cross(x1, x2, axis=axis) @export def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Compute the outer product of two 1-dimensional arrays. JAX implementation of :func:`numpy.linalg.outer`. Args: x1: array x2: array Returns: array containing the outer product of ``x1`` and ``x2`` See also: :func:`jax.numpy.outer`: similar function in the main :mod:`jax.numpy` module. Examples: >>> x1 = jnp.array([1, 2, 3]) >>> x2 = jnp.array([4, 5, 6]) >>> jnp.linalg.outer(x1, x2) Array([[ 4, 5, 6], [ 8, 10, 12], [12, 15, 18]], dtype=int32) """ x1, x2 = ensure_arraylike("jnp.linalg.outer", x1, x2) if x1.ndim != 1 or x2.ndim != 1: raise ValueError(f"Input arrays must be one-dimensional, but they are {x1.ndim=} {x2.ndim=}") return x1[:, None] * x2[None, :] @export def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str | int = 'fro') -> Array: """Compute the norm of a matrix or stack of matrices. JAX implementation of :func:`numpy.linalg.matrix_norm` Args: x: array of shape ``(..., M, N)`` for which to take the norm. keepdims: if True, keep the reduced dimensions in the output. ord: A string or int specifying the type of norm; default is the Frobenius norm. See :func:`numpy.linalg.norm` for details on available options. Returns: array containing the norm of ``x``. Has shape ``x.shape[:-2]`` if ``keepdims`` is False, or shape ``(..., 1, 1)`` if ``keepdims`` is True. See also: - :func:`jax.numpy.linalg.vector_norm`: Norm of a vector or stack of vectors. - :func:`jax.numpy.linalg.norm`: More general matrix or vector norm. Examples: >>> x = jnp.array([[1, 2, 3], ... [4, 5, 6], ... [7, 8, 9]]) >>> jnp.linalg.matrix_norm(x) Array(16.881943, dtype=float32) """ x = ensure_arraylike('jnp.linalg.matrix_norm', x) return norm(x, ord=ord, keepdims=keepdims, axis=(-2, -1)) @export def matrix_transpose(x: ArrayLike, /) -> Array: """Transpose a matrix or stack of matrices. JAX implementation of :func:`numpy.linalg.matrix_transpose`. Args: x: array of shape ``(..., M, N)`` Returns: array of shape ``(..., N, M)`` containing the matrix transpose of ``x``. See also: :func:`jax.numpy.transpose`: more general transpose operation. Examples: Transpose of a single matrix: >>> x = jnp.array([[1, 2, 3], ... 
[4, 5, 6]]) >>> jnp.linalg.matrix_transpose(x) Array([[1, 4], [2, 5], [3, 6]], dtype=int32) Transpose of a stack of matrices: >>> x = jnp.array([[[1, 2], ... [3, 4]], ... [[5, 6], ... [7, 8]]]) >>> jnp.linalg.matrix_transpose(x) Array([[[1, 3], [2, 4]], [[5, 7], [6, 8]]], dtype=int32) For convenience, the same computation can be done via the :attr:`~jax.Array.mT` property of JAX array objects: >>> x.mT Array([[[1, 3], [2, 4]], [[5, 7], [6, 8]]], dtype=int32) """ x_arr = ensure_arraylike('jnp.linalg.matrix_transpose', x) ndim = x_arr.ndim if ndim < 2: raise ValueError(f"matrix_transpose requires at least 2 dimensions; got {ndim=}") return lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2)) @export def vector_norm(x: ArrayLike, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ord: int | str = 2) -> Array: """Compute the vector norm of a vector or batch of vectors. JAX implementation of :func:`numpy.linalg.vector_norm`. Args: x: N-dimensional array for which to take the norm. axis: optional axis along which to compute the vector norm. If None (default) then ``x`` is flattened and the norm is taken over all values. keepdims: if True, keep the reduced dimensions in the output. ord: A string or int specifying the type of norm; default is the 2-norm. See :func:`numpy.linalg.norm` for details on available options. Returns: array containing the norm of ``x``. See also: - :func:`jax.numpy.linalg.matrix_norm`: Norm of a matrix or stack of matrices. - :func:`jax.numpy.linalg.norm`: More general matrix or vector norm. Examples: Norm of a single vector: >>> x = jnp.array([1., 2., 3.]) >>> jnp.linalg.vector_norm(x) Array(3.7416575, dtype=float32) Norm of a batch of vectors: >>> x = jnp.array([[1., 2., 3.], ... [4., 5., 7.]]) >>> jnp.linalg.vector_norm(x, axis=1) Array([3.7416575, 9.486833 ], dtype=float32) """ x = ensure_arraylike('jnp.linalg.vector_norm', x) if ord is None or ord == 2: return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis, keepdims=keepdims)) elif ord == np.inf: return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims, initial=0) elif ord == -np.inf: return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims) elif ord == 0: return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype, axis=axis, keepdims=keepdims) elif ord == 1: # Numpy has a special case for ord == 1 as an optimization. We don't # really need the optimization (XLA could do it for us), but the Numpy # code has slightly different type promotion semantics, so we need a # special case too. return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims) elif isinstance(ord, str): msg = f"Invalid order '{ord}' for vector norm." if ord == "inf": msg += "Use 'jax.numpy.inf' instead." if ord == "-inf": msg += "Use '-jax.numpy.inf' instead." raise ValueError(msg) else: abs_x = ufuncs.abs(x) ord_arr = lax_internal._const(abs_x, ord) ord_inv = lax_internal._const(abs_x, 1. / ord_arr) out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims) return ufuncs.power(out, ord_inv) @export def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: """Compute the (batched) vector conjugate dot product of two arrays. JAX implementation of :func:`numpy.linalg.vecdot`. Args: x1: left-hand side array. x2: right-hand side array. Size of ``x2[axis]`` must match size of ``x1[axis]``, and remaining dimensions must be broadcast-compatible. 
axis: axis along which to compute the dot product (default: -1) precision: either ``None`` (default), which means the default precision for the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two such values indicating precision of ``x1`` and ``x2``. preferred_element_type: either ``None`` (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype. Returns: array containing the conjugate dot product of ``x1`` and ``x2`` along ``axis``. The non-contracted dimensions are broadcast together. See also: - :func:`jax.numpy.vecdot`: similar API in the ``jax.numpy`` namespace. - :func:`jax.numpy.linalg.matmul`: matrix multiplication. - :func:`jax.numpy.linalg.tensordot`: general tensor dot product. Examples: Vector dot product of two 1D arrays: >>> x1 = jnp.array([1, 2, 3]) >>> x2 = jnp.array([4, 5, 6]) >>> jnp.linalg.vecdot(x1, x2) Array(32, dtype=int32) Batched vector dot product of two 2D arrays: >>> x1 = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> x2 = jnp.array([[2, 3, 4]]) >>> jnp.linalg.vecdot(x1, x2, axis=-1) Array([20, 47], dtype=int32) """ x1, x2 = ensure_arraylike('jnp.linalg.vecdot', x1, x2) return tensor_contractions.vecdot(x1, x2, axis=axis, precision=precision, preferred_element_type=preferred_element_type) @export def matmul(x1: ArrayLike, x2: ArrayLike, /, *, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: """Perform a matrix multiplication. JAX implementation of :func:`numpy.linalg.matmul`. Args: x1: first input array, of shape ``(..., N)``. x2: second input array. Must have shape ``(N,)`` or ``(..., N, M)``. In the multi-dimensional case, leading dimensions must be broadcast-compatible with the leading dimensions of ``x1``. precision: either ``None`` (default), which means the default precision for the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two such values indicating precision of ``x1`` and ``x2``. preferred_element_type: either ``None`` (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype. Returns: array containing the matrix product of the inputs. Shape is ``x1.shape[:-1]`` if ``x2.ndim == 1``, otherwise the shape is ``(..., M)``. See Also: :func:`jax.numpy.matmul`: NumPy API for this function. :func:`jax.numpy.linalg.vecdot`: batched vector product. :func:`jax.numpy.linalg.tensordot`: batched tensor product. Examples: Vector dot products: >>> x1 = jnp.array([1, 2, 3]) >>> x2 = jnp.array([4, 5, 6]) >>> jnp.linalg.matmul(x1, x2) Array(32, dtype=int32) Matrix dot product: >>> x1 = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> x2 = jnp.array([[1, 2], ... [3, 4], ... 
[5, 6]]) >>> jnp.linalg.matmul(x1, x2) Array([[22, 28], [49, 64]], dtype=int32) For convenience, in all cases you can do the same computation using the ``@`` operator: >>> x1 @ x2 Array([[22, 28], [49, 64]], dtype=int32) """ x1, x2 = ensure_arraylike('jnp.linalg.matmul', x1, x2) return tensor_contractions.matmul(x1, x2, precision=precision, preferred_element_type=preferred_element_type) @export def tensordot(x1: ArrayLike, x2: ArrayLike, /, *, axes: int | tuple[Sequence[int], Sequence[int]] = 2, precision: PrecisionLike = None, preferred_element_type: DTypeLike | None = None) -> Array: """Compute the tensor dot product of two N-dimensional arrays. JAX implementation of :func:`numpy.linalg.tensordot`. Args: x1: N-dimensional array x2: M-dimensional array axes: integer or tuple of sequences of integers. If an integer `k`, then sum over the last `k` axes of ``x1`` and the first `k` axes of ``x2``, in order. If a tuple, then ``axes[0]`` specifies the axes of ``x1`` and ``axes[1]`` specifies the axes of ``x2``. precision: either ``None`` (default), which means the default precision for the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two such values indicating precision of ``x1`` and ``x2``. preferred_element_type: either ``None`` (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype. Returns: array containing the tensor dot product of the inputs See also: - :func:`jax.numpy.tensordot`: equivalent API in the :mod:`jax.numpy` namespace. - :func:`jax.numpy.einsum`: NumPy API for more general tensor contractions. - :func:`jax.lax.dot_general`: XLA API for more general tensor contractions. Examples: >>> x1 = jnp.arange(24.).reshape(2, 3, 4) >>> x2 = jnp.ones((3, 4, 5)) >>> jnp.linalg.tensordot(x1, x2) Array([[ 66., 66., 66., 66., 66.], [210., 210., 210., 210., 210.]], dtype=float32) Equivalent result when specifying the axes as explicit sequences: >>> jnp.linalg.tensordot(x1, x2, axes=([1, 2], [0, 1])) Array([[ 66., 66., 66., 66., 66.], [210., 210., 210., 210., 210.]], dtype=float32) Equivalent result via :func:`~jax.numpy.einsum`: >>> jnp.einsum('ijk,jkm->im', x1, x2) Array([[ 66., 66., 66., 66., 66.], [210., 210., 210., 210., 210.]], dtype=float32) Setting ``axes=1`` for two-dimensional inputs is equivalent to a matrix multiplication: >>> x1 = jnp.array([[1, 2], ... [3, 4]]) >>> x2 = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.linalg.tensordot(x1, x2, axes=1) Array([[ 9, 12, 15], [19, 26, 33]], dtype=int32) >>> x1 @ x2 Array([[ 9, 12, 15], [19, 26, 33]], dtype=int32) Setting ``axes=0`` for one-dimensional inputs is equivalent to :func:`jax.numpy.linalg.outer`: >>> x1 = jnp.array([1, 2]) >>> x2 = jnp.array([1, 2, 3]) >>> jnp.linalg.tensordot(x1, x2, axes=0) Array([[1, 2, 3], [2, 4, 6]], dtype=int32) >>> jnp.linalg.outer(x1, x2) Array([[1, 2, 3], [2, 4, 6]], dtype=int32) """ x1, x2 = ensure_arraylike('jnp.linalg.tensordot', x1, x2) return tensor_contractions.tensordot(x1, x2, axes=axes, precision=precision, preferred_element_type=preferred_element_type) @export def svdvals(x: ArrayLike, /) -> Array: """Compute the singular values of a matrix. JAX implementation of :func:`numpy.linalg.svdvals`. Args: x: array of shape ``(..., M, N)`` for which singular values will be computed. Returns: array of singular values of shape ``(..., K)`` with ``K = min(M, N)``. 
See also: :func:`jax.numpy.linalg.svd`: compute singular values and singular vectors Examples: >>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.linalg.svdvals(x) Array([9.508031 , 0.7728694], dtype=float32) """ x = ensure_arraylike('jnp.linalg.svdvals', x) return svd(x, compute_uv=False, hermitian=False) @export def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array: """Extract the diagonal of a matrix or stack of matrices. JAX implementation of :func:`numpy.linalg.diagonal`. Args: x: array of shape ``(..., M, N)`` from which the diagonal will be extracted. offset: positive or negative offset from the main diagonal (default: 0). Returns: Array of shape ``(..., K)`` where ``K`` is the length of the specified diagonal. See Also: - :func:`jax.numpy.diagonal`: more general functionality for extracting diagonals. - :func:`jax.numpy.diag`: create a diagonal matrix from values. Examples: Diagonals of a single matrix: >>> x = jnp.array([[1, 2, 3, 4], ... [5, 6, 7, 8], ... [9, 10, 11, 12]]) >>> jnp.linalg.diagonal(x) Array([ 1, 6, 11], dtype=int32) >>> jnp.linalg.diagonal(x, offset=1) Array([ 2, 7, 12], dtype=int32) >>> jnp.linalg.diagonal(x, offset=-1) Array([ 5, 10], dtype=int32) Batched diagonals: >>> x = jnp.arange(24).reshape(2, 3, 4) >>> jnp.linalg.diagonal(x) Array([[ 0, 5, 10], [12, 17, 22]], dtype=int32) """ x = ensure_arraylike('jnp.linalg.diagonal', x) return jnp.diagonal(x, offset=offset, axis1=-2, axis2=-1) @export def tensorinv(a: ArrayLike, ind: int = 2) -> Array: """Compute the tensor inverse of an array. JAX implementation of :func:`numpy.linalg.tensorinv`. This computes the inverse of the :func:`~jax.numpy.linalg.tensordot` operation with the same ``ind`` value. Args: a: array to be inverted. Must have ``prod(a.shape[:ind]) == prod(a.shape[ind:])``. ind: positive integer specifying the number of indices in the tensor product. Returns: array of shape ``(*a.shape[ind:], *a.shape[:ind])`` containing the tensor inverse of ``a``. See also: - :func:`jax.numpy.linalg.tensordot` - :func:`jax.numpy.linalg.tensorsolve` Examples: >>> key = jax.random.key(1337) >>> x = jax.random.normal(key, shape=(2, 2, 4)) >>> xinv = jnp.linalg.tensorinv(x, 2) >>> xinv_x = jnp.linalg.tensordot(xinv, x, axes=2) >>> jnp.allclose(xinv_x, jnp.eye(4), atol=1E-4) Array(True, dtype=bool) """ arr = ensure_arraylike("tensorinv", a) ind = operator.index(ind) if ind <= 0: raise ValueError(f"ind must be a positive integer; got {ind=}") contracting_shape, batch_shape = arr.shape[:ind], arr.shape[ind:] flatshape = (math.prod(contracting_shape), math.prod(batch_shape)) if flatshape[0] != flatshape[1]: raise ValueError("tensorinv is only possible when the product of the first" " `ind` dimensions equals that of the remaining dimensions." f" got {arr.shape=} with {ind=}.") return inv(arr.reshape(flatshape)).reshape(*batch_shape, *contracting_shape) @export def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None) -> Array: """Solve the tensor equation a x = b for x. JAX implementation of :func:`numpy.linalg.tensorsolve`. Args: a: input array. After reordering via ``axes`` (see below), shape must be ``(*b.shape, *x.shape)``. b: right-hand-side array. axes: optional tuple specifying axes of ``a`` that should be moved to the end. Returns: array ``x`` such that, after reordering the axes of ``a``, ``tensordot(a, x, x.ndim)`` is equivalent to ``b``.
See also: - :func:`jax.numpy.linalg.tensordot` - :func:`jax.numpy.linalg.tensorinv` Examples: >>> key1, key2 = jax.random.split(jax.random.key(8675309)) >>> a = jax.random.normal(key1, shape=(2, 2, 4)) >>> b = jax.random.normal(key2, shape=(2, 2)) >>> x = jnp.linalg.tensorsolve(a, b) >>> x.shape (4,) Now show that ``x`` can be used to reconstruct ``b`` using :func:`~jax.numpy.linalg.tensordot`: >>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim) >>> jnp.allclose(b, b_reconstructed) Array(True, dtype=bool) """ a_arr, b_arr = ensure_arraylike("tensorsolve", a, b) if axes is not None: a_arr = jnp.moveaxis(a_arr, axes, len(axes) * (a_arr.ndim - 1,)) out_shape = a_arr.shape[b_arr.ndim:] if a_arr.shape[:b_arr.ndim] != b_arr.shape: raise ValueError("After moving axes to end, leading shape of a must match shape of b." f" got a.shape={a_arr.shape}, b.shape={b_arr.shape}") if b_arr.size != math.prod(out_shape): raise ValueError("Input arrays must have prod(a.shape[:b.ndim]) == prod(a.shape[b.ndim:]);" f" got a.shape={a_arr.shape}, b.ndim={b_arr.ndim}.") a_arr = a_arr.reshape(b_arr.size, math.prod(out_shape)) return solve(a_arr, b_arr.ravel()).reshape(out_shape) @export def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -> Array: """Efficiently compute matrix products between a sequence of arrays. JAX implementation of :func:`numpy.linalg.multi_dot`. JAX internally uses the opt_einsum library to compute the most efficient operation order. Args: arrays: sequence of arrays. All must be two-dimensional, except the first and last which may be one-dimensional. precision: either ``None`` (default), which means the default precision for the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``). Returns: an array representing the equivalent of ``reduce(jnp.matmul, arrays)``, but evaluated in the optimal order. This function exists because the cost of computing sequences of matmul operations can differ vastly depending on the order in which the operations are evaluated. For a single matmul, the number of floating point operations (flops) required to compute a matrix product can be approximated this way: >>> def approx_flops(x, y): ... # for 2D x and y, with x.shape[1] == y.shape[0] ... return 2 * x.shape[0] * x.shape[1] * y.shape[1] Suppose we have three matrices that we'd like to multiply in sequence: >>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3) >>> x = jax.random.normal(key1, shape=(200, 5)) >>> y = jax.random.normal(key2, shape=(5, 100)) >>> z = jax.random.normal(key3, shape=(100, 10)) Because of associativity of matrix products, there are two orders in which we might evaluate the product ``x @ y @ z``, and both produce equivalent outputs up to floating point precision: >>> result1 = (x @ y) @ z >>> result2 = x @ (y @ z) >>> jnp.allclose(result1, result2, atol=1E-4) Array(True, dtype=bool) But the computational cost of these differ greatly: >>> print("(x @ y) @ z flops:", approx_flops(x, y) + approx_flops(x @ y, z)) (x @ y) @ z flops: 600000 >>> print("x @ (y @ z) flops:", approx_flops(y, z) + approx_flops(x, y @ z)) x @ (y @ z) flops: 30000 The second approach is about 20x more efficient in terms of estimated flops! 
``multi_dot`` is a function that will automatically choose the fastest computational path for such problems: >>> result3 = jnp.linalg.multi_dot([x, y, z]) >>> jnp.allclose(result1, result3, atol=1E-4) Array(True, dtype=bool) We can use JAX's :ref:`ahead-of-time-lowering` tools to estimate the total flops of each approach, and confirm that ``multi_dot`` is choosing the more efficient option: >>> jax.jit(lambda x, y, z: (x @ y) @ z).lower(x, y, z).cost_analysis()['flops'] 600000.0 >>> jax.jit(lambda x, y, z: x @ (y @ z)).lower(x, y, z).cost_analysis()['flops'] 30000.0 >>> jax.jit(jnp.linalg.multi_dot).lower([x, y, z]).cost_analysis()['flops'] 30000.0 """ arrs = list(ensure_arraylike('jnp.linalg.multi_dot', *arrays)) if len(arrs) < 2: raise ValueError(f"multi_dot requires at least two arrays; got len(arrays)={len(arrs)}") if not (arrs[0].ndim in (1, 2) and arrs[-1].ndim in (1, 2) and all(a.ndim == 2 for a in arrs[1:-1])): raise ValueError("multi_dot: input arrays must all be two-dimensional, except for" " the first and last array which may be 1 or 2 dimensional." f" Got array shapes {[a.shape for a in arrs]}") if any(a.shape[-1] != b.shape[0] for a, b in zip(arrs[:-1], arrs[1:])): raise ValueError("multi_dot: last dimension of each array must match first dimension" f" of following array. Got array shapes {[a.shape for a in arrs]}") einsum_axes: list[tuple[int, ...]] = [(i, i+1) for i in range(len(arrs))] if arrs[0].ndim == 1: einsum_axes[0] = einsum_axes[0][1:] if arrs[-1].ndim == 1: einsum_axes[-1] = einsum_axes[-1][:1] return einsum.einsum(*itertools.chain(*zip(arrs, einsum_axes)), # type: ignore[call-overload] optimize='auto', precision=precision) @export @partial(jit, static_argnames=['p']) def cond(x: ArrayLike, p=None): """Compute the condition number of a matrix. JAX implementation of :func:`numpy.linalg.cond`. The condition number is defined as ``norm(x, p) * norm(inv(x), p)``. For ``p = 2`` (the default), the condition number is the ratio of the largest to the smallest singular value. Args: x: array of shape ``(..., M, N)`` for which to compute the condition number. p: the order of the norm to use. One of ``{None, 1, -1, 2, -2, inf, -inf, 'fro'}``; see :func:`jax.numpy.linalg.norm` for the meaning of these. The default is ``p = None``, which is equivalent to ``p = 2``. If not in ``{None, 2, -2}`` then ``x`` must be square, i.e. ``M = N``. Returns: array of shape ``x.shape[:-2]`` containing the condition number. See also: :func:`jax.numpy.linalg.norm` Examples: Well-conditioned matrix: >>> x = jnp.array([[1, 2], ... [2, 1]]) >>> jnp.linalg.cond(x) Array(3., dtype=float32) Ill-conditioned matrix: >>> x = jnp.array([[1, 2], ... [0, 0]]) >>> jnp.linalg.cond(x) Array(inf, dtype=float32) """ arr = ensure_arraylike("cond", x) if arr.ndim < 2: raise ValueError(f"jnp.linalg.cond: input array must be at least 2D; got {arr.shape=}") if arr.shape[-1] == 0 or arr.shape[-2] == 0: raise ValueError(f"jnp.linalg.cond: input array must not be empty; got {arr.shape=}") if p is None or p == 2: s = svdvals(x) return s[..., 0] / s[..., -1] elif p == -2: s = svdvals(x) r = s[..., -1] / s[..., 0] else: if arr.shape[-2] != arr.shape[-1]: raise ValueError(f"jnp.linalg.cond: for {p=}, array must be square; got {arr.shape=}") r = norm(x, ord=p, axis=(-2, -1)) * norm(inv(x), ord=p, axis=(-2, -1)) # Convert NaNs to infs where original array has no NaNs. 
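# A NaN in r computed from a NaN-free input signals a singular matrix (for example
# a 0/0 in the singular-value ratio), so it is reported as an infinite condition
# number; NaNs originating from the input itself are propagated unchanged.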
return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), np.inf, r) @export def trace(x: ArrayLike, /, *, offset: int = 0, dtype: DTypeLike | None = None) -> Array: """Compute the trace of a matrix. JAX implementation of :func:`numpy.linalg.trace`. Args: x: array of shape ``(..., M, N)`` whose innermost two dimensions form the MxN matrices for which to take the trace. offset: positive or negative offset from the main diagonal (default: 0). dtype: data type of the returned array (default: ``None``). If ``None``, then the output dtype will match the dtype of ``x``, promoted to default precision in the case of integer types. Returns: array of batched traces with shape ``x.shape[:-2]``. See also: - :func:`jax.numpy.trace`: similar API in the ``jax.numpy`` namespace. Examples: Trace of a single matrix: >>> x = jnp.array([[1, 2, 3, 4], ... [5, 6, 7, 8], ... [9, 10, 11, 12]]) >>> jnp.linalg.trace(x) Array(18, dtype=int32) >>> jnp.linalg.trace(x, offset=1) Array(21, dtype=int32) >>> jnp.linalg.trace(x, offset=-1, dtype="float32") Array(15., dtype=float32) Batched traces: >>> x = jnp.arange(24).reshape(2, 3, 4) >>> jnp.linalg.trace(x) Array([15, 51], dtype=int32) """ x = ensure_arraylike('jnp.linalg.trace', x) return jnp.trace(x, offset=offset, axis1=-2, axis2=-1, dtype=dtype)
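# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): a minimal example, assuming
# jax.numpy is imported as jnp at the call site, of how a few of the routines
# above relate to one another. The matrix is arbitrary and chosen only for
# illustration.
#
#   a = jnp.array([[2.0, 1.0],
#                  [1.0, 3.0]])
#
#   # The trace is the sum of the entries returned by diagonal().
#   assert jnp.allclose(jnp.linalg.trace(a), jnp.linalg.diagonal(a).sum())
#
#   # The default (2-norm) condition number is the ratio of the largest to
#   # the smallest singular value reported by svdvals().
#   s = jnp.linalg.svdvals(a)
#   assert jnp.allclose(jnp.linalg.cond(a), s[0] / s[-1])
# ---------------------------------------------------------------------------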