# Copyright 2018 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from functools import partial import operator from typing import cast, Any import numpy as np import jax.numpy as jnp from jax import jit from jax import jvp from jax import vmap from jax import lax from jax._src import api_util from jax._src import config from jax._src import core from jax._src import custom_derivatives from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact from jax._src.ops import special as ops_special from jax._src.third_party.scipy.betaln import betaln as _betaln_impl from jax._src.typing import Array, ArrayLike from jax._src.nn.functions import softmax as nn_softmax from jax._src.nn.functions import log_softmax as nn_log_softmax def gammaln(x: ArrayLike) -> Array: r"""Natural log of the absolute value of the gamma function. JAX implementation of :obj:`scipy.special.gammaln`. .. math:: \mathrm{gammaln}(x) = \log(|\Gamma(x)|) Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function. Args: x: arraylike, real valued. Returns: array containing the values of the log-gamma function See Also: - :func:`jax.scipy.special.gammaln`: the natural log of the gamma function - :func:`jax.scipy.special.gammasgn`: the sign of the gamma function Notes: ``gammaln`` does not support complex-valued inputs. """ x, = promote_args_inexact("gammaln", x) return lax.lgamma(x) @jit def gammasgn(x: ArrayLike) -> Array: r"""Sign of the gamma function. JAX implementation of :obj:`scipy.special.gammasgn`. .. math:: \mathrm{gammasgn}(x) = \begin{cases} +1 & \Gamma(x) > 0 \\ -1 & \Gamma(x) < 0 \end{cases} Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function. Because :math:`\Gamma(x)` is never zero, no condition is required for this case. * if :math:`x = -\infty`, NaN is returned. * if :math:`x = \pm 0`, :math:`\pm 1` is returned. * if :math:`x` is a negative integer, NaN is returned. The sign of gamma at a negative integer depends on from which side the pole is approached. * if :math:`x = \infty`, :math:`1` is returned. * if :math:`x` is NaN, NaN is returned. Args: x: arraylike, real valued. Returns: array containing the sign of the gamma function See Also: - :func:`jax.scipy.special.gamma`: the gamma function - :func:`jax.scipy.special.gammaln`: the natural log of the gamma function """ x, = promote_args_inexact("gammasgn", x) typ = x.dtype.type floor_x = lax.floor(x) x_negative = x < 0 return jnp.select( [(x_negative & (x == floor_x)) | jnp.isnan(x), (x_negative & (floor_x % 2 != 0)) | ((x == 0) & jnp.signbit(x))], [typ(np.nan), typ(-1.0)], typ(1.0)) def gamma(x: ArrayLike) -> Array: r"""The gamma function. JAX implementation of :obj:`scipy.special.gamma`. The gamma function is defined for :math:`\Re(z)>0` as .. math:: \mathrm{gamma}(z) = \Gamma(z) = \int_0^\infty t^{z-1}e^{-t}\mathrm{d}t and is extended by analytic continuation to arbitrary complex values `z`. For positive integers `n`, the gamma function is related to the :func:`~jax.scipy.special.factorial` function via the following identity: .. math:: \Gamma(n) = (n - 1)! * if :math:`z = -\infty`, NaN is returned. * if :math:`x = \pm 0`, :math:`\pm \infty` is returned. * if :math:`x` is a negative integer, NaN is returned. The sign of gamma at a negative integer depends on from which side the pole is approached. * if :math:`x = \infty`, :math:`\infty` is returned. * if :math:`x` is NaN, NaN is returned. Args: x: arraylike, real valued. Returns: array containing the values of the gamma function See Also: - :func:`jax.scipy.special.factorial`: the factorial function. - :func:`jax.scipy.special.gammaln`: the natural log of the gamma function - :func:`jax.scipy.special.gammasgn`: the sign of the gamma function Notes: Unlike the scipy version, JAX's ``gamma`` does not support complex-valued inputs. """ x, = promote_args_inexact("gamma", x) return gammasgn(x) * lax.exp(lax.lgamma(x)) def betaln(a: ArrayLike, b: ArrayLike) -> Array: r"""Natural log of the absolute value of the beta function JAX implementation of :obj:`scipy.special.betaln`. .. math:: \mathrm{betaln}(a, b) = \log B(a, b) where :math:`B` is the :func:`~jax.scipy.special.beta` function. Args: a: arraylike, real-valued. Parameter *a* of the beta distribution. b: arraylike, real-valued. Parameter *b* of the beta distribution. Returns: array containing the values of the log-beta function See Also: :func:`jax.scipy.special.beta` """ a, b = promote_args_inexact("betaln", a, b) return _betaln_impl(a, b) def factorial(n: ArrayLike, exact: bool = False) -> Array: r"""Factorial function JAX implementation of :obj:`scipy.special.factorial` .. math:: \mathrm{factorial}(n) = n! = \prod_{k=1}^n k Args: n: arraylike, values for which factorial will be computed elementwise exact: bool, only ``exact=False`` is supported. Returns: array containing values of the factorial. Notes: This computes the float-valued factorial via the :func:`~jax.scipy.special.gamma` function. JAX does not support exact factorials, because it is not particularly useful: above ``n=20``, the exact result cannot be represented by 64-bit integers, which are the largest integers available to JAX. See Also: :func:`jax.scipy.special.gamma` """ if exact: raise NotImplementedError("factorial with exact=True") n, = promote_args_inexact("factorial", n) return jnp.where(n < 0, 0, lax.exp(lax.lgamma(n + 1))) def beta(a: ArrayLike, b: ArrayLike) -> Array: r"""The beta function JAX implementation of :obj:`scipy.special.beta`. .. math:: \mathrm{beta}(a, b) = B(a, b) = \frac{\Gamma(a)\Gamma(b)}{\Gamma(a + b)} where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function. Args: a: arraylike, real-valued. Parameter *a* of the beta distribution. b: arraylike, real-valued. Parameter *b* of the beta distribution. Returns: array containing the values of the beta function. See Also: - :func:`jax.scipy.special.gamma` - :func:`jax.scipy.special.betaln` """ a, b = promote_args_inexact("beta", a, b) sign = gammasgn(a) * gammasgn(b) * gammasgn(a + b) return sign * lax.exp(betaln(a, b)) def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: r"""The regularized incomplete beta function. JAX implementation of :obj:`scipy.special.betainc`. .. math:: \mathrm{betainc}(a, b, x) = B(a, b)\int_0^x t^{a-1}(1-t^{b-1})\mathrm{d}t where :math:`B(a, b)` is the :func:`~jax.scipy.special.beta` function. Args: a: arraylike, real-valued. Parameter *a* of the beta distribution. b: arraylike, real-valued. Parameter *b* of the beta distribution. x: arraylike, real-valued. Upper limit of the integration. Returns: array containing values of the betainc function See Also: - :func:`jax.scipy.special.beta` - :func:`jax.scipy.special.betaln` """ a, b, x = promote_args_inexact("betainc", a, b, x) return lax.betainc(a, b, x) def digamma(x: ArrayLike) -> Array: r"""The digamma function JAX implementation of :obj:`scipy.special.digamma`. .. math:: \mathrm{digamma}(z) = \psi(z) = \frac{\mathrm{d}}{\mathrm{d}z}\log \Gamma(z) where :math:`\Gamma(z)` is the :func:`~jax.scipy.special.gamma` function. Args: x: arraylike, real-valued. Returns: array containing values of the digamma function. Notes: The JAX version of `digamma` accepts real-valued inputs. See also: - :func:`jax.scipy.special.gamma` - :func:`jax.scipy.special.polygamma` """ x, = promote_args_inexact("digamma", x) return lax.digamma(x) def gammainc(a: ArrayLike, x: ArrayLike) -> Array: r"""The regularized lower incomplete gamma function. JAX implementation of :obj:`scipy.special.gammainc`. .. math:: \mathrm{gammainc}(x; a) = \frac{1}{\Gamma(a)}\int_0^x t^{a-1}e^{-t}\mathrm{d}t where :math:`\Gamma(a)` is the :func:`~jax.scipy.special.gamma` function. Args: a: arraylike, real-valued. Positive shape parameter of the gamma distribution. x: arraylike, real-valued. Non-negative upper limit of integration Returns: array containing values of the gammainc function. See Also: - :func:`jax.scipy.special.gamma` - :func:`jax.scipy.special.gammaincc` """ a, x = promote_args_inexact("gammainc", a, x) return lax.igamma(a, x) def gammaincc(a: ArrayLike, x: ArrayLike) -> Array: r"""The regularized upper incomplete gamma function. JAX implementation of :obj:`scipy.special.gammaincc`. .. math:: \mathrm{gammaincc}(x; a) = \frac{1}{\Gamma(a)}\int_x^\infty t^{a-1}e^{-t}\mathrm{d}t where :math:`\Gamma(a)` is the :func:`~jax.scipy.special.gamma` function. Args: a: arraylike, real-valued. Positive shape parameter of the gamma distribution. x: arraylike, real-valued. Non-negative lower limit of integration Returns: array containing values of the gammaincc function. See Also: - :func:`jax.scipy.special.gamma` - :func:`jax.scipy.special.gammainc` """ a, x = promote_args_inexact("gammaincc", a, x) return lax.igammac(a, x) def erf(x: ArrayLike) -> Array: r"""The error function JAX implementation of :obj:`scipy.special.erf`. .. math:: \mathrm{erf}(x) = \frac{2}{\sqrt\pi} \int_{0}^x e^{-t^2} \mathrm{d}t Args: x: arraylike, real-valued. Returns: array containing values of the error function. Notes: The JAX version only supports real-valued inputs. See also: - :func:`jax.scipy.special.erfc` - :func:`jax.scipy.special.erfinv` """ x, = promote_args_inexact("erf", x) return lax.erf(x) def erfc(x: ArrayLike) -> Array: r"""The complement of the error function JAX implementation of :obj:`scipy.special.erfc`. .. math:: \mathrm{erfc}(x) = \frac{2}{\sqrt\pi} \int_{x}^\infty e^{-t^2} \mathrm{d}t This is the complement of the error function :func:`~jax.scipy.special.erf`, ``erfc(x) = 1 - erf(x)``. Args: x: arraylike, real-valued. Returns: array containing values of the complement of the error function. Notes: The JAX version only supports real-valued inputs. See also: - :func:`jax.scipy.special.erf` - :func:`jax.scipy.special.erfinv` """ x, = promote_args_inexact("erfc", x) return lax.erfc(x) def erfinv(x: ArrayLike) -> Array: """The inverse of the error function JAX implementation of :obj:`scipy.special.erfinv`. Returns the inverse of :func:`~jax.scipy.special.erf`. Args: x: arraylike, real-valued. Returns: array containing values of the inverse error function. Notes: The JAX version only supports real-valued inputs. See also: - :func:`jax.scipy.special.erf` - :func:`jax.scipy.special.erfc` """ x, = promote_args_inexact("erfinv", x) return lax.erf_inv(x) @custom_derivatives.custom_jvp def logit(x: ArrayLike) -> Array: r"""The logit function JAX implementation of :obj:`scipy.special.logit`. .. math:: \mathrm{logit}(p) = \log\frac{p}{1 - p} Args: x: arraylike, real-valued. Returns: array containing values of the logit function. """ x, = promote_args_inexact("logit", x) return lax.log(lax.div(x, lax.sub(_lax_const(x, 1), x))) logit.defjvps( lambda g, ans, x: lax.div(g, lax.mul(x, lax.sub(_lax_const(x, 1), x)))) def expit(x: ArrayLike) -> Array: r"""The logistic sigmoid (expit) function JAX implementation of :obj:`scipy.special.expit`. .. math:: \mathrm{expit}(x) = \frac{1}{1 + e^{-x}} Args: x: arraylike, real-valued. Returns: array containing values of the expit function. """ x, = promote_args_inexact("expit", x) return lax.logistic(x) logsumexp = ops_special.logsumexp @custom_derivatives.custom_jvp def xlogy(x: ArrayLike, y: ArrayLike) -> Array: """Compute x*log(y), returning 0 for x=0. JAX implementation of :obj:`scipy.special.xlogy`. This is defined to return zero when :math:`(x, y) = (0, 0)`, with a custom derivative rule so that automatic differentiation is well-defined at this point. Args: x: arraylike, real-valued. y: arraylike, real-valued. Returns: array containing xlogy values. See also: :func:`jax.scipy.special.xlog1py` """ # Note: xlogy(0, 0) should return 0 according to the function documentation. x, y = promote_args_inexact("xlogy", x, y) x_ok = x != 0. return jnp.where(x_ok, lax.mul(x, lax.log(y)), jnp.zeros_like(x)) def _xlogy_jvp(primals, tangents): (x, y) = primals (x_dot, y_dot) = tangents result = xlogy(x, y) return result, (x_dot * lax.log(y) + y_dot * x / y).astype(result.dtype) xlogy.defjvp(_xlogy_jvp) @custom_derivatives.custom_jvp def xlog1py(x: ArrayLike, y: ArrayLike) -> Array: """Compute x*log(1 + y), returning 0 for x=0. JAX implementation of :obj:`scipy.special.xlog1py`. This is defined to return 0 when :math:`(x, y) = (0, -1)`, with a custom derivative rule so that automatic differentiation is well-defined at this point. Args: x: arraylike, real-valued. y: arraylike, real-valued. Returns: array containing xlog1py values. See also: :func:`jax.scipy.special.xlogy` """ # Note: xlog1py(0, -1) should return 0 according to the function documentation. x, y = promote_args_inexact("xlog1py", x, y) x_ok = x != 0. return jnp.where(x_ok, lax.mul(x, lax.log1p(y)), jnp.zeros_like(x)) def _xlog1py_jvp(primals, tangents): (x, y) = primals (x_dot, y_dot) = tangents result = xlog1py(x, y) return result, (x_dot * lax.log1p(y) + y_dot * x / (1 + y)).astype(result.dtype) xlog1py.defjvp(_xlog1py_jvp) @custom_derivatives.custom_jvp def _xlogx(x): """Compute x log(x) with well-defined derivatives.""" return xlogy(x, x) def _xlogx_jvp(primals, tangents): x, = primals x_dot, = tangents return _xlogx(x), x_dot * (lax.log(x) + 1) _xlogx.defjvp(_xlogx_jvp) def entr(x: ArrayLike) -> Array: r"""The entropy function JAX implementation of :obj:`scipy.special.entr`. .. math:: \mathrm{entr}(x) = \begin{cases} -x\log(x) & x > 0 \\ 0 & x = 0\\ -\infty & \mathrm{otherwise} \end{cases} Args: x: arraylike, real-valued. Returns: array containing entropy values. See also: - :func:`jax.scipy.special.kl_div` - :func:`jax.scipy.special.rel_entr` """ x, = promote_args_inexact("entr", x) return lax.select(lax.lt(x, _lax_const(x, 0)), lax.full_like(x, -np.inf), lax.neg(_xlogx(x))) def multigammaln(a: ArrayLike, d: ArrayLike) -> Array: r"""The natural log of the multivariate gamma function. JAX implementation of :func:`scipy.special.multigammaln`. .. math:: \mathrm{multigammaln}(a, d) = \log\Gamma_d(a) where .. math:: \Gamma_d(a) = \pi^{d(d-1)/4}\prod_{i=1}^d\Gamma(a-(i-1)/2) and :math:`\Gamma(x)` is the :func:`~jax.scipy.special.gamma` function. Args: a: arraylike, real-valued. d: int, the dimension of the integration space. Returns: array containing values of the log-multigamma function. See also: - :func:`jax.scipy.special.gamma` """ d = core.concrete_or_error(int, d, "d argument of multigammaln") a, d_ = promote_args_inexact("multigammaln", a, d) constant = lax.mul(lax.mul(lax.mul(_lax_const(a, 0.25), d_), lax.sub(d_, _lax_const(a, 1))), lax.log(_lax_const(a, np.pi))) b = lax.div(jnp.arange(d, dtype=d_.dtype), _lax_const(a, 2)) res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) - jnp.expand_dims(b, axis=tuple(range(a.ndim)))), axis=-1) return res + constant def kl_div( p: ArrayLike, q: ArrayLike, ) -> Array: r"""The Kullback-Leibler divergence. JAX implementation of :obj:`scipy.special.kl_div`. .. math:: \mathrm{kl\_div}(p, q) = \begin{cases} p\log(p/q)-p+q & p>0,q>0\\ q & p=0,q\ge 0\\ \infty & \mathrm{otherwise} \end{cases} Args: p: arraylike, real-valued. q: arraylike, real-valued. Returns: array of KL-divergence values See also: - :func:`jax.scipy.special.entr` - :func:`jax.scipy.special.rel_entr` """ p, q = promote_args_inexact("kl_div", p, q) return rel_entr(p, q) - p + q def rel_entr( p: ArrayLike, q: ArrayLike, ) -> Array: r"""The relative entropy function. JAX implementation of :obj:`scipy.special.rel_entr`. .. math:: \mathrm{rel\_entr}(p, q) = \begin{cases} p\log(p/q) & p>0,q>0\\ 0 & p=0,q\ge 0\\ \infty & \mathrm{otherwise} \end{cases} Args: p: arraylike, real-valued. q: arraylike, real-valued. Returns: array of relative entropy values. See also: - :func:`jax.scipy.special.entr` - :func:`jax.scipy.special.kl_div` """ p, q = promote_args_inexact("rel_entr", p, q) zero = _lax_const(p, 0.0) both_gt_zero_mask = lax.bitwise_and(lax.gt(p, zero), lax.gt(q, zero)) one_zero_mask = lax.bitwise_and(lax.eq(p, zero), lax.ge(q, zero)) safe_p = jnp.where(both_gt_zero_mask, p, 1) safe_q = jnp.where(both_gt_zero_mask, q, 1) log_val = lax.sub(_xlogx(safe_p), xlogy(safe_p, safe_q)) result = jnp.where( both_gt_zero_mask, log_val, jnp.where(one_zero_mask, zero, jnp.inf) ) return result # coefs of (2k)! / B_{2k} where B are bernoulli numbers # those numbers are obtained using https://www.wolframalpha.com _BERNOULLI_COEFS = [ 12, -720, 30240, -1209600, 47900160, -1307674368000 / 691, 74724249600, -10670622842880000 / 3617, 5109094217170944000 / 43867, -802857662698291200000 / 174611, 14101100039391805440000 / 77683, -1693824136731743669452800000 / 236364091, 186134520519971831808000000 / 657931, -37893265687455865519472640000000 / 3392780147, 759790291646040068357842010112000000 / 1723168255201, -134196726836183700385281186201600000000 / 7709321041217, ] @custom_derivatives.custom_jvp def zeta(x: ArrayLike, q: ArrayLike | None = None) -> Array: r"""The Hurwitz zeta function. JAX implementation of :func:`scipy.special.zeta`. JAX does not implement the Riemann zeta function (i.e. ``q = None``). .. math:: \zeta(x, q) = \sum_{n=0}^\infty \frac{1}{(n + q)^x} Args: x: arraylike, real-valued q: arraylike, real-valued Returns: array of zeta function values """ if q is None: raise NotImplementedError( "Riemann zeta function not implemented; pass q != None to compute the Hurwitz Zeta function.") x, q = promote_args_inexact("zeta", x, q) return lax.zeta(x, q) # There is no general closed-form derivative for the zeta function, so we compute # derivatives via a series expansion def _zeta_series_expansion(x: ArrayLike, q: ArrayLike | None = None) -> Array: if q is None: raise NotImplementedError( "Riemann zeta function not implemented; pass q != None to compute the Hurwitz Zeta function.") # Reference: Johansson, Fredrik. # "Rigorous high-precision computation of the Hurwitz zeta function and its derivatives." # Numerical Algorithms 69.2 (2015): 253-270. # https://arxiv.org/abs/1309.2877 - formula (5) # here we keep the same notation as in reference s, a = promote_args_inexact("zeta", x, q) dtype = lax.dtype(a).type s_, a_ = jnp.expand_dims(s, -1), jnp.expand_dims(a, -1) # precision ~ N, M N = M = dtype(8) if lax.dtype(a) == jnp.float32 else dtype(16) assert M <= len(_BERNOULLI_COEFS) k = jnp.expand_dims(np.arange(N, dtype=N.dtype), tuple(range(a.ndim))) S = jnp.sum((a_ + k) ** -s_, -1) I = lax.div((a + N) ** (dtype(1) - s), s - dtype(1)) T0 = (a + N) ** -s m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim))) s_over_a = (s_ + m) / (a_ + N) T1 = jnp.cumprod(s_over_a, -1)[..., ::2] T1 = jnp.clip(T1, max=jnp.finfo(dtype).max) coefs = np.expand_dims(np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype), tuple(range(a.ndim))) T1 = T1 / coefs T = T0 * (dtype(0.5) + T1.sum(-1)) return S + I + T zeta.defjvp(partial(jvp, _zeta_series_expansion)) def polygamma(n: ArrayLike, x: ArrayLike) -> Array: r"""The polygamma function. JAX implementation of :func:`scipy.special.polygamma`. .. math:: \mathrm{polygamma}(n, x) = \psi^{(n)}(x) = \frac{\mathrm{d}^n}{\mathrm{d}x^n}\log \Gamma(x) where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function. Args: n: arraylike, integer-valued. The order of the derivative. x: arraylike, real-valued. The value at which to evaluate the function. Returns: array See also: - :func:`jax.scipy.special.gamma` - :func:`jax.scipy.special.digamma` """ assert jnp.issubdtype(lax.dtype(n), jnp.integer) n_arr, x_arr = promote_args_inexact("polygamma", n, x) return lax.polygamma(n_arr, x_arr) # Normal distributions # Functions "ndtr" and "ndtri" are derived from calculations made in: # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html # The "spence" function is also based on the Cephes library with # the corresponding spence.c file located in the tarball: # https://netlib.org/cephes/misc.tgz # In the following email exchange, the author gives his consent to redistribute # derived works under an Apache 2.0 license. # # From: Stephen Moshier # Date: Sat, Jun 9, 2018 at 2:36 PM # Subject: Re: Licensing cephes under Apache (BSD-like) license. # To: rif # # # # Hello Rif, # # Yes, Google may distribute Cephes files under the Apache 2 license. # # If clarification is needed, I do not favor BSD over other free licenses. # I would agree that Apache 2 seems to cover the concern you mentioned # about sublicensees. # # Best wishes for good luck with your projects! # Steve Moshier # # # # On Thu, 31 May 2018, rif wrote: # # > Hello Steve. # > My name is Rif. I work on machine learning software at Google. # > # > Your cephes software continues to be incredibly useful and widely used. I # > was wondering whether it would be permissible for us to use the Cephes code # > under the Apache 2.0 license, which is extremely similar in permissions to # > the BSD license (Wikipedia comparisons). This would be quite helpful to us # > in terms of avoiding multiple licenses on software. # > # > I'm sorry to bother you with this (I can imagine you're sick of hearing # > about this by now), but I want to be absolutely clear we're on the level and # > not misusing your important software. In former conversation with Eugene # > Brevdo (ebrevdo@google.com), you wrote "If your licensing is similar to BSD, # > the formal way that has been handled is simply to add a statement to the # > effect that you are incorporating the Cephes software by permission of the # > author." I wanted to confirm that (a) we could use the Apache license, (b) # > that we don't need to (and probably you don't want to) keep getting # > contacted about individual uses, because your intent is generally to allow # > this software to be reused under "BSD-like" license, and (c) you're OK # > letting incorporators decide whether a license is sufficiently BSD-like? # > # > Best, # > # > rif # > # > # > # log_ndtr uses different functions over the ranges # (-infty, lower](lower, upper](upper, infty) # Lower bound values were chosen by examining where the support of ndtr # appears to be zero, relative to scipy's (which is always 64bit). They were # then made more conservative just to be safe. (Conservative means use the # expansion more than we probably need to.) _LOGNDTR_FLOAT64_LOWER = np.array(-20, np.float64) _LOGNDTR_FLOAT32_LOWER = np.array(-10, np.float32) # Upper bound values were chosen by examining for which values of 'x' # Log[cdf(x)] is 0, after which point we need to use the approximation # Log[cdf(x)] = Log[1 - cdf(-x)] approx -cdf(-x). We chose a value slightly # conservative, meaning we use the approximation earlier than needed. _LOGNDTR_FLOAT64_UPPER = np.array(8, np.float64) _LOGNDTR_FLOAT32_UPPER = np.array(5, np.float32) def ndtr(x: ArrayLike) -> Array: r"""Normal distribution function. JAX implementation of :obj:`scipy.special.ndtr`. Returns the area under the Gaussian probability density function, integrated from minus infinity to x: .. math:: \begin{align} \mathrm{ndtr}(x) =& \ \frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt \\ =&\ \frac{1}{2} (1 + \mathrm{erf}(\frac{x}{\sqrt{2}})) \\ =&\ \frac{1}{2} \mathrm{erfc}(\frac{x}{\sqrt{2}}) \end{align} Args: x: An array of type `float32`, `float64`. Returns: An array with `dtype=x.dtype`. Raises: TypeError: if `x` is not floating-type. """ x = jnp.asarray(x) dtype = lax.dtype(x) if dtype not in (jnp.float32, jnp.float64): raise TypeError( "x.dtype={} is not supported, see docstring for supported types." .format(dtype)) return _ndtr(x) def _ndtr(x: ArrayLike) -> Array: """Implements ndtr core logic.""" dtype = lax.dtype(x).type half_sqrt_2 = dtype(0.5) * np.sqrt(2., dtype=dtype) w = x * half_sqrt_2 z = lax.abs(w) y = lax.select(lax.lt(z, half_sqrt_2), dtype(1.) + lax.erf(w), lax.select(lax.gt(w, dtype(0.)), dtype(2.) - lax.erfc(z), lax.erfc(z))) return dtype(0.5) * y def ndtri(p: ArrayLike) -> Array: r"""The inverse of the CDF of the Normal distribution function. JAX implementation of :obj:`scipy.special.ndtri`. Returns `x` such that the area under the PDF from :math:`-\infty` to `x` is equal to `p`. A piece-wise rational approximation is done for the function. This is based on the implementation in netlib. Args: p: an array of type `float32`, `float64`. Returns: an array with `dtype=p.dtype`. Raises: TypeError: if `p` is not floating-type. """ dtype = lax.dtype(p) if dtype not in (jnp.float32, jnp.float64): raise TypeError( "x.dtype={} is not supported, see docstring for supported types." .format(dtype)) return _ndtri(p) def _ndtri(p: ArrayLike) -> Array: """Implements ndtri core logic.""" # Constants used in piece-wise rational approximations. Taken from the cephes # library: # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html p0 = list(reversed([-5.99633501014107895267E1, 9.80010754185999661536E1, -5.66762857469070293439E1, 1.39312609387279679503E1, -1.23916583867381258016E0])) q0 = list(reversed([1.0, 1.95448858338141759834E0, 4.67627912898881538453E0, 8.63602421390890590575E1, -2.25462687854119370527E2, 2.00260212380060660359E2, -8.20372256168333339912E1, 1.59056225126211695515E1, -1.18331621121330003142E0])) p1 = list(reversed([4.05544892305962419923E0, 3.15251094599893866154E1, 5.71628192246421288162E1, 4.40805073893200834700E1, 1.46849561928858024014E1, 2.18663306850790267539E0, -1.40256079171354495875E-1, -3.50424626827848203418E-2, -8.57456785154685413611E-4])) q1 = list(reversed([1.0, 1.57799883256466749731E1, 4.53907635128879210584E1, 4.13172038254672030440E1, 1.50425385692907503408E1, 2.50464946208309415979E0, -1.42182922854787788574E-1, -3.80806407691578277194E-2, -9.33259480895457427372E-4])) p2 = list(reversed([3.23774891776946035970E0, 6.91522889068984211695E0, 3.93881025292474443415E0, 1.33303460815807542389E0, 2.01485389549179081538E-1, 1.23716634817820021358E-2, 3.01581553508235416007E-4, 2.65806974686737550832E-6, 6.23974539184983293730E-9])) q2 = list(reversed([1.0, 6.02427039364742014255E0, 3.67983563856160859403E0, 1.37702099489081330271E0, 2.16236993594496635890E-1, 1.34204006088543189037E-2, 3.28014464682127739104E-4, 2.89247864745380683936E-6, 6.79019408009981274425E-9])) dtype = lax.dtype(p).type shape = jnp.shape(p) def _create_polynomial(var, coeffs): """Compute n_th order polynomial via Horner's method.""" coeffs = np.array(coeffs, dtype) if not coeffs.size: return jnp.zeros_like(var) return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.)), dtype(1.) - p, p) # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs # later on. The result from the computation when p == 0 is not used so any # number that doesn't result in NaNs is fine. sanitized_mcp = jnp.where( maybe_complement_p == dtype(0.), jnp.full(shape, dtype(0.5)), maybe_complement_p) # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2). w = sanitized_mcp - dtype(0.5) ww = lax.square(w) x_for_big_p = w + w * ww * (_create_polynomial(ww, p0) / _create_polynomial(ww, q0)) x_for_big_p *= -dtype(np.sqrt(2. * np.pi)) # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z), # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different # arrays based on whether p < exp(-32). z = lax.sqrt(dtype(-2.) * lax.log(sanitized_mcp)) first_term = z - lax.log(z) / z second_term_small_p = ( _create_polynomial(dtype(1.) / z, p2) / _create_polynomial(dtype(1.) / z, q2) / z) second_term_otherwise = ( _create_polynomial(dtype(1.) / z, p1) / _create_polynomial(dtype(1.) / z, q1) / z) x_for_small_p = first_term - second_term_small_p x_otherwise = first_term - second_term_otherwise x = jnp.where(sanitized_mcp > dtype(np.exp(-2.)), x_for_big_p, jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise)) x = jnp.where(p > dtype(1. - np.exp(-2.)), x, -x) with config.debug_infs(False): infinity = jnp.full(shape, dtype(np.inf)) x = jnp.where( p == dtype(0.0), -infinity, jnp.where(p == dtype(1.0), infinity, x)) if not isinstance(x, core.Tracer): try: dispatch.check_special("ndtri", [x]) except api_util.InternalFloatingPointError as e: raise FloatingPointError( f"invalid value ({e.ty}) encountered in ndtri.") from None return x @partial(custom_derivatives.custom_jvp, nondiff_argnums=(1,)) def log_ndtr(x: ArrayLike, series_order: int = 3) -> Array: r"""Log Normal distribution function. JAX implementation of :obj:`scipy.special.log_ndtr`. For details of the Normal distribution function see `ndtr`. This function calculates :math:`\log(\mathrm{ndtr}(x))` by either calling :math:`\log(\mathrm{ndtr}(x))` or using an asymptotic series. Specifically: - For `x > upper_segment`, use the approximation `-ndtr(-x)` based on :math:`\log(1-x) \approx -x, x \ll 1`. - For `lower_segment < x <= upper_segment`, use the existing `ndtr` technique and take a log. - For `x <= lower_segment`, we use the series approximation of `erf` to compute the log CDF directly. The `lower_segment` is set based on the precision of the input: .. math:: \begin{align} \mathit{lower\_segment} =& \ \begin{cases} -20 & x.\mathrm{dtype}=\mathit{float64} \\ -10 & x.\mathrm{dtype}=\mathit{float32} \\ \end{cases} \\ \mathit{upper\_segment} =& \ \begin{cases} 8& x.\mathrm{dtype}=\mathit{float64} \\ 5& x.\mathrm{dtype}=\mathit{float32} \\ \end{cases} \end{align} When `x < lower_segment`, the `ndtr` asymptotic series approximation is: .. math:: \begin{align} \mathrm{ndtr}(x) =&\ \mathit{scale} * (1 + \mathit{sum}) + R_N \\ \mathit{scale} =&\ \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\ \mathit{sum} =&\ \sum_{n=1}^N {-1}^n (2n-1)!! / (x^2)^n \\ R_N =&\ O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3}) \end{align} where :math:`(2n-1)!! = (2n-1) (2n-3) (2n-5) ... (3) (1)` is a `double-factorial `_ operator. Args: x: an array of type `float32`, `float64`. series_order: Positive Python integer. Maximum depth to evaluate the asymptotic expansion. This is the `N` above. Returns: an array with `dtype=x.dtype`. Raises: TypeError: if `x.dtype` is not handled. TypeError: if `series_order` is a not Python `integer.` ValueError: if `series_order` is not in `[0, 30]`. """ if not isinstance(series_order, int): raise TypeError("series_order must be a Python integer.") if series_order < 0: raise ValueError("series_order must be non-negative.") if series_order > 30: raise ValueError("series_order must be <= 30.") x_arr = jnp.asarray(x) dtype = lax.dtype(x_arr) if dtype == jnp.float64: lower_segment: np.ndarray = _LOGNDTR_FLOAT64_LOWER upper_segment: np.ndarray = _LOGNDTR_FLOAT64_UPPER elif dtype == jnp.float32: lower_segment = _LOGNDTR_FLOAT32_LOWER upper_segment = _LOGNDTR_FLOAT32_UPPER else: raise TypeError(f"x.dtype={np.dtype(dtype)} is not supported.") # The basic idea here was ported from: # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html # We copy the main idea, with a few changes # * For x >> 1, and X ~ Normal(0, 1), # Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x], # which extends the range of validity of this function. # * We use one fixed series_order for all of 'x', rather than adaptive. # * Our docstring properly reflects that this is an asymptotic series, not a # Taylor series. We also provided a correct bound on the remainder. # * We need to use the max/min in the _log_ndtr_lower arg to avoid nan when # x=0. This happens even though the branch is unchosen because when x=0 # the gradient of a select involves the calculation 1*dy+0*(-inf)=nan # regardless of whether dy is finite. Note that the minimum is a NOP if # the branch is chosen. return jnp.where( lax.gt(x_arr, upper_segment), -_ndtr(-x_arr), # log(1-x) ~= -x, x << 1 jnp.where(lax.gt(x_arr, lower_segment), lax.log(_ndtr(lax.max(x_arr, lower_segment))), _log_ndtr_lower(lax.min(x_arr, lower_segment), series_order))) def _log_ndtr_jvp(series_order, primals, tangents): (x,), (t,) = primals, tangents ans = log_ndtr(x, series_order=series_order) t_out = lax.mul(t, lax.exp(lax.sub(_norm_logpdf(x), ans))) return ans, t_out log_ndtr.defjvp(_log_ndtr_jvp) def _log_ndtr_lower(x, series_order): """Asymptotic expansion version of `Log[cdf(x)]`, appropriate for `x<<-1`.""" dtype = lax.dtype(x).type x_2 = lax.square(x) # Log of the term multiplying (1 + sum) log_scale = -dtype(0.5) * x_2 - lax.log(-x) - dtype(0.5 * np.log(2. * np.pi)) return log_scale + lax.log(_log_ndtr_asymptotic_series(x, series_order)) def _log_ndtr_asymptotic_series(x, series_order): """Calculates the asymptotic series used in log_ndtr.""" dtype = lax.dtype(x).type if series_order <= 0: return np.array(1, dtype) x_2 = lax.square(x) even_sum = jnp.zeros_like(x) odd_sum = jnp.zeros_like(x) x_2n = x_2 # Start with x^{2*1} = x^{2*n} with n = 1. for n in range(1, series_order + 1): y = np.array(_double_factorial(2 * n - 1), dtype) / x_2n if n % 2: odd_sum += y else: even_sum += y x_2n *= x_2 return dtype(1.) + even_sum - odd_sum def _double_factorial(n: int) -> np.ndarray: """The double factorial function for small Python integer `n`.""" return np.prod(np.arange(n, 1, -2)) _norm_logpdf_constant = np.log(np.sqrt(2 * np.pi)) def _norm_logpdf(x): neg_half = _lax_const(x, -0.5) log_normalizer = _lax_const(x, _norm_logpdf_constant) return lax.sub(lax.mul(neg_half, lax.square(x)), log_normalizer) def i0e(x: ArrayLike) -> Array: r"""Exponentially scaled modified bessel function of zeroth order. JAX implementation of :obj:`scipy.special.i0e`. .. math:: \mathrm{i0e}(x) = e^{-|x|} I_0(x) where :math:`I_0(x)` is the modified Bessel function :func:`~jax.scipy.special.i0`. Args: x: array, real-valued Returns: array of bessel function values. See also: - :func:`jax.scipy.special.i0` - :func:`jax.scipy.special.i1` - :func:`jax.scipy.special.i1e` """ x, = promote_args_inexact("i0e", x) return lax.bessel_i0e(x) def i0(x: ArrayLike) -> Array: r"""Modified bessel function of zeroth order. JAX implementation of :obj:`scipy.special.i0`. .. math:: \mathrm{i0}(x) = I_0(x) = \sum_{k=0}^\infty \frac{(x^2/4)^k}{(k!)^2} Args: x: array, real-valued Returns: array of bessel function values. See also: - :func:`jax.scipy.special.i0e` - :func:`jax.scipy.special.i1` - :func:`jax.scipy.special.i1e` """ x, = promote_args_inexact("i0", x) return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i0e(x)) def i1e(x: ArrayLike) -> Array: r"""Exponentially scaled modified bessel function of first order. JAX implementation of :obj:`scipy.special.i1e`. .. math:: \mathrm{i1e}(x) = e^{-|x|} I_1(x) where :math:`I_1(x)` is the modified Bessel function :func:`~jax.scipy.special.i1`. Args: x: array, real-valued Returns: array of bessel function values See also: - :func:`jax.scipy.special.i0` - :func:`jax.scipy.special.i0e` - :func:`jax.scipy.special.i1` """ x, = promote_args_inexact("i1e", x) return lax.bessel_i1e(x) def i1(x: ArrayLike) -> Array: r"""Modified bessel function of first order. JAX implementation of :obj:`scipy.special.i1`. .. math:: \mathrm{i1}(x) = I_1(x) = \frac{1}{2}x\sum_{k=0}^\infty\frac{(x^2/4)^k}{k!(k+1)!} Args: x: array, real-valued Returns: array of bessel function values See also: - :func:`jax.scipy.special.i0` - :func:`jax.scipy.special.i0e` - :func:`jax.scipy.special.i1e` """ x, = promote_args_inexact("i1", x) return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i1e(x)) def _bessel_jn_scan_body_fun(carry, k): f0, f1, bs, z = carry f = 2.0 * (k + 1.0) * f1 / z - f0 def true_fn_update_bs(u): bs, f = u return bs + 2.0 * f def false_fn_update_bs(u): bs, _ = u return bs bs = lax.cond(jnp.mod(k, 2) == 0, true_fn_update_bs, false_fn_update_bs, operand=(bs, f)) f0 = f1 f1 = f return (f0, f1, bs, z), f def _bessel_jn(z: ArrayLike, *, v: int, n_iter: int=50) -> Array: f0 = _lax_const(z, 0.0) f1 = _lax_const(z, 1E-16) f = _lax_const(z, 0.0) bs = _lax_const(z, 0.0) (_, _, bs, _), j_vals = lax.scan( f=_bessel_jn_scan_body_fun, init=(f0, f1, bs, z), xs=lax.iota(lax.dtype(z), n_iter+1), reverse=True) f = j_vals[0] # Use the value at the last iteration. j_vals = j_vals[:v+1] j_vals = j_vals / (bs - f) return j_vals @partial(jit, static_argnames=["v", "n_iter"]) def bessel_jn(z: ArrayLike, *, v: int, n_iter: int=50) -> Array: """Bessel function of the first kind of integer order and real argument. Reference: Shanjie Zhang and Jian-Ming Jin. Computation of special functions. Wiley-Interscience, 1996. Args: z: The sampling point(s) at which the Bessel function of the first kind are computed. v: The order (int) of the Bessel function. n_iter: The number of iterations required for updating the function values. As a rule of thumb, `n_iter` is the smallest nonnegative integer that satisfies the condition `int(0.5 * log10(6.28 + n_iter) - n_iter * log10(1.36 + abs(z) / n_iter)) > 20`. Details in `BJNDD` (https://people.sc.fsu.edu/~jburkardt/f77_src/special_functions/special_functions.f) Returns: An array of shape `(v+1, *z.shape)` containing the values of the Bessel function of orders 0, 1, ..., v. The return type matches the type of `z`. Raises: TypeError if `v` is not integer. ValueError if elements of array `z` are not float. """ z = jnp.asarray(z) z, = promote_dtypes_inexact(z) z_dtype = lax.dtype(z) if dtypes.issubdtype(z_dtype, complex): raise ValueError("complex input not supported.") v = core.concrete_or_error(operator.index, v, 'Argument v of bessel_jn.') n_iter = core.concrete_or_error(int, n_iter, 'Argument n_iter of bessel_jn.') bessel_jn_fun = partial(_bessel_jn, v=v, n_iter=n_iter) for _ in range(z.ndim): bessel_jn_fun = vmap(bessel_jn_fun) return jnp.moveaxis(bessel_jn_fun(z), -1, 0) def _gen_recurrence_mask( l_max: int, is_normalized: bool, dtype: Any ) -> tuple[Array, Array]: """Generates a mask for recurrence relation on the remaining entries. The remaining entries are with respect to the diagonal and offdiagonal entries. Args: l_max: see `gen_normalized_legendre`. is_normalized: True if the recurrence mask is used by normalized associated Legendre functions. Returns: Arrays representing the mask used by the recurrence relations. """ # Computes all coefficients. m_mat, l_mat = jnp.meshgrid( jnp.arange(l_max + 1, dtype=dtype), jnp.arange(l_max + 1, dtype=dtype), indexing='ij') if is_normalized: c0 = l_mat * l_mat c1 = m_mat * m_mat c2 = 2.0 * l_mat c3 = (l_mat - 1.0) * (l_mat - 1.0) d0 = jnp.sqrt((4.0 * c0 - 1.0) / (c0 - c1)) d1 = jnp.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1))) else: d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat) d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat) d0_mask_indices = jnp.triu_indices(l_max + 1, 1) d1_mask_indices = jnp.triu_indices(l_max + 1, 2) d_zeros = jnp.zeros((l_max + 1, l_max + 1), dtype=dtype) d0_mask = d_zeros.at[d0_mask_indices].set(d0[d0_mask_indices]) d1_mask = d_zeros.at[d1_mask_indices].set(d1[d1_mask_indices]) # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere. # i = jnp.arange(l_max + 1)[:, None, None] # j = jnp.arange(l_max + 1)[None, :, None] # k = jnp.arange(l_max + 1)[None, None, :] i, j, k = jnp.ogrid[:l_max + 1, :l_max + 1, :l_max + 1] mask = (i + j - k == 0).astype(dtype) d0_mask_3d = jnp.einsum('jk,ijk->ijk', d0_mask, mask) d1_mask_3d = jnp.einsum('jk,ijk->ijk', d1_mask, mask) return (d0_mask_3d, d1_mask_3d) @partial(jit, static_argnums=(2)) def _gen_derivatives(p: Array, x: Array, is_normalized: bool) -> Array: """Generates derivatives of associated Legendre functions of the first kind. Args: p: The 3D array containing the values of associated Legendre functions; the dimensions are in the sequence of order (m), degree (l), and evaluation points. x: A vector of type `float32` or `float64` containing the sampled points. is_normalized: True if the associated Legendre functions are normalized. Returns: The 3D array representing the derivatives of associated Legendre functions of the first kind. """ num_m, num_l, num_x = p.shape # p_{l-1}^m. p_m_lm1 = jnp.pad(p, ((0, 0), (1, 0), (0, 0)))[:, :num_l, :] # p_{l-1}^{m+2}. p_mp2_lm1 = jnp.pad(p_m_lm1, ((0, 2), (0, 0), (0, 0)))[2:num_m + 2, :, :] # p_{l-1}^{m-2}. p_mm2_lm1 = jnp.pad(p_m_lm1, ((2, 0), (0, 0), (0, 0)))[:num_m, :, :] # Derivative computation requires negative orders. if is_normalized: raise NotImplementedError( 'Negative orders for normalization is not implemented yet.') else: if num_l > 1: l_vec = jnp.arange(1, num_l - 1, dtype=x.dtype) p_p1 = p[1, 1:num_l - 1, :] coeff = -1.0 / ((l_vec + 1) * l_vec) update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1) p_mm2_lm1 = p_mm2_lm1.at[1, 2:num_l, :].set(update_p_p1) if num_l > 2: l_vec = jnp.arange(2, num_l - 1, dtype=x.dtype) p_p2 = p[2, 2:num_l - 1, :] coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec * (l_vec - 1)) update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2) p_mm2_lm1 = p_mm2_lm1.at[0, 3:num_l, :].set(update_p_p2) m_mat, l_mat = jnp.meshgrid( jnp.arange(num_m, dtype=x.dtype), jnp.arange(num_l, dtype=x.dtype), indexing='ij') coeff_zeros = jnp.zeros((num_m, num_l), dtype=x.dtype) upper_0_indices = jnp.triu_indices(num_m, 0, num_l) zero_vec = jnp.zeros((num_l,), dtype=x.dtype) a0 = -0.5 / (m_mat - 1.0) a0_masked = coeff_zeros.at[upper_0_indices].set(a0[upper_0_indices]) a0_masked = a0_masked.at[1, :].set(zero_vec) b0 = l_mat + m_mat c0 = a0 * (b0 - 2.0) * (b0 - 1.0) c0_masked = coeff_zeros.at[upper_0_indices].set(c0[upper_0_indices]) c0_masked = c0_masked.at[1, :].set(zero_vec) # p_l^{m-1}. p_mm1_l = (jnp.einsum('ij,ijk->ijk', a0_masked, p_m_lm1) + jnp.einsum('ij,ijk->ijk', c0_masked, p_mm2_lm1)) d0 = -0.5 / (m_mat + 1.0) d0_masked = coeff_zeros.at[upper_0_indices].set(d0[upper_0_indices]) e0 = d0 * b0 * (b0 + 1.0) e0_masked = coeff_zeros.at[upper_0_indices].set(e0[upper_0_indices]) # p_l^{m+1}. p_mp1_l = (jnp.einsum('ij,ijk->ijk', d0_masked, p_mp2_lm1) + jnp.einsum('ij,ijk->ijk', e0_masked, p_m_lm1)) f0 = b0 * (l_mat - m_mat + 1.0) / 2.0 f0_masked = coeff_zeros.at[upper_0_indices].set(f0[upper_0_indices]) p_derivative = jnp.einsum('ij,ijk->ijk', f0_masked, p_mm1_l) - 0.5 * p_mp1_l # Special treatment of the singularity at m = 1. if num_m > 1: l_vec = jnp.arange(num_l, dtype=p.dtype) g0 = jnp.einsum('i,ij->ij', (l_vec + 1) * l_vec, p[0, :, :]) if num_l > 2: g0 = g0 - p[2, :, :] p_derivative_m0 = jnp.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0) p_derivative = p_derivative.at[1, :, :].set(p_derivative_m0) p_derivative = p_derivative.at[1, 0, :].set(0) return p_derivative @partial(jit, static_argnums=(0, 2)) def _gen_associated_legendre(l_max: int, x: Array, is_normalized: bool) -> Array: r"""Computes associated Legendre functions (ALFs) of the first kind. The ALFs of the first kind are used in spherical harmonics. The spherical harmonic of degree `l` and order `m` can be written as `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the normalization factor and θ and φ are the colatitude and longitude, respectively. `N_l^m` is chosen in the way that the spherical harmonics form a set of orthonormal basis functions of L^2(S^2). For the computational efficiency of spherical harmonics transform, the normalization factor is used in the computation of the ALFs. In addition, normalizing `P_l^m` avoids overflow/underflow and achieves better numerical stability. Three recurrence relations are used in the computation. Args: l_max: The maximum degree of the associated Legendre function. Both the degrees and orders are `[0, 1, 2, ..., l_max]`. x: A vector of type `float32`, `float64` containing the sampled points in spherical coordinates, at which the ALFs are computed; `x` is essentially `cos(θ)`. For the numerical integration used by the spherical harmonics transforms, `x` contains the quadrature points in the interval of `[-1, 1]`. There are several approaches to provide the quadrature points: Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev method (`scipy.special.roots_chebyu`), and Driscoll & Healy method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier transforms and convolutions on the 2-sphere." Advances in applied mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature points are nearly equal-spaced along θ and provide exact discrete orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose operation, `W` is a diagonal matrix containing the quadrature weights, and `I` is the identity matrix. The Gauss-Chebyshev points are equally spaced, which only provide approximate discrete orthogonality. The Driscoll & Healy quadrature points are equally spaced and provide the exact discrete orthogonality. The number of sampling points is required to be twice as the number of frequency points (modes) in the Driscoll & Healy approach, which enables FFT and achieves a fast spherical harmonics transform. is_normalized: True if the associated Legendre functions are normalized. With normalization, `N_l^m` is applied such that the spherical harmonics form a set of orthonormal basis functions of L^2(S^2). Returns: The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values of the ALFs at `x`; the dimensions in the sequence of order, degree, and evaluation points. """ p = jnp.zeros((l_max + 1, l_max + 1, x.shape[0]), dtype=x.dtype) a_idx = jnp.arange(1, l_max + 1, dtype=x.dtype) b_idx = jnp.arange(l_max, dtype=x.dtype) if is_normalized: initial_value: ArrayLike = 0.5 / jnp.sqrt(jnp.pi) # The initial value p(0,0). f_a = jnp.cumprod(-1 * jnp.sqrt(1.0 + 0.5 / a_idx)) f_b = jnp.sqrt(2.0 * b_idx + 3.0) else: initial_value = 1.0 # The initial value p(0,0). f_a = jnp.cumprod(1.0 - 2.0 * a_idx) f_b = 2.0 * b_idx + 1.0 p = p.at[(0, 0)].set(initial_value) # Compute the diagonal entries p(l,l) with recurrence. y = jnp.cumprod( jnp.broadcast_to(jnp.sqrt(1.0 - x * x), (l_max, x.shape[0])), axis=0) p_diag = initial_value * jnp.einsum('i,ij->ij', f_a, y) diag_indices = jnp.diag_indices(l_max + 1) p = p.at[(diag_indices[0][1:], diag_indices[1][1:])].set(p_diag) # Compute the off-diagonal entries with recurrence. p_offdiag = jnp.einsum('ij,ij->ij', jnp.einsum('i,j->ij', f_b, x), p[jnp.diag_indices(l_max)]) offdiag_indices = (diag_indices[0][:l_max], diag_indices[1][:l_max] + 1) p = p.at[offdiag_indices].set(p_offdiag) # Compute the remaining entries with recurrence. d0_mask_3d, d1_mask_3d = _gen_recurrence_mask( l_max, is_normalized=is_normalized, dtype=x.dtype) def body_fun(i, p_val): coeff_0 = d0_mask_3d[i] coeff_1 = d1_mask_3d[i] h = (jnp.einsum('ij,ijk->ijk', coeff_0, jnp.einsum( 'ijk,k->ijk', jnp.roll(p_val, shift=1, axis=1), x)) - jnp.einsum('ij,ijk->ijk', coeff_1, jnp.roll(p_val, shift=2, axis=1))) p_val = p_val + h return p_val # TODO(jakevdp): use some sort of fixed-point procedure here instead? p = p.astype(jnp.result_type(p, x, d0_mask_3d)) if l_max > 1: p = lax.fori_loop(lower=2, upper=l_max+1, body_fun=body_fun, init_val=p) return p def lpmn(m: int, n: int, z: Array) -> tuple[Array, Array]: """The associated Legendre functions (ALFs) of the first kind. Args: m: The maximum order of the associated Legendre functions. n: The maximum degree of the associated Legendre function, often called `l` in describing ALFs. Both the degrees and orders are `[0, 1, 2, ..., l_max]`, where `l_max` denotes the maximum degree. z: A vector of type `float32` or `float64` containing the sampling points at which the ALFs are computed. Returns: A 2-tuple of 3D arrays of shape `(l_max + 1, l_max + 1, len(z))` containing the values and derivatives of the associated Legendre functions of the first kind. The return type matches the type of `z`. Raises: TypeError if elements of array `z` are not in (float32, float64). ValueError if array `z` is not 1D. NotImplementedError if `m!=n`. """ dtype = lax.dtype(z) if dtype not in (jnp.float32, jnp.float64): raise TypeError( 'z.dtype={} is not supported, see docstring for supported types.' .format(dtype)) if z.ndim != 1: raise ValueError('z must be a 1D array.') m = core.concrete_or_error(int, m, 'Argument m of lpmn.') n = core.concrete_or_error(int, n, 'Argument n of lpmn.') if m != n: raise NotImplementedError('Computations for m!=n are not yet supported.') l_max = n is_normalized = False p_vals = _gen_associated_legendre(l_max, z, is_normalized) p_derivatives = _gen_derivatives(p_vals, z, is_normalized) return (p_vals, p_derivatives) def lpmn_values(m: int, n: int, z: Array, is_normalized: bool) -> Array: r"""The associated Legendre functions (ALFs) of the first kind. Unlike `lpmn`, this function only computes the values of ALFs. The ALFs of the first kind can be used in spherical harmonics. The spherical harmonic of degree `l` and order `m` can be written as :math:`Y_l^m(\theta, \phi) = N_l^m * P_l^m(\cos \theta) * \exp(i m \phi)`, where :math:`N_l^m` is the normalization factor and θ and φ are the colatitude and longitude, respectively. :math:`N_l^m` is chosen in the way that the spherical harmonics form a set of orthonormal basis function of :math:`L^2(S^2)`. Normalizing :math:`P_l^m` avoids overflow/underflow and achieves better numerical stability. Args: m: The maximum order of the associated Legendre functions. n: The maximum degree of the associated Legendre function, often called `l` in describing ALFs. Both the degrees and orders are `[0, 1, 2, ..., l_max]`, where `l_max` denotes the maximum degree. z: A vector of type `float32` or `float64` containing the sampling points at which the ALFs are computed. is_normalized: True if the associated Legendre functions are normalized. With normalization, :math:`N_l^m` is applied such that the spherical harmonics form a set of orthonormal basis functions of :math:`L^2(S^2)`. Returns: A 3D array of shape `(l_max + 1, l_max + 1, len(z))` containing the values of the associated Legendre functions of the first kind. The return type matches the type of `z`. Raises: TypeError if elements of array `z` are not in (float32, float64). ValueError if array `z` is not 1D. NotImplementedError if `m!=n`. """ dtype = lax.dtype(z) if dtype not in (jnp.float32, jnp.float64): raise TypeError( 'z.dtype={} is not supported, see docstring for supported types.' .format(dtype)) if z.ndim != 1: raise ValueError('z must be a 1D array.') m = core.concrete_or_error(int, m, 'Argument m of lpmn.') n = core.concrete_or_error(int, n, 'Argument n of lpmn.') if m != n: raise NotImplementedError('Computations for m!=n are not yet supported.') l_max = n return _gen_associated_legendre(l_max, z, is_normalized) @partial(jit, static_argnums=(4,)) def _sph_harm(n: Array, m: Array, theta: Array, phi: Array, n_max: int) -> Array: """Computes the spherical harmonics.""" cos_colatitude = jnp.cos(theta) legendre = _gen_associated_legendre(n_max, cos_colatitude, True) legendre_val = legendre.at[abs(m), n, jnp.arange(len(n))].get(mode="clip") angle = abs(m) * phi vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle)) harmonics = lax.complex(legendre_val * jnp.real(vandermonde), legendre_val * jnp.imag(vandermonde)) # Negative order. harmonics = jnp.where(m < 0, (-1.0)**abs(m) * jnp.conjugate(harmonics), harmonics) return harmonics def sph_harm_y(n: Array, m: Array, theta: Array, phi: Array, diff_n: int | None = None, n_max: int | None = None) -> Array: r"""Computes the spherical harmonics. The JAX version has one extra argument `n_max`, the maximum value in `n`. The spherical harmonic of degree `n` and order `m` can be written as :math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \theta) * \exp(i m \phi)`, where :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!} {4 \pi \left(n+m\right)!}}` is the normalization factor and :math:`\theta` and :math:`\phi` are the colatitude and longitude, respectively. :math:`N_n^m` is chosen in the way that the spherical harmonics form a set of orthonormal basis functions of :math:`L^2(S^2)`. Args: n: The degree of the harmonic; must have `n >= 0`. The standard notation for degree in descriptions of spherical harmonics is `l (lower case L)`. We use `n` here to be consistent with `scipy.special.sph_harm_y`. Return values for `n < 0` are undefined. m: The order of the harmonic; must have `|m| <= n`. Return values for `|m| > n` are undefined. theta: The polar (colatitudinal) coordinate; must be in [0, pi]. phi: The azimuthal (longitudinal) coordinate; must be in [0, 2*pi]. diff_n: Unsupported by JAX. n_max: The maximum degree `max(n)`. If the supplied `n_max` is not the true maximum value of `n`, the results are clipped to `n_max`. For example, `sph_harm(m=jnp.array([2]), n=jnp.array([10]), theta, phi, n_max=6)` actually returns `sph_harm(m=jnp.array([2]), n=jnp.array([6]), theta, phi, n_max=6)` Returns: A 1D array containing the spherical harmonics at (m, n, theta, phi). """ if diff_n is not None: raise NotImplementedError( "The 'diff_n' argument to jax.scipy.special.sph_harm_y is not supported.") if jnp.isscalar(theta): theta = jnp.array([theta]) if n_max is None: n_max = np.max(n) n_max = core.concrete_or_error( int, n_max, 'The `n_max` argument of `jnp.scipy.special.sph_harm` must ' 'be statically specified to use `sph_harm` within JAX transformations.') return _sph_harm(n, m, theta, phi, n_max) def sph_harm(m: Array, n: Array, theta: Array, phi: Array, n_max: int | None = None) -> Array: r"""Computes the spherical harmonics. Note: This function is deprecated, and :func:`~jax.scipy.special.sph_harm_y` should be used instead, noting that the order of ``m`` and ``n`` are reversed, and definitions of ``theta`` and ``phi`` are swapped. The JAX version has one extra argument `n_max`, the maximum value in `n`. The spherical harmonic of degree `n` and order `m` can be written as :math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \phi) * \exp(i m \theta)`, where :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!} {4 \pi \left(n+m\right)!}}` is the normalization factor and :math:`\phi` and :math:`\theta` are the colatitude and longitude, respectively. :math:`N_n^m` is chosen in the way that the spherical harmonics form a set of orthonormal basis functions of :math:`L^2(S^2)`. Args: m: The order of the harmonic; must have `|m| <= n`. Return values for `|m| > n` are undefined. n: The degree of the harmonic; must have `n >= 0`. The standard notation for degree in descriptions of spherical harmonics is `l (lower case L)`. We use `n` here to be consistent with `scipy.special.sph_harm`. Return values for `n < 0` are undefined. theta: The azimuthal (longitudinal) coordinate; must be in [0, 2*pi]. phi: The polar (colatitudinal) coordinate; must be in [0, pi]. n_max: The maximum degree `max(n)`. If the supplied `n_max` is not the true maximum value of `n`, the results are clipped to `n_max`. For example, `sph_harm(m=jnp.array([2]), n=jnp.array([10]), theta, phi, n_max=6)` actually returns `sph_harm(m=jnp.array([2]), n=jnp.array([6]), theta, phi, n_max=6)` Returns: A 1D array containing the spherical harmonics at (m, n, theta, phi). """ # Added 2025-01-06. # TODO(dfm): Remove after deprecation period. deprecations.warn( "jax-scipy-special-sph-harm", ("jax.scipy.special.sph_harm is deprecated. Please use " "jax.scipy.special.sph_harm_y instead, noting that the order of `m` and " "`n` are reversed, and definitions of `theta` and `phi` are swapped."), stacklevel=2, ) return sph_harm_y(n, m, phi, theta, n_max=n_max) # exponential integrals # these algorithms are ported over from the files ei.c and expn.c in the Cephes mathematical library. # https://fossies.org/dox/cephes-math-28/ei_8c_source.html # https://fossies.org/dox/cephes-math-28/expn_8c_source.html def _expint1(x: Array) -> Array: # 0 < x <= 2 A = [ -5.350447357812542947283e0, 2.185049168816613393830e2, -4.176572384826693777058e3, 5.541176756393557601232e4, -3.313381331178144034309e5, 1.592627163384945414220e6, ] B = [ 1.0, -5.250547959112862969197e1, 1.259616186786790571525e3, -1.756549581973534652631e4, 1.493062117002725991967e5, -7.294949239640527645655e5, 1.592627163384945429726e6, ] A_arr = jnp.array(A, dtype=x.dtype) B_arr = jnp.array(B, dtype=x.dtype) f = jnp.polyval(A_arr, x) / jnp.polyval(B_arr, x) return x * f + jnp.euler_gamma + jnp.log(x) def _eval_expint_k(A: list[float], B: list[float], x: Array) -> Array: # helper function for all subsequent intervals A_arr = jnp.array(A, dtype=x.dtype) B_arr = jnp.array(B, dtype=x.dtype) one = _lax_const(x, 1.0) w = one / x f = jnp.polyval(A_arr, w) / jnp.polyval(B_arr, w) f = w * f + one return jnp.exp(x) * w * f def _expint2(x: Array) -> Array: # 2 <= x < 4 A = [ 1.981808503259689673238e-2, -1.271645625984917501326e0, -2.088160335681228318920e0, 2.755544509187936721172e0, -4.409507048701600257171e-1, 4.665623805935891391017e-2, -1.545042679673485262580e-3, 7.059980605299617478514e-5, ] B = [ 1.0, 1.476498670914921440652e0, 5.629177174822436244827e-1, 1.699017897879307263248e-1, 2.291647179034212017463e-2, 4.450150439728752875043e-3, 1.727439612206521482874e-4, 3.953167195549672482304e-5, ] return _eval_expint_k(A, B, x) def _expint3(x: Array) -> Array: # 4 <= x <= 8 A = [ -1.373215375871208729803e0, -7.084559133740838761406e-1, 1.580806855547941010501e0, -2.601500427425622944234e-1, 2.994674694113713763365e-2, -1.038086040188744005513e-3, 4.371064420753005429514e-5, 2.141783679522602903795e-6, ] B = [ 1.0, 8.585231423622028380768e-1, 4.483285822873995129957e-1, 7.687932158124475434091e-2, 2.449868241021887685904e-2, 8.832165941927796567926e-4, 4.590952299511353531215e-4, -4.729848351866523044863e-6, 2.665195537390710170105e-6, ] return _eval_expint_k(A, B, x) def _expint4(x: Array) -> Array: # 8 <= x <= 16 A = [ -2.106934601691916512584e0, 1.732733869664688041885e0, -2.423619178935841904839e-1, 2.322724180937565842585e-2, 2.372880440493179832059e-4, -8.343219561192552752335e-5, 1.363408795605250394881e-5, -3.655412321999253963714e-7, 1.464941733975961318456e-8, 6.176407863710360207074e-10, ] B = [ 1.0, -2.298062239901678075778e-1, 1.105077041474037862347e-1, -1.566542966630792353556e-2, 2.761106850817352773874e-3, -2.089148012284048449115e-4, 1.708528938807675304186e-5, -4.459311796356686423199e-7, 1.394634930353847498145e-8, 6.150865933977338354138e-10, ] return _eval_expint_k(A, B, x) def _expint5(x): # 16 <= x <= 32 A = [ -2.458119367674020323359e-1, -1.483382253322077687183e-1, 7.248291795735551591813e-2, -1.348315687380940523823e-2, 1.342775069788636972294e-3, -7.942465637159712264564e-5, 2.644179518984235952241e-6, -4.239473659313765177195e-8, ] B = [ 1.0, -1.044225908443871106315e-1, -2.676453128101402655055e-1, 9.695000254621984627876e-2, -1.601745692712991078208e-2, 1.496414899205908021882e-3, -8.462452563778485013756e-5, 2.728938403476726394024e-6, -4.239462431819542051337e-8, ] return _eval_expint_k(A, B, x) def _expint6(x): # 32 <= x <= 64 A = [ 1.212561118105456670844e-1, -5.823133179043894485122e-1, 2.348887314557016779211e-1, -3.040034318113248237280e-2, 1.510082146865190661777e-3, -2.523137095499571377122e-5, ] B = [ 1.0, -1.002252150365854016662e0, 2.928709694872224144953e-1, -3.337004338674007801307e-2, 1.560544881127388842819e-3, -2.523137093603234562648e-5, ] return _eval_expint_k(A, B, x) def _expint7(x): # x > 64 A = [ -7.657847078286127362028e-1, 6.886192415566705051750e-1, -2.132598113545206124553e-1, 3.346107552384193813594e-2, -3.076541477344756050249e-3, 1.747119316454907477380e-4, -6.103711682274170530369e-6, 1.218032765428652199087e-7, -1.086076102793290233007e-9, ] B = [ 1.0, -1.888802868662308731041e0, 1.066691687211408896850e0, -2.751915982306380647738e-1, 3.930852688233823569726e-2, -3.414684558602365085394e-3, 1.866844370703555398195e-4, -6.345146083130515357861e-6, 1.239754287483206878024e-7, -1.086076102793126632978e-9, ] return _eval_expint_k(A, B, x) def _expi_pos(x: Array) -> Array: # x >= 0 _c = _lax_const conds = [(_c(x, 0) < x) & (x <= _c(x, 2))] + [ (_c(x, 2 ** i) < x) & (x <= _c(x, 2 ** (i + 1))) for i in range(1, 6) ] return jnp.piecewise( x, conds, [_expint1, _expint2, _expint3, _expint4, _expint5, _expint6, _expint7], ) def _expi_neg(x: Array) -> Array: # x < 0 return -exp1(-x) @custom_derivatives.custom_jvp @jit def expi(x: ArrayLike) -> Array: r"""Exponential integral function. JAX implementation of :obj:`scipy.special.expi` .. math:: \mathrm{expi}(x) = \int_{-\infty}^x \frac{e^t}{t} \mathrm{d}t Args: x: arraylike, real-valued Returns: array of expi values See also: - :func:`jax.scipy.special.expn` - :func:`jax.scipy.special.exp1` """ x_arr, = promote_args_inexact("expi", x) return jnp.piecewise(x_arr, [x_arr < 0], [_expi_neg, _expi_pos]) @expi.defjvp @jit def expi_jvp(primals, tangents): (x,) = primals (x_dot,) = tangents return expi(x), jnp.exp(x) / x * x_dot def _expn1(x: Array, n: Array) -> Array: # exponential integral En _c = _lax_const MACHEP = jnp.finfo(x.dtype).eps zero = _c(x, 0.0) one = _c(x, 1.0) psi = -jnp.euler_gamma - jnp.log(x) psi = lax.fori_loop(_c(n, 1), n, lambda i, psi: psi + one / i, psi) n1 = jnp.where(n == _c(n, 1), one + one, n) init = dict( x=x, z=-x, xk=zero, yk=one, pk=one - n, ans=jnp.where(n == _c(n, 1), zero, one / (one - n1)), t=jnp.inf, ) def body(d): d["xk"] += one d["yk"] *= d["z"] / d["xk"] d["pk"] += one d["ans"] += jnp.where(d["pk"] != zero, d["yk"] / d["pk"], zero) d["t"] = jnp.where(d["ans"] != zero, abs(d["yk"] / d["ans"]), one) return d def cond(d): return (d["x"] > _c(d["x"], 0.0)) & (d["t"] > MACHEP) d = lax.while_loop(cond, body, init) t = n r = n - _c(n, 1) return d["z"] ** r * psi / jnp.exp(gammaln(t)) - d["ans"] def _expn2(x: Array, n: Array) -> Array: # x > 1. _c = _lax_const BIG = _c(x, 1.44115188075855872e17) MACHEP = jnp.finfo(BIG.dtype).eps # ? zero = _c(x, 0.0) one = _c(x, 1.0) init = dict( k=_c(n, 1), pkm2=one, qkm2=x, pkm1=one, qkm1=x + n, ans=one / (x + n), t=_c(x, jnp.inf), r=zero, x=x, ) def body(d): x = d["x"] d["k"] += _c(d["k"], 1) k = d["k"] odd = k % _c(k, 2) == _c(k, 1) yk = jnp.where(odd, one, x) xk = jnp.where(odd, n + (k - _c(k, 1)) / _c(k, 2), k / _c(k, 2)) pk = d["pkm1"] * yk + d["pkm2"] * xk qk = d["qkm1"] * yk + d["qkm2"] * xk nz = qk != zero d["r"] = r = jnp.where(nz, pk / qk, d["r"]) d["t"] = jnp.where(nz, abs((d["ans"] - r) / r), one) d["ans"] = jnp.where(nz, r, d["ans"]) d["pkm2"] = d["pkm1"] d["pkm1"] = pk d["qkm2"] = d["qkm1"] d["qkm1"] = qk is_big = abs(pk) > BIG for s in "pq": for i in "12": key = s + "km" + i d[key] = jnp.where(is_big, d[key] / BIG, d[key]) return d def cond(d): return (d["x"] > _c(d["k"], 0)) & (d["t"] > MACHEP) d = lax.while_loop(cond, body, init) return d["ans"] * jnp.exp(-x) def _expn3(x: Array, n: Array) -> Array: # n >= 5000 _c = _lax_const one = _c(x, 1.0) xk = x + n yk = one / (xk * xk) t = n ans = yk * t * (_c(x, 6) * x * x - _c(x, 8) * t * x + t * t) ans = yk * (ans + t * (t - _c(x, 2) * x)) ans = yk * (ans + t) return (ans + one) * jnp.exp(-x) / xk @partial(custom_derivatives.custom_jvp, nondiff_argnums=(0,)) @jnp.vectorize @jit def expn(n: ArrayLike, x: ArrayLike) -> Array: r"""Generalized exponential integral function. JAX implementation of :obj:`scipy.special.expn`. .. math:: \mathrm{expn}(x) = E_n(x) = x^{n-1}\int_x^\infty\frac{e^{-t}}{t^n}\mathrm{d}t Args: n: arraylike, real-valued x: arraylike, real-valued Returns: array of expn values See also: - :func:`jax.scipy.special.expi` - :func:`jax.scipy.special.exp1` """ n, x = promote_args_inexact("expn", n, x) _c = _lax_const zero = _c(x, 0) one = _c(x, 1) conds = [ (n < _c(n, 0)) | (x < zero), (x == zero) & (n < _c(n, 2)), (x == zero) & (n >= _c(n, 2)), (n == _c(n, 0)) & (x >= zero), (n >= _c(n, 5000)), (x > one), ] n1 = jnp.where(n == _c(n, 1), n + n, n) vals = [ jnp.nan, jnp.inf, one / n1, # prevent div by zero jnp.exp(-x) / x, _expn3, _expn2, _expn1, ] ret = jnp.piecewise(x, conds, vals, n=n) return ret @expn.defjvp @jit def expn_jvp(n, primals, tangents): (x,), (x_dot,) = primals, tangents return expn(n, x), lax.mul( lax.neg(x_dot), expn(lax.sub(n, _lax_const(n, 1)), x) ) def exp1(x: ArrayLike) -> Array: r"""Exponential integral function. JAX implementation of :obj:`scipy.special.exp1` .. math:: \mathrm{exp1}(x) = E_1(x) = x^{n-1}\int_x^\infty\frac{e^{-t}}{t}\mathrm{d}t Args: x: arraylike, real-valued Returns: array of exp1 values See also: - :func:`jax.scipy.special.expi` - :func:`jax.scipy.special.expn` """ x, = promote_args_inexact("exp1", x) # Casting because custom_jvp generic does not work correctly with mypy. return cast(Array, expn(1, x)) def _spence_poly(w: Array) -> Array: A = jnp.array([4.65128586073990045278E-5, 7.31589045238094711071E-3, 1.33847639578309018650E-1, 8.79691311754530315341E-1, 2.71149851196553469920E0, 4.25697156008121755724E0, 3.29771340985225106936E0, 1.00000000000000000126E0, ], dtype=w.dtype) B = jnp.array([6.90990488912553276999E-4, 2.54043763932544379113E-2, 2.82974860602568089943E-1, 1.41172597751831069617E0, 3.63800533345137075418E0, 5.03278880143316990390E0, 3.54771340985225096217E0, 9.99999999999999998740E-1, ],dtype=w.dtype) return -w * jnp.polyval(A, w) / jnp.polyval(B, w) def _spence_calc(x: Array) -> Array: x2_bool = x > 2.0 x = jnp.piecewise(x, [x2_bool], [lambda x: 1.0 / x, lambda x: x]) x1_5_bool = x > 1.5 x_5_bool = x < 0.5 x2_bool = x2_bool | x1_5_bool w = jnp.piecewise(x, [x1_5_bool, x_5_bool], [lambda x: 1.0 / x - 1.0, lambda x: -x, lambda x: x - 1.0]) y = _spence_poly(w) y_flag_one = jnp.pi ** 2 / 6.0 - jnp.log(x) * jnp.log(1.0 - x) - y y = jnp.where(x_5_bool, y_flag_one, y) y_flag_two = -0.5 * jnp.log(x) ** 2 - y return jnp.where(x2_bool, y_flag_two, y) def _spence(x: Array) -> Array: return jnp.piecewise(x, [x < 0.0, x == 1.0, x == 0.0], [jnp.nan, 0, jnp.pi ** 2 / 6, _spence_calc]) def spence(x: Array) -> Array: r"""Spence's function, also known as the dilogarithm for real values. JAX implementation of :obj:`scipy.special.spence`. It is defined to be: .. math:: \mathrm{spence}(x) = \begin{equation} \int_1^x \frac{\log(t)}{1 - t}dt \end{equation} Unlike the SciPy implementation, this is only defined for positive real values of `z`. For negative values, `NaN` is returned. Args: z: An array of type `float32`, `float64`. Returns: An array with `dtype=z.dtype`. computed values of Spence's function. Raises: TypeError: if elements of array `z` are not in (float32, float64). Notes: There is a different convention which defines Spence's function by the integral: .. math:: \begin{equation} -\int_0^z \frac{\log(1 - t)}{t}dt \end{equation} This is our spence(1 - z). """ x = jnp.asarray(x) dtype = lax.dtype(x) if dtype not in (jnp.float32, jnp.float64): raise TypeError( f"x.dtype={dtype} is not supported, see docstring for supported types.") return _spence(x) def bernoulli(n: int) -> Array: """Generate the first N Bernoulli numbers. JAX implementation of :func:`scipy.special.bernoulli`. Args: n: integer, the number of Bernoulli terms to generate. Returns: Array containing the first ``n`` Bernoulli numbers. Notes: ``bernoulli`` generates numbers using the :math:`B_n^-` convention, such that :math:`B_1=-1/2`. """ # Generate Bernoulli numbers using the Chowla and Hartung algorithm. n = core.concrete_or_error(operator.index, n, "Argument n of bernoulli") if n < 0: raise ValueError("n must be a non-negative integer.") b3 = jnp.array([1, -1/2, 1/6]) if n < 3: return b3[:n + 1] bn = jnp.zeros(n + 1).at[:3].set(b3) m = jnp.arange(4, n + 1, 2, dtype=bn.dtype) q1 = (1. / jnp.pi ** 2) * jnp.cumprod(-(m - 1) * m / 4 / jnp.pi ** 2) k = jnp.arange(2, 50, dtype=bn.dtype) # Choose 50 because 2 ** -50 < 1E-15 q2 = jnp.sum(k[:, None] ** -m[None, :], axis=0) return bn.at[4::2].set(q1 * (1 + q2)) @custom_derivatives.custom_jvp def poch(z: ArrayLike, m: ArrayLike) -> Array: r"""The Pochammer symbol. JAX implementation of :obj:`scipy.special.poch`. .. math:: \mathrm{poch}(z, m) = (z)_m = \frac{\Gamma(z + m)}{\Gamma(z)} where :math:`\Gamma(z)` is the :func:`~jax.scipy.special.gamma` function. Args: z: arraylike, real-valued m: arraylike, real-valued Returns: array of Pochammer values. Notes: The JAX version supports only real-valued inputs. """ z, m = promote_args_inexact("poch", z, m) return jnp.where(m == 0., jnp.array(1, dtype=z.dtype), gamma(z + m) / gamma(z)) def _poch_z_derivative(z, m): """ Defined in : https://functions.wolfram.com/GammaBetaErf/Pochhammer/20/01/01/ """ return (digamma(z + m) - digamma(z)) * poch(z, m) def _poch_m_derivative(z, m): """ Defined in : https://functions.wolfram.com/GammaBetaErf/Pochhammer/20/01/02/ """ return digamma(z + m) * poch(z, m) poch.defjvps( lambda z_dot, primal_out, z, m: _poch_z_derivative(z, m) * z_dot, lambda m_dot, primal_out, z, m: _poch_m_derivative(z, m) * m_dot, ) def _hyp1f1_serie(a, b, x): """ Compute the 1F1 hypergeometric function using the taylor expansion See Eq. 3.2 and associated method (a) from PEARSON, OLVER & PORTER 2014 https://doi.org/10.48550/arXiv.1407.7786 """ precision = jnp.finfo(x.dtype).eps def body(state): serie, k, term = state serie += term term *= (a + k) / (b + k) * x / (k + 1) k += 1 return serie, k, term def cond(state): serie, k, term = state return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision) init = 1, 1, a / b * x return lax.while_loop(cond, body, init)[0] def _hyp1f1_asymptotic(a, b, x): """ Compute the 1F1 hypergeometric function using asymptotic expansion See Eq. 3.8 and simplification for real inputs from PEARSON, OLVER & PORTER 2014 https://doi.org/10.48550/arXiv.1407.7786 """ precision = jnp.finfo(x.dtype).eps def body(state): serie, k, term = state serie += term term *= (b - a + k) * (1 - a + k) / (k + 1) / x k += 1 return serie, k, term def cond(state): serie, k, term = state return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision) init = 1, 1, (b - a) * (1 - a) / x serie = lax.while_loop(cond, body, init)[0] return gamma(b) / gamma(a) * lax.exp(x) * x ** (a - b) * serie @jit @jnp.vectorize def _hyp1f1_a_derivative(a, b, x): """ Define it as a serie using : https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/01/ """ precision = jnp.finfo(x.dtype).eps def body(state): serie, k, term = state serie += term * (digamma(a + k) - digamma(a)) term *= (a + k) / (b + k) * x / (k + 1) k += 1 return serie, k, term def cond(state): serie, k, term = state return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision) init = 0, 1, a / b * x return lax.while_loop(cond, body, init)[0] @jit @jnp.vectorize def _hyp1f1_b_derivative(a, b, x): """ Define it as a serie using : https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/02/ """ precision = jnp.finfo(x.dtype).eps def body(state): serie, k, term = state serie += term * (digamma(b) - digamma(b + k)) term *= (a + k) / (b + k) * x / (k + 1) k += 1 return serie, k, term def cond(state): serie, k, term = state return (k < 250) & (lax.abs(term) / lax.abs(serie) > precision) init = 0, 1, a / b * x return lax.while_loop(cond, body, init)[0] @jit def _hyp1f1_x_derivative(a, b, x): """ Define it as a serie using : https://functions.wolfram.com/HypergeometricFunctions/Hypergeometric1F1/20/01/04/ """ return a / b * hyp1f1(a + 1, b + 1, x) @custom_derivatives.custom_jvp @jit @jnp.vectorize def hyp1f1(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array: r"""The 1F1 hypergeometric function. JAX implementation of :obj:`scipy.special.hyp1f1`. .. math:: \mathrm{hyp1f1}(a, b, x) = {}_1F_1(x;a, b) = \sum_{k=0}^\infty \frac{(a)_k}{(b)_kk!}x^k where :math:`(\cdot)_k` is the Pochammer symbol (refer to :func:`~jax.scipy.special.poch`). The JAX version only accepts positive and real inputs. Values of ``a``, ``b``, and ``x``, leading to high values of 1F1 may lead to erroneous results; consider enabling double precision in this case. The convention for ``a = b = 0`` is ``1``, unlike in scipy's implementation. Args: a: arraylike, real-valued b: arraylike, real-valued x: arraylike, real-valued Returns: array of 1F1 values. """ # This is backed by https://doi.org/10.48550/arXiv.1407.7786 # There is room for improvement in the implementation using recursion to # evaluate lower values of hyp1f1 when a or b or both are > 60-80 a, b, x = promote_args_inexact('hyp1f1', a, b, x) result = lax.cond(lax.abs(x) < 100, _hyp1f1_serie, _hyp1f1_asymptotic, a, b, x) index = (a == 0) * 1 + ((a == b) & (a != 0)) * 2 + ((b == 0) & (a != 0)) * 3 return lax.select_n(index, result, jnp.array(1, dtype=x.dtype), jnp.exp(x), jnp.array(jnp.inf, dtype=x.dtype)) hyp1f1.defjvps( lambda a_dot, primal_out, a, b, x: _hyp1f1_a_derivative(a, b, x) * a_dot, lambda b_dot, primal_out, a, b, x: _hyp1f1_b_derivative(a, b, x) * b_dot, lambda x_dot, primal_out, a, b, x: _hyp1f1_x_derivative(a, b, x) * x_dot ) def _hyp2f1_terminal(a, b, c, x): """ The Taylor series representation of the 2F1 hypergeometric function terminates when either a or b is a non-positive integer. See Eq. 4.1 and Taylor Series Method (a) from PEARSON, OLVER & PORTER 2014 https://doi.org/10.48550/arXiv.1407.7786 """ # Ensure that between a and b, the negative integer parameter with the greater # absolute value - that still has a magnitude less than the absolute value of # c if c is non-positive - is used for the upper limit in the loop. eps = jnp.finfo(x.dtype).eps * 50 ib = jnp.round(b) mask = jnp.logical_and( b < a, jnp.logical_and( jnp.abs(b - ib) < eps, jnp.logical_not( jnp.logical_and( c % 1 == 0, jnp.logical_and( c <= 0, c > b ) ) ) ) ) orig_a = a a = jnp.where(mask, b, a) b = jnp.where(mask, orig_a, b) a = jnp.abs(a) def body(i, state): serie, term = state term *= -(a - i + 1) / (c + i - 1) * (b + i - 1) / i * x serie += term return serie, term init = (jnp.array(1, dtype=x.dtype), jnp.array(1, dtype=x.dtype)) return lax.fori_loop(jnp.array(1, dtype=a.dtype), a + 1, body, init)[0] def _hyp2f1_serie(a, b, c, x): """ Compute the 2F1 hypergeometric function using the Taylor expansion. See Eq. 4.1 from PEARSON, OLVER & PORTER 2014 https://doi.org/10.48550/arXiv.1407.7786 """ rtol = jnp.finfo(x.dtype).eps def body(state): serie, k, term = state serie += term term *= (a + k - 1) * (b + k - 1) / (c + k - 1) / k * x k += 1 return serie, k, term def cond(state): serie, k, term = state return (k < 250) & (lax.abs(term) > rtol * lax.abs(serie)) init = (jnp.array(0, dtype=x.dtype), jnp.array(1, dtype=x.dtype), jnp.array(1, dtype=x.dtype)) return lax.while_loop(cond, body, init)[0] def _hyp2f1_terminal_or_serie(a, b, c, x): """ Check for recurrence relations along with whether or not the series terminates. True recursion is not possible; however, the recurrence relation may still be approximated. See 4.6.1. Recurrence Relations from PEARSON, OLVER & PORTER 2014 https://doi.org/10.48550/arXiv.1407.7786 """ eps = jnp.finfo(x.dtype).eps * 50 d = c - a - b ia = jnp.round(a) ib = jnp.round(b) id = jnp.round(d) neg_int_a = jnp.logical_and(a <= 0, jnp.abs(a - ia) < eps) neg_int_b = jnp.logical_and(b <= 0, jnp.abs(b - ib) < eps) neg_int_a_or_b = jnp.logical_or(neg_int_a, neg_int_b) not_neg_int_a_or_b = jnp.logical_not(neg_int_a_or_b) index = jnp.where(jnp.logical_and(x > 0.9, not_neg_int_a_or_b), jnp.where(jnp.abs(d - id) >= eps, 0, 1), jnp.where(neg_int_a_or_b, 2, 0)) return lax.select_n(index, _hyp2f1_serie(a, b, c, x), _hyp2f1_digamma_transform(a, b, c, x), _hyp2f1_terminal(a, b, c, x)) def _hyp2f1_digamma_transform(a, b, c, x): """ Digamma transformation of the 2F1 hypergeometric function. See AMS55 #15.3.10, #15.3.11, #15.3.12 """ rtol = jnp.finfo(x.dtype).eps d = c - a - b s = 1 - x rd = jnp.round(d) e = jnp.where(rd >= 0, d, -d) d1 = jnp.where(rd >= 0, d, jnp.array(0, dtype=d.dtype)) d2 = jnp.where(rd >= 0, jnp.array(0, dtype=d.dtype), d) ard = jnp.where(rd >= 0, rd, -rd).astype('int32') ax = jnp.log(s) y = digamma(1.0) + digamma(1.0 + e) - digamma(a + d1) - digamma(b + d1) - ax y /= gamma(e + 1.0) p = (a + d1) * (b + d1) * s / gamma(e + 2.0) def cond(state): _, _, _, _, _, _, q, _, _, t, y = state return jnp.logical_and( t < 250, jnp.abs(q) >= rtol * jnp.abs(y) ) def body(state): a, ax, b, d1, e, p, q, r, s, t, y = state r = digamma(1.0 + t) + digamma(1.0 + t + e) - digamma(a + t + d1) \ - digamma(b + t + d1) - ax q = p * r y += q p *= s * (a + t + d1) / (t + 1.0) p *= (b + t + d1) / (t + 1.0 + e) t += 1.0 return a, ax, b, d1, e, p, q, r, s, t, y init = (a, ax, b, d1, e, p, y, jnp.array(0, dtype=x.dtype), s, jnp.array(1, dtype=x.dtype), y) _, _, _, _, _, _, q, r, _, _, y = lax.while_loop(cond, body, init) def compute_sum(y): y1 = jnp.array(1, dtype=x.dtype) t = jnp.array(0, dtype=x.dtype) p = jnp.array(1, dtype=x.dtype) def for_body(i, state): a, b, d2, e, p, s, t, y1 = state r = 1.0 - e + t p *= s * (a + t + d2) * (b + t + d2) / r t += 1.0 p /= t y1 += p return a, b, d2, e, p, s, t, y1 init_val = a, b, d2, e, p, s, t, y1 y1 = lax.fori_loop(1, ard, for_body, init_val)[-1] p = gamma(c) y1 *= gamma(e) * p / (gamma(a + d1) * gamma(b + d1)) y *= p / (gamma(a + d2) * gamma(b + d2)) y = jnp.where((ard & 1) != 0, -y, y) q = s ** rd return jnp.where(rd > 0, y * q + y1, y + y1 * q) return jnp.where( rd == 0, y * gamma(c) / (gamma(a) * gamma(b)), compute_sum(y) ) @jit @jnp.vectorize def hyp2f1(a: ArrayLike, b: ArrayLike, c: ArrayLike, x: ArrayLike) -> Array: r"""The 2F1 hypergeometric function. JAX implementation of :obj:`scipy.special.hyp2f1`. .. math:: \mathrm{hyp2f1}(a, b, c, x) = {}_2F_1(a; b; c; x) = \sum_{k=0}^\infty \frac{(a)_k(b)_k}{(c)_k}\frac{x^k}{k!} where :math:`(\cdot)_k` is the Pochammer symbol. The JAX version only accepts positive and real inputs. Values of ``a``, ``b``, ``c``, and ``x`` leading to high values of 2F1 may lead to erroneous results; consider enabling double precision in this case. Args: a: arraylike, real-valued b: arraylike, real-valued c: arraylike, real-valued x: arraylike, real-valued Returns: array of 2F1 values. """ # This is backed by https://doi.org/10.48550/arXiv.1407.7786 a, b, c, x = promote_args_inexact('hyp2f1', a, b, c, x) eps = jnp.finfo(x.dtype).eps * 50 d = c - a - b s = 1 - x ca = c - a cb = c - b id = jnp.round(d) ica = jnp.round(ca) icb = jnp.round(cb) neg_int_ca = jnp.logical_and(ca <= 0, jnp.abs(ca - ica) < eps) neg_int_cb = jnp.logical_and(cb <= 0, jnp.abs(cb - icb) < eps) neg_int_ca_or_cb = jnp.logical_or(neg_int_ca, neg_int_cb) index = jnp.where(jnp.logical_or(x == 0, jnp.logical_and(jnp.logical_or(a == 0, b == 0), c != 0)), 0, jnp.where(jnp.logical_or(c == 0, jnp.logical_and(c < 0, c % 1 == 0)), 1, jnp.where(jnp.logical_and(d <= -1, jnp.logical_not(jnp.logical_and(jnp.abs(d - id) >= eps, s < 0))), 2, jnp.where(jnp.logical_and(d <= 0, x == 1), 1, jnp.where(jnp.logical_and(x < 1, b == c), 3, jnp.where(jnp.logical_and(x < 1, a == c), 4, jnp.where(x > 1, 1, jnp.where(x == 1, 5, 6)))))))) return lax.select_n(index, jnp.array(1, dtype=x.dtype), jnp.array(jnp.inf, dtype=x.dtype), s ** d * _hyp2f1_terminal_or_serie(ca, cb, c, x), s ** (-a), s ** (-b), gamma(c) * gamma(d) / (gamma(ca) * gamma(cb)), _hyp2f1_terminal_or_serie(a, b, c, x)) def softmax(x: ArrayLike, /, *, axis: int | tuple[int, ...] | None = None, ) -> Array: r"""Softmax function. JAX implementation of :func:`scipy.special.softmax`. Computes the function which rescales elements to the range :math:`[0, 1]` such that the elements along :code:`axis` sum to :math:`1`. .. math :: \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} Args: x : input array axis: the axis or axes along which the softmax should be computed. The softmax output summed across these dimensions should sum to :math:`1`. Returns: An array of the same shape as ``x``. Note: If any input values are ``+inf``, the result will be all ``NaN``: this reflects the fact that ``inf / inf`` is not well-defined in the context of floating-point math. See also: :func:`log_softmax` """ return nn_softmax(x, axis=axis) def log_softmax(x: ArrayLike, /, *, axis: int | tuple[int, ...] | None = None, ) -> Array: r"""Log-Softmax function. JAX implementation of :func:`scipy.special.log_softmax` Computes the logarithm of the :code:`softmax` function, which rescales elements to the range :math:`[-\infty, 0)`. .. math :: \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right) Args: x : input array axis: the axis or axes along which the :code:`log_softmax` should be computed. Returns: An array of the same shape as ``x`` Note: If any input values are ``+inf``, the result will be all ``NaN``: this reflects the fact that ``inf / inf`` is not well-defined in the context of floating-point math. See also: :func:`softmax` """ return nn_log_softmax(x, axis=axis)