70 lines
2.3 KiB
Python
70 lines
2.3 KiB
Python
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
from functools import partial
|
|
|
|
from jax import jit
|
|
from jax._src.typing import Array, ArrayLike
|
|
import jax.numpy as jnp
|
|
|
|
|
|
@partial(jit, static_argnames=('axis',))
|
|
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
|
|
axis: int = -1) -> Array:
|
|
r"""
|
|
Integrate along the given axis using the composite trapezoidal rule.
|
|
|
|
JAX implementation of :func:`scipy.integrate.trapezoid`
|
|
|
|
The trapezoidal rule approximates the integral under a curve by summing the
|
|
areas of trapezoids formed between adjacent data points.
|
|
|
|
Args:
|
|
y: array of data to integrate.
|
|
x: optional array of sample points corresponding to the ``y`` values. If not
|
|
provided, ``x`` defaults to equally spaced with spacing given by ``dx``.
|
|
dx: The spacing between sample points when `x` is None (default: 1.0).
|
|
axis: The axis along which to integrate (default: -1)
|
|
|
|
Returns:
|
|
The definite integral approximated by the trapezoidal rule.
|
|
|
|
See also:
|
|
:func:`jax.numpy.trapezoid`: NumPy-style API for trapezoidal integration
|
|
|
|
Examples:
|
|
Integrate over a regular grid, with spacing 1.0:
|
|
|
|
>>> y = jnp.array([1, 2, 3, 2, 3, 2, 1])
|
|
>>> jax.scipy.integrate.trapezoid(y, dx=1.0)
|
|
Array(13., dtype=float32)
|
|
|
|
Integrate over an irregular grid:
|
|
|
|
>>> x = jnp.array([0, 2, 5, 7, 10, 15, 20])
|
|
>>> jax.scipy.integrate.trapezoid(y, x)
|
|
Array(43., dtype=float32)
|
|
|
|
Approximate :math:`\int_0^{2\pi} \sin^2(x)dx`, which equals :math:`\pi`:
|
|
|
|
>>> x = jnp.linspace(0, 2 * jnp.pi, 1000)
|
|
>>> y = jnp.sin(x) ** 2
|
|
>>> result = jax.scipy.integrate.trapezoid(y, x)
|
|
>>> jnp.allclose(result, jnp.pi)
|
|
Array(True, dtype=bool)
|
|
"""
|
|
return jnp.trapezoid(y, x, dx, axis)
|