70 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			70 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright 2023 The JAX Authors.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     https://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| from __future__ import annotations
 | |
| 
 | |
| from functools import partial
 | |
| 
 | |
| from jax import jit
 | |
| from jax._src.typing import Array, ArrayLike
 | |
| import jax.numpy as jnp
 | |
| 
 | |
| 
 | |
@partial(jit, static_argnames=('axis',))
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
              axis: int = -1) -> Array:
  r"""
  Integrate along the given axis using the composite trapezoidal rule.

  JAX implementation of :func:`scipy.integrate.trapezoid`

  Each pair of neighbouring samples along ``axis`` defines a trapezoid; the
  result is the sum of the areas of all such trapezoids.

  Args:
    y: array of data to integrate.
    x: optional array of sample points corresponding to the ``y`` values. When
       omitted, samples are treated as equally spaced with spacing ``dx``.
    dx: spacing between sample points, used only when ``x`` is None
       (default: 1.0).
    axis: the axis along which to integrate (default: -1). Static under
       :func:`jax.jit`.

  Returns:
    The definite integral approximated by the trapezoidal rule.

  See also:
    :func:`jax.numpy.trapezoid`: NumPy-style API for trapezoidal integration

  Examples:
    Integrate over a regular grid, with spacing 1.0:

    >>> y = jnp.array([1, 2, 3, 2, 3, 2, 1])
    >>> jax.scipy.integrate.trapezoid(y, dx=1.0)
    Array(13., dtype=float32)

    Integrate over an irregular grid:

    >>> x = jnp.array([0, 2, 5, 7, 10, 15, 20])
    >>> jax.scipy.integrate.trapezoid(y, x)
    Array(43., dtype=float32)

    Approximate :math:`\int_0^{2\pi} \sin^2(x)dx`, which equals :math:`\pi`:

    >>> x = jnp.linspace(0, 2 * jnp.pi, 1000)
    >>> y = jnp.sin(x) ** 2
    >>> result = jax.scipy.integrate.trapezoid(y, x)
    >>> jnp.allclose(result, jnp.pi)
    Array(True, dtype=bool)
  """
  # The SciPy-compatible entry point shares its implementation with the
  # NumPy-style API; keywords make the argument mapping explicit.
  return jnp.trapezoid(y, x=x, dx=dx, axis=axis)
 |