# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for state primitives."""
from __future__ import annotations
from functools import partial
import types
from typing import Any, Union
import numpy as np
from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src import traceback_util
from jax._src import tree_util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax
from jax._src.state import indexing
from jax._src.state.types import (
AbstractRef,
AccumEffect,
ReadEffect,
Transform,
TransformedRef,
WriteEffect,
)
from jax._src.typing import Array
from jax._src.util import safe_map, safe_zip
## General utilities
## JAX utilities
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
traceback_util.register_exclusion(__file__)
## get/swap/addupdate implementations
# `get` reads a value from a `Ref` type, a.k.a.:
# a = get_p.bind(x)
# or we can read using indices:
# a = get_p.bind(x, 0, 1)
# Staging out `a = get_p.bind(x)` where the aval of `x` is
# `Ref((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
# a:f32[3] <- x[]
get_p = core.Primitive("get")
get_p.def_impl(partial(dispatch.apply_primitive, get_p))
batching.ragged_prop_rules[get_p] = batching.ragged_mask_transfer_identity
Indexer = Union[int, slice, Array, types.EllipsisType]
def get_ref_and_transforms(
ref_or_view: Any,
idx: Indexer | tuple[Indexer, ...] | None,
function_name: str,
force_trailing_indexer: bool = True, # TODO(apaszke): Clean this up.
) -> tuple[Any, tuple[Transform, ...]]:
if isinstance(ref_or_view, TransformedRef):
ref, transforms = ref_or_view.ref, ref_or_view.transforms
else:
ref, transforms = ref_or_view, ()
ref_aval = core.get_aval(ref)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"Can only call `{function_name}` on a `Ref`: {ref}.")
if not isinstance(ref_aval.inner_aval, core.ShapedArray):
return ref, ()
if idx is None or idx is Ellipsis:
idx = ()
elif not isinstance(idx, tuple):
idx = (idx,)
if not idx and not force_trailing_indexer:
return ref, transforms
if not idx and transforms and isinstance(transforms[-1], indexing.NDIndexer):
return ref, transforms
nd_indexer = indexing.NDIndexer.from_indices_shape(idx, ref_or_view.shape)
return ref, (*transforms, nd_indexer)
def ref_get(
ref_or_view: Any, idx: Indexer | tuple[Indexer, ...] | None = None
) -> Array:
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_get")
flat_transforms, tree = tree_util.tree_flatten(transforms)
return get_p.bind(ref, *flat_transforms, tree=tree)
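# Example (illustrative sketch, names hypothetical): for a ref `r` whose aval
# is `Ref{f32[3,4]}`,
#   row = ref_get(r, 0)   # reads r[0, :], an f32[4]
#   full = ref_get(r)     # no index: reads the whole buffer, an f32[3,4]
# Staging out the first read produces a jaxpr eqn printed roughly as
#   row:f32[4] <- r[0,:]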
# `swap` mutates a `Ref`, setting its value and returning its previous value.
# b = swap_p.bind(x, a)
# It generalizes the setting operation for a `Ref` as we can ignore the return
# value:
# _ = swap_p.bind(x, a)
# `swap_p` also takes in index arguments following the value, i.e.:
# _ = swap_p.bind(x, a, 0, 1)
# Staging out `b = swap_p.bind(x, a)` where the aval of `x` is
# `Ref((3,), np.dtype('float32'))` and the aval of `a` is
# `ShapedArray((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
# b:f32[3], x:Ref{f32[3]} <- x, a
# Staging out `_ = swap_p.bind(x, a, i, j)` where the aval of `x` is
# `Ref((3,), np.dtype('float32'))` , the aval of `a` is
# `ShapedArray((3,), np.dtype('float32'))`, and the avals of both `i` and `j`
# are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like
# x:Ref{f32[3]}[i, j] <- a
swap_p = core.Primitive("swap")
swap_p.def_impl(partial(dispatch.apply_primitive, swap_p))
def swap_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
assert len(invar_raggedness) == 2
invar_raggedness_lhs = invar_raggedness[0]
invar_raggedness_rhs = invar_raggedness[1]
return [invar_raggedness_rhs, invar_raggedness_lhs], [None]
batching.ragged_prop_rules[swap_p] = swap_ragged_prop_rule
def ref_swap(
ref_or_view: AbstractRef | TransformedRef,
idx: Indexer | tuple[Indexer, ...] | None,
value: Array,
_function_name: str = "ref_swap",
) -> Array:
"""Sets a `Ref`'s value and returns the original value."""
if hasattr(ref_or_view, 'dtype'):
value = _maybe_implicit_cast(ref_or_view.dtype, value)
ref, transforms = get_ref_and_transforms(ref_or_view, idx, _function_name)
flat_transforms, tree = tree_util.tree_flatten(transforms)
return swap_p.bind(ref, value, *flat_transforms, tree=tree)
# TODO(slebedev,mattjj): replace with special handling of Python numeric types:
# if (isinstance(value, (int, float, complex)) and
# value == np.array(value, dtype).item()): return cast
def _maybe_implicit_cast(dtype, value):
aval = core.typeof(value)
if (aval.weak_type and
(dtypes.issubdtype(dtype, np.floating) and
dtypes.issubdtype(aval.dtype, np.floating)) or
(dtypes.issubdtype(dtype, np.integer) and
dtypes.issubdtype(aval.dtype, np.integer))):
return lax.convert_element_type(value, dtype)
return value
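# Example (illustrative, `r` is a hypothetical f32[] ref): writing a Python
# scalar relies on the weak-type cast above,
#   ref_set(r, None, 1.0)   # 1.0 is weakly typed, so it is cast to f32
# whereas a strongly typed f64 value of the same shape is left uncast and
# fails the dtype check in `_swap_abstract_eval` below.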
def ref_set(
ref_or_view: AbstractRef | TransformedRef,
idx: Indexer | tuple[Indexer, ...] | None,
value: Array,
) -> None:
"""Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
ref_swap(ref_or_view, idx, value, _function_name="ref_set")
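# Example (illustrative, names hypothetical): `ref_set` is `ref_swap` with the
# old value dropped,
#   old = ref_swap(r, 0, new_row)   # r[0] <- new_row, returns the old r[0]
#   ref_set(r, 0, new_row)          # r[0] <- new_row, old value discarded
# The dropped-result form prints roughly as `r[0,:] <- new_row` in jaxprs (see
# `_swap_pp_rule` below).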
# `addupdate_p` mutates a `Ref`, adding a value to its existing value.
# Semantically,
# ```
# addupdate ref a *idx
# ```
# is equivalent to
# ```
# b = get ref *idx
# c = add b a
# _ = swap ref c *idx
# ```
addupdate_p = core.Primitive('addupdate')
addupdate_p.multiple_results = True
addupdate_p.def_impl(partial(dispatch.apply_primitive, addupdate_p))
def ref_addupdate(
ref_or_view: AbstractRef,
idx: Indexer | tuple[Indexer, ...] | None,
x: Array,
) -> None:
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_addupdate")
flat_transforms, tree = tree_util.tree_flatten(transforms)
return addupdate_p.bind(ref, x, *flat_transforms, tree=tree)
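# Example (illustrative): `ref_addupdate(r, 0, x)` behaves like the unfused
# sequence
#   cur = ref_get(r, 0)
#   ref_set(r, 0, cur + x)
# but stages out as a single `addupdate` eqn, keeping the accumulation
# recognizable (e.g. for the transpose rules below).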
## get/set/addupdate abstract evaluation rules
def _shape_after_transforming(
shape: tuple[int | Array, ...], transforms: tuple[Transform, ...]
) -> tuple[int | Array, ...]:
for transform in transforms:
shape = transform.transform_shape(shape) # type: ignore
assert shape is not None
return shape
def _dtype_after_transforming(
dtype: Any, transforms: tuple[Transform, ...]
) -> Any:
for transform in transforms:
dtype = transform.transform_dtype(dtype)
assert dtype is not None
return dtype
def _sharding_after_transforming(sharding, transforms):
for transform in transforms:
sharding = transform.transform_sharding(sharding)
assert sharding is not None
return sharding
def _get_abstract_eval(ref_aval: AbstractRef, *args,
tree):
transforms = tree_util.tree_unflatten(tree, args)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray):
out_shape = _shape_after_transforming(ref_aval.shape, transforms)
out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms)
out_sharding = _sharding_after_transforming(ref_aval.sharding, transforms)
out_aval = ref_aval.inner_aval.update(
shape=out_shape, dtype=out_dtype, sharding=out_sharding)
else:
if transforms:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
out_aval = ref_aval.inner_aval
return (out_aval, {ReadEffect(0)})
get_p.def_effectful_abstract_eval(_get_abstract_eval)
def _swap_abstract_eval(ref_aval: AbstractRef,
val_aval: core.AbstractValue,
*args: Any, tree):
transforms = tree_util.tree_unflatten(tree, args)
out_aval: core.AbstractValue
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray):
assert isinstance(val_aval, core.ShapedArray)
expected_out_shape = _shape_after_transforming(ref_aval.shape, transforms)
expected_out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms)
if expected_out_shape != val_aval.shape:
raise ValueError("Invalid shape for `swap`. "
f"Ref shape: {ref_aval.shape}. "
f"Expected shape: {expected_out_shape}. "
f"Value shape: {val_aval.shape}. "
f"Transforms: {transforms}. ")
if expected_out_dtype != val_aval.dtype:
raise ValueError(
"Invalid dtype for `swap`. "
f"Ref dtype: {expected_out_dtype}. "
f"Value dtype: {val_aval.dtype}. "
)
out_aval = core.ShapedArray(expected_out_shape, expected_out_dtype)
else:
if transforms:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
out_aval = ref_aval.inner_aval
return (out_aval, {WriteEffect(0)})
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
def _addupdate_abstract_eval(ref_aval: AbstractRef,
val_aval: core.AbstractValue,
*args: Any, tree):
transforms = tree_util.tree_unflatten(tree, args)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray):
out_shape = _shape_after_transforming(ref_aval.shape, transforms)
out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms)
assert isinstance(val_aval, core.ShapedArray)
if out_shape != val_aval.shape:
raise ValueError(
"Invalid shape for `addupdate`. "
f"Ref shape: {ref_aval.shape}. "
f"Expected shape: {out_shape}. "
f"Value shape: {val_aval.shape}. "
f"Transforms: {transforms}. "
)
if out_dtype != val_aval.dtype:
      raise ValueError("Invalid dtype for `addupdate`. "
                       f"Ref dtype: {ref_aval.dtype}. "
                       f"Value dtype: {val_aval.dtype}. ")
else:
# Check that the transforms are valid
if transforms:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
return [], {AccumEffect(0)}
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
## Pretty printing for `get` and `swap` in jaxprs
pp_ref_var = partial(pp.color, intensity=pp.Intensity.NORMAL,
foreground=pp.Color.GREEN)
def _pp_transforms(
context: core.JaxprPpContext,
transforms: tuple[Transform, ...],
):
if not transforms:
return pp.text("[...]")
return pp.concat(
[transform.pretty_print(context) for transform in transforms]
)
def pp_ref_transforms(context: core.JaxprPpContext, ref, transforms):
return pp_ref_var(
pp.concat([
pp.text(core.pp_var(ref, context)),
_pp_transforms(context, transforms),
])
)
def _get_pp_rule(eqn, context, settings) -> pp.Doc:
  # Pretty prints `a = get x i` as `a <- x[i]`
y, = eqn.outvars
x, *flat_idx = eqn.invars
transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes)
return pp.concat(
[lhs, pp.text(" <- "), pp_ref_transforms(context, x, transforms)]
)
core.pp_eqn_rules[get_p] = _get_pp_rule
def _swap_pp_rule(eqn, context, settings) -> pp.Doc:
y, = eqn.outvars
x, v, *flat_idx = eqn.invars
transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
if type(y) is core.DropVar:
# In the case of a set (ignored return value),
# pretty print `_ = swap x v i` as `x[i] <- v`
del y
return pp.concat([
pp_ref_transforms(context, x, transforms),
pp.text(" <- "),
pp.text(core.pp_var(v, context)),
])
else:
# pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v`
x_i = pp_ref_transforms(context, x, transforms)
y = core.pp_vars([y], context, print_shapes=settings.print_shapes)
return pp.concat([y, pp.text(', '), x_i, pp.text(' <- '),
x_i, pp.text(', '),
pp.text(core.pp_var(v, context))])
core.pp_eqn_rules[swap_p] = _swap_pp_rule
def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc:
del settings
# pretty-print ` = addupdate x i v` as `x[i] += v`
() = eqn.outvars
x, v, *flat_idx = eqn.invars
transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx)
return pp.concat([
pp_ref_transforms(context, x, transforms),
pp.text(" += "),
pp.text(core.pp_var(v, context)),
])
core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule
## get/swap/addupdate JVP rules
def _get_jvp(primals: list[Any], tangents: list[Any], **params: Any):
ref_primal, *idx = primals
assert isinstance(ref_primal.aval, AbstractRef)
ref_tangent, *_ = tangents
assert isinstance(ref_tangent.aval, AbstractRef)
return (get_p.bind(ref_primal, *idx, **params),
get_p.bind(ref_tangent, *idx, **params))
ad.primitive_jvps[get_p] = _get_jvp
def _swap_jvp(primals: list[Any], tangents: list[Any], **params: Any):
ref_primal, x_primal, *idx = primals
assert isinstance(ref_primal.aval, AbstractRef)
ref_tangent, x_tangent, *_ = tangents
assert isinstance(ref_tangent.aval, AbstractRef)
x_tangent = ad_util.instantiate(x_tangent)
return (swap_p.bind(ref_primal, x_primal, *idx, **params),
swap_p.bind(ref_tangent, x_tangent, *idx, **params))
ad.primitive_jvps[swap_p] = _swap_jvp
def addupdate_jvp_rule(primals: list[Any], tangents: list[Any], **params: Any):
ref_primal, x_primal, *idx = primals
ref_tangent, x_tangent, *_ = tangents
x_tangent = ad_util.instantiate(x_tangent)
addupdate_p.bind(ref_primal, x_primal, *idx, **params)
addupdate_p.bind(ref_tangent, x_tangent, *idx, **params)
return [], []
ad.primitive_jvps[addupdate_p] = addupdate_jvp_rule
## get/swap/addupdate transpose rules
def _get_transpose(g, ref, *idx, **params):
# get transpose is addupdate
if type(g) is not ad_util.Zero:
addupdate_p.bind(ref, g, *idx, **params)
return [None] + [None] * len(idx)
ad.primitive_transposes[get_p] = _get_transpose
def _swap_transpose(g, ref, x, *idx, **params):
del x # old value doesn't matter anymore
# swap transpose is swap
x_bar = swap_p.bind(ref, ad_util.instantiate(g), *idx, **params)
return [None, x_bar] + [None] * len(idx)
ad.primitive_transposes[swap_p] = _swap_transpose
def addupdate_transpose(cts_in, ref, x, *idx, **params):
# addupdate transpose is get
del cts_in, x
g = get_p.bind(ref, *idx, **params)
return [None, g] + [None] * len(idx)
ad.primitive_transposes[addupdate_p] = addupdate_transpose
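# Taken together, the three primitives transpose into one another:
#   get       <-> addupdate   (reading transposes to accumulating)
#   swap      <-> swap        (swapping is its own transpose)
#   addupdate <-> get         (accumulating transposes to reading)
# Illustrative sketch (names hypothetical): if the primal program reads
# `y = ref_get(r, i)`, the transposed program accumulates the cotangent with
# `ref_addupdate(ct_r, i, ct_y)`.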
## get/swap/addupdate partial_eval_custom rules
def _state_partial_eval_custom(prim, saveable, unks_in, inst_in, eqn):
if any(unks_in):
res = [v for v, inst in zip(eqn.invars, inst_in) if not inst]
return None, eqn, [True] * len(eqn.outvars), [True] * len(eqn.outvars), res
elif saveable(prim, *[var.aval for var in eqn.invars], **eqn.params):
return eqn, None, [False] * len(eqn.outvars), [False] * len(eqn.outvars), []
res = [v for v, inst in zip(eqn.invars, inst_in) if not inst]
return eqn, eqn, [False] * len(eqn.outvars), [True] * len(eqn.outvars), res
pe.partial_eval_jaxpr_custom_rules[get_p] = partial(_state_partial_eval_custom,
get_p)
pe.partial_eval_jaxpr_custom_rules[swap_p] = partial(_state_partial_eval_custom,
swap_p)
pe.partial_eval_jaxpr_custom_rules[addupdate_p] = partial(
_state_partial_eval_custom, addupdate_p)
## get/swap/addupdate batching rules
def _batch_indexer(
indexer: indexing.NDIndexer,
dims,
axis_size: int,
ref_shape: tuple[int, ...],
ref_dim: int | batching.NotMapped,
idx_is_batched: bool,
) -> indexing.NDIndexer:
"""Converts a batched indexer into an unbatched one.
This function handles the complexity of `vmap`-style batching where either the
`ref` being indexed, the indexer, or both may have batched dimensions. The
goal is to produce a new indexer that acts as if applied in a batched context,
but without actual batching, enabling downstream code to process it as usual.
  If any index in `indexer` is batched, all array indexers are normalized. If
  an array indexer contains a batched dimension, that dimension is moved to the
  front (axis 0); if it is not batched, it is broadcast to include a batch
  dimension at the front. This guarantees that all array indexers end up with
  the same shape.
Slices are passed through unchanged unless they contain dynamic elements and
are themselves batched, which is currently unsupported.
If `ref` is batched (`ref_dim` is not `NotMapped`), we simulate per-example
indexing by inserting a new iota array at the position corresponding to
`ref_dim` in the indexer.
  Note that if the array indexers in the original indexer are contiguous but
  become non-contiguous in the new indexer due to the insertion of the iota,
  the dimensions corresponding to the array indexers are moved to the front of
  the indexing result: the batched dimension ends up at axis 0, and the
  dimensions corresponding to the original array indexers start at axis 1.
  This creates a layout mismatch between the original and the new indexer, so
  callers must transpose the arrays involved to compensate.
Args:
indexer: An `NDIndexer` that indexes into `ref`.
dims: A pytree with the same structure as `indexer`, indicating which
dimension (if any) is batched for each array indexer.
axis_size: Size of the batch dimension.
ref_shape: Shape of `ref`.
ref_dim: The dimension of `ref` that is batched (if any).
idx_is_batched: Whether any index in the `indexer` is batched.
"""
indices = indexer.indices
indices_dims = dims.indices
new_indices: list[Array | indexing.Slice | int] = []
new_integer_indexer_shape = (axis_size, *indexer.int_indexer_shape)
for idx, dim in zip(indices, indices_dims):
if idx_is_batched:
      # If at least one of the indices is batched, we broadcast them all and
      # move the batch dim to the front.
if isinstance(idx, indexing.Slice):
# size is static, but start can be dynamic
# Check if start is static (which it can be)
is_static_slice = len(tree_util.tree_leaves(idx)) == 0
if is_static_slice:
new_indices.append(idx)
continue
dim = dim.start
if dim is batching.not_mapped:
# Broadcasting the slice is free (the start index stays the same)
new_indices.append(idx)
else:
raise NotImplementedError(
f"No support for vmapping over nontrivial slices just yet: {idx}")
else:
# Check if we are indexing with a scalar or not. If we are indexing
# with a scalar and we are not batched, we can avoid broadcasting it.
assert hasattr(idx, "shape")
if not idx.shape:
if dim is not batching.not_mapped:
assert idx.shape == (axis_size,)
idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape, (0,))
new_indices.append(idx)
else:
if dim is batching.not_mapped:
bcast_dims = tuple(range(1, np.ndim(idx) + 1))
idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape,
bcast_dims)
else:
idx = batching.moveaxis(idx, dim, 0) # type: ignore[arg-type]
new_indices.append(idx)
else:
if ref_dim is not batching.not_mapped:
if not isinstance(idx, indexing.Slice):
assert hasattr(idx, "shape")
if idx.shape:
bcast_dims = tuple(range(1, np.ndim(idx) + 1))
idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape,
bcast_dims)
new_indices.append(idx)
if ref_dim is not batching.not_mapped:
iota = lax.broadcasted_iota(np.dtype('int32'), new_integer_indexer_shape, 0)
new_indices.insert(ref_dim, iota)
return indexing.NDIndexer(
tuple(new_indices), ref_shape, new_integer_indexer_shape, validate=True
)
def _get_vmap(batched_args, batched_dims, *, tree):
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
if d is not batching.not_mapped}
ref, *flat_idxs = batched_args
ref_dim, *flat_idx_dims = batched_dims
indexers = tree_util.tree_unflatten(tree, flat_idxs)
indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims)
idx_is_batched = any(i_dim is not batching.not_mapped
for i_dim in flat_idx_dims)
if len(indexers) > 1:
raise NotImplementedError("Batching with multiple indexers not supported.")
# TODO(sharadmv): handle vmap of multiple indexers
new_indexers = tuple(_batch_indexer(indexer, dims, axis_size,
ref.shape, ref_dim, idx_is_batched)
for indexer, dims in zip(indexers, indexers_dims))
flat_indexers, tree = tree_util.tree_flatten(new_indexers)
is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0])
int_indexers_contiguous = bool(
np.all(np.diff(np.where(is_int_indexing)[0]) == 1)
)
is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0])
new_int_indexers_contiguous = bool(
np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1)
)
out = get_p.bind(ref, *flat_indexers, tree=tree)
if not int_indexers_contiguous: # will always be moved to the front
out_bdim = 0
else: # originally not going to be moved to the front
if new_int_indexers_contiguous: # now not going to be moved to the front
out_bdim = is_new_int_indexing.index(True)
else: # now going to be moved to the front
original_pos = is_int_indexing.index(True)
array_indexer_shape = new_indexers[0].int_indexer_shape
array_indexer_len = len(array_indexer_shape)
transpose_order = list(range(len(out.shape)))
transpose_order = (
transpose_order[0],
*transpose_order[array_indexer_len:array_indexer_len+original_pos],
*transpose_order[1:array_indexer_len],
*transpose_order[array_indexer_len+original_pos:],
)
out = lax.transpose(out, transpose_order)
out_bdim = 0
return out, out_bdim
batching.primitive_batchers[get_p] = _get_vmap
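# NumPy analogy for the transpose logic above (a descriptive note, not new
# behavior): under advanced indexing, contiguous array indexers keep their
# position in the result (e.g. `x[:, idx0, idx1, :]`), while non-contiguous
# ones are moved to the front (e.g. `x[idx0, :, idx1, :]`). Inserting the
# batching iota can turn a contiguous pattern into a non-contiguous one, so
# the result is transposed back so that only the batch dimension stays at the
# front.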
def _swap_vmap(batched_args, batched_dims, *, tree):
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
if d is not batching.not_mapped}
ref, val, *flat_idxs = batched_args
ref_dim, val_dim, *flat_idx_dims = batched_dims
indexers = tree_util.tree_unflatten(tree, flat_idxs)
indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims)
ref_is_batched = ref_dim is not batching.not_mapped
val_is_batched = val_dim is not batching.not_mapped
idx_is_batched = any(i_dim is not batching.not_mapped
for i_dim in flat_idx_dims)
if not ref_is_batched:
raise Exception("performing a set/swap operation with vmapped value on "
"an unbatched mutable array reference "
f"of type {core.typeof(ref)}. Move the mutable array to be "
"an argument to the vmapped function?")
if len(indexers) > 1:
raise NotImplementedError("Batching with multiple indexers not supported.")
# TODO(sharadmv): handle vmap of multiple indexers
new_indexers = tuple(_batch_indexer(indexer, dims, axis_size,
ref.shape, ref_dim, idx_is_batched)
for indexer, dims in zip(indexers, indexers_dims))
flat_indexers, tree = tree_util.tree_flatten(new_indexers)
is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0])
int_indexers_contiguous = bool(
np.all(np.diff(np.where(is_int_indexing)[0]) == 1)
)
is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0])
new_int_indexers_contiguous = bool(
np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1)
)
if not new_int_indexers_contiguous: # will be moved to the front
batched_dim_in_result = 0
else:
    batched_dim_in_result = is_new_int_indexing.index(True)
if not val_is_batched:
if ref_is_batched or idx_is_batched:
val = batching.broadcast(val, axis_size, batched_dim_in_result)
else:
val = batching.moveaxis(val, val_dim, batched_dim_in_result)
transpose_order_inversed = None
# Originally not going to be moved to the front, but now going to be moved to
# the front.
if int_indexers_contiguous and not new_int_indexers_contiguous:
original_pos = is_int_indexing.index(True)
array_indexer_shape = new_indexers[0].int_indexer_shape
array_indexer_len = len(array_indexer_shape)
transpose_order = list(range(len(val.shape)))
transpose_order = (
transpose_order[0],
*transpose_order[1+original_pos:(1+original_pos)+(array_indexer_len-1)],
*transpose_order[1:1+original_pos],
*transpose_order[(1+original_pos)+(array_indexer_len-1):],
)
val = val.transpose(transpose_order)
transpose_order_inversed = np.argsort(transpose_order)
out = swap_p.bind(ref, val, *flat_indexers, tree=tree)
# `val` should not be transposed, but we needed to transpose it to match
# `swap_p`. As a result, the output of `swap_p` is also transposed. Now we
# need to transpose it back.
if transpose_order_inversed is not None:
out = out.transpose(transpose_order_inversed)
return out, batched_dim_in_result
batching.primitive_batchers[swap_p] = _swap_vmap
def _addupdate_vmap(batched_args, batched_dims, *, tree):
axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
if d is not batching.not_mapped}
ref, val, *flat_idxs = batched_args
ref_dim, val_dim, *flat_idx_dims = batched_dims
indexers = tree_util.tree_unflatten(tree, flat_idxs)
indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims)
ref_is_batched = ref_dim is not batching.not_mapped
val_is_batched = val_dim is not batching.not_mapped
idx_is_batched = any(i_dim is not batching.not_mapped
for i_dim in flat_idx_dims)
if len(indexers) > 1:
raise NotImplementedError("Batching with multiple indexers not supported.")
# TODO(sharadmv): handle vmap of multiple indexers
new_indexers = tuple(_batch_indexer(indexer, dims, axis_size,
ref.shape, ref_dim, idx_is_batched)
for indexer, dims in zip(indexers, indexers_dims))
flat_indexers, tree = tree_util.tree_flatten(new_indexers)
is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0])
int_indexers_contiguous = bool(
np.all(np.diff(np.where(is_int_indexing)[0]) == 1)
)
is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0])
new_int_indexers_contiguous = bool(
np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1)
)
if not new_int_indexers_contiguous: # will be moved to the front
batched_dim_in_result = 0
else:
batched_dim_in_result = is_new_int_indexing.index(True)
if not val_is_batched:
if ref_is_batched or idx_is_batched:
val = batching.broadcast(val, axis_size, batched_dim_in_result)
else:
val = batching.moveaxis(val, val_dim, batched_dim_in_result)
# Originally not going to be moved to the front, but now going to be moved to
# the front.
if int_indexers_contiguous and not new_int_indexers_contiguous:
original_pos = is_int_indexing.index(True)
array_indexer_shape = new_indexers[0].int_indexer_shape
array_indexer_len = len(array_indexer_shape)
transpose_order = list(range(len(val.shape)))
transpose_order = (
transpose_order[0],
*transpose_order[1+original_pos:(1+original_pos)+(array_indexer_len-1)],
*transpose_order[1:1+original_pos],
*transpose_order[(1+original_pos)+(array_indexer_len-1):],
)
val = val.transpose(transpose_order)
return addupdate_p.bind(ref, val, *flat_indexers, tree=tree), []
batching.primitive_batchers[addupdate_p] = _addupdate_vmap
# Currently, JAX doesn't have a primitive that does an equal-rank broadcast.
# We could use `jnp.broadcast_to` but that lowers to squeezing,
# then broadcast_in_dim. Triton has an equal-rank broadcast (`tl.broadcast_to`)
# so in the lowering, we have to expand out those squeezed dimensions again.
# Having a simple `broadcast_to` primitive allows us to lower directly
# to `tl.broadcast_to`.
broadcast_to_p = core.Primitive('broadcast_to')
def broadcast_to(a: Array, shape: tuple[int, ...]) -> Array:
import jax.numpy as jnp
a = jnp.asarray(a)
if a.shape == shape:
return a
return broadcast_to_p.bind(a, shape=shape)
@broadcast_to_p.def_impl
def _broadcast_to_impl(a, *, shape):
import jax.numpy as jnp
return jnp.broadcast_to(a, shape)
@broadcast_to_p.def_abstract_eval
def _broadcast_to_abstract_eval(aval, *, shape):
return core.ShapedArray(shape, aval.dtype)
mlir.register_lowering(
broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False)
)
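# Example (illustrative): `broadcast_to` is a no-op when the shapes already
# match and otherwise binds `broadcast_to_p`,
#   y = broadcast_to(jnp.ones((1, 4)), (3, 4))   # equal-rank broadcast
# Outside kernel lowerings (e.g. Triton's `tl.broadcast_to`), it lowers via
# `jnp.broadcast_to`, so the numerics follow the usual NumPy broadcasting
# rules.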
# === AD rules for mutable arrays ===
def _mut_jvp(primals, tangents):
(init_val,), (init_val_dot,) = primals, tangents
primal_out = core.mutable_array_p.bind(init_val)
if type(init_val_dot) is ad_util.Zero:
tangent_out = core.mutable_array_p.bind(ad_util.zeros_like_aval(init_val_dot.aval))
else:
tangent_out = core.mutable_array_p.bind(init_val_dot)
return primal_out, tangent_out
ad.primitive_jvps[core.mutable_array_p] = _mut_jvp
ad.defjvp(core.freeze_p, lambda g, _: core.freeze(g))
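# Illustrative sketch (assuming `core.mutable_array` is the constructor that
# binds `core.mutable_array_p`): under `jax.jvp`, a call like
#   r = core.mutable_array(x)
# also binds `mutable_array_p` on `x_dot` (or on zeros when the tangent is a
# symbolic `Zero`), producing a parallel tangent ref; `core.freeze` has the
# matching JVP defined just above.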