Metadata-Version: 2.2
Name: jax
Version: 0.6.2
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Classifier: Development Status :: 5 - Production/Stable
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: Free Threading :: 3 - Stable
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
License-File: AUTHORS
Requires-Dist: jaxlib<=0.6.2,>=0.6.2
Requires-Dist: ml_dtypes>=0.5.0
Requires-Dist: numpy>=1.26
Requires-Dist: opt_einsum
Requires-Dist: scipy>=1.12
Provides-Extra: minimum-jaxlib
Requires-Dist: jaxlib==0.6.2; extra == "minimum-jaxlib"
Provides-Extra: cpu
Provides-Extra: ci
Requires-Dist: jaxlib==0.6.1; extra == "ci"
Provides-Extra: tpu
Requires-Dist: jaxlib<=0.6.2,>=0.6.2; extra == "tpu"
Requires-Dist: libtpu==0.0.17.*; extra == "tpu"
Requires-Dist: requests; extra == "tpu"
Provides-Extra: cuda
Requires-Dist: jaxlib<=0.6.2,>=0.6.2; extra == "cuda"
Requires-Dist: jax-cuda12-plugin[with-cuda]<=0.6.2,>=0.6.2; extra == "cuda"
Provides-Extra: cuda12
Requires-Dist: jaxlib<=0.6.2,>=0.6.2; extra == "cuda12"
Requires-Dist: jax-cuda12-plugin[with-cuda]<=0.6.2,>=0.6.2; extra == "cuda12"
Provides-Extra: cuda12-local
Requires-Dist: jaxlib<=0.6.2,>=0.6.2; extra == "cuda12-local"
Requires-Dist: jax-cuda12-plugin<=0.6.2,>=0.6.2; extra == "cuda12-local"
Provides-Extra: rocm
Requires-Dist: jaxlib<=0.6.2,>=0.6.2; extra == "rocm"
Requires-Dist: jax-rocm60-plugin<=0.6.2,>=0.6.2; extra == "rocm"
Provides-Extra: k8s
Requires-Dist: kubernetes; extra == "k8s"
Provides-Extra: xprof
Requires-Dist: xprof; extra == "xprof"
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: license
Dynamic: provides-extra
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

<div align="center">
<img src="https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png" alt="logo"></img>
</div>

# Transformable numerical computing at scale

[Continuous integration](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml)
[PyPI version](https://pypi.org/project/jax/)

[**Transformations**](#transformations)
| [**Scaling**](#scaling)
| [**Install guide**](#installation)
| [**Change logs**](https://docs.jax.dev/en/latest/changelog.html)
| [**Reference docs**](https://docs.jax.dev/en/latest/)

## What is JAX?

JAX is a Python library for accelerator-oriented array computation and program transformation,
designed for high-performance numerical computing and large-scale machine learning.

JAX can automatically differentiate native
Python and NumPy functions. It can differentiate through loops, branches,
recursion, and closures, and it can take derivatives of derivatives of
derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation)
via [`jax.grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation,
and the two can be composed arbitrarily to any order.
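
For instance, composing reverse mode with forward mode gives an efficient Hessian-vector product. A minimal sketch (the helper `hvp` and the function `f` here are just for illustration):

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sum(jnp.sin(x))

def hvp(f, x, v):
  # forward-mode (jvp) over reverse-mode (grad): a Hessian-vector product
  return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.arange(3.0)
print(hvp(f, x, jnp.ones(3)))  # equals -sin(x) * v, element-wise
```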

JAX uses [XLA](https://www.tensorflow.org/xla)
to compile and scale your NumPy programs on TPUs, GPUs, and other hardware accelerators.
You can compile your own pure functions with [`jax.jit`](#compilation-with-jit).
Compilation and automatic differentiation can be composed arbitrarily.

Dig a little deeper, and you'll see that JAX is really an extensible system for
[composable function transformations](#transformations) at [scale](#scaling).

This is a research project, not an official Google product. Expect
[sharp edges](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html).
Please help by trying it out, [reporting bugs](https://github.com/jax-ml/jax/issues),
and letting us know what you think!

```python
import jax
import jax.numpy as jnp

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)  # inputs to the next layer
  return outputs                # no activation on last layer

def loss(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_loss = jax.jit(jax.grad(loss))  # compiled gradient evaluation function
perex_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads
```

### Contents
* [Transformations](#transformations)
* [Scaling](#scaling)
* [Current gotchas](#gotchas-and-sharp-bits)
* [Installation](#installation)
* [Neural net libraries](#neural-network-libraries)
* [Citing JAX](#citing-jax)
* [Reference documentation](#reference-documentation)

## Transformations

At its core, JAX is an extensible system for transforming numerical functions.
Here are three: `jax.grad`, `jax.jit`, and `jax.vmap`.

### Automatic differentiation with `grad`

Use [`jax.grad`](https://docs.jax.dev/en/latest/jax.html#jax.grad)
to efficiently compute reverse-mode gradients:

```python
import jax
import jax.numpy as jnp

def tanh(x):
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = jax.grad(tanh)
print(grad_tanh(1.0))
# prints 0.4199743
```

You can differentiate to any order with `grad`:

```python
print(jax.grad(jax.grad(jax.grad(tanh)))(1.0))
# prints 0.62162673
```

You're free to use differentiation with Python control flow:

```python
def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = jax.grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)
```

See the [JAX Autodiff
Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html)
and the [reference docs on automatic
differentiation](https://docs.jax.dev/en/latest/jax.html#automatic-differentiation)
for more.

### Compilation with `jit`

Use XLA to compile your functions end-to-end with
[`jit`](https://docs.jax.dev/en/latest/jax.html#just-in-time-compilation-jit),
used either as an `@jit` decorator or as a higher-order function.

```python
import jax
import jax.numpy as jnp

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jax.jit(slow_f)
%timeit -n10 -r3 fast_f(x)
%timeit -n10 -r3 slow_f(x)
```

Using `jax.jit` constrains the kind of Python control flow
the function can use; see
the tutorial on [Control Flow and Logical Operators with JIT](https://docs.jax.dev/en/latest/control-flow.html)
for more.
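
For example, a Python `if` on a value that depends on a jitted argument fails during tracing. A minimal sketch of one common workaround (the function name here is just for illustration) is to express the branch with `jnp.where`:

```python
import jax
import jax.numpy as jnp

@jax.jit
def relu(x):
  # `if x > 0:` would raise an error under jit, because `x` is an
  # abstract tracer here; select element-wise instead.
  return jnp.where(x > 0, x, 0.0)

print(relu(jnp.array([-1.0, 2.0])))  # [0. 2.]
```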

### Auto-vectorization with `vmap`

[`vmap`](https://docs.jax.dev/en/latest/jax.html#vectorization-vmap) maps
a function along array axes.
But instead of just looping over function applications, it pushes the loop down
onto the function's primitive operations, e.g. turning matrix-vector multiplies into
matrix-matrix multiplies for better performance.
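
As a rough illustration of that point, mapping a single-vector matrix-vector product over a batch amounts to one matrix-matrix multiply (the array names and shapes below are just for the example):

```python
import jax
import jax.numpy as jnp

W = jnp.ones((150, 100))
xs = jnp.ones((10, 100))          # a batch of 10 vectors

matvec = lambda x: jnp.dot(W, x)  # acts on a single length-100 vector
ys = jax.vmap(matvec)(xs)         # batched: equivalent to one (10, 150) matmul
print(ys.shape)                   # (10, 150)
```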

Using `vmap` can save you from having to carry around batch dimensions in your
code:

```python
import jax
import jax.numpy as jnp

def l1_distance(x, y):
  assert x.ndim == y.ndim == 1  # only works on 1D inputs
  return jnp.sum(jnp.abs(x - y))

def pairwise_distances(dist1D, xs):
  return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs)

xs = jax.random.normal(jax.random.key(0), (100, 3))
dists = pairwise_distances(l1_distance, xs)
dists.shape  # (100, 100)
```

By composing `jax.vmap` with `jax.grad` and `jax.jit`, we can get efficient
Jacobian matrices, or per-example gradients:

```python
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))
```
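
For full Jacobian matrices, `jax.jacfwd` and `jax.jacrev` compose the same building blocks in forward and reverse mode; a minimal sketch, with `f` chosen only for illustration:

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * x   # R^3 -> R^3, element-wise

x = jnp.arange(3.0)
J_fwd = jax.jacfwd(f)(x)  # Jacobian via forward-mode
J_rev = jax.jacrev(f)(x)  # Jacobian via reverse-mode
print(J_fwd.shape)        # (3, 3)
```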

## Scaling

To scale your computations across thousands of devices, you can use any
composition of these:
* [**Compiler-based automatic parallelization**](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
  where you program as if using a single global machine, and the compiler chooses
  how to shard data and partition computation (with some user-provided constraints);
* [**Explicit sharding and automatic partitioning**](https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html)
  where you still have a global view but data shardings are
  explicit in JAX types, inspectable using `jax.typeof`;
* [**Manual per-device programming**](https://docs.jax.dev/en/latest/notebooks/shard_map.html)
  where you have a per-device view of data
  and computation, and can communicate with explicit collectives.

| Mode | View? | Explicit sharding? | Explicit Collectives? |
|---|---|---|---|
| Auto | Global | ❌ | ❌ |
| Explicit | Global | ✅ | ❌ |
| Manual | Per-device | ✅ | ✅ |

```python
from jax.sharding import set_mesh, AxisType, PartitionSpec as P
mesh = jax.make_mesh((8,), ('data',), axis_types=(AxisType.Explicit,))
set_mesh(mesh)

# parameters are sharded for FSDP:
for W, b in params:
  print(f'{jax.typeof(W)}')  # f32[512@data,512]
  print(f'{jax.typeof(b)}')  # f32[512]

# shard data for batch parallelism:
inputs, targets = jax.device_put((inputs, targets), P('data'))

# evaluate gradients, automatically parallelized!
gradfun = jax.jit(jax.grad(loss))
param_grads = gradfun(params, inputs, targets)
```

See the [tutorial](https://docs.jax.dev/en/latest/sharded-computation.html) and
[advanced guides](https://docs.jax.dev/en/latest/advanced_guide.html) for more.

## Gotchas and sharp bits

See the [Gotchas
Notebook](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html).

## Installation

### Supported platforms

|            | Linux x86_64 | Linux aarch64 | Mac aarch64  | Windows x86_64 | Windows WSL2 x86_64 |
|------------|--------------|---------------|--------------|----------------|---------------------|
| CPU        | yes          | yes           | yes          | yes            | yes                 |
| NVIDIA GPU | yes          | yes           | n/a          | no             | experimental        |
| Google TPU | yes          | n/a           | n/a          | n/a            | n/a                 |
| AMD GPU    | yes          | no            | n/a          | no             | no                  |
| Apple GPU  | n/a          | no            | experimental | n/a            | n/a                 |
| Intel GPU  | experimental | n/a           | n/a          | no             | no                  |

### Instructions

| Platform        | Instructions                                                                                                    |
|-----------------|-----------------------------------------------------------------------------------------------------------------|
| CPU             | `pip install -U jax`                                                                                            |
| NVIDIA GPU      | `pip install -U "jax[cuda12]"`                                                                                  |
| Google TPU      | `pip install -U "jax[tpu]"`                                                                                     |
| AMD GPU (Linux) | Follow [AMD's instructions](https://github.com/jax-ml/jax/blob/main/build/rocm/README.md).                      |
| Mac GPU         | Follow [Apple's instructions](https://developer.apple.com/metal/jax/).                                          |
| Intel GPU       | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md).  |

See [the documentation](https://docs.jax.dev/en/latest/installation.html)
for information on alternative installation strategies. These include compiling
from source, installing with Docker, using other versions of CUDA, a
community-supported conda build, and answers to some frequently-asked questions.

## Citing JAX

To cite this repository:

```
@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/jax-ml/jax},
  version = {0.3.13},
  year = {2018},
}
```

In the above bibtex entry, names are in alphabetical order, the version number
is intended to be that from [jax/version.py](../main/jax/version.py), and
the year corresponds to the project's open-source release.

A nascent version of JAX, supporting only automatic differentiation and
compilation to XLA, was described in a [paper that appeared at SysML
2018](https://mlsys.org/Conferences/2019/doc/2018/146.pdf). We're currently working on
covering JAX's ideas and capabilities in a more comprehensive and up-to-date
paper.

## Reference documentation

For details about the JAX API, see the
[reference documentation](https://docs.jax.dev/).

For getting started as a JAX developer, see the
[developer documentation](https://docs.jax.dev/en/latest/developer.html).