# Copyright 2022 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Module for state primitives.""" from __future__ import annotations from functools import partial import types from typing import Any, Union import numpy as np from jax._src import ad_util from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src import pretty_printer as pp from jax._src import traceback_util from jax._src import tree_util from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax from jax._src.state import indexing from jax._src.state.types import ( AbstractRef, AccumEffect, ReadEffect, Transform, TransformedRef, WriteEffect, ) from jax._src.typing import Array from jax._src.util import safe_map, safe_zip ## General utilities ## JAX utilities map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip traceback_util.register_exclusion(__file__) ## get/swap/addupdate implementations # `get` reads a value from a `Ref` type, a.k.a.: # a = get_p.bind(x) # or we can read using indices: # a = get_p.bind(x, 0, 1) # Staging out `a = get_p.bind(x)` where the aval of `x` is # `Ref((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like # a:f32[3] <- x[] get_p = core.Primitive("get") get_p.def_impl(partial(dispatch.apply_primitive, get_p)) batching.ragged_prop_rules[get_p] = batching.ragged_mask_transfer_identity Indexer = Union[int, slice, Array, types.EllipsisType] def get_ref_and_transforms( ref_or_view: Any, idx: Indexer | tuple[Indexer, ...] | None, function_name: str, force_trailing_indexer: bool = True, # TODO(apaszke): Clean this up. ) -> tuple[Any, tuple[Transform, ...]]: if isinstance(ref_or_view, TransformedRef): ref, transforms = ref_or_view.ref, ref_or_view.transforms else: ref, transforms = ref_or_view, () ref_aval = core.get_aval(ref) if not isinstance(ref_aval, AbstractRef): raise ValueError(f"Can only call `{function_name}` on a `Ref`: {ref}.") if not isinstance(ref_aval.inner_aval, core.ShapedArray): return ref, () if idx is None or idx is Ellipsis: idx = () elif not isinstance(idx, tuple): idx = (idx,) if not idx and not force_trailing_indexer: return ref, transforms if not idx and transforms and isinstance(transforms[-1], indexing.NDIndexer): return ref, transforms nd_indexer = indexing.NDIndexer.from_indices_shape(idx, ref_or_view.shape) return ref, (*transforms, nd_indexer) def ref_get( ref_or_view: Any, idx: Indexer | tuple[Indexer, ...] | None = None ) -> Array: """Reads a value from a `Ref`, a.k.a. value <- ref[idx].""" ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_get") flat_transforms, tree = tree_util.tree_flatten(transforms) return get_p.bind(ref, *flat_transforms, tree=tree) # `swap` mutates a `Ref`, setting its value and returns its previous value. # b = swap_p.bind(x, a) # It generalizes the setting operation for a `Ref` as we can ignore the return # value: # _ = swap_p.bind(x, a) # `swap_p` also takes in index arguments following the value, i.e.: # _ = swap_p.bind(x, a, 0, 1) # Staging out `b = swap_p.bind(x, a)` where the aval of `x` is # `Ref((3,), np.dtype('float32'))` and the aval of `a` is # `ShapedArray((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like # b:f32[3], x:Ref{f32[3]} <- x, a # Staging out `_ = swap_p.bind(x, a, i, j)` where the aval of `x` is # `Ref((3,), np.dtype('float32'))` , the aval of `a` is # `ShapedArray((3,), np.dtype('float32'))`, and the avals of both `i` and `j` # are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like # x:Ref{f32[3]}[i, j] <- a swap_p = core.Primitive("swap") swap_p.def_impl(partial(dispatch.apply_primitive, swap_p)) def swap_ragged_prop_rule(eqn_params, invar_raggedness, outvars): assert len(invar_raggedness) == 2 invar_raggedness_lhs = invar_raggedness[0] invar_raggedness_rhs = invar_raggedness[1] return [invar_raggedness_rhs, invar_raggedness_lhs], [None] batching.ragged_prop_rules[swap_p] = swap_ragged_prop_rule def ref_swap( ref_or_view: AbstractRef | TransformedRef, idx: Indexer | tuple[Indexer, ...] | None, value: Array, _function_name: str = "ref_swap", ) -> Array: """Sets a `Ref`'s value and returns the original value.""" if hasattr(ref_or_view, 'dtype'): value = _maybe_implicit_cast(ref_or_view.dtype, value) ref, transforms = get_ref_and_transforms(ref_or_view, idx, _function_name) flat_transforms, tree = tree_util.tree_flatten(transforms) return swap_p.bind(ref, value, *flat_transforms, tree=tree) # TODO(slebedev,mattjj): replace with special handling of Python numeric types: # if (isinstance(value, (int, float, complex)) and # value == np.array(value, dtype).item()): return cast def _maybe_implicit_cast(dtype, value): aval = core.typeof(value) if (aval.weak_type and (dtypes.issubdtype(dtype, np.floating) and dtypes.issubdtype(aval.dtype, np.floating)) or (dtypes.issubdtype(dtype, np.integer) and dtypes.issubdtype(aval.dtype, np.integer))): return lax.convert_element_type(value, dtype) return value def ref_set( ref_or_view: AbstractRef | TransformedRef, idx: Indexer | tuple[Indexer, ...] | None, value: Array, ) -> None: """Sets a `Ref`'s value, a.k.a. ref[idx] <- value.""" ref_swap(ref_or_view, idx, value, _function_name="ref_set") # `addupdate_p` mutates a `Ref`, adding a value to its existing value. # Semantically, # ``` # addupdate ref a *idx # ``` # is equivalent to # ``` # b = get ref *idx # c = add b x # _ = swap ref c *idx # ``` addupdate_p = core.Primitive('addupdate') addupdate_p.multiple_results = True addupdate_p.def_impl(partial(dispatch.apply_primitive, addupdate_p)) def ref_addupdate( ref_or_view: AbstractRef, idx: Indexer | tuple[Indexer, ...] | None, x: Array, ) -> None: """Mutates a ref with an additive update i.e. `ref[idx] += x`.""" ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_addupdate") flat_transforms, tree = tree_util.tree_flatten(transforms) return addupdate_p.bind(ref, x, *flat_transforms, tree=tree) ## get/set/addupdate abstract evaluation rules def _shape_after_transforming( shape: tuple[int | Array, ...], transforms: tuple[Transform, ...] ) -> tuple[int | Array, ...]: for transform in transforms: shape = transform.transform_shape(shape) # type: ignore assert shape is not None return shape def _dtype_after_transforming( dtype: Any, transforms: tuple[Transform, ...] ) -> Any: for transform in transforms: dtype = transform.transform_dtype(dtype) assert dtype is not None return dtype def _sharding_after_transforming(sharding, transforms): for transform in transforms: sharding = transform.transform_sharding(sharding) assert sharding is not None return sharding def _get_abstract_eval(ref_aval: AbstractRef, *args, tree): transforms = tree_util.tree_unflatten(tree, args) if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): out_shape = _shape_after_transforming(ref_aval.shape, transforms) out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) out_sharding = _sharding_after_transforming(ref_aval.sharding, transforms) out_aval = ref_aval.inner_aval.update( shape=out_shape, dtype=out_dtype, sharding=out_sharding) else: if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") out_aval = ref_aval.inner_aval return (out_aval, {ReadEffect(0)}) get_p.def_effectful_abstract_eval(_get_abstract_eval) def _swap_abstract_eval(ref_aval: AbstractRef, val_aval: core.AbstractValue, *args: Any, tree): transforms = tree_util.tree_unflatten(tree, args) out_aval: core.AbstractValue if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): assert isinstance(val_aval, core.ShapedArray) expected_out_shape = _shape_after_transforming(ref_aval.shape, transforms) expected_out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) if expected_out_shape != val_aval.shape: raise ValueError("Invalid shape for `swap`. " f"Ref shape: {ref_aval.shape}. " f"Expected shape: {expected_out_shape}. " f"Value shape: {val_aval.shape}. " f"Transforms: {transforms}. ") if expected_out_dtype != val_aval.dtype: raise ValueError( "Invalid dtype for `swap`. " f"Ref dtype: {expected_out_dtype}. " f"Value dtype: {val_aval.dtype}. " ) out_aval = core.ShapedArray(expected_out_shape, expected_out_dtype) else: if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") out_aval = ref_aval.inner_aval return (out_aval, {WriteEffect(0)}) swap_p.def_effectful_abstract_eval(_swap_abstract_eval) def _addupdate_abstract_eval(ref_aval: AbstractRef, val_aval: core.AbstractValue, *args: Any, tree): transforms = tree_util.tree_unflatten(tree, args) if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): out_shape = _shape_after_transforming(ref_aval.shape, transforms) out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) assert isinstance(val_aval, core.ShapedArray) if out_shape != val_aval.shape: raise ValueError( "Invalid shape for `addupdate`. " f"Ref shape: {ref_aval.shape}. " f"Expected shape: {out_shape}. " f"Value shape: {val_aval.shape}. " f"Transforms: {transforms}. " ) if out_dtype != val_aval.dtype: raise ValueError("Invalid dtype for `addupdate`. " f"Ref dtype: {ref_aval.dtype}. " f"Value shape: {val_aval.dtype}. ") else: # Check that the transforms are valid if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") return [], {AccumEffect(0)} addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval) ## Pretty printing for `get` and `swap` in jaxprs pp_ref_var = partial(pp.color, intensity=pp.Intensity.NORMAL, foreground=pp.Color.GREEN) def _pp_transforms( context: core.JaxprPpContext, transforms: tuple[Transform, ...], ): if not transforms: return pp.text("[...]") return pp.concat( [transform.pretty_print(context) for transform in transforms] ) def pp_ref_transforms(context: core.JaxprPpContext, ref, transforms): return pp_ref_var( pp.concat([ pp.text(core.pp_var(ref, context)), _pp_transforms(context, transforms), ]) ) def _get_pp_rule(eqn, context, settings) -> pp.Doc: # Pretty prints `a = get x i` as `x[i] <- a` y, = eqn.outvars x, *flat_idx = eqn.invars transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes) return pp.concat( [lhs, pp.text(" <- "), pp_ref_transforms(context, x, transforms)] ) core.pp_eqn_rules[get_p] = _get_pp_rule def _swap_pp_rule(eqn, context, settings) -> pp.Doc: y, = eqn.outvars x, v, *flat_idx = eqn.invars transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) if type(y) is core.DropVar: # In the case of a set (ignored return value), # pretty print `_ = swap x v i` as `x[i] <- v` del y return pp.concat([ pp_ref_transforms(context, x, transforms), pp.text(" <- "), pp.text(core.pp_var(v, context)), ]) else: # pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v` x_i = pp_ref_transforms(context, x, transforms) y = core.pp_vars([y], context, print_shapes=settings.print_shapes) return pp.concat([y, pp.text(', '), x_i, pp.text(' <- '), x_i, pp.text(', '), pp.text(core.pp_var(v, context))]) core.pp_eqn_rules[swap_p] = _swap_pp_rule def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc: del settings # pretty-print ` = addupdate x i v` as `x[i] += v` () = eqn.outvars x, v, *flat_idx = eqn.invars transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) return pp.concat([ pp_ref_transforms(context, x, transforms), pp.text(" += "), pp.text(core.pp_var(v, context)), ]) core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule ## get/swap/addupdate JVP rules def _get_jvp(primals: list[Any], tangents: list[Any], **params: Any): ref_primal, *idx = primals assert isinstance(ref_primal.aval, AbstractRef) ref_tangent, *_ = tangents assert isinstance(ref_tangent.aval, AbstractRef) return (get_p.bind(ref_primal, *idx, **params), get_p.bind(ref_tangent, *idx, **params)) ad.primitive_jvps[get_p] = _get_jvp def _swap_jvp(primals: list[Any], tangents: list[Any], **params: Any): ref_primal, x_primal, *idx = primals assert isinstance(ref_primal.aval, AbstractRef) ref_tangent, x_tangent, *_ = tangents # if type(ref_tangent) is ad_util.Zero: # raise Exception("you're an idiot") assert isinstance(ref_tangent.aval, AbstractRef) x_tangent = ad_util.instantiate(x_tangent) return (swap_p.bind(ref_primal, x_primal, *idx, **params), swap_p.bind(ref_tangent, x_tangent, *idx, **params)) ad.primitive_jvps[swap_p] = _swap_jvp def addupdate_jvp_rule(primals: list[Any], tangents: list[Any], **params: Any): ref_primal, x_primal, *idx = primals ref_tangent, x_tangent, *_ = tangents x_tangent = ad_util.instantiate(x_tangent) addupdate_p.bind(ref_primal, x_primal, *idx, **params) addupdate_p.bind(ref_tangent, x_tangent, *idx, **params) return [], [] ad.primitive_jvps[addupdate_p] = addupdate_jvp_rule ## get/swap/addupdate transpose rules def _get_transpose(g, ref, *idx, **params): # get transpose is addupdate if type(g) is not ad_util.Zero: addupdate_p.bind(ref, g, *idx, **params) return [None] + [None] * len(idx) ad.primitive_transposes[get_p] = _get_transpose def _swap_transpose(g, ref, x, *idx, **params): del x # old value doesn't matter anymore # swap transpose is swap x_bar = swap_p.bind(ref, ad_util.instantiate(g), *idx, **params) return [None, x_bar] + [None] * len(idx) ad.primitive_transposes[swap_p] = _swap_transpose def addupdate_transpose(cts_in, ref, x, *idx, **params): # addupdate transpose is get del cts_in, x g = get_p.bind(ref, *idx, **params) return [None, g] + [None] * len(idx) ad.primitive_transposes[addupdate_p] = addupdate_transpose ## get/swap/addupdate partial_eval_custom rules def _state_partial_eval_custom(prim, saveable, unks_in, inst_in, eqn): if any(unks_in): res = [v for v, inst in zip(eqn.invars, inst_in) if not inst] return None, eqn, [True] * len(eqn.outvars), [True] * len(eqn.outvars), res elif saveable(prim, *[var.aval for var in eqn.invars], **eqn.params): return eqn, None, [False] * len(eqn.outvars), [False] * len(eqn.outvars), [] res = [v for v, inst in zip(eqn.invars, inst_in) if not inst] return eqn, eqn, [False] * len(eqn.outvars), [True] * len(eqn.outvars), res pe.partial_eval_jaxpr_custom_rules[get_p] = partial(_state_partial_eval_custom, get_p) pe.partial_eval_jaxpr_custom_rules[swap_p] = partial(_state_partial_eval_custom, swap_p) pe.partial_eval_jaxpr_custom_rules[addupdate_p] = partial( _state_partial_eval_custom, addupdate_p) ## get/swap/addupdate batching rules def _batch_indexer( indexer: indexing.NDIndexer, dims, axis_size: int, ref_shape: tuple[int, ...], ref_dim: int | batching.NotMapped, idx_is_batched: bool, ) -> indexing.NDIndexer: """Converts a batched indexer into an unbatched one. This function handles the complexity of `vmap`-style batching where either the `ref` being indexed, the indexer, or both may have batched dimensions. The goal is to produce a new indexer that acts as if applied in a batched context, but without actual batching, enabling downstream code to process it as usual. If any index in `indexer` is batched, all array indexers are normalized. If the array indexer contains a batched dimension, the dimension is moved to the front (axis 0). If the array indexer not batched, it is broadcasted to include a batch dimension at the front. This is to guarantee that all array indexers are still of the same shape. Slices are passed through unchanged unless they contain dynamic elements and are themselves batched, which is currently unsupported. If `ref` is batched (`ref_dim` is not `NotMapped`), we simulate per-example indexing by inserting a new iota array at the position corresponding to `ref_dim` in the indexer. It is worth noting that if the array indexers in the original indexer are contiguous, but become non-contiguous in the new indexer due to the insertion of the iota, the dimensions corresponding to the array indexers will be moved to the front in the indexing result. The batched dimension will be at axis 0, while the dimensions corresponding to the array indexers in the original indexer will start from axis 1. This behavior would cause a mismatch between the original indexer and the new indexer. Callers must take this behavior into account and properly transpose the arrays involved to avoid this mismatch. Args: indexer: An `NDIndexer` that indexes into `ref`. dims: A pytree with the same structure as `indexer`, indicating which dimension (if any) is batched for each array indexer. axis_size: Size of the batch dimension. ref_shape: Shape of `ref`. ref_dim: The dimension of `ref` that is batched (if any). idx_is_batched: Whether any index in the `indexer` is batched. """ indices = indexer.indices indices_dims = dims.indices new_indices: list[Array | indexing.Slice | int] = [] new_integer_indexer_shape = (axis_size, *indexer.int_indexer_shape) for idx, dim in zip(indices, indices_dims): if idx_is_batched: # If at least one of the idx is batched, we broadcast them all and move the # batch dim to the front. if isinstance(idx, indexing.Slice): # size is static, but start can be dynamic # Check if start is static (which it can be) is_static_slice = len(tree_util.tree_leaves(idx)) == 0 if is_static_slice: new_indices.append(idx) continue dim = dim.start if dim is batching.not_mapped: # Broadcasting the slice is free (the start index stays the same) new_indices.append(idx) else: raise NotImplementedError( f"No support for vmapping over nontrivial slices just yet: {idx}") else: # Check if we are indexing with a scalar or not. If we are indexing # with a scalar and we are not batched, we can avoid broadcasting it. assert hasattr(idx, "shape") if not idx.shape: if dim is not batching.not_mapped: assert idx.shape == (axis_size,) idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape, (0,)) new_indices.append(idx) else: if dim is batching.not_mapped: bcast_dims = tuple(range(1, np.ndim(idx) + 1)) idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape, bcast_dims) else: idx = batching.moveaxis(idx, dim, 0) # type: ignore[arg-type] new_indices.append(idx) else: if ref_dim is not batching.not_mapped: if not isinstance(idx, indexing.Slice): assert hasattr(idx, "shape") if idx.shape: bcast_dims = tuple(range(1, np.ndim(idx) + 1)) idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape, bcast_dims) new_indices.append(idx) if ref_dim is not batching.not_mapped: iota = lax.broadcasted_iota(np.dtype('int32'), new_integer_indexer_shape, 0) new_indices.insert(ref_dim, iota) return indexing.NDIndexer( tuple(new_indices), ref_shape, new_integer_indexer_shape, validate=True ) def _get_vmap(batched_args, batched_dims, *, tree): axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims) if d is not batching.not_mapped} ref, *flat_idxs = batched_args ref_dim, *flat_idx_dims = batched_dims indexers = tree_util.tree_unflatten(tree, flat_idxs) indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims) idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in flat_idx_dims) if len(indexers) > 1: raise NotImplementedError("Batching with multiple indexers not supported.") # TODO(sharadmv): handle vmap of multiple indexers new_indexers = tuple(_batch_indexer(indexer, dims, axis_size, ref.shape, ref_dim, idx_is_batched) for indexer, dims in zip(indexers, indexers_dims)) flat_indexers, tree = tree_util.tree_flatten(new_indexers) is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0]) int_indexers_contiguous = bool( np.all(np.diff(np.where(is_int_indexing)[0]) == 1) ) is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0]) new_int_indexers_contiguous = bool( np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1) ) out = get_p.bind(ref, *flat_indexers, tree=tree) if not int_indexers_contiguous: # will always be moved to the front out_bdim = 0 else: # originally not going to be moved to the front if new_int_indexers_contiguous: # now not going to be moved to the front out_bdim = is_new_int_indexing.index(True) else: # now going to be moved to the front original_pos = is_int_indexing.index(True) array_indexer_shape = new_indexers[0].int_indexer_shape array_indexer_len = len(array_indexer_shape) transpose_order = list(range(len(out.shape))) transpose_order = ( transpose_order[0], *transpose_order[array_indexer_len:array_indexer_len+original_pos], *transpose_order[1:array_indexer_len], *transpose_order[array_indexer_len+original_pos:], ) out = lax.transpose(out, transpose_order) out_bdim = 0 return out, out_bdim batching.primitive_batchers[get_p] = _get_vmap def _swap_vmap(batched_args, batched_dims, *, tree): axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims) if d is not batching.not_mapped} ref, val, *flat_idxs = batched_args ref_dim, val_dim, *flat_idx_dims = batched_dims indexers = tree_util.tree_unflatten(tree, flat_idxs) indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims) ref_is_batched = ref_dim is not batching.not_mapped val_is_batched = val_dim is not batching.not_mapped idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in flat_idx_dims) if not ref_is_batched: raise Exception("performing a set/swap operation with vmapped value on " "an unbatched mutable array reference " f"of type {core.typeof(ref)}. Move the mutable array to be " "an argument to the vmapped function?") if len(indexers) > 1: raise NotImplementedError("Batching with multiple indexers not supported.") # TODO(sharadmv): handle vmap of multiple indexers new_indexers = tuple(_batch_indexer(indexer, dims, axis_size, ref.shape, ref_dim, idx_is_batched) for indexer, dims in zip(indexers, indexers_dims)) flat_indexers, tree = tree_util.tree_flatten(new_indexers) is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0]) int_indexers_contiguous = bool( np.all(np.diff(np.where(is_int_indexing)[0]) == 1) ) is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0]) new_int_indexers_contiguous = bool( np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1) ) if not new_int_indexers_contiguous: # will be moved to the front batched_dim_in_result = 0 else: batched_dim_in_result = is_new_int_indexing.index(True) + 0 if not val_is_batched: if ref_is_batched or idx_is_batched: val = batching.broadcast(val, axis_size, batched_dim_in_result) else: val = batching.moveaxis(val, val_dim, batched_dim_in_result) transpose_order_inversed = None # Originally not going to be moved to the front, but now going to be moved to # the front. if int_indexers_contiguous and not new_int_indexers_contiguous: original_pos = is_int_indexing.index(True) array_indexer_shape = new_indexers[0].int_indexer_shape array_indexer_len = len(array_indexer_shape) transpose_order = list(range(len(val.shape))) transpose_order = ( transpose_order[0], *transpose_order[1+original_pos:(1+original_pos)+(array_indexer_len-1)], *transpose_order[1:1+original_pos], *transpose_order[(1+original_pos)+(array_indexer_len-1):], ) val = val.transpose(transpose_order) transpose_order_inversed = np.argsort(transpose_order) out = swap_p.bind(ref, val, *flat_indexers, tree=tree) # `val` should not be transposed, but we needed to transpose it to match # `swap_p`. As a result, the output of `swap_p` is also transposed. Now we # need to transpose it back. if transpose_order_inversed is not None: out = out.transpose(transpose_order_inversed) return out, batched_dim_in_result batching.primitive_batchers[swap_p] = _swap_vmap def _addupdate_vmap(batched_args, batched_dims, *, tree): axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims) if d is not batching.not_mapped} ref, val, *flat_idxs = batched_args ref_dim, val_dim, *flat_idx_dims = batched_dims indexers = tree_util.tree_unflatten(tree, flat_idxs) indexers_dims = tree_util.tree_unflatten(tree, flat_idx_dims) ref_is_batched = ref_dim is not batching.not_mapped val_is_batched = val_dim is not batching.not_mapped idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in flat_idx_dims) if len(indexers) > 1: raise NotImplementedError("Batching with multiple indexers not supported.") # TODO(sharadmv): handle vmap of multiple indexers new_indexers = tuple(_batch_indexer(indexer, dims, axis_size, ref.shape, ref_dim, idx_is_batched) for indexer, dims in zip(indexers, indexers_dims)) flat_indexers, tree = tree_util.tree_flatten(new_indexers) is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0]) int_indexers_contiguous = bool( np.all(np.diff(np.where(is_int_indexing)[0]) == 1) ) is_new_int_indexing, _, _ = indexing.unpack_ndindexer(new_indexers[0]) new_int_indexers_contiguous = bool( np.all(np.diff(np.where(is_new_int_indexing)[0]) == 1) ) if not new_int_indexers_contiguous: # will be moved to the front batched_dim_in_result = 0 else: batched_dim_in_result = is_new_int_indexing.index(True) if not val_is_batched: if ref_is_batched or idx_is_batched: val = batching.broadcast(val, axis_size, batched_dim_in_result) else: val = batching.moveaxis(val, val_dim, batched_dim_in_result) # Originally not going to be moved to the front, but now going to be moved to # the front. if int_indexers_contiguous and not new_int_indexers_contiguous: original_pos = is_int_indexing.index(True) array_indexer_shape = new_indexers[0].int_indexer_shape array_indexer_len = len(array_indexer_shape) transpose_order = list(range(len(val.shape))) transpose_order = ( transpose_order[0], *transpose_order[1+original_pos:(1+original_pos)+(array_indexer_len-1)], *transpose_order[1:1+original_pos], *transpose_order[(1+original_pos)+(array_indexer_len-1):], ) val = val.transpose(transpose_order) return addupdate_p.bind(ref, val, *flat_indexers, tree=tree), [] batching.primitive_batchers[addupdate_p] = _addupdate_vmap # Currently, JAX doesn't have a primitive that does an equal-rank broadcast. # We could use `jnp.broadcast_to` but that lowers to squeezing, # then broadcast_in_dim. Triton has an equal-rank broadcast (`tl.broadcast_to`) # so in the lowering, we have to expand out those squeezed dimensions again. # Having a simple `broadcast_to` primitive allows us to lower directly # to `tl.broadcast_to`. broadcast_to_p = core.Primitive('broadcast_to') def broadcast_to(a: Array, shape: tuple[int, ...]) -> Array: import jax.numpy as jnp a = jnp.asarray(a) if a.shape == shape: return a return broadcast_to_p.bind(a, shape=shape) @broadcast_to_p.def_impl def _broadcast_to_impl(a, *, shape): import jax.numpy as jnp return jnp.broadcast_to(a, shape) @broadcast_to_p.def_abstract_eval def _broadcast_to_abstract_eval(aval, *, shape): return core.ShapedArray(shape, aval.dtype) mlir.register_lowering( broadcast_to_p, mlir.lower_fun(_broadcast_to_impl, False) ) # === AD rules for mutable arrays === def _mut_jvp(primals, tangents): (init_val,), (init_val_dot,) = primals, tangents primal_out = core.mutable_array_p.bind(init_val) if type(init_val_dot) is ad_util.Zero: tangent_out = core.mutable_array_p.bind(ad_util.zeros_like_aval(init_val_dot.aval)) else: tangent_out = core.mutable_array_p.bind(init_val_dot) return primal_out, tangent_out ad.primitive_jvps[core.mutable_array_p] = _mut_jvp ad.defjvp(core.freeze_p, lambda g, _: core.freeze(g))