2025-08-11 12:24:21 +08:00

1213 lines
50 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import collections
from collections.abc import Callable, Sequence
import dataclasses
from functools import partial
from typing import Any, Union
import numpy as np
from jax._src import config
from jax._src import core
from jax._src import source_info_util
from jax._src import linear_util as lu
from jax._src.partition_spec import PartitionSpec as P
from jax._src.sharding_impls import NamedSharding
from jax._src import mesh as mesh_lib
from jax._src.ad_util import Zero, SymbolicZero, add_jaxvals, add_jaxvals_p
from jax._src.core import Trace, Tracer, TraceTag, AxisName
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import (tree_unflatten, tree_flatten,
register_pytree_node, PyTreeDef)
from jax._src.typing import Array
from jax._src.util import (unzip2, safe_map, safe_zip, split_list,
canonicalize_axis, moveaxis, as_hashable_function,
curry, memoize, weakref_lru_cache, tuple_insert)
# Shadow the builtins `map` and `zip` with the `safe_map` / `safe_zip` helpers
# imported from jax._src.util for the rest of this module; the original
# builtins remain reachable as `unsafe_map` and `unsafe_zip`.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
# Jumbles
# i:(Fin 3) => f32[[3, 1, 4].i]
@dataclasses.dataclass(frozen=True)
class JumbleTy:
  """Type of a jumble: a binder, a length and a binder-dependent element type.

  Instances print as ``Var<id(binder)>:<length> => <elt_ty>``.
  """
  # Index variable that `IndexedAxisSize` entries in `elt_ty.shape` refer to.
  binder: core.Var
  # Size of the stacked axis (`_jumble_result` passes the vmap axis size here).
  length: int | Tracer | core.Var
  # Type of one element; its shape may contain `IndexedAxisSize` entries.
  elt_ty: core.DShapedArray

  def __repr__(self) -> str:
    binder_name = f'Var{id(self.binder)}'
    return f'{binder_name}:{self.length} => {self.elt_ty}'

  # Exposed as a method: `ty.replace(field=value)` returns an updated copy.
  replace = dataclasses.replace
# [3, 1, 4].i
@dataclasses.dataclass(frozen=True)
class IndexedAxisSize:
  """A dimension size written as ``lengths.idx``, e.g. ``[3, 1, 4].i``.

  Instances print as ``<lengths>.Var<id(idx)>``.
  """
  # Binder variable used to index into `lengths`.
  idx: core.Var
  # Per-element sizes of this axis.  `_jumble_flatten` temporarily replaces
  # this with a 1-based integer position, which `_jumble_unflatten` resolves.
  lengths: Array | core.Var | Tracer

  def __repr__(self) -> str:
    index_name = f'Var{id(self.idx)}'
    return f'{self.lengths}.{index_name}'

  # Exposed as a method: `size.replace(field=value)` returns an updated copy.
  replace = dataclasses.replace
# Jumble(aval=a:3 => f32[[3 1 4].a],
# data=Array([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
@dataclasses.dataclass(frozen=True)
class Jumble:
  """A batched value paired with the `JumbleTy` that describes it.

  Registered as a pytree in this module (see `_jumble_flatten` /
  `_jumble_unflatten`).
  """
  # Jumble type of `data`.
  aval: JumbleTy
  # Underlying array.  `_jumble_result` stores the whole batched array here
  # and derives the element shape by dropping its leading (stacked) axis.
  data: Array
# To vmap over a jumble, one must specify the axis as JumbleAxis.
class JumbleAxis:
  """Marker type for the axis spec that selects jumble inputs/outputs."""


# Singleton instance; `to_elt` / `from_elt` compare specs against it by
# identity.
jumble_axis = JumbleAxis()
# As a temporary measure before we have more general JITable / ADable interfaces
# (analogues to vmappable), to enable Jumbles to be used with other
# transformations and higher-order primitives (primarily jit, though also grad
# with allow_int=True) we register them as pytrees.
# TODO(mattjj): add JITable / ADable interfaces, remove this pytree registration
def _jumble_flatten(jumble):
  """Pytree-flatten a `Jumble` into ``((lengths, data), aval)``.

  Every `IndexedAxisSize` in the element shape has its `lengths` moved into
  the returned `lengths` list and replaced, inside the aval used as aux
  data, by its 1-based position in that list.
  """
  lengths = []
  indexed_shape = []
  for dim in jumble.aval.elt_ty.shape:
    if type(dim) is IndexedAxisSize:
      lengths.append(dim.lengths)
      # 1-based position of this axis's lengths within `lengths`.
      dim = dim.replace(lengths=len(lengths))
    indexed_shape.append(dim)
  elt_ty = jumble.aval.elt_ty.update(shape=tuple(indexed_shape))
  aux_aval = jumble.aval.replace(elt_ty=elt_ty)
  return (lengths, jumble.data), aux_aval
def _ragged_axis_parts(dim: RaggedAxis) -> tuple[int, int, int]:
stacked_axis = dim.stacked_axis
ragged_axes = dim.ragged_axes
if len(ragged_axes) != 1:
raise ValueError('Multiple ragged axes not yet implemented.')
ragged_axis_dim = ragged_axes[0][0]
ragged_axis_length = ragged_axes[0][1]
return stacked_axis, ragged_axis_dim, ragged_axis_length
def _jumble_unflatten(aval, x):
  """Inverse of `_jumble_flatten`: rebuild a `Jumble` from aux data and leaves.

  `x` is the ``(lengths, data)`` pair; each `IndexedAxisSize` in
  `aval.elt_ty.shape` holds a 1-based position into `lengths`, which is
  swapped back for the actual lengths value.
  """
  lengths, data = x

  def restore(dim):
    if type(dim) is IndexedAxisSize:
      return dim.replace(lengths=lengths[dim.lengths - 1])
    return dim

  restored_shape = tuple([restore(dim) for dim in aval.elt_ty.shape])
  elt_ty = aval.elt_ty.update(shape=restored_shape)
  return Jumble(aval.replace(elt_ty=elt_ty), data)


register_pytree_node(Jumble, _jumble_flatten, _jumble_unflatten)
def _jumble_result(axis_size, stacked_axis, ragged_axes, x):
  """Wrap batched array `x` as a `Jumble`.

  Args:
    axis_size: length recorded in the resulting `JumbleTy`.
    stacked_axis: position of the stacked axis in `x`; only 0 is supported.
    ragged_axes: ``(axis, segment_lens)`` pairs, with `axis` indexing into
      `x.shape` (i.e. counting the stacked axis).
    x: the batched array, stored unchanged as the jumble's data.

  Raises:
    NotImplementedError: if `stacked_axis` is not 0.
  """
  binder = core.Var(core.ShapedArray((), np.dtype('int32')))
  if stacked_axis != 0:
    raise NotImplementedError  # TODO Transpose x so the stacked axis is axis 0
  elt_shape = list(x.shape)
  elt_shape.pop(0)  # drop the stacked axis
  for ragged_axis, segment_lens in ragged_axes:
    # Ragged axis positions include the stacked axis, hence the -1.
    elt_shape[ragged_axis-1] = IndexedAxisSize(binder, segment_lens)
  elt_ty = core.DShapedArray(tuple(elt_shape), x.dtype, x.weak_type)
  jumble_ty = JumbleTy(binder, axis_size, elt_ty)
  return Jumble(jumble_ty, x)
@dataclasses.dataclass(frozen=True)
class RaggedAxis:
  """Batch-dimension descriptor for a value with ragged axes."""
  # Position of the stacked (batch) axis.
  stacked_axis: int
  # For each axis, we store its index and the corresponding segment lengths.
  # For example, the jumble i:(Fin 3) => f32[lens1.i, 7, lens2.i]
  # would be represented with ragged_axes = [(1, lens1), (3, lens2)]
  ragged_axes: tuple[tuple[int, Any], ...]

  @property
  def size(self):
    """Length of the first ragged axis's segment-lengths value."""
    # TODO(mattjj, axch): All the segment lengths arrays better be the
    # same length!
    first_entry = self.ragged_axes[0]
    return len(first_entry[1])

  def move_stacked_axis(self: RaggedAxis, dst: int) -> RaggedAxis:
    """Return a copy whose stacked axis is `dst`, re-indexing ragged axes.

    Assumes that all stored and incoming axes are already canonicalized.
    """
    src = self.stacked_axis

    def shifted(ax):
      # Axes lying between the old and new stacked positions slide by one.
      if src > ax >= dst:
        return ax + 1
      if src < ax <= dst:
        return ax - 1
      return ax

    moved = tuple((shifted(ax), sizes) for ax, sizes in self.ragged_axes)
    return RaggedAxis(dst, moved)
def transpose_ragged_axes(dim: RaggedAxis, perm: tuple[int, ...]) -> RaggedAxis:
  """Re-index the ragged axes of `dim` under the axis permutation `perm`.

  A ragged axis stored at position ``perm[idx]`` is moved to position `idx`;
  entries of `perm` that name no ragged axis are ignored.  The stacked axis
  is passed through unchanged.
  """
  # Keep the first entry for each axis, matching a first-match linear search.
  size_by_axis = {}
  for ax, size in dim.ragged_axes:
    size_by_axis.setdefault(ax, size)
  new_ragged_axes = [(idx, size_by_axis[old_idx])
                     for idx, old_idx in enumerate(perm)
                     if old_idx in size_by_axis]
  return _sorted_ragged_axis(dim.stacked_axis, new_ragged_axes)
def _sorted_ragged_axis(stacked_axis, ragged_axes):
  """Build a `RaggedAxis` whose ragged axes are sorted by axis index."""
  ordered = sorted(ragged_axes, key=lambda pair: pair[0])
  return RaggedAxis(stacked_axis, tuple(ordered))
def make_batch_axis(
    ndim: int,
    stacked_axis: int,
    ragged_axes: list[tuple[int, Array | core.Var]],
) -> int | RaggedAxis:
  """Canonicalize a batch-axis description against an array of rank `ndim`.

  Returns a plain canonical `int` when `ragged_axes` is empty, otherwise a
  `RaggedAxis` with every axis canonicalized and the ragged axes sorted.
  """
  if not ragged_axes:
    return canonicalize_axis(stacked_axis, ndim)
  canonical = [(canonicalize_axis(ax, ndim), sz) for ax, sz in ragged_axes]
  canonical_stacked = canonicalize_axis(stacked_axis, ndim)
  return _sorted_ragged_axis(canonical_stacked, canonical)
def bdim_as_shape(
    bdim: int | RaggedAxis, data_shape: core.Shape) -> core.Shape:
  """Rewrite `data_shape` so ragged axes appear as `IndexedAxisSize` entries.

  A non-ragged `bdim` leaves `data_shape` untouched (the same object is
  returned).  For a `RaggedAxis`, every listed axis position is replaced by
  an `IndexedAxisSize` sharing one freshly created binder.
  """
  if not isinstance(bdim, RaggedAxis):
    return data_shape
  result = list(data_shape)
  binder = core.Var(core.ShapedArray((), np.dtype('int32')))
  for ragged_axis, segment_lens in bdim.ragged_axes:
    result[ragged_axis] = IndexedAxisSize(binder, segment_lens)
  return tuple(result)
def shape_as_bdim(
    stacked_axis: int, data_shape: core.Shape) -> int | RaggedAxis:
  """Derive a batch-axis description from a shape with indexed sizes.

  Every `IndexedAxisSize` entry of `data_shape` becomes a ragged axis at its
  position; the result is built with `make_batch_axis`.
  """
  # This assumes that there is only one binder in the data_shape.
  ragged_axes = []
  for position, size in enumerate(data_shape):
    if isinstance(size, IndexedAxisSize):
      ragged_axes.append((position, size.lengths))
  return make_batch_axis(len(data_shape), stacked_axis, ragged_axes)
def _update_annotation(
    f: lu.WrappedFun, orig_type: core.InputType | None,
    axis_size: core.AxisSize, axis_name: AxisName,
    explicit_in_dims: Sequence[int | RaggedAxis | None],
    segment_lens: Sequence[Array],
  ) -> lu.WrappedFun:
  """Re-annotate `f` with an input type that accounts for the batch axis.

  Args:
    f: the function to annotate.
    orig_type: sequence of ``(aval, explicit)`` pairs, or None.  When None,
      `f` is returned unchanged.
    axis_size: size of the batch axis; a `Tracer` is treated as dynamic.
    axis_name: not referenced in the body of this function.
    explicit_in_dims: one batch dimension per explicit entry of `orig_type`.
    segment_lens: values whose avals are placed first among the new explicit
      inputs.

  Returns:
    The result of `lu.annotate(f, new_in_type)`.

  Raises:
    NotImplementedError: if any entry of `explicit_in_dims` is a `RaggedAxis`.
  """
  if orig_type is None: return f
  # By convention, `explicit_in_dims` only accounts for explicit arguments.
  assert len(explicit_in_dims) == sum(explicit for _, explicit in orig_type)
  # We need to:
  # * if `axis_size` is dynamic, add a new implicit binder (type) for it;
  # * for each element of `segment_lens`, add a new explicit binder for it;
  # * drop other implicit binders, replacing DBIdx which refer to them with
  #   Name objects;
  # * for each (aval, in_dim) pair: if int-valued in_dim, add batch axis (int
  #   size if `axis_size` is int, otherwise Name); if RaggedAxis-valued in_dim,
  #   add batch axis (int if corresponding segment_lens is concrete, Name if
  #   not);
  # * generate full in_type with implicit args too.

  # Placeholder standing in for a binder while De Bruijn indices are being
  # recomputed.  It defines no __eq__/__hash__, so instances compare and hash
  # by identity.
  class Name:
    def __init__(self, a): self.a = a
  # One Name per entry of the original type, so that `DBIdx.val` can index it.
  names = [Name(a) for a, _ in orig_type]
  # Explicit avals only, with DBIdx dimensions swapped for their Name.
  avals = [a.update(shape=tuple(names[d.val] if type(d) is pe.DBIdx else d
                                for d in a.shape))
           if type(a) is core.DShapedArray else a for a, e in orig_type if e]
  new_avals = [core.get_aval(s) for s in segment_lens]
  sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size
  for a, d in zip(avals, explicit_in_dims):
    if isinstance(d, RaggedAxis):
      raise NotImplementedError
    else:
      new_avals.append(core.unmapped_aval(sz, d, a))  # type: ignore

  # Names that appear as a dimension of some new aval.
  mentioned = {d for a in new_avals if type(a) is core.DShapedArray
               for d in a.shape if type(d) is Name}
  # NOTE(review): these are freshly created Name objects, so the set
  # difference below cannot remove anything from `mentioned` (identity
  # equality) -- confirm this is intended.
  expl_names = set(map(Name, new_avals))
  impl_names = mentioned - expl_names  # type: ignore
  impl_part = [(n.a, False) for n in impl_names]  # type: ignore
  # Implicit binders are numbered first, then the explicit ones.
  name_map = {n: pe.DBIdx(i) for i, n in enumerate((*impl_names, *expl_names))}
  expl_part = [(a.update(shape=tuple(name_map.get(d, d) for d in a.shape))
                if type(a) is core.DShapedArray else a, True) for a in new_avals]
  return lu.annotate(f, (*impl_part, *expl_part))
### vmappable typeclass
# Loosely-typed aliases used in the signatures of the vmappable machinery.
Vmappable = Any
Elt = Any
MapSpec = Any
AxisSize = Any
MeshAxis = Any
# Zero-argument callable producing a Tracer.
GetIdx = Callable[[], Tracer] # TODO(mattjj): revise this laziness
# `to_elt` invokes these as handler(continuation, get_idx, x, spec).
ToEltHandler = Callable[[Callable, GetIdx, Vmappable, MapSpec], Elt]
# `from_elt` invokes these as handler(continuation, axis_size, x, spec).
FromEltHandler = Callable[[Callable, AxisSize, Elt, MapSpec], Vmappable]
# `make_iota` invokes these as handler(axis_size).
MakeIotaHandler = Callable[[AxisSize], Array]
def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
  """Convert one vmap input `x` into its per-element form under `trace`.

  Dispatch order:
    1. A handler registered for ``type(x)`` in `to_elt_handlers` is called
       with a continuation, `get_idx`, `x` and `spec`.
    2. A `Jumble` (which must be paired with `jumble_axis`) becomes a
       `BatchTracer` over its data with a ragged batch axis.
    3. An `int` or `None` spec: `None` returns `x` unchanged, otherwise `x`
       is wrapped in a `BatchTracer` at that axis.
    4. A `JumbleAxis` spec under a `BatchTrace` wraps `x` at axis 0.

  Raises:
    TypeError: if `x` is a `Jumble` but `spec` is not `jumble_axis`.
    AssertionError: if none of the cases above apply.
  """
  handler = to_elt_handlers.get(type(x))
  if handler:
    return handler(partial(to_elt, trace, get_idx), get_idx, x, spec)
  elif type(x) is Jumble:
    if spec is not jumble_axis:
      raise TypeError("jumble input without using jumble_axis in_axes spec")
    ias: IndexedAxisSize  # Not present in the AxisSize union in core.py
    # Exactly one ragged dimension is expected; the unpacking fails otherwise.
    (d, ias), = ((i, sz)  # type: ignore
                 for i, sz in enumerate(x.aval.elt_ty.shape)
                 if type(sz) is IndexedAxisSize)
    # `d` indexes the element shape; +1 accounts for the stacked axis 0.
    batch_axis = make_batch_axis(x.data.ndim, 0, [(d+1, ias.lengths)])
    return BatchTracer(trace, x.data, batch_axis)
  elif isinstance(spec, int) or spec is None:
    # Falsy specs (None and 0) short-circuit and are not canonicalized.
    spec = spec and canonicalize_axis(spec, len(np.shape(x)))
    return (BatchTracer(trace, x, spec, source_info_util.current())
            if spec is not None else x)
  else:
    if isinstance(trace, BatchTrace) and isinstance(spec, JumbleAxis):
      # TODO(mvoz): A vaguely questionable assumption that it is always
      # sound to have a 0 axis here. This is true for the current use cases
      # and comes from how we handle intermediary products of jumbles in
      # vmap.
      return BatchTracer(trace, x, 0, source_info_util.current())
    # TODO(mvoz): This is a terrible place to fall into if you pass
    # a non jumble type in, make it clearer what went wrong.
    # Raised explicitly (rather than `assert False`) so the failure is not
    # stripped under `python -O`, which would silently return None.
    raise AssertionError(f'Unexpected type in ELT? {type(x)}')


# Registry of per-type `to_elt` handlers, filled by `register_vmappable`.
to_elt_handlers: dict[type, ToEltHandler] = {}
def from_elt(trace: BatchTrace, axis_size: AxisSize, mesh_axis: MeshAxis,
             i: int, x: Elt, spec: MapSpec) -> Vmappable:
  """Convert per-element output `x` (output number `i`) to its batched form.

  A handler registered for ``type(x)`` takes precedence.  Otherwise a ragged
  batch dimension is turned into a `Jumble` (requires `spec` to be
  `jumble_axis`), and any other batch dimension is matched against `spec`
  with `matchaxis`.

  Raises:
    TypeError: for a ragged output whose `spec` is not `jumble_axis`.
    SpecMatchError: if `matchaxis` cannot reconcile the output with `spec`.
  """
  custom_handler = from_elt_handlers.get(type(x))
  if custom_handler:
    def recurse(axis_size, elt, axis):
      return from_elt(trace, axis_size, mesh_axis, i, elt, axis)
    return custom_handler(recurse, axis_size, x, spec)

  val, bdim = trace.to_batch_info(x)
  if type(bdim) is not RaggedAxis:
    try:
      return matchaxis(trace.axis_data.name, axis_size, mesh_axis,
                       bdim, spec, val)
    except SpecMatchError:
      # Re-raise with the output index and spec filled in.
      raise SpecMatchError(i, x.batch_dim, spec) from None

  if spec is not jumble_axis:
    # TODO(mattjj): improve this error message
    raise TypeError("ragged output without using jumble_axis out_axes spec")
  return _jumble_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val)


# Registry of per-type `from_elt` handlers, filled by `register_vmappable`.
from_elt_handlers: dict[type, FromEltHandler] = {}
def make_iota(axis_size: AxisSize) -> Array:
  """Return an index array for an axis of size `axis_size`.

  Uses the handler registered for ``type(axis_size)`` when there is one,
  otherwise ``lax.iota('int32', int(axis_size))``.
  """
  # Callers of this utility, via batch() or vtile(), must be in a context
  # where lax is importable.
  from jax import lax  # pytype: disable=import-error
  custom_handler = make_iota_handlers.get(type(axis_size))
  if not custom_handler:
    return lax.iota('int32', int(axis_size))
  return custom_handler(axis_size)


# Registry of iota handlers keyed by axis-size type.
make_iota_handlers: dict[type, MakeIotaHandler] = {}
def register_vmappable(data_type: type, spec_type: type, axis_size_type: type,
                       to_elt: Callable, from_elt: Callable,
                       make_iota: Callable | None):
  """Register `data_type` as a vmappable container type.

  Args:
    data_type: the container type being registered.
    spec_type: type of the axis specs used with `data_type`.
    axis_size_type: type of the axis sizes used with `data_type`.
    to_elt: handler stored in `to_elt_handlers` for `data_type`.
    from_elt: handler stored in `from_elt_handlers` for `data_type`.
    make_iota: optional handler stored in `make_iota_handlers` under
      `axis_size_type`; nothing is stored when it is falsy.
  """
  vmappables[data_type] = (spec_type, axis_size_type)
  spec_types.add(spec_type)
  to_elt_handlers[data_type] = to_elt
  from_elt_handlers[data_type] = from_elt
  if make_iota:
    make_iota_handlers[axis_size_type] = make_iota


# data_type -> (spec_type, axis_size_type) for every registered vmappable.
vmappables: dict[type, tuple[type, type]] = {}
# Every known spec type; `JumbleAxis` is always a member.
spec_types: set[type] = {JumbleAxis}
def unregister_vmappable(data_type: type) -> None:
  """Undo a previous `register_vmappable(data_type, ...)`.

  Removes the `to_elt` / `from_elt` handlers for `data_type` and rebuilds
  `spec_types` from the remaining registrations.  The iota handler keyed by
  the type's axis-size type is removed only when no other registered
  vmappable still uses that axis-size type, so unregistering one type does
  not break `make_iota` for another type sharing it.

  Raises:
    KeyError: if `data_type` is not currently registered.
  """
  global spec_types
  _, axis_size_type = vmappables.pop(data_type)
  del to_elt_handlers[data_type]
  del from_elt_handlers[data_type]
  # `make_iota_handlers` is keyed by axis-size type, which several data types
  # may share; only drop the handler once its last user is gone.
  still_in_use = any(other_axis_size_type is axis_size_type
                     for _, other_axis_size_type in vmappables.values())
  if not still_in_use and axis_size_type in make_iota_handlers:
    del make_iota_handlers[axis_size_type]
  spec_types = (
      {JumbleAxis} | {spec_type for spec_type, _ in vmappables.values()}
  )
def is_vmappable(x: Any) -> bool:
  """Return True if `x` is a `Jumble` or an instance of a registered type.

  The check is on the exact type of `x`; subclasses do not match.
  """
  x_type = type(x)
  return x_type is Jumble or x_type in vmappables
@lu.transformation_with_aux2
def flatten_fun_for_vmap(f: Callable,
                         store: lu.Store, in_tree: PyTreeDef, *args_flat):
  """Call `f` on flat arguments and return its flattened result.

  `args_flat` is unflattened with `in_tree` into ``(args, kwargs)``.  The
  result is flattened treating vmappable values as leaves, and the output
  tree definition is recorded in `store`.
  """
  args, kwargs = tree_unflatten(in_tree, args_flat)
  result = f(*args, **kwargs)
  flat_result, out_tree = tree_flatten(result, is_leaf=is_vmappable)
  store.store(out_tree)
  return flat_result
# A ragged masking rule propagates raggedness from an equation's invars to
# its outvars.  Signature:
#   rule(eqn_params, invar_raggedness, outvars) ->
#       (invar_raggedness, outvar_raggedness)
RaggedMaskingRule = Callable[
    [list[Any], list[Any], list[Any]], tuple[list[Any], list[Any]]
]
# Registry mapping a primitive to its ragged masking rule.
ragged_prop_rules: dict[core.Primitive, RaggedMaskingRule] = {}
def ragged_mask_elementwise_rule(eqn_params, invar_raggedness, outvars):
  """Ragged rule requiring all inputs to share one raggedness.

  Returns `invar_raggedness` unchanged together with a list that repeats the
  shared raggedness once per outvar.

  Raises:
    ValueError: if any input's raggedness differs from the first input's.
  """
  # TODO(mvoz): A util for getting the ragged representations
  reference = invar_raggedness[0]
  for candidate in invar_raggedness[1:]:
    if candidate != reference:
      raise ValueError(f'{candidate} != {reference}')
  return invar_raggedness, [reference] * len(outvars)
def ragged_mask_assert_no_op_rule(eqn_params, invar_raggedness, outvars):
  """Ragged rule for equations that must not receive ragged inputs.

  Returns `invar_raggedness` unchanged and ``None`` for every outvar.

  Raises:
    ValueError: if any entry of `invar_raggedness` is truthy.
  """
  if not any(invar_raggedness):
    return invar_raggedness, [None] * len(outvars)
  raise ValueError(f'unexpected invar_raggedness: {invar_raggedness}')
def ragged_mask_no_op_rule(eqn_params, invar_raggedness, outvars):
  """Ragged rule that passes inputs through and marks no output as ragged."""
  outvar_raggedness = [None] * len(outvars)
  return invar_raggedness, outvar_raggedness
def ragged_mask_transfer_identity(
    eqn_params, invar_raggedness, outvar_raggedness
):
  """Ragged rule for single-input equations: output raggedness equals input.

  The incoming `outvar_raggedness` argument is ignored; the same
  `invar_raggedness` object is returned in both positions.
  """
  assert len(invar_raggedness) == 1, invar_raggedness
  return invar_raggedness, invar_raggedness
### tracer
# TODO(mattjj): use a special sentinel type rather than None
# `None` currently doubles as the "not batched" marker: `not_mapped` is the
# value, and `NotMapped` is its type (used in annotations and type checks).
NotMapped = type(None)
not_mapped = None
class BatchTracer(Tracer):
  """Tracer pairing a batched value with the position of its batch axis.

  Slots:
    val: the underlying (batched) value.
    batch_dim: `not_mapped` (None), an int axis index, or a `RaggedAxis`.
    source_info: optional source information used in `_origin_msg`.
  """
  __slots__ = ['val', 'batch_dim', 'source_info']

  def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis,
               source_info: source_info_util.SourceInfo | None = None):
    # Validation only runs when the `enable_checks` config flag is set.
    if config.enable_checks.value:
      assert type(batch_dim) in (NotMapped, int, RaggedAxis)
      if type(batch_dim) is int:
        aval = core.get_aval(val)
        assert 0 <= batch_dim < len(aval.shape)
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim
    self.source_info = source_info

  @property
  def aval(self):
    """Abstract value of this tracer, derived from the aval of `val`.

    For an int `batch_dim` the result comes from `core.mapped_aval` at that
    axis.  For a `RaggedAxis` the stacked axis is mapped the same way and
    each ragged dimension is replaced by a `BatchTracer` over its segment
    lengths.  No value is returned for any other `batch_dim` type.
    """
    aval = core.get_aval(self.val)
    if self._trace.axis_data.spmd_name is not None:
      if config._check_vma.value:
        # Drop the spmd axis names from the aval's `vma` set.
        aval = aval.update(
            vma=aval.vma - frozenset(self._trace.axis_data.spmd_name))
    if self.batch_dim is not_mapped:
      return aval
    elif type(self.batch_dim) is int:
      return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval)
    elif type(self.batch_dim) is RaggedAxis:
      new_aval = core.mapped_aval(
          aval.shape[self.batch_dim.stacked_axis], self.batch_dim.stacked_axis, aval)
      shape = list(new_aval.shape)  # pytype: disable=attribute-error
      for ragged_axis, segment_lengths in self.batch_dim.ragged_axes:
        size_tracer = BatchTracer(self._trace, segment_lengths, 0)
        # Ragged axis indices count the stacked axis; shift those after it.
        if self.batch_dim.stacked_axis < ragged_axis:
          ragged_axis -= 1
        shape[ragged_axis] = size_tracer
      return core.DShapedArray(shape=tuple(shape), dtype=aval.dtype,
                               weak_type=aval.weak_type)

  def full_lower(self):
    """Return the fully lowered `val` if unbatched, otherwise this tracer."""
    if self.batch_dim is not_mapped:
      return core.full_lower(self.val)
    else:
      return self

  def _origin_msg(self):
    """Describe where this tracer was created, or '' without source info."""
    if self.source_info is None:
      return ""
    return (f"\nThis BatchTracer with object id {id(self)} was created on line:"
            f"\n  {source_info_util.summarize(self.source_info)}")

  def _contents(self):
    """Return the (name, value) pairs describing this tracer."""
    return [('val', self.val), ('batch_dim', self.batch_dim)]

  def get_referent(self):
    """Return the referent of `val`, or this tracer for a ragged `batch_dim`."""
    if self.batch_dim is None or type(self.batch_dim) is int:
      return core.get_referent(self.val)
    else:  # TODO(mattjj): could handle the RaggedAxis case?
      return self
@dataclasses.dataclass(frozen=True)
class AxisData:
  """Description of the axis being mapped by a `BatchTrace`."""
  # Name of the mapped axis.
  name : Any
  # Size of the mapped axis.
  size : Any
  # Only one of spmd_name and explicit_mesh_axis is set.
  spmd_name : Any
  explicit_mesh_axis: Any
def get_sharding_for_vmap(axis_data, orig_sharding, axis):
  """Return `orig_sharding` with the explicit mesh axis inserted at `axis`.

  The new spec is `orig_sharding.spec` with `axis_data.explicit_mesh_axis`
  inserted at position `axis`; the mesh is taken from `orig_sharding`.
  """
  mesh_axis = axis_data.explicit_mesh_axis
  # TODO(yashkatariya): Preserve unreduced here using
  # `orig_sharding.spec.update`
  spec_entries = tuple_insert(orig_sharding.spec, axis, mesh_axis)
  new_spec = P(*spec_entries)
  return NamedSharding(orig_sharding.mesh, new_spec)
class BatchTrace(Trace):
  """Trace that applies batching rules to operations on `BatchTracer`s.

  Attributes:
    parent_trace: the trace that batched operations are bound with.
    axis_data: `AxisData` describing the mapped axis.
    tag: identity token; only `BatchTracer`s whose trace carries this same
      tag are treated as batched by `to_batch_info`.
  """

  def __init__(self, parent_trace, tag, axis_data):
    super().__init__()
    self.parent_trace = parent_trace
    assert isinstance(axis_data, AxisData)
    self.axis_data = axis_data
    self.tag = tag

  def to_batch_info(self, val):
    """Split `val` into ``(value, batch_dim)`` with respect to this trace.

    Values that are not `BatchTracer`s of a trace with this tag are returned
    unchanged with `not_mapped`.
    """
    if isinstance(val, BatchTracer) and val._trace.tag is self.tag:
      return val.val, val.batch_dim
    else:
      return val, not_mapped

  def process_primitive(self, p, tracers, params):
    """Apply primitive `p` to `tracers` using its batching rule.

    Rules are looked up in the module-level `fancy_primitive_batchers` and
    `primitive_batchers` registries.  When no argument is batched, `p` is
    bound directly with the parent trace, except for a fancy rule that is
    not skippable or whose `skippable_batchers` entry reports this trace's
    axis name.

    Raises:
      NotImplementedError: if a batched argument is present and `p` has no
        batching rule.
    """
    if config.dynamic_shapes.value:
      # Result is discarded; only the call itself matters here.
      p.abstract_eval(*(map(core.get_aval, tracers)), **params)
    vals_in, dims_in = unzip2(map(self.to_batch_info, tracers))
    args_not_mapped = all(bdim is not_mapped for bdim in dims_in)
    if p in fancy_primitive_batchers:
      if (args_not_mapped
          and p in skippable_batchers
          and not any(self.axis_data.name == axis_name
                      for axis_name in skippable_batchers[p](params))):
        # no-op shortcut
        return p.bind_with_trace(self.parent_trace, vals_in, params)
      else:
        with core.set_current_trace(self.parent_trace):
          val_out, dim_out = fancy_primitive_batchers[p](
              self.axis_data, vals_in, dims_in, **params)
    elif args_not_mapped:
      # no-op shortcut
      return p.bind_with_trace(self.parent_trace, vals_in, params)
    elif p in primitive_batchers:
      with core.set_current_trace(self.parent_trace):
        val_out, dim_out = primitive_batchers[p](vals_in, dims_in, **params)
    else:
      raise NotImplementedError("Batching rule for '{}' not implemented".format(p))
    src = source_info_util.current()
    # Outputs the rule reports as unbatched are returned without a tracer.
    if p.multiple_results:
      with core.set_current_trace(self.parent_trace):  # val_out may be lazy map
        return [BatchTracer(self, x, d, src) if d is not not_mapped else x
                for x, d in zip(val_out, dim_out)]
    else:
      return (BatchTracer(self, val_out, dim_out, src)
              if dim_out is not not_mapped else val_out)

  def process_call(self, call_primitive, f, tracers, params):
    """Batch a call primitive by batching its callee `f`.

    Segment lengths of ragged batch dims are passed as extra leading
    arguments to the batched callee and resolved again on the outputs.
    """
    assert call_primitive.multiple_results
    params = dict(params, name=params.get('name', f.__name__))
    vals, dims = unzip2(map(self.to_batch_info, tracers))
    segment_lens, dims = indirectify_ragged_axes(dims)
    f_, dims_out = batch_subtrace(f, self.tag, self.axis_data, tuple(dims))
    f_ = _update_annotation(
        f_, f.in_type, self.axis_data.size, self.axis_data.name, dims, segment_lens)

    with core.set_current_trace(self.parent_trace):
      vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params)
    vals_out, dims_out = resolve_ragged_axes(vals_out, dims_out())
    src = source_info_util.current()
    return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out)]

  def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
    """Batch a map primitive, adjusting its in/out axes around batch dims."""
    vals, dims = unzip2(map(self.to_batch_info, tracers))
    # The logic for the dimension math below is as follows:
    # ╔═════════════╦════════════════════════════════════════╦═══════════╗
    # ║ d / in_axis ║ None                                   ║ int       ║
    # ╠═════════════╬════════════════════════════════════════╩═══════════╣
    # ║ None        ║ No extra axis, so in_axis unaffected               ║
    # ╠═════════════╬════════════════════════════════════════╦═══════════╣
    # ║ int         ║ Not mapped, so batching dim unaffected ║ See below ║
    # ╚═════════════╩════════════════════════════════════════╩═══════════╝
    # When both d and in_axis are defined then:
    # - If `d <= in_axis`, we have to move the `in_axis` one dimension further;
    # - If `d >  in_axis`, we have to decrement `d` (as `in_axis` will get removed).
    def both_mapped(in_out_axis, d):
      return in_out_axis is not None and d is not not_mapped
    new_in_axes = tuple(
        in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis
        for d, in_axis in zip(dims, params['in_axes']))
    new_dims = tuple(
        d - 1 if both_mapped(in_axis, d) and in_axis < d else d
        for d, in_axis in zip(dims, params['in_axes']))
    f, dims_out = batch_subtrace(f, self.tag, self.axis_data, new_dims)
    out_axes_thunk = params['out_axes_thunk']
    # NOTE: This assumes that the choice of the dimensions over which outputs
    #       are batched is entirely dependent on the function and not e.g. on the
    #       data or its shapes.
    @as_hashable_function(closure=out_axes_thunk)
    def new_out_axes_thunk():
      return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
                   for out_axis, d in zip(out_axes_thunk(), dims_out()))
    new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
    with core.set_current_trace(self.parent_trace):
      vals_out = map_primitive.bind(f, *vals, **new_params)
    # Re-insert the mapped output axis into the reported batch dims.
    dims_out_ = [d + 1 if both_mapped(out_axis, d) and out_axis <= d else d
                 for d, out_axis in zip(dims_out(), out_axes_thunk())]
    src = source_info_util.current()
    return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out_)]

  def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
    """Batch a custom-JVP call by batching both `fun` and its `jvp` rule."""
    in_vals, in_dims = unzip2(map(self.to_batch_info, tracers))
    fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims)
    jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.tag, self.axis_data, in_dims)
    out_vals = prim.bind_with_trace(self.parent_trace, (fun, jvp, *in_vals),
                                    dict(symbolic_zeros=symbolic_zeros))
    # Only one of the two aux outputs is populated; take whichever it is.
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    src = source_info_util.current()
    return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)]

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees,
                              symbolic_zeros):  # pytype: disable=signature-mismatch
    """Batch a custom-VJP call by batching `fun`, `fwd` and `bwd`."""
    in_vals, in_dims = unzip2(map(self.to_batch_info, tracers))
    # `fwd` receives each input dim followed by a `not_mapped` entry.
    fwd_in_dims = [d for in_dim in in_dims for d in [in_dim, not_mapped]]

    fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims)
    fwd, out_dims2 = batch_subtrace(fwd, self.tag, self.axis_data, fwd_in_dims)

    def bwd_in_dims():
      _, _, input_fwds = out_trees()
      pruned_dims = iter(out_dims2())
      # Entries of `input_fwds` that are not None index into `in_dims`;
      # the others consume the next dim reported by `fwd`.
      full_dims = [next(pruned_dims) if f is None else in_dims[f] for f in input_fwds]
      return [*full_dims, *pruned_dims]

    bwd = batch_custom_vjp_bwd(bwd, self.tag, self.axis_data, bwd_in_dims, in_dims)
    out_vals = prim.bind_with_trace(self.parent_trace,
                                    (fun, fwd, bwd) + tuple(in_vals),
                                    dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros))
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
      # Dims came from `fwd`: drop the leading residual dims.
      _, res_tree, input_fwds = out_trees()
      num_res = res_tree.num_leaves - sum(f is not None for f in input_fwds)
      _, out_dims = split_list(out_dims, [num_res])
    src = source_info_util.current()
    return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)]
### API for batching callables with vmappable inputs and outputs
def batch(fun: lu.WrappedFun, axis_data,
          in_dims, out_dim_dests) -> lu.WrappedFun:
  """Wrap `fun` so it is batched over the axis described by `axis_data`.

  Args:
    fun: the function to batch.
    axis_data: `AxisData` for the mapped axis.
    in_dims: per-input axis specs (or a callable returning them).
    out_dim_dests: per-output axis specs (or a callable returning them).
  """
  # we split up _batch_inner and _batch_outer for the leak checker
  inner = _batch_inner(fun, axis_data, out_dim_dests)
  return _batch_outer(inner, axis_data, in_dims)
@lu.transformation2
def _batch_outer(f, axis_data, in_dims, *in_vals):
  """Outer half of `batch`: create the trace tag and run the inner half.

  `f` is expected to return ``(outs, trace)``; the trace is deleted inside
  `core.ensure_no_leaks` before the outputs are returned.
  """
  tag = TraceTag()
  with source_info_util.transform_name_stack('vmap'):
    outs, trace = f(tag, in_dims, *in_vals)
  # Drop the local reference inside the leak-check context.
  with core.ensure_no_leaks(trace): del trace
  return outs
@lu.transformation2
def _batch_inner(f: Callable, axis_data, out_dim_dests, tag, in_dims, *in_vals):
  """Inner half of `batch`: trace `f` under a fresh `BatchTrace`.

  Inputs are converted with `to_elt`, `f` is called on the results, and the
  outputs are converted back with `from_elt`.  `in_dims` and `out_dim_dests`
  may each be given as a callable returning the actual sequence.

  Returns:
    ``(out_vals, trace)``.
  """
  in_dims = in_dims() if callable(in_dims) else in_dims
  with core.take_current_trace() as parent_trace:
    trace = BatchTrace(parent_trace, tag, axis_data)
    # Memoized thunk for the index tracer handed to `to_elt` handlers.
    idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0,
                                      source_info_util.current()))
    with core.set_current_trace(parent_trace):
      in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)
    with (core.set_current_trace(trace),
          core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
          core.add_spmd_axis_names(axis_data.spmd_name)):
      outs = f(*in_tracers)
      out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
      out_vals = map(partial(from_elt, trace, axis_data.size, axis_data.explicit_mesh_axis),
                     range(len(outs)), outs, out_dim_dests)

  return out_vals, trace
# NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it.
def vtile(f_flat: lu.WrappedFun,
          in_axes_flat: tuple[int | None, ...],
          out_axes_flat: tuple[int | None, ...],
          tile_size: int | None,
          axis_name: AxisName):
  """Batch `f_flat` over tiles of its mapped axes.

  Each mapped input axis of size ``s`` is reshaped into
  ``(tile_size, s // tile_size)`` before calling the batched function, and
  each mapped output axis is merged with the axis following it afterwards.
  When `tile_size` is falsy, the size of the first mapped input axis is
  used for the reshape; the `AxisData` passed to `batch` still carries the
  original `tile_size` value.
  """
  @curry
  def tile_axis(arg, axis: int | None, tile_size):
    # Split `axis` into two axes: (tile_size, size // tile_size).
    if axis is None:
      return arg
    tiled_shape = list(arg.shape)
    tiled_shape[axis:axis+1] = [tile_size, tiled_shape[axis] // tile_size]
    return arg.reshape(tiled_shape)

  def untile_axis(out, axis: int | None):
    # Merge `axis` and the axis after it back into one.
    if axis is None:
      return out
    merged_shape = list(out.shape)
    merged_shape[axis:axis+2] = [merged_shape[axis] * merged_shape[axis+1]]
    return out.reshape(merged_shape)

  @lu.transformation2
  def _map_to_tile(f, *args_flat):
    mapped_sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat)
                    if i is not None)
    tile_size_ = tile_size or next(mapped_sizes, None)
    assert tile_size_ is not None, "No mapped arguments?"
    tiled_args = map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat)
    outputs_flat = f(*tiled_args)
    return map(untile_axis, outputs_flat, out_axes_flat)

  axis_data = AxisData(axis_name, tile_size, None, None)
  batched = batch(f_flat, axis_data, in_axes_flat, out_axes_flat)
  return _map_to_tile(batched)
### API for batching functions with jaxpr type inputs and outputs
@lu.transformation_with_aux2
def batch_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
  """Run `f` under a `BatchTrace` on values with the given batch dims.

  `in_dims` may be a callable returning the dims.  Ragged input dims are
  resolved against `in_vals` first (see `resolve_ragged_axes`), and ragged
  output dims are made indirect again: their segment lengths are prepended
  to the returned values.  The output batch dims are recorded in `store`.
  """
  with core.take_current_trace() as parent_trace:
    trace = BatchTrace(parent_trace, tag, axis_data)
    with core.set_current_trace(trace):
      in_dims = in_dims() if callable(in_dims) else in_dims
      in_vals, in_dims = resolve_ragged_axes(in_vals, in_dims)
      # Inputs whose dim is None are passed through without a tracer.
      in_tracers = [BatchTracer(trace, x, dim, source_info_util.current())
                    if dim is not None else x for x, dim in zip(in_vals, in_dims)]
      outs = f(*in_tracers)
      out_vals, out_dims = unzip2(map(trace.to_batch_info, outs))
  segment_lens, out_dims = indirectify_ragged_axes(out_dims)
  store.store(out_dims)
  return (*segment_lens, *out_vals)
def indirectify_ragged_axes(dims):
  """Replace segment lengths inside ragged dims with `pe.DBIdx` references.

  Returns:
    ``(segment_lens, new_dims)``: the distinct segment-length values in
    first-seen order, and `dims` with each `RaggedAxis` rewritten so its
    lengths are `pe.DBIdx` positions into `segment_lens`.  When no entry is
    exactly of type `RaggedAxis`, returns ``([], dims)`` with `dims`
    untouched.
  """
  if not any(type(d) is RaggedAxis for d in dims):
    return [], dims

  # id(referent of segment lengths) -> (segment lengths, assigned DBIdx)
  axis_map: dict[int, tuple[Array, pe.DBIdx]] = collections.OrderedDict()

  def canonicalize_segment_lengths(d: RaggedAxis) -> RaggedAxis:
    indirect_axes = []
    for ragged_axis, segment_lengths in d.ragged_axes:
      key = id(core.get_referent(segment_lengths))
      if key not in axis_map:
        axis_map[key] = (segment_lengths, pe.DBIdx(len(axis_map)))
      dbidx = axis_map[key][1]
      indirect_axes.append((ragged_axis, dbidx))
    return RaggedAxis(d.stacked_axis, tuple(indirect_axes))

  new_dims = []
  for d in dims:
    if isinstance(d, RaggedAxis):
      d = canonicalize_segment_lengths(d)
    new_dims.append(d)
  segment_lens = [lengths for lengths, _ in axis_map.values()]
  return segment_lens, new_dims
def indirectify_ragged_axes_against_inputs_outputs(dims, in_vals, out_vals):
  """Replace segment lengths in ragged dims with input/output references.

  Each segment-lengths value is located by identity among `in_vals` and then
  `out_vals` (see `_locate_value`) and replaced by the resulting index
  object.  Non-ragged entries of `dims` are returned unchanged.
  """
  def indirectify(d):
    if not isinstance(d, RaggedAxis):
      return d
    located_axes = []
    for ragged_axis, segment_lengths in d.ragged_axes:
      referent_id = id(core.get_referent(segment_lengths))
      located_axes.append(
          (ragged_axis, _locate_value(referent_id, in_vals, out_vals)))
    return RaggedAxis(d.stacked_axis, tuple(located_axes))

  return [indirectify(d) for d in dims]
def _locate_value(key, in_vals, out_vals):
for ix, candidate in enumerate(in_vals):
if key == id(candidate):
return pe.InDBIdx(ix)
for ix, candidate in enumerate(out_vals):
if key == id(candidate):
return pe.OutDBIdx(ix)
assert False, "Could not find segment lengths"
def resolve_ragged_axes(vals, dims):
  """Inverse of `indirectify_ragged_axes`.

  Every `RaggedAxis` in `dims` holds index objects whose `.val` is a
  position in `vals`; these are replaced by the values themselves, and the
  referenced positions are removed from `vals`.

  Returns:
    ``(remaining_vals, resolved_dims)``.
  """
  length_positions = set()
  for d in dims:
    if isinstance(d, RaggedAxis):
      for _, lengths_idx in d.ragged_axes:
        length_positions.add(lengths_idx.val)

  def resolve(d):
    if not isinstance(d, RaggedAxis):
      return d
    resolved_axes = tuple((ragged_axis, vals[lengths_idx.val])
                          for ragged_axis, lengths_idx in d.ragged_axes)
    return RaggedAxis(d.stacked_axis, resolved_axes)

  resolved_dims = [resolve(d) for d in dims]
  remaining_vals = [x for i, x in enumerate(vals) if i not in length_positions]
  return remaining_vals, resolved_dims
def resolve_ragged_axes_against_inputs_outputs(in_vals, out_vals, dims):
  """Inverse of `indirectify_ragged_axes_against_inputs_outputs`.

  Each `pe.InDBIdx` / `pe.OutDBIdx` stored in a `RaggedAxis` is replaced by
  the value it points at in `in_vals` / `out_vals`.  Non-ragged entries of
  `dims` are returned unchanged.
  """
  def fetch(idx):
    if not isinstance(idx, pe.InDBIdx):
      assert isinstance(idx, pe.OutDBIdx)
      return out_vals[idx.val]
    return in_vals[idx.val]

  def resolve(d):
    if not isinstance(d, RaggedAxis):
      return d
    resolved_axes = tuple((ragged_axis, fetch(lengths_idx))
                          for ragged_axis, lengths_idx in d.ragged_axes)
    return RaggedAxis(d.stacked_axis, resolved_axes)

  return [resolve(d) for d in dims]
### API for batching jaxprs
# TODO(axch): parameterize RaggedAxis annotations by a type parameter so as to
# indicate whether we're dealing with instances that contain Arrays or DBIdx.
# Can reuse same pattern for all dynamic shape stuff.
def batch_jaxpr2(
    closed_jaxpr: core.ClosedJaxpr,
    axis_data,
    in_axes: tuple[int | NotMapped | RaggedAxis, ...],
) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped | RaggedAxis, ...]]:
  """Batch `closed_jaxpr`, letting the callee choose the output batch axes.

  This is only ever used in pjit.  The difference vs batch_jaxpr is that
  batch_jaxpr2 lets the callee decide which outputs are batched and what
  their batch axes are; whereas batch_jaxpr has to obey caller-imposed
  consistency constraints, such as type-agreement across arms of a
  `lax.cond`, or input-output agreement for the body of a `lax.scan`.

  Returns:
    ``(batched_closed_jaxpr, out_axes)``.
  """
  # Normalize to a tuple before handing off to the cached implementation.
  in_axes_tuple = tuple(in_axes)
  return _batch_jaxpr2(closed_jaxpr, axis_data, in_axes_tuple)
@weakref_lru_cache
def _batch_jaxpr2(
    closed_jaxpr: core.ClosedJaxpr,
    axis_data,
    in_axes: tuple[int | NotMapped | RaggedAxis, ...],
) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]:
  """Cached implementation behind `batch_jaxpr2`.

  Builds the batched function, computes the unmapped (batched) input avals,
  and retraces to a new closed jaxpr.

  Returns:
    ``(batched_closed_jaxpr, out_axes)`` where `out_axes` is whatever the
    inner transformation recorded.
    NOTE(review): the return annotation omits `RaggedAxis` although the
    public wrapper's annotation includes it -- confirm which is intended.
  """
  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr),
                   debug_info=closed_jaxpr.jaxpr.debug_info)
  f, out_axes = _batch_jaxpr_inner(f, axis_data)
  f = _batch_jaxpr_outer(f, axis_data, in_axes)
  # Ragged input axes are reduced to (stacked axis, adjusted aval) pairs.
  in_axes2, avals_in = unzip2([
      handle_ragged(closed_jaxpr.in_avals, dim, aval)
      if isinstance(dim, RaggedAxis) else (dim, aval)
      for dim, aval in zip(in_axes, closed_jaxpr.in_avals)])
  avals_in2 = []
  for aval, b in unsafe_zip(avals_in, in_axes2):
    if b is not_mapped:
      avals_in2.append(aval)
    else:
      aval = core.unmapped_aval(
          axis_data.size, b, aval, axis_data.explicit_mesh_axis)
      if axis_data.spmd_name is not None:
        if config._check_vma.value:
          # Add the spmd axis names to the aval's `vma` set.
          aval = aval.update(vma=aval.vma | frozenset(axis_data.spmd_name))  # type: ignore
      avals_in2.append(aval)
  # The empty-tuple target requires the fourth returned element to be empty.
  jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2)
  return core.ClosedJaxpr(jaxpr_out, consts), out_axes()
def handle_ragged(in_avals: list[core.AbstractValue], dim: RaggedAxis,
                  aval: core.ShapedArray) -> tuple[int, core.ShapedArray]:
  """Replace the ragged dimensions of `aval` with their bounds.

  For each ``(axis, dbi)`` in `dim.ragged_axes`, the size written into the
  shape is ``in_avals[dbi.val].dtype.bound``.  Axis indices count the
  stacked axis, so those after it are shifted down by one.

  Returns:
    ``(dim.stacked_axis, updated_aval)``.
  """
  new_shape = list(aval.shape)
  for ragged_axis, dbi in dim.ragged_axes:
    bound = in_avals[dbi.val].dtype.bound
    if dim.stacked_axis < ragged_axis:
      target = ragged_axis - 1
    else:
      target = ragged_axis
    new_shape[target] = bound
  bounded_aval = aval.update(shape=tuple(new_shape))
  return dim.stacked_axis, bounded_aval
def batch_jaxpr(closed_jaxpr, axis_data, in_batched, instantiate):
  """Batch `closed_jaxpr` given per-input batched flags.

  Args:
    closed_jaxpr: the jaxpr to batch.
    axis_data: `AxisData` for the mapped axis.
    in_batched: one flag per input; converted to a tuple.
    instantiate: a bool or a sequence of bools; a list is converted to a
      tuple, anything else is forwarded as is.
  """
  if isinstance(instantiate, list):
    instantiate = tuple(instantiate)
  return _batch_jaxpr(closed_jaxpr, axis_data, tuple(in_batched), instantiate)
def _batch_jaxpr(closed_jaxpr, axis_data, in_batched, instantiate):
  """Translate batched/instantiate flags into axes and batch the jaxpr.

  Batched inputs are mapped along axis 0.  An output with a true
  `instantiate` flag is requested at axis 0; otherwise its destination is
  `zero_if_mapped`.  A single bool `instantiate` applies to every output.
  """
  is_single_flag = isinstance(instantiate, bool)
  assert (is_single_flag or
          isinstance(instantiate, (list, tuple)) and
          all(isinstance(b, bool) for b in instantiate))
  if is_single_flag:
    instantiate = [instantiate] * len(closed_jaxpr.out_avals)
  in_axes = [0 if batched else not_mapped for batched in in_batched]
  out_axes_dest = [0 if inst else zero_if_mapped for inst in instantiate]
  return batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest)
def batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest):
  """Batch `closed_jaxpr` with explicit input axes and output destinations.

  Both axis sequences are converted to tuples before calling the cached
  implementation.
  """
  in_axes_tuple = tuple(in_axes)
  out_axes_dest_tuple = tuple(out_axes_dest)
  return _batch_jaxpr_axes(
      closed_jaxpr, axis_data, in_axes_tuple, out_axes_dest_tuple)
@weakref_lru_cache
def _batch_jaxpr_axes(closed_jaxpr: core.ClosedJaxpr,
                      axis_data: AxisData,
                      in_axes: Sequence[int], out_axes_dest: Sequence[int]):
  """Cached implementation behind `batch_jaxpr_axes`.

  Returns:
    ``(batched_closed_jaxpr, out_batched)`` where `out_batched` is the list
    recorded by `_match_axes_jaxpr`.
  """
  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr),
                   debug_info=closed_jaxpr.jaxpr.debug_info)
  f, out_axes = _batch_jaxpr_inner(f, axis_data)
  f, out_batched = _match_axes_jaxpr(f, axis_data, out_axes_dest, out_axes)
  f = _batch_jaxpr_outer(f, axis_data, in_axes)
  # Batched inputs get the batch axis added to their aval; others keep theirs.
  avals_in = [core.unmapped_aval(axis_data.size, b, aval,
                                 axis_data.explicit_mesh_axis)
              if b is not not_mapped
              else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
  # The empty-tuple target requires the fourth returned element to be empty.
  jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
@lu.transformation_with_aux2
def _batch_jaxpr_inner(f, store, axis_data, tag, in_axes, *in_vals):
  """Run `f` under a `BatchTrace` and record its output batch axes.

  Unlike `batch_subtrace`, the input values are passed through in full (the
  values returned by `resolve_ragged_axes` are discarded), and ragged output
  axes are expressed as references into the inputs/outputs.  The output
  axes are recorded in `store`; the output values are returned.
  """
  with core.take_current_trace() as parent_trace:
    trace = BatchTrace(parent_trace, tag, axis_data)
    _, in_axes = resolve_ragged_axes(in_vals, in_axes)
    # Inputs whose axis is None are passed through without a tracer.
    in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                  for val, dim in zip(in_vals, in_axes)]
    with (core.set_current_trace(trace),
          core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
          core.add_spmd_axis_names(axis_data.spmd_name)):
      outs = f(*in_tracers)
      out_vals, out_axes = unzip2(map(trace.to_batch_info, outs))
      new_out_axes = indirectify_ragged_axes_against_inputs_outputs(
          out_axes, in_vals, out_vals)
      store.store(new_out_axes)
  return out_vals
@lu.transformation_with_aux2
def _match_axes_jaxpr(f, store, axis_data, out_axes_dest, out_axes, trace, in_axes,
                      *in_vals):
  """Move every output of ``f`` to its requested batch axis.

  Args (after the decorator-supplied ``f`` and ``store``):
    axis_data: provides ``name``, ``size`` and ``explicit_mesh_axis``.
    out_axes_dest: requested destination per output. An entry may be the
      `zero_if_mapped` sentinel, which resolves to ``0`` when the output is
      mapped and to ``None`` when it is not. A single-element sequence is
      applied to every output.
    out_axes: thunk returning the actual batch axis of each output; it is
      called only after ``f`` has run.
    trace, in_axes, *in_vals: forwarded unchanged to ``f``.

  Returns:
    The outputs with their batch axes moved. A list of booleans saying which
    outputs ended up batched is delivered through ``store.store``.

  Raises:
    ValueError: from the tuple unpacking below, when the number of
      destinations is neither one nor at least the number of outputs.
  """
  out_vals = f(trace, in_axes, *in_vals)
  out_axes = out_axes()
  # A lone destination stands for every output. Expand it *before* resolving
  # `zero_if_mapped`, so each output is resolved against its own source axis
  # instead of all outputs inheriting the resolution of the first one.
  if len(out_axes_dest) == 1 and len(out_axes) != 1:
    out_axes_dest = list(out_axes_dest) * len(out_axes)
  out_axes_dest = [(None if src is not_mapped else 0)
                   if dst is zero_if_mapped else dst
                   for src, dst in unsafe_zip(out_axes, out_axes_dest)]
  if len(out_axes_dest) != len(out_axes):
    # Only reachable with too few (and not exactly one) destinations; the
    # unpacking then raises ValueError, as it did before.
    out_axis_dest, = out_axes_dest
    out_axes_dest = [out_axis_dest] * len(out_axes)
  out_vals = map(partial(matchaxis, axis_data.name, axis_data.size,
                         axis_data.explicit_mesh_axis),
                 out_axes, out_axes_dest, out_vals)
  out_batched = [dst is not None for dst in out_axes_dest]
  store.store(out_batched)
  return out_vals
@lu.transformation2
def _batch_jaxpr_outer(f, axis_data, in_dims, *in_vals):
  """Outermost stage: canonicalize input axes and mint a fresh trace tag.

  Args (after the decorator-supplied ``f``):
    axis_data: unused here; kept for the shared pipeline signature.
    in_dims: per-input batch axes, or a thunk returning them.
    *in_vals: the input values.

  Returns:
    The result of calling ``f(tag, canonical_dims, *in_vals)``.
  """
  if callable(in_dims):
    in_dims = in_dims()
  canonical_dims = []
  for val, ax in unsafe_zip(in_vals, in_dims):
    # Only integer axes are canonicalized against the rank of their value;
    # every other kind of entry is passed through untouched.
    if isinstance(ax, int):
      ax = canonicalize_axis(ax, np.ndim(val))
    canonical_dims.append(ax)
  return f(TraceTag(), canonical_dims, *in_vals)
def _merge_bdims(x, y):
if x == y:
return x
elif x is not_mapped:
return y
elif y is not_mapped:
return x
else:
return x # arbitrary
class ZeroIfMapped:
  """Type of the `zero_if_mapped` output-axis sentinel.

  `_batch_jaxpr` requests it for outputs that need not be instantiated, and
  `_match_axes_jaxpr` resolves it to axis ``0`` for mapped outputs and to
  ``None`` for unmapped ones.
  """

  def __repr__(self) -> str:
    # A stable name instead of the default "<... object at 0x...>" keeps debug
    # output and error messages readable.
    return 'zero_if_mapped'

zero_if_mapped = ZeroIfMapped()
### functions for handling custom_jvp and custom_vjp
# Runs `f` under a new BatchTrace and returns outputs whose first and second
# halves share batch dimensions pairwise.
#
# Args (after the decorator-supplied `f` and `store`):
#   tag: trace tag handed to the `BatchTrace` constructor.
#   axis_data: provides `size` and `explicit_mesh_axis` of the mapped axis.
#   in_dims: batch dims for one half of the inputs; it is repeated twice to
#     pair with `in_vals` (presumably primals followed by tangents -- confirm
#     against the caller).
#   *in_vals: the input values; entries may be `SymbolicZero`.
#
# Returns the matched first-half outputs followed by the matched second-half
# outputs; the merged per-pair batch dims are delivered through `store.store`.
@lu.transformation_with_aux2
def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
size = axis_data.size
mesh_axis = axis_data.explicit_mesh_axis
with core.take_current_trace() as parent_trace:
trace = BatchTrace(parent_trace, tag, axis_data)
# Per input: unmapped values pass through; a mapped SymbolicZero is rebuilt
# with the mapped aval; any other mapped value is wrapped in a BatchTracer.
in_tracers = [val if dim is None else
SymbolicZero(core.mapped_aval(size, dim, val.aval))
if type(val) is SymbolicZero else BatchTracer(trace, val, dim)
for val, dim in zip(in_vals, in_dims * 2)]
with core.set_current_trace(trace):
out_tracers: list[BatchTracer | SymbolicZero] = f(*in_tracers)
out_vals, out_dims = unzip2(map(trace.to_batch_info, out_tracers))
# Outputs are split at the midpoint into a primal half and a tangent half.
out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2])
out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2])
# Each primal/tangent pair is moved to one common batch dim.
out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds)
out_primals = map(partial(matchaxis, trace.axis_data.name, size, mesh_axis),
out_primal_bds, out_dims, out_primals)
# The tangent half goes through the symbolic-zero-aware matcher.
out_tangents = map(partial(_matchaxis_symzeros, trace.axis_data.name, size, mesh_axis),
out_tangent_bds, out_dims, out_tangents)
store.store(out_dims)
return out_primals + out_tangents
def batch_custom_vjp_bwd(bwd: lu.WrappedFun, tag: core.TraceTag,
                         axis_data: AxisData,
                         in_dims: Callable[[], Sequence[int | None]],
                         out_dim_dests: Sequence[int | None]) -> lu.WrappedFun:
  """Wrap the backward function ``bwd`` so that it is batched when called.

  Args:
    bwd: the wrapped backward function.
    tag: trace tag forwarded to `batch_subtrace`.
    axis_data: provides ``size``, ``name`` and ``explicit_mesh_axis``.
    in_dims: batch dims of the arguments of ``bwd``, or a thunk returning
      them; it is resolved when the wrapper is called.
    out_dim_dests: destination batch dim for each output of ``bwd``.

  Returns:
    A new `lu.WrappedFun` carrying the debug info of ``bwd``.
  """
  axis_size = axis_data.size
  axis_name = axis_data.name
  mesh_axis = axis_data.explicit_mesh_axis

  def new_bwd(*args):
    dims = in_dims() if callable(in_dims) else in_dims
    call_args, call_dims = [], []
    for arg, dim in zip(args, dims):
      if type(arg) is SymbolicZero:
        # Symbolic zeros are rebuilt with the mapped aval and are then treated
        # as unbatched.
        arg = SymbolicZero(core.mapped_aval(axis_size, dim, arg.aval))
        dim = None
      call_args.append(arg)
      call_dims.append(dim)
    batched_bwd, out_dims_thunk = batch_subtrace(bwd, tag, axis_data, call_dims)
    batched_bwd = _match_axes_and_sum(batched_bwd, axis_size, axis_name,
                                      mesh_axis, out_dims_thunk, out_dim_dests)
    return batched_bwd.call_wrapped(*call_args)

  return lu.wrap_init(new_bwd, debug_info=bwd.debug_info)
@lu.transformation2
def _match_axes_and_sum(f, axis_size, axis_name, mesh_axis, out_dims_thunk,
                        out_dim_dests, *in_vals):
  """Call ``f`` and move each output to its destination batch dim.

  This is like _match_axes, but reduce-sums are performed as needed
  (``sum_match=True`` is passed to `_matchaxis_symzeros`).
  """
  results = f(*in_vals)
  match_and_sum = partial(_matchaxis_symzeros, axis_name, axis_size, mesh_axis,
                          sum_match=True)
  # The thunk is only consulted after `f` has run.
  return map(match_and_sum, out_dims_thunk(), out_dim_dests, results)
def _matchaxis_symzeros(axis_name, sz, mesh_axis, src, dst, x, sum_match=False):
  """Like `matchaxis`, but also handles ``Zero`` / ``SymbolicZero`` values.

  For symbolic zeros no array is touched: a new zero of the same class is
  built around the aval that the moved value would have. Every other value is
  delegated to `matchaxis`.

  Raises:
    ValueError: for a symbolic zero whose ``src``/``dst`` combination is not
      handled (e.g. mapped source, unmapped destination, ``sum_match`` false).
  """
  # TODO(mattjj): dedup with matchaxis
  if not isinstance(x, (Zero, SymbolicZero)):
    return matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=sum_match)

  if src == dst:
    return x
  zero_cls = type(x)
  if type(src) is int and type(dst) is int:
    # Remove the batch axis at `src`, then re-insert it at `dst`.
    unbatched = core.mapped_aval(sz, src, x.aval)
    return zero_cls(core.unmapped_aval(sz, dst, unbatched, mesh_axis))
  if src is not_mapped and dst is not not_mapped:
    return zero_cls(core.unmapped_aval(sz, dst, x.aval, mesh_axis))
  if dst is not_mapped and sum_match:
    return zero_cls(core.mapped_aval(sz, src, x.aval))
  raise ValueError((axis_name, x, src, dst))
### utilities for defining primitives' batching rules
# Signature of an ordinary batching rule: any arguments in, and a pair
# `(outputs, output batch dims)` out, where the dims are a single int/None or a
# tuple of them.
BatchingRule = Callable[
...,
tuple[Any, Union[int, None, tuple[Union[int, None], ...]]]
]
# Registry mapping each primitive to its ordinary batching rule.
primitive_batchers : dict[core.Primitive, BatchingRule] = {}
# "fancy" primitive batchers just take an extra leading `AxisData` and "trace type" args
# NOTE(review): the backwards-compat shim in this file calls fancy batchers as
# (axis_data, vals, dims, **params), with no trace-type argument -- confirm
# whether the "trace type" part of the comment above is stale.
fancy_primitive_batchers: dict[core.Primitive, Callable] = {}
# backwards compat shim. TODO: delete
class AxisPrimitiveBatchersProxy:
  """Write-only proxy that registers legacy batchers as fancy batchers.

  Assigning ``proxy[prim] = batcher`` stores a wrapper in
  `fancy_primitive_batchers` that adapts the fancy calling convention
  ``(axis_data, vals, dims, **params)`` to the legacy one
  ``(size, name, None, vals, dims, **params)``.
  """

  def __setitem__(self, prim, batcher):
    def wrapped(axis_data, vals, dims, **params):
      # The third legacy positional argument is always passed as None.
      legacy_prefix = (axis_data.size, axis_data.name, None)
      return batcher(*legacy_prefix, vals, dims, **params)
    fancy_primitive_batchers[prim] = wrapped

axis_primitive_batchers = AxisPrimitiveBatchersProxy()
# Presence in this table allows fancy batchers to be skipped by batch traces for
# irrelevant axes. The Callable takes the params and returns a list of relevant
# axes.
# Keys are primitives; the table starts empty and is filled by registrations
# elsewhere.
skippable_batchers : dict[core.Primitive, Callable] = {}
def defvectorized(prim):
  """Register `vectorized_batcher` as the batching rule of ``prim``."""
  rule = partial(vectorized_batcher, prim)
  primitive_batchers[prim] = rule
def vectorized_batcher(prim, batched_args, batch_dims, **params):
  """Batching rule for primitives applied unchanged to batched arguments.

  All arguments must share the same batch dimension (checked by an
  assertion); that dimension is also the batch dimension of the result.

  Returns:
    ``(prim.bind(*batched_args, **params), batch_dims[0])``.
  """
  trailing_dims = batch_dims[1:]
  assert all(batch_dims[0] == bd for bd in trailing_dims), batch_dims
  result = prim.bind(*batched_args, **params)
  return result, batch_dims[0]
def defbroadcasting(prim):
  """Register `broadcast_batcher` as the batching rule of ``prim``."""
  rule = partial(broadcast_batcher, prim)
  primitive_batchers[prim] = rule
def broadcast_batcher(prim, args, dims, **params):
  """Process a primitive with built-in broadcasting.

  Args:
    prim: the primitive to bind.
    args: the possibly-batched arguments
    dims: list or tuple of the same length as `args`, where each
      entry indicates the batching state of the corresponding entry to `args`:
      either an int indicating the batch dimension, or else `not_mapped`
      indicating no batching.
    **params: parameters forwarded to ``prim.bind``.

  Returns:
    ``(out, out_dims)``; ``out_dims`` is a tuple with one entry per output
    when ``prim.multiple_results`` is set and a single value otherwise.
  """
  assert len(args) > 1
  ref_shape, ref_dim = next((x.shape, d) for x, d in zip(args, dims)
                            if d is not not_mapped)
  aligned = all(core.definitely_equal_shape(ref_shape, x.shape) and d == ref_dim
                for x, d in zip(args, dims) if np.ndim(x))
  if aligned:
    # if there's only agreeing batch dims and scalars, just call the primitive
    out_dim = ref_dim
  else:
    # We pass size of 1 here because (1) at least one argument has a real batch
    # dimension and (2) all unmapped axes can have a singleton axis inserted and
    # then rely on the primitive's built-in broadcasting.
    args = [bdim_at_front(x, d, 1) if np.ndim(x) else x
            for x, d in zip(args, dims)]
    max_rank = max(np.ndim(x) for x in args)  # special-case scalar broadcasting
    args = [_handle_scalar_broadcasting(max_rank, x, d)
            for x, d in zip(args, dims)]
    out_dim = 0
  out = prim.bind(*args, **params)
  if prim.multiple_results:
    return out, (out_dim,) * len(out)
  return out, out_dim
def _handle_scalar_broadcasting(nd, x, d):
  """Pad a mapped ``x`` with trailing unit axes up to rank ``nd``.

  Unmapped values, and values that already have rank ``nd``, are returned
  unchanged.
  """
  # Callers of this utility, via broadcast_batcher() or defbroadcasting(),
  # must be in a context where lax is importable.
  from jax import lax  # pytype: disable=import-error
  if d is not not_mapped:
    rank = np.ndim(x)
    if rank != nd:
      return lax.expand_dims(x, tuple(range(rank, nd)))
  return x
def defreducer(prim, ident):
  """Register `reducer_batcher`, bound to ``ident``, as the rule of ``prim``."""
  rule = partial(reducer_batcher, prim, ident)
  primitive_batchers[prim] = rule
def reducer_batcher(prim, ident, batched_args, batch_dims, axes, **params):
  """Batching rule for a reduction primitive over a single operand.

  Args:
    prim: the reduction primitive to bind.
    ident: passed to `mask_ragged_axes` for ragged batch dimensions; must not
      be None in that case.
    batched_args: one-element sequence holding the batched operand.
    batch_dims: one-element sequence holding the operand's batch dimension,
      either an int or a `RaggedAxis`.
    axes: reduction axes, numbered as if the batch axis were absent.
    **params: further primitive parameters. An ``input_shape`` entry, when
      present, is replaced by the batched operand's shape (int case only).

  Returns:
    ``(result, out_bdim)``: the bound primitive's output and the batch
    dimension of that output.

  Raises:
    TypeError: if the batch dimension is neither an int nor a `RaggedAxis`.
  """
  operand, = batched_args
  bdim, = batch_dims

  def out_axis(axes, axis):
    # Position of `axis` among the operand axes that survive the reduction.
    return int(list(np.delete(np.arange(operand.ndim), axes)).index(axis))

  if isinstance(bdim, int):
    # Shift reduction axes at or after the batch axis up by one so that they
    # refer to the batched operand.
    axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
    bdim_out = out_axis(axes, bdim)
    if 'input_shape' in params:
      params = dict(params, input_shape=operand.shape)
    return prim.bind(operand, axes=axes, **params), bdim_out
  elif isinstance(bdim, RaggedAxis):
    assert ident is not None, "TODO Ragged batching a reduction requires an identity"
    axes = tuple(np.where(np.less(axes, bdim.stacked_axis), axes, np.add(axes, 1)))
    bdim_out = out_axis(axes, bdim.stacked_axis)
    # For each ragged_axis, we either mask the operand there or append
    # it to the set of axes that will be ragged in the result.
    axes_to_mask = []
    ragged_axes_out = []
    for ragged_axis, segment_lengths in bdim.ragged_axes:
      if ragged_axis in axes:
        axes_to_mask.append((ragged_axis, segment_lengths))
      else:
        ragged_axes_out.append((out_axis(axes, ragged_axis), segment_lengths))
    operand = mask_ragged_axes(
        operand, ident, RaggedAxis(bdim.stacked_axis, tuple(axes_to_mask)))
    result = prim.bind(operand, axes=axes, **params)
    return result, make_batch_axis(operand.ndim, bdim_out, ragged_axes_out)
  else:
    # A bare `assert False` vanishes under `python -O`, which would make this
    # rule silently return None; fail loudly instead.
    raise TypeError(f"Unrecognized batch dimension type {bdim}")
def expand_dims_batcher(prim, args, dims, **params):
  """A batching rule for primitives that support matching leading batch
  dimensions in all arguments.

  Every argument gets its batch dimension moved (or broadcast) to the front
  before the primitive is bound, so all outputs are batched along axis 0.
  """
  # All mapped arguments must agree on the batch size; the single-element
  # unpacking fails otherwise.
  batch_sizes = {x.shape[bd] for x, bd in zip(args, dims)
                 if bd is not not_mapped}
  size, = batch_sizes
  fronted = [bdim_at_front(x, bd, size) for x, bd in zip(args, dims)]
  out = prim.bind(*fronted, **params)
  if prim.multiple_results:
    return out, (0,) * len(out)
  return out, 0
def mask_ragged_axes(operand: Array, ident, axis_spec: RaggedAxis) -> Array:
  """Mask ``operand`` along every ragged axis listed in ``axis_spec``.

  The axes are processed one at a time by `_mask_one_ragged_axis`, each with
  a single-entry `RaggedAxis` that shares the stacked axis of ``axis_spec``.
  """
  # TODO(mattjj, axch) Can we mask multiple axes more efficiently at
  # once, rather than one at a time?
  masked = operand
  for ragged_axis, segment_lengths in axis_spec.ragged_axes:
    single_entry = ((ragged_axis, segment_lengths),)
    single_spec = RaggedAxis(axis_spec.stacked_axis, single_entry)
    masked = _mask_one_ragged_axis(masked, ident, single_spec)
  return masked
def _mask_one_ragged_axis(
    operand: Array, ident, axis_spec: RaggedAxis) -> Array:
  """Replace out-of-range entries along one ragged axis with ``ident``'s value.

  Args:
    operand: the array to mask.
    ident: callable taking ``operand.dtype`` and returning the fill value.
    axis_spec: a `RaggedAxis` with exactly one ragged axis (asserted).

  Returns:
    ``operand`` where the index along the ragged axis is below the segment
    length, and the fill value elsewhere.
  """
  # Callers of this utility, via reducer_batcher() or defreducer(),
  # must be in a context where lax is importable.
  from jax import lax  # pytype: disable=import-error
  assert len(axis_spec.ragged_axes) == 1, "Mask just one ragged axis at a time"
  ragged_axis, segment_lengths = axis_spec.ragged_axes[0]
  fill_value = ident(operand.dtype)
  index_along_ragged = lax.broadcasted_iota('int32', operand.shape, ragged_axis)
  # TODO(mattjj, axch) can't get ._data, need to convert it
  # lengths = lax.convert_element_type(segment_lengths._data, 'int32')
  lengths_i32 = lax.convert_element_type(segment_lengths, 'int32')
  # Lay the per-segment lengths out along the stacked axis of the operand.
  per_segment_limit = lax.broadcast_in_dim(
      lengths_i32, operand.shape, [axis_spec.stacked_axis])
  in_bounds = index_along_ragged < per_segment_limit
  filler = lax.broadcast(fill_value, operand.shape)
  return lax.select(in_bounds, operand, filler)
def move_stacked_axis(operand, bdim, dst):
  """Move the batch (stacked) axis of ``operand`` to position ``dst``.

  Args:
    operand: the batched array.
    bdim: the current batch dimension, an int or a `RaggedAxis`.
    dst: target position; canonicalized against ``operand.ndim``.

  Returns:
    ``(moved_operand, new_bdim)``; ``new_bdim`` is the canonical ``dst`` for
    an int ``bdim`` and ``bdim.move_stacked_axis(dst)`` for a `RaggedAxis`.

  Raises:
    TypeError: if ``bdim`` is neither an int nor a `RaggedAxis`.
  """
  target = canonicalize_axis(dst, operand.ndim)
  if isinstance(bdim, int):
    return moveaxis(operand, bdim, target), target
  if isinstance(bdim, RaggedAxis):
    moved = moveaxis(operand, bdim.stacked_axis, target)
    return moved, bdim.move_stacked_axis(target)
  raise TypeError(f"Unrecognized batch dimension type {bdim}")
### general utilities for manipulating axes on jaxpr types (not vmappables)
# Insert a new axis of size `sz` at position `axis` of `x` by broadcasting.
#
# Args:
#   x: the value to broadcast; its shape is read with np.shape and its aval
#     with core.get_aval.
#   sz: size of the new axis.
#   axis: position of the new axis in the result shape.
#   mesh_axis: entry inserted into the sharding spec at `axis`; forced to None
#     when the aval's sharding mesh is empty.
#
# Returns the broadcast value. When `config._check_vma` is set and one spmd
# axis name is active, the result is additionally passed through core.pvary.
#
# Raises NotImplementedError when `config._check_vma` is set and more than one
# spmd axis name is active.
def broadcast(x, sz, axis, mesh_axis=None):
# Callers of this utility must be in a context where lax is importable.
from jax import lax # pytype: disable=import-error
shape = list(np.shape(x))
shape.insert(axis, sz)
# Every result axis except the inserted one corresponds to an axis of `x`.
broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
x_aval = core.get_aval(x)
if x_aval.sharding.mesh.empty:
mesh_axis = None
# The result sharding is the input spec with `mesh_axis` inserted at `axis`.
new_spec = P(*tuple_insert(x_aval.sharding.spec, axis, mesh_axis))
sharding = x_aval.sharding.update(spec=new_spec)
# TODO(dougalm, yashkatariya): Delete this context manager once we figure
# out how to ensure jaxpr arguments always have the context mesh.
with mesh_lib.use_abstract_mesh(sharding.mesh):
x = lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding)
if config._check_vma.value:
# TODO(yashkatariya,parkers): don't do this, fix during fixit week 2026
spmd_names = core.get_axis_env().spmd_axis_names
if len(spmd_names) > 1:
raise NotImplementedError
if spmd_names:
x = core.pvary(x, tuple(spmd_names))
return x
def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False):
  """Move the batch axis of ``x`` from ``src`` to ``dst``.

  Args:
    axis_name: name of the mapped axis; only used to choose the error raised.
    sz: size of the mapped axis.
    mesh_axis: mesh axis forwarded to `broadcast`.
    src: current batch axis of ``x`` (an int or `not_mapped`).
    dst: requested batch axis: an int, `not_mapped`, or `jumble_axis`.
    x: the value to move.
    sum_match: if true, a mapped ``x`` with an unmapped destination is summed
      over ``src`` instead of being rejected.

  Returns:
    ``x`` with its batch axis at ``dst``, or a `Jumble` when ``dst`` is
    `jumble_axis`.

  Raises:
    TypeError: if ``x`` has no aval.
    ValueError: for a mapped output with an incompatible destination under a
      user-visible axis name.
    SpecMatchError: for the same mismatch under a temporary or absent axis
      name.
  """
  if dst == jumble_axis:
    fronted = bdim_at_front(x, src, sz)
    elt_ty = fronted.aval.update(shape=fronted.shape[1:])
    binder = core.Var(core.ShapedArray((), np.dtype('int32')))
    return Jumble(JumbleTy(binder, fronted.shape[0], elt_ty), fronted)

  # Only the side effect matters here: get_aval raises for non-JAX values.
  try:
    core.get_aval(x)
  except TypeError as e:
    raise TypeError(f"Output from batched function {x!r} with type "
                    f"{type(x)} is not a valid JAX type") from e

  if src == dst:
    return x
  if type(src) is int and type(dst) is int:
    return moveaxis(x, src, dst)
  if src is not_mapped and dst is not not_mapped:
    return broadcast(x, sz, canonicalize_axis(dst, np.ndim(x) + 1), mesh_axis)
  if dst is not_mapped and sum_match:
    return x.sum(src)

  anonymous_axis = (isinstance(axis_name, core._TempAxisName) or
                    axis_name is core.no_axis_name)
  if anonymous_axis:
    raise SpecMatchError(None, None, None)
  raise ValueError(f'vmap has mapped output ({axis_name=}) but out_axes is {dst}')
class SpecMatchError(Exception):
  """Raised when a batch-axis source cannot be matched to its destination.

  Attributes are plain data for the handler: ``leaf_idx``, ``src`` and
  ``dst``. `matchaxis` raises this with all three set to None.
  """

  def __init__(self, leaf_idx, src, dst):
    # Forward the fields to Exception so that `args` is populated however the
    # error is constructed; without this, keyword construction leaves
    # `args == ()` and the instance cannot be re-created by pickle.
    super().__init__(leaf_idx, src, dst)
    self.leaf_idx = leaf_idx
    self.src = src
    self.dst = dst
def bdim_at_front(x, bdim, size, mesh_axis=None):
  """Return ``x`` with its batch axis at position 0.

  A mapped ``x`` has its axis ``bdim`` moved to the front; an unmapped ``x``
  is broadcast to get a new leading axis of length ``size``.
  """
  if bdim is not not_mapped:
    return moveaxis(x, bdim, 0)
  return broadcast(x, size, 0, mesh_axis=mesh_axis)
def add_batched(batched_args, batch_dims):
  """Batching rule for `add_jaxvals_p`.

  The two operands are first brought to a common batch dimension: an unmapped
  operand is broadcast along the other's batch axis, and when both are mapped
  along different axes the first is moved to the second's axis.

  Returns:
    ``(add_jaxvals(x, y), out_bdim)``.
  """
  bdx, bdy = batch_dims
  x, y = batched_args
  if bdx == bdy:
    out_bdim = bdx
  elif bdx is not_mapped:
    x = broadcast(x, y.shape[bdy], bdy)
    out_bdim = bdy
  elif bdy is not_mapped:
    y = broadcast(y, x.shape[bdx], bdx)
    out_bdim = bdx
  else:
    x = moveaxis(x, bdx, bdy)
    out_bdim = bdy
  return add_jaxvals(x, y), out_bdim

primitive_batchers[add_jaxvals_p] = add_batched
########################### core. ##################################
def _pvary_batcher(vals_in, dims_in, *, axes, axis_index_groups):
  """Batching rule for `core.pvary_p`.

  The primitive is bound to the batched values unchanged, so the output batch
  dims equal the input batch dims.

  Raises:
    NotImplementedError: if any entry of ``axes`` is a plain int.
  """
  for axis in axes:
    if type(axis) is int:
      raise NotImplementedError
  vals_out = core.pvary_p.bind(
      *vals_in, axes=axes, axis_index_groups=axis_index_groups)
  return vals_out, dims_in

primitive_batchers[core.pvary_p] = _pvary_batcher
### mutable arrays
# `core.mutable_array_p` uses the generic vectorized rule: it is bound to the
# batched arguments unchanged and the output keeps the input batch dimension.
defvectorized(core.mutable_array_p)