# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Callable, Hashable, Sequence, Set
import enum
from functools import partial
import inspect
from math import prod
import operator as op
from typing import Any, TypeVar, Union
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec
from jax._src import ad_util
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import debugging
from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.core import pvary
from jax._src.core import Tracer, typeof
from jax._src.mesh import (AbstractMesh, Mesh, AxisType, use_abstract_mesh,
get_abstract_mesh, get_concrete_mesh)
from jax._src.api import _shared_code_pmap, _prepare_pmap
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo, sdy
from jax._src.util import (HashableFunction, HashablePartial, unzip2,
as_hashable_function, memoize, partition_list,
merge_lists, split_list, subs_list2,
fun_name as util_fun_name)
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.interpreters import ad
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
tree_structure, tree_leaves, keystr)
from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef,
generate_key_paths, KeyPath)
from jax.experimental.multihost_utils import (host_local_array_to_global_array,
global_array_to_host_local_array)
P = PartitionSpec
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
traceback_util.register_exclusion(__file__)
# API
Specs = Any # PyTree[PartitionSpec]
AxisName = Hashable
def shard_map(f=None, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(),
in_specs: Specs | None = None,
mesh: Mesh | AbstractMesh | None = None, check_vma: bool = True):
"""Map a function over shards of data using a mesh of devices.
See the docs at https://docs.jax.dev/en/latest/notebooks/shard_map.html.
Args:
f: callable to be mapped. Each application of ``f``, or "instance" of ``f``,
takes as input a shard of the mapped-over arguments and produces a shard
of the output.
mesh: (optional, default None) a ``jax.sharding.Mesh`` representing the
array of devices over which to shard the data and on which to execute
instances of ``f``. The names of the ``Mesh`` can be used in collective
    communication operations in ``f``. If mesh is None, it will be inferred
    from the context, which can be set via the ``jax.sharding.use_mesh``
    context manager.
in_specs: (optional, default None) a pytree with
``jax.sharding.PartitionSpec`` instances as leaves, with a tree structure
that is a tree prefix of the args tuple to be mapped over. Similar to
``jax.sharding.NamedSharding``, each ``PartitionSpec`` represents how the
corresponding argument (or subtree of arguments) should be sharded along
the named axes of ``mesh``. In each ``PartitionSpec``, mentioning a
    ``mesh`` axis name at a position expresses sharding the corresponding
    argument array axis (at that position) along that mesh axis; not
    mentioning an axis
name expresses replication. If ``None``, all mesh axes must be of type
`Explicit`, in which case the in_specs are inferred from the argument types.
out_specs: a pytree with ``PartitionSpec`` instances as leaves, with a tree
structure that is a tree prefix of the output of ``f``. Each
``PartitionSpec`` represents how the corresponding output shards should be
concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name
at a position expresses concatenation of that mesh axis's shards along the
corresponding positional axis; not mentioning a ``mesh`` axis name
expresses a promise that the output values are equal along that mesh axis,
      and that, rather than concatenating, only a single value should be produced.
axis_names: (optional, default set()) set of axis names from ``mesh`` over
      which the function ``f`` is manual. If empty, ``f`` is manual
over all mesh axes.
check_vma: (optional) boolean (default True) representing whether to enable
additional validity checks and automatic differentiation optimizations.
The validity checks concern whether any mesh axis names not mentioned in
``out_specs`` are consistent with how the outputs of ``f`` are replicated.
Returns:
A callable representing a mapped version of ``f``, which accepts positional
arguments corresponding to those of ``f`` and produces output corresponding
to that of ``f``.
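  Example:
    A minimal usage sketch, assuming eight available devices (the mesh axis
    name ``'i'`` here is illustrative)::

      import jax
      import jax.numpy as jnp
      from jax.sharding import PartitionSpec as P

      mesh = jax.make_mesh((8,), ('i',))

      @jax.shard_map(mesh=mesh, in_specs=P('i'), out_specs=P())
      def f(x):
        # Each instance sees one shard of shape (4,); psum sums across 'i'.
        return jax.lax.psum(jnp.sum(x), 'i')

      y = f(jnp.arange(32.0))  # replicated scalar equal to jnp.sum of the input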
"""
kwargs = dict(mesh=mesh, in_specs=in_specs, out_specs=out_specs,
axis_names=axis_names, check_vma=check_vma)
if f is None:
return lambda g: _shard_map(g, **kwargs)
return _shard_map(f, **kwargs)
def _axes_to_pspec(axis_name, axis):
if axis is None:
return P()
return P(*[None] * axis + [axis_name])
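# For example (the axis name 'i' is illustrative): _axes_to_pspec('i', 0) ==
# P('i'), _axes_to_pspec('i', 2) == P(None, None, 'i'), and a None axis maps
# to P().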
class InferFromArgs:
def __repr__(self):
return "jax.sharding.Infer"
def __reduce__(self):
return (_get_default_infer, ())
Infer = InferFromArgs()
def _get_default_infer():
return Infer
# TODO(yashkatariya): We need a singleton which users can provide to `in_axes`
# to tell smap to infer in_specs from args when mesh is fully explicit.
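# Illustrative sketch (the axis name 'i' is hypothetical): over a mesh with an
# axis 'i', smap(f, in_axes=0, out_axes=0, axis_name='i') is equivalent to
# shard_map(f, in_specs=P('i'), out_specs=P('i'), axis_names={'i'}), mapping
# over axis 0 of each argument and output.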
def smap(f, /, *, in_axes=Infer, out_axes, axis_name: AxisName):
if isinstance(axis_name, (list, tuple)):
raise TypeError(
f"smap axis_name should be a `str` or a `Hashable`, but got {axis_name}")
if (in_axes is not None and in_axes is not Infer and
not isinstance(in_axes, (int, tuple))):
raise TypeError(
"smap in_axes must be an int, None, jax.sharding.Infer, or a tuple of"
" entries corresponding to the positional arguments passed to the"
f" function, but got {in_axes}.")
if (in_axes is not Infer and
not all(isinstance(l, int) for l in tree_leaves(in_axes))):
raise TypeError(
"smap in_axes must be an int, None, jax.sharding.Infer, or (nested)"
f" container with those types as leaves, but got {in_axes}.")
if not all(isinstance(l, int) for l in tree_leaves(out_axes)):
raise TypeError("smap out_axes must be an int, None, or (nested) container "
f"with those types as leaves, but got {out_axes}.")
in_specs = (None if in_axes is Infer else
tree_map(partial(_axes_to_pspec, axis_name), in_axes,
is_leaf=lambda x: x is None))
out_specs = tree_map(partial(_axes_to_pspec, axis_name), out_axes,
is_leaf=lambda x: x is None)
return _shard_map(f, mesh=None, in_specs=in_specs, out_specs=out_specs,
axis_names={axis_name}, check_vma=True, _smap=True)
def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None,
in_specs: Specs, out_specs: Specs | Callable[[], Specs],
axis_names: Set[AxisName], check_vma: bool,
_skip_mesh_check: bool = False, _smap: bool = False) -> Callable:
if not callable(f):
raise TypeError("shard_map requires a callable for its first argument, "
f"but got {f} of type {type(f)}.")
@util.wraps(f)
@traceback_util.api_boundary
def wrapped(*args):
nonlocal mesh, axis_names
mesh, axis_names = _shmap_checks(mesh, axis_names, in_specs, out_specs,
_skip_mesh_check, _smap)
fun = lu.wrap_init(
f, debug_info=api_util.debug_info("shard_map", f, args, {}))
args_flat, in_tree = tree_flatten(args)
fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
try:
in_specs_flat = broadcast_prefix(
in_specs, args, is_leaf=lambda x: x is None)
except ValueError:
e, *_ = prefix_errors(in_specs, args)
raise e('shard_map in_specs') from None
if (in_specs is None and
all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)):
arg_s = [typeof(a).sharding for a in args_flat]
assert all(i is None for i in in_specs_flat), in_specs_flat
in_specs_flat = [_manual_spec(axis_names, s.spec) for s in arg_s]
dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat)
if s is not None)
fun, args_flat = api_util.argnums_partial(fun, dyn_argnums, args_flat, False)
_check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat,
args_flat)
@memoize
def out_specs_thunk():
if callable(out_specs):
out_specs_ = out_specs()
_check_specs(SpecErrorType.out, out_specs_, axis_names)
else:
out_specs_ = out_specs
dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves)
try:
out_specs_flat = broadcast_prefix(out_specs_, dummy)
except ValueError:
e, *_ = prefix_errors(out_specs_, dummy)
raise e('shard_map out_specs') from None
return tuple(out_specs_flat)
if check_vma:
fun = _implicit_pvary_on_output(fun, out_specs_thunk)
try:
out_flat = shard_map_p.bind(
fun, *args_flat, mesh=mesh, in_specs=in_specs_flat,
out_specs_thunk=out_specs_thunk, check_vma=check_vma,
manual_axes=axis_names)
except _SpecError as e:
fails, = e.args
if not callable(out_specs):
msg = _spec_rank_error(SpecErrorType.out, f, out_tree(), out_specs, fails)
if any(fail is not no_fail and not fail.shape for fail in fails):
msg += (" In particular, for rank 0 outputs which are not constant "
"over the mesh, add at least one (singleton) axis to them so "
"that they can be concatenated using out_specs.")
      raise ValueError(msg) from None
    raise
except _RepError as e:
fails, = e.args
if not callable(out_specs):
msg = _inout_vma_error(f, mesh, out_tree(), out_specs, fails)
      raise ValueError(msg) from None
    raise
return tree_unflatten(out_tree(), out_flat)
return wrapped
def _shmap_checks(mesh, axis_names, in_specs, out_specs, _skip_mesh_check,
_smap):
if mesh is None:
mesh = get_abstract_mesh()
if mesh.empty:
raise ValueError(
"The context mesh cannot be empty. Use"
" `jax.sharding.use_mesh(mesh)` to enter into a mesh context")
else:
ctx_mesh = get_abstract_mesh()
if (not _skip_mesh_check and not ctx_mesh.empty and
mesh.abstract_mesh != ctx_mesh):
raise ValueError(
f"The context mesh {ctx_mesh} should match the mesh passed to"
f" shard_map {mesh}")
if not isinstance(mesh, (Mesh, AbstractMesh)):
raise TypeError("shard_map requires a `jax.sharding.Mesh` or a "
"`jax.sharding.AbstractMesh` instance for its "
f"second argument, but got {mesh} of type {type(mesh)}.")
if not isinstance(axis_names, (frozenset, set)):
raise TypeError(
"`axis_names` argument of shard_map should be of type `frozenset` or"
f" `set`. Got type: {type(axis_names)}")
if isinstance(axis_names, set):
axis_names = frozenset(axis_names)
if not axis_names:
axis_names = frozenset(mesh.axis_names)
if not axis_names.issubset(mesh.axis_names):
raise ValueError(
f"jax.shard_map requires axis_names={axis_names} to be a subset of "
f"mesh.axis_names={mesh.axis_names}")
if (in_specs is None and
not all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)):
axis_types = ', '.join(str(mesh._name_to_type[a]) for a in axis_names)
if _smap:
msg = (f"in_axes was not specified when axis_name={axis_names} was of"
f" type {axis_types}")
else:
msg = ("shard_map in_specs argument must be a pytree of"
" `jax.sharding.PartitionSpec` instances, but it was `None` when"
f" {axis_names=} are of type {axis_types}")
raise TypeError(msg)
if in_specs is not None:
_check_specs(SpecErrorType.input, in_specs, axis_names)
if not callable(out_specs):
_check_specs(SpecErrorType.out, out_specs, axis_names)
return mesh, axis_names
def _manual_spec(manual_axes, spec: P) -> P:
out = [] # type: ignore
for s in spec:
if s is None:
out.append(s)
elif isinstance(s, tuple):
temp = [p if p in manual_axes else None for p in s]
while temp and temp[-1] is None:
temp.pop()
if None in temp:
raise ValueError(f"Invalid spec: {spec}")
out.append(None if len(temp) == 0 else tuple(temp))
else:
out.append(s if s in manual_axes else None)
return P(*out)
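# For example, with manual_axes={'i'} (illustrative): _manual_spec({'i'},
# P('i', 'j')) == P('i', None), and _manual_spec({'i'}, P(('i', 'j'))) ==
# P(('i',)); non-manual names are dropped, and inside a tuple only trailing
# drops are valid.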
# Error checking and messages
SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out'])
def _check_specs(error_type: SpecErrorType, specs: Any, manual_axes) -> None:
if error_type == SpecErrorType.input and specs is None:
raise TypeError(
"shard_map in_specs argument must be a pytree of "
"`jax.sharding.PartitionSpec` instances, but it was None.\n"
"Instead of `in_specs=None`, did you mean `in_specs=P()`, "
"where `P = jax.sharding.PartitionSpec`?")
def check_spec(p):
if not isinstance(p, PartitionSpec):
return False
for names in p:
names = (names,) if not isinstance(names, tuple) else names
for name in names:
if name is not None and name not in manual_axes:
return False
return True
if all(check_spec(p) for p in tree_leaves(specs)):
return
prefix = 'in' if error_type == SpecErrorType.input else 'out'
msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, "
for key, x in generate_key_paths(specs) if not isinstance(x, P)]
if not msgs:
for key, p in generate_key_paths(specs):
for names in p:
names = (names,) if not isinstance(names, tuple) else names
for name in names:
if name is not None and name not in manual_axes:
msgs.append(f" {prefix}_specs{keystr(key)} refers to {repr(name)}")
raise ValueError(
f"shard_map {prefix}_specs argument must refer to an axis "
f"marked as manual ({manual_axes}), but:\n\n"
+ '\n\n'.join(msgs) + '\n\n'
f"Check the {prefix}_specs values passed to shard_map.")
raise TypeError(
f"shard_map {prefix}_specs argument must be a pytree of "
f"`jax.sharding.PartitionSpec` instances, but:\n\n"
+ '\n\n'.join(msgs) + '\n\n'
f"Check the {prefix}_specs values passed to shard_map.")
class NoFail: pass
no_fail = NoFail()
def _check_specs_vs_args(
f: Callable, mesh: Mesh | AbstractMesh, in_tree: PyTreeDef, in_specs: Specs,
dyn_argnums: Sequence[int], in_specs_flat: Sequence[P],
xs: Sequence) -> None:
in_avals = map(core.shaped_abstractify, xs)
fail = [a if not len(p) <= a.ndim else no_fail
for p, a in zip(in_specs_flat, in_avals)]
if any(f is not no_fail for f in fail):
fail = _expand_fail(in_tree, dyn_argnums, fail)
msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail)
raise ValueError(msg)
in_names_flat = tuple(map(_spec_to_names, in_specs_flat))
fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns)
for d, ns in names.items()) else no_fail
for a, names in zip(in_avals, in_names_flat)]
if any(f is not no_fail for f in fail):
fail = _expand_fail(in_tree, dyn_argnums, fail)
msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail)
raise ValueError(msg)
def _expand_fail(in_tree: PyTreeDef, dyn_argnums: Sequence[int],
fail: Sequence[core.ShapedArray | NoFail]
) -> list[core.ShapedArray | NoFail]:
fail_: list[core.ShapedArray | NoFail] = [no_fail] * in_tree.num_leaves
for i, f in zip(dyn_argnums, fail):
fail_[i] = f
return fail_
def _spec_rank_error(
error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs,
fails: list[core.ShapedArray | NoFail]) -> str:
fun_name = util_fun_name(f)
if error_type == SpecErrorType.input:
prefix, base = 'in', 'args'
ba = _try_infer_args(f, tree)
else:
prefix, base = 'out', f'{fun_name}(*args)'
msgs = []
for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
extra = ""
if error_type == SpecErrorType.input and ba is not None:
arg_key, *_ = fail_key
if arg_key.idx < len(ba.arguments):
param_name = list(ba.arguments.keys())[arg_key.idx]
extra = (f", where {base}{arg_key} is bound to {fun_name}'s "
f"parameter '{param_name}',")
else:
param = list(ba.signature.parameters.values())[-1]
assert param.kind == inspect.Parameter.VAR_POSITIONAL
extra = (f", where {base}{arg_key} is the index "
f"{arg_key.idx - len(ba.signature.parameters) + 1} component "
f"of {fun_name}'s varargs parameter '{param.name}',")
msgs.append(
f"* {prefix}_specs{keystr(spec_key)} is {spec} which has length "
f"{len(spec)}, but "
f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, "
f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})")
assert msgs
if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point
msg = (f"shard_map applied to the function '{fun_name}' was given an "
f"{prefix}_specs entry which is too long to be compatible with the "
f"corresponding {prefix}put value from the function:\n\n"
+ '\n\n'.join(msgs) + '\n\n' +
f"Entries in {prefix}_specs must be of length no greater than the "
f"number of axes in the corresponding {prefix}put value.\n\n"
f"Either revise the spec to be shorter, or modify '{fun_name}' so "
f"that its {prefix}puts have sufficient rank.")
if any(not aval.ndim for _, (_, aval) in _iter_paths(tree, specs, fails)):
msg += (f"\n\nFor scalar values (rank 0), consider using an {prefix}_specs "
"entry of `P()`, where `P = jax.sharding.PartitionSpec`.")
return msg
def _spec_divisibility_error(
f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, specs: Specs,
fails: list[core.ShapedArray | NoFail]) -> str:
ba = _try_infer_args(f, tree)
fun_name = getattr(f, '__name__', str(f))
msgs = []
for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
extra = ""
if ba is not None:
arg_key, *_ = fail_key
if arg_key.idx < len(ba.arguments):
param_name = list(ba.arguments.keys())[arg_key.idx]
extra = (f", where args{arg_key} is bound to {fun_name}'s "
f"parameter '{param_name}',")
else:
param = list(ba.signature.parameters.values())[-1]
assert param.kind == inspect.Parameter.VAR_POSITIONAL
extra = (f", where args{arg_key} is the index "
f"{arg_key.idx - len(ba.signature.parameters) + 1} component "
f"of {fun_name}'s varargs parameter '{param.name}',")
names = _spec_to_names(spec)
for d, ns in names.items():
if aval.shape[d] % prod(mesh.shape[n] for n in ns):
axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'"
total = 'total ' if len(ns) > 1 else ''
sz = prod(mesh.shape[n] for n in ns)
msgs.append(
f"* args{keystr(fail_key)} of shape {aval.str_short()}{extra} "
f"corresponds to in_specs{keystr(spec_key)} of value {spec}, "
f"which maps array axis {d} (of size {aval.shape[d]}) to mesh "
f"{axis} (of {total}size {sz}), but {sz} does not evenly divide "
f"{aval.shape[d]}")
assert msgs
if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point
msg = (f"shard_map applied to the function '{fun_name}' was given argument "
f"arrays with axis sizes that are not evenly divisible by the "
f"corresponding mesh axis sizes:\n\n"
f"The mesh given has shape {tuple(mesh.shape.values())} with "
f"corresponding axis names {mesh.axis_names}.\n\n"
+ '\n\n'.join(msgs) + '\n\n' +
f"Array arguments' axis sizes must be evenly divisible by the mesh "
f"axis or axes indicated by the corresponding elements of the "
f"argument's in_specs entry. Consider checking that in_specs are "
f"correct, and if so consider changing the mesh axis sizes or else "
f"padding the input and adapting '{fun_name}' appropriately.")
return msg
def _inout_vma_error(f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef,
specs: Specs, fails: list[set | NoFail]) -> str:
fun_name = getattr(f, '__name__', str(f))
msgs = []
for (spec_key, spec), (fail_key, vma) in _iter_paths(tree, specs, fails):
unmentioned = _unmentioned(mesh, spec)
if len(unmentioned) > 1:
need_vma = ','.join(map(str, order_wrt_mesh(mesh, _spec_to_vma(spec))))
got_vma = ','.join(map(str, order_wrt_mesh(mesh, vma)))
diff = ','.join(map(str, order_wrt_mesh(
mesh, [n for n in unmentioned if n in vma])))
msgs.append(
f"* out_specs{keystr(spec_key)} is {spec} which implies that the "
f"corresponding output value is only varying across mesh axes "
f"{{{need_vma}}} and not {{{diff}}}, but it was inferred to be "
f"possibly varying over {{{got_vma}}}")
else:
need_rep_, = unmentioned
msgs.append(
f"* out_specs{keystr(spec_key)} is {spec} which implies that the "
f"corresponding output value is replicated across mesh axis "
f"'{need_rep_}', but could not infer replication over any axes")
assert msgs
if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point
msg = (f"shard_map applied to the function '{fun_name}' was given "
f"out_specs which require replication which can't be statically "
f"inferred given the mesh:\n\n"
f"The mesh given has shape {tuple(mesh.shape.values())} with "
f"corresponding axis names {mesh.axis_names}.\n\n"
+ '\n\n'.join(msgs) + '\n\n' +
"Check if these output values are meant to be replicated over those "
"mesh axes. If not, consider revising the corresponding out_specs "
"entries. If so, consider disabling the check by passing the "
"check_vma=False argument to `jax.shard_map`.")
return msg
def _unmentioned(mesh: Mesh | AbstractMesh, spec) -> list[AxisName]:
vma_set = _spec_to_vma(spec)
return [n for n in mesh.axis_names if n not in vma_set]
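# For example, with mesh.axis_names == ('i', 'j') (illustrative),
# _unmentioned(mesh, P('i')) == ['j'].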
def _try_infer_args(f, tree):
dummy_args = tree_unflatten(tree, [False] * tree.num_leaves)
try:
return inspect.signature(f).bind(*dummy_args)
except (TypeError, ValueError):
return None
T = TypeVar('T')
def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail]
) -> list[tuple[tuple[KeyPath, P], tuple[KeyPath, T]]]:
failures = tree_unflatten(tree, fails)
failures_aug = generate_key_paths(failures)
specs_ = tree_unflatten(tree_structure(specs), generate_key_paths(specs))
leaf = lambda x: x is None or type(x) is tuple and len(x) == 2 and type(x[1]) is P
specs_aug = broadcast_prefix(specs_, failures, is_leaf=leaf)
return [(s, (fail_key, fail_data)) for s, (fail_key, fail_data)
in zip(specs_aug, failures_aug)
if s is not None and fail_data is not no_fail]
# Primitive
@lu.transformation2
def _implicit_pvary_on_output(f, out_specs_thunk, *args, **kwargs):
out_flat = f(*args, **kwargs)
return [pvary(o, tuple(_spec_to_vma(sp) - typeof(o).vma))
for o, sp in zip(out_flat, out_specs_thunk())]
JaxType = Any
MaybeTracer = Union[JaxType, Tracer]
class ShardMapPrimitive(core.Primitive):
multiple_results = True
def bind(self, *args, **params):
return self._true_bind(*args, **params)
def bind_with_trace(self, trace, fun_and_args, params):
fun: lu.WrappedFun
fun, *args = fun_and_args
return trace.process_shard_map(shard_map_p, fun, args, **params)
def get_bind_params(self, params):
new_params = dict(params)
jaxpr: core.Jaxpr = new_params.pop('jaxpr')
subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr,
debug_info=jaxpr.debug_info),
jaxpr, ())
axes = new_params.pop('out_specs')
new_params['out_specs_thunk'] = HashableFunction(lambda: axes, closure=axes)
return [subfun], new_params
shard_map_p = ShardMapPrimitive('shard_map')
# Staging
@util.cache(max_size=256, trace_context_in_key=True)
def _as_manual_mesh(mesh, manual_axes: frozenset):
not_manual = set(mesh.axis_names) - manual_axes
cur_mesh = get_abstract_mesh()
if cur_mesh.empty:
cur_mesh = mesh
explicit_axes, auto_axes = set(), set() # type: ignore
for a in not_manual:
if cur_mesh._name_to_type[a] == AxisType.Auto:
auto_axes.add(a)
else:
assert cur_mesh._name_to_type[a] == AxisType.Explicit, (
a, cur_mesh._name_to_type[a])
explicit_axes.add(a)
new_axis_types = []
for n in mesh.axis_names:
if n in manual_axes:
new_axis_types.append(AxisType.Manual)
elif n in auto_axes:
new_axis_types.append(AxisType.Auto)
else:
assert n in explicit_axes
new_axis_types.append(AxisType.Explicit)
return AbstractMesh(mesh.axis_sizes, mesh.axis_names,
axis_types=tuple(new_axis_types))
def _extend_axis_env(mesh, manual_axes):
return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items()
if k in manual_axes])
def _shard_map_staging(
trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun,
in_tracers: Sequence[Any], *, mesh: Mesh,
in_specs, out_specs_thunk, check_vma: bool, manual_axes: frozenset,
) -> Sequence[pe.DynamicJaxprTracer]:
source_info = source_info_util.current()
to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, source_info=source_info)
in_tracers = map(to_jaxpr_tracer, in_tracers)
inner_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))
in_avals = [t.aval for t in in_tracers]
in_avals_ = map(partial(_shard_aval, mesh, manual_axes, check_vma), in_specs,
in_avals)
with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh),
config._check_vma(check_vma)):
jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_)
_check_names(out_specs_thunk(), out_avals_)
if check_vma:
out_vma = [v.aval.vma for v in jaxpr.outvars]
_check_vmas(mesh, out_specs_thunk(), out_vma)
out_avals = map(_check_shapedarray, out_avals_)
out_avals = [_check_shapedarray(_unshard_aval(mesh, check_vma, spec, aval))
for spec, aval in zip(out_specs_thunk(), out_avals)]
out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals]
invars = map(trace.getvar, in_tracers)
constvars = map(trace.getvar, map(to_jaxpr_tracer, consts))
outvars = map(trace.makevar, out_tracers)
in_specs_staged = (P(),) * len(consts) + tuple(in_specs) # type: ignore
with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh),
config._check_vma(check_vma)):
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
params = dict(mesh=mesh, in_specs=in_specs_staged,
out_specs=tuple(out_specs_thunk()), jaxpr=jaxpr,
check_vma=check_vma, manual_axes=manual_axes)
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params,
effs, source_info)
trace.frame.add_eqn(eqn)
return out_tracers
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
# TODO add underscore version, for direct-linearize to consume
def _spec_to_names(spec: PartitionSpec):
return {i: names if isinstance(names, tuple) else (names,)
for i, names in enumerate(spec) if names is not None}
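# For example, _spec_to_names(P('i', None, ('j', 'k'))) ==
# {0: ('i',), 2: ('j', 'k')}.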
def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray:
assert isinstance(aval, core.ShapedArray)
return aval
def _shard_aval(mesh: Mesh, manual_axes, check_vma, spec,
aval: core.AbstractValue) -> core.AbstractValue:
if type(aval) in core.shard_aval_handlers:
return core.shard_aval_handlers[type(aval)](mesh, manual_axes, check_vma,
spec, aval)
raise NotImplementedError(f"Unsupported aval type: {type(aval)}")
def _unshard_aval(mesh: Mesh, check_vma, spec,
aval: core.AbstractValue) -> core.AbstractValue:
if type(aval) in core.unshard_aval_handlers:
return core.unshard_aval_handlers[type(aval)](mesh, check_vma, spec, aval)
else:
raise NotImplementedError(f"Unsupported aval type: {type(aval)}")
def _shard_shaped_array(mesh: Mesh, manual_axes: frozenset, check_vma,
spec, aval: core.AbstractValue) -> core.AbstractValue:
assert isinstance(aval, core.ShapedArray)
names = _spec_to_names(spec)
new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape))
manual_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))
new_sharding = NamedSharding(manual_mesh, aval.sharding.spec)
vma = _spec_to_vma(spec) if check_vma else frozenset()
vma = vma | aval.vma
return aval.update(shape=new_shape, sharding=new_sharding, vma=vma)
core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array
def _unshard_shaped_array(mesh: Mesh, check_vma, spec, aval: core.AbstractValue
) -> core.AbstractValue:
assert isinstance(aval, core.ShapedArray)
names = _spec_to_names(spec)
new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape))
names_spec = spec._normalized_spec_for_aval(aval.ndim)
if aval.ndim == 0:
out_spec = P()
else:
out_spec = [] # type: ignore
for name_s, aval_s in zip(names_spec, aval.sharding.spec):
if name_s and not aval_s:
out_spec.append(name_s)
elif aval_s and not name_s:
out_spec.append(aval_s)
elif not name_s and not aval_s:
out_spec.append(None)
else:
assert name_s and aval_s
name_s = name_s if isinstance(name_s, tuple) else (name_s,)
aval_s = aval_s if isinstance(aval_s, tuple) else (aval_s,)
out_spec.append(name_s + aval_s)
out_spec = PartitionSpec(*out_spec)
new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else
get_abstract_mesh())
new_sharding = NamedSharding(new_mesh, out_spec)
manual_axes = set(new_mesh.manual_axes)
vma = (frozenset(v for v in aval.vma if v in manual_axes)
if check_vma else frozenset())
return aval.update(shape=new_shape, sharding=new_sharding, vma=vma)
core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array
# Type-checking
def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_specs, out_specs,
check_vma, manual_axes):
# TODO(mattjj,parkers): check auto
for v, x, in_spec in zip(jaxpr.invars, in_atoms, in_specs):
if not core.typecompat(v.aval, _shard_aval(
mesh, manual_axes, check_vma, in_spec, x.aval)):
raise core.JaxprTypeError("shard_map argument avals not compatible with "
"jaxpr binder avals and in_specs")
with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
core.check_jaxpr(jaxpr)
if check_vma:
out_vma = [v.aval.vma for v in jaxpr.outvars]
for vma, out_spec in zip(out_vma, out_specs):
if not _valid_repeats(mesh, vma, out_spec):
raise core.JaxprTypeError(
"shard_map can't prove output is sufficiently replicated")
out_avals_sharded = [x.aval for x in jaxpr.outvars]
out_avals = map(partial(_unshard_aval, mesh, check_vma), out_specs,
out_avals_sharded)
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
return out_avals, effs
core.custom_typechecks[shard_map_p] = _shard_map_typecheck
def _valid_repeats(mesh: Mesh, vma: Set[AxisName], spec) -> bool:
um = set(_unmentioned(mesh, spec)) - set(mesh.manual_axes)
if any(u in vma for u in um):
return False
return True
# Lowering
def _shardy_shard_map_sharding(
ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec, aval_in
) -> sharding_impls.SdyArray:
ns = _make_scoped_manual_sharding(ctx, mesh, spec)
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
ns = sharding_impls.physical_sharding(aval_in, ns)
aval_in = core.physical_aval(aval_in)
sdy_sharding = ns._to_sdy_sharding(aval_in.ndim)
if len(manual_axes) < len(mesh.axis_names):
for dim_sharding in sdy_sharding.dim_shardings:
dim_sharding.is_open = True
return sdy_sharding
def _shardy_shard_map_token_sharding(
ctx: mlir.LoweringRuleContext, mesh
) -> ir.Attribute:
ns = _make_scoped_manual_sharding(ctx, mesh, P())
return ns._to_sdy_sharding(0)
def _get_spmdaxis_ctx_mesh(mesh):
if isinstance(mesh, AbstractMesh):
concrete_mesh = get_concrete_mesh()
return concrete_mesh if concrete_mesh is not None else mesh
return mesh
def _shard_map_lowering_shardy(
ctx, in_nodes, jaxpr, mesh, in_specs, out_specs, manual_axes, check_vma):
axis_ctx = ctx.module_context.axis_context
in_avals_ = [v.aval for v in jaxpr.invars]
if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
# Nested `ManualComputationOp`s cannot refer to axes that are already
# manual. So figure out what axes are free thus far.
shardy_manual_axes = frozenset(mesh.axis_names) - axis_ctx.manual_axes
else:
shardy_manual_axes = manual_axes
new_axis_context = sharding_impls.SPMDAxisContext(
_get_spmdaxis_ctx_mesh(mesh), manual_axes)
sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
tokens = [ctx.tokens_in.get(eff) for eff in ctx.tokens_in.effects()]
num_tokens = len(tokens)
manual_axes = order_wrt_mesh(mesh, shardy_manual_axes)
if np.prod([mesh.shape[a] for a in manual_axes]) == 1:
# No need for a `ManualComputationOp` if all manual axes are size 1.
with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
out_nodes, tokens_out = mlir.jaxpr_subcomp(
sub_ctx, jaxpr, ctx.name_stack,
mlir.TokenSet(zip(ctx.tokens_in.effects(), tokens)),
(), *in_nodes,
dim_var_values=ctx.dim_var_values)
ctx.set_tokens_out(tokens_out)
return out_nodes
in_shardings = list(
map(partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes),
in_specs, ctx.avals_in))
num_dim_vars = len(ctx.dim_var_values)
in_shardings = ([_shardy_shard_map_token_sharding(ctx, mesh)]
* (num_tokens + num_dim_vars) + in_shardings)
in_shardings = sharding_impls.SdyArrayList(in_shardings).build()
out_shardings = list(
map(partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes),
out_specs, ctx.avals_out))
out_shardings = [
_shardy_shard_map_token_sharding(ctx, mesh)] * num_tokens + out_shardings
out_shardings = sharding_impls.SdyArrayList(out_shardings).build()
output_types = ([hlo.TokenType.get()] * num_tokens +
list(map(mlir.aval_to_ir_type, ctx.avals_out)))
args = (*ctx.dim_var_values, *tokens, *in_nodes)
manual_computation_op = sdy.ManualComputationOp(
output_types,
mlir.flatten_ir_values(args),
in_shardings, out_shardings,
sdy.ManualAxesAttr.get(
ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes])))
block = ir.Block.create_at_start(
manual_computation_op.body,
(*(i if isinstance(i, ir.Type) else i.type for i in ctx.dim_var_values),
*([hlo.TokenType.get()] * num_tokens),
*map(mlir.aval_to_ir_type, in_avals_)))
with (ir.InsertionPoint(block), _extend_axis_env(mesh, manual_axes),
config._check_vma(check_vma)):
out_nodes_, tokens_out = mlir.jaxpr_subcomp(
sub_ctx, jaxpr, ctx.name_stack,
mlir.TokenSet(zip(
ctx.tokens_in.effects(), block.arguments[:num_tokens])),
(), *block.arguments[num_tokens+num_dim_vars:],
dim_var_values=ctx.dim_var_values)
sdy.ReturnOp([ir.Value(x) for x in (*[v for _, v in tokens_out.items()],
*out_nodes_)])
num_tokens = len(tokens_out.effects())
tokens_out = tokens_out.update_tokens(mlir.TokenSet(zip(
ctx.tokens_in.effects(), manual_computation_op.results[:num_tokens])))
ctx.set_tokens_out(tokens_out)
return manual_computation_op.results[num_tokens:]
def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_specs, out_specs,
check_vma, manual_axes):
if config.use_shardy_partitioner.value:
return _shard_map_lowering_shardy(
ctx, in_nodes, jaxpr, mesh, in_specs, out_specs, manual_axes, check_vma)
in_avals_ = [v.aval for v in jaxpr.invars]
out_avals_ = [x.aval for x in jaxpr.outvars]
in_nodes_ = map(partial(_xla_shard, ctx, mesh, manual_axes), in_specs,
ctx.avals_in, in_avals_, in_nodes)
new_axis_context = sharding_impls.SPMDAxisContext(
_get_spmdaxis_ctx_mesh(mesh), manual_axes)
sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
out_nodes_, tokens_out = mlir.call_lowering(
"shmap_body", ctx.name_stack, jaxpr, None, sub_ctx, in_avals_,
out_avals_, ctx.tokens_in, *in_nodes_,
dim_var_values=ctx.dim_var_values,
arg_names=map(_pspec_mhlo_attrs, in_specs, in_avals_),
result_names=map(_pspec_mhlo_attrs, out_specs, out_avals_))
ctx.set_tokens_out(tokens_out)
return map(partial(_xla_unshard, ctx, mesh, manual_axes), out_specs,
out_avals_, ctx.avals_out, out_nodes_)
mlir.register_lowering(shard_map_p, _shard_map_lowering)
def _make_scoped_manual_sharding(ctx, mesh, spec):
axis_ctx = ctx.module_context.axis_context
mesh = mesh.abstract_mesh
if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
mesh = mesh.update_axis_types(
{a: AxisType.Manual for a in axis_ctx.manual_axes})
return NamedSharding(mesh, spec)
def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec,
aval_in, aval_out, x):
if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1:
return x
ns = _make_scoped_manual_sharding(ctx, mesh, spec)
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
ns = sharding_impls.physical_sharding(aval_in, ns)
aval_in = core.physical_aval(aval_in)
shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto()
unspecified = (set(range(aval_in.ndim))
if len(manual_axes) < len(mesh.axis_names) else set())
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,
unspecified_dims=unspecified)
manual_proto = pxla.manual_proto(
aval_in, manual_axes | set(mesh.manual_axes), mesh)
return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto,
unspecified)
def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec,
aval_in, aval_out, x):
if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1:
return x
ns = _make_scoped_manual_sharding(ctx, mesh, spec)
if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
ns = sharding_impls.physical_sharding(aval_out, ns)
aval_out = core.physical_aval(aval_out)
unspecified = (set(range(aval_in.ndim))
if len(manual_axes) < len(mesh.axis_names) else set())
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
aval_in = core.physical_aval(aval_in)
manual_proto = pxla.manual_proto(
aval_in, manual_axes | set(mesh.manual_axes), mesh)
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto,
unspecified_dims=unspecified)
shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()
return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto,
unspecified)
def _pspec_mhlo_attrs(spec, aval: core.AbstractValue) -> str:
if isinstance(aval, core.ShapedArray):
names = _spec_to_names(spec)
return str(map(names.get, range(aval.ndim)))
return ''
# Eager evaluation
def get_mesh_from_args(args_flat, mesh):
for a in args_flat:
if hasattr(a, 'sharding') and isinstance(a.sharding, NamedSharding):
if a.sharding.mesh.shape_tuple != mesh.shape_tuple:
aval = core.shaped_abstractify(a)
raise ValueError(
f"Mesh shape of the input {a.sharding.mesh.shape_tuple} does not"
" match the mesh shape passed to shard_map "
f" {mesh.shape_tuple} for shape {aval.str_short()}")
mesh = a.sharding.mesh
if isinstance(mesh, AbstractMesh):
raise ValueError(
"Please pass `jax.Array`s with a `NamedSharding` as input to"
" `shard_map` when passing `AbstractMesh` to the mesh argument.")
assert isinstance(mesh, Mesh)
return mesh
def _vma_to_spec(mesh, vma):
return P(order_wrt_mesh(mesh, vma))
def _spec_to_vma(spec):
return frozenset(p for s in spec if s is not None
for p in (s if isinstance(s, tuple) else (s,)))
def order_wrt_mesh(mesh, x):
return tuple(a for a in mesh.axis_names if a in x)
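# For example (axis names illustrative): _spec_to_vma(P(('i', 'j'), None, 'k'))
# == frozenset({'i', 'j', 'k'}); with mesh.axis_names == ('i', 'j', 'k'),
# order_wrt_mesh(mesh, {'k', 'i'}) == ('i', 'k'), and _vma_to_spec(mesh,
# {'k', 'i'}) == P(('i', 'k')).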
def _shard_map_impl(trace, prim, fun, args, *, mesh, in_specs, out_specs_thunk,
check_vma, manual_axes):
if len(manual_axes) < len(mesh.axis_names):
raise NotImplementedError
del prim
if isinstance(mesh, AbstractMesh):
concrete_mesh = get_concrete_mesh()
mesh = concrete_mesh if concrete_mesh is not None else mesh
mesh = get_mesh_from_args(args, mesh)
cur_mesh = get_abstract_mesh()
args = map(partial(_unmatch_spec, mesh, check_vma, context_mesh=cur_mesh),
in_specs, args)
in_vma = map(_spec_to_vma, in_specs)
outs, out_vma = _run_shmap(fun, mesh, manual_axes, args, in_vma, check_vma,
cur_mesh)
out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs]
_check_names(out_specs_thunk(), out_avals) # pytype: disable=wrong-arg-types
if check_vma:
_check_vmas(mesh, out_specs_thunk(), out_vma)
src_pspecs = tuple(_vma_to_spec(mesh, r) for r in out_vma)
else:
src_pspecs = tuple(P(mesh.axis_names) for _ in out_vma)
dst_pspecs = out_specs_thunk()
return map(partial(_match_spec, mesh, check_vma), src_pspecs, dst_pspecs,
outs)
core.EvalTrace.process_shard_map = _shard_map_impl
def _run_shmap(f, mesh, manual_axes, args, vmas, check_vma, context_mesh):
trace = ShardMapTrace(mesh, manual_axes, check_vma, context_mesh)
in_tracers = map(partial(ShardMapTracer, trace), vmas, args)
inner_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))
with (core.set_current_trace(trace), _extend_axis_env(mesh, manual_axes),
use_abstract_mesh(inner_mesh), config._check_vma(check_vma)):
ans = f.call_wrapped(*in_tracers)
outs, out_vma = unzip2(map(trace.to_val_vma_pair, ans))
return outs, out_vma
def _unmatch_spec(mesh: Mesh, check_vma, in_spec, x: JaxType, context_mesh
) -> JaxType:
with (core.eval_context(), jax.disable_jit(False),
use_abstract_mesh(context_mesh)):
return jax.jit(HashablePartial(_unmatch, mesh, check_vma, in_spec))(x)
def _unmatch(mesh, check_vma, in_spec, x):
if check_vma:
used_axes = _spec_to_vma(in_spec)
dst = P(order_wrt_mesh(mesh, used_axes))
else:
dst = P(mesh.axis_names)
check_vma = False
return shard_map(_add_singleton, mesh=mesh, in_specs=(in_spec,),
out_specs=dst, check_vma=check_vma)(x)
def _check_names(specs, avals: Sequence[core.ShapedArray]) -> None:
fail = [a if sp and len(sp) > a.ndim else no_fail
for sp, a in zip(specs, avals)]
if any(f is not no_fail for f in fail):
raise _SpecError(fail)
class _SpecError(Exception):
pass
def _check_vmas(mesh, specs, vmas):
fail = [vma if not _valid_repeats(mesh, vma, sp) else no_fail
for sp, vma in zip(specs, vmas)]
if any(f is not no_fail for f in fail):
raise _RepError(fail)
class _RepError(Exception):
pass
def _match_spec(mesh: Mesh, check_vma, src_pspec: PartitionSpec,
dst_pspec: PartitionSpec, x: JaxType) -> JaxType:
fn = HashablePartial(_match, mesh, check_vma, src_pspec, dst_pspec)
with core.eval_context(), jax.disable_jit(False):
return jax.jit(fn, out_shardings=NamedSharding(mesh, dst_pspec))(x)
def _match(mesh, check_vma, src_pspec, dst_pspec, x):
return shard_map(_rem_singleton, mesh=mesh, in_specs=src_pspec,
out_specs=dst_pspec, check_vma=check_vma)(x)
def _rem_singleton(x): return jnp.squeeze(x, axis=0)
def _add_singleton(x): return jnp.expand_dims(x, axis=0)
def _maybe_check_special(outs):
if not config.debug_nans.value and not config.debug_infs.value: return
bufs = [s.data for leaf in tree_leaves(outs)
for s in getattr(leaf, 'addressable_shards', [])]
try:
dispatch.check_special('shard_map', bufs)
except api_util.InternalFloatingPointError as e:
raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None
class ShardMapTrace(core.Trace):
__slots__ = ("mesh", "manual_axes", "check", "context_mesh")
mesh: Mesh
manual_axes: frozenset[AxisName]
check: bool
context_mesh: AbstractMesh
def __init__(self, mesh, manual_axes, check, context_mesh):
super().__init__()
self.mesh = mesh
self.manual_axes = manual_axes
self.check = check
self.context_mesh = context_mesh
def to_val_vma_pair(self, val):
if isinstance(val, ShardMapTracer):
return val.val, val.vma
elif isinstance(val, Tracer):
raise Exception(f"Shouldn't have any non-shard_map tracers: {val}")
else:
val_ = _unmatch_spec(self.mesh, self.check, P(), val, self.context_mesh)
return val_, frozenset()
def process_primitive(self, prim, tracers, params):
in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers))
if self.check:
out_avals, _ = prim.abstract_eval(*(typeof(t) for t in tracers), **params)
out_avals = tuple(out_avals) if type(out_avals) is list else out_avals
out_vma = tree_map(lambda a: a.vma, out_avals)
in_specs = tuple(map(partial(_vma_to_spec, self.mesh), in_vma))
out_specs = tree_map(partial(_vma_to_spec, self.mesh), out_vma)
else:
out_vma = frozenset()
in_specs = out_specs = P(self.mesh.axis_names)
eager_rule = eager_rules.get(prim)
if eager_rule:
out_vals = eager_rule(self.mesh, *in_vals, **params)
else:
f = HashablePartial(
_prim_applier, prim, self.check, tuple(params.items()), self.mesh,
in_specs, out_specs)
with (core.eval_context(), jax.disable_jit(False), jax.debug_nans(False),
jax.debug_infs(False), use_abstract_mesh(self.context_mesh)):
out_vals = jax.jit(f)(*in_vals)
_maybe_check_special(out_vals)
if prim.multiple_results:
out_vma = (out_vma if isinstance(out_vma, (list, tuple))
else [out_vma] * len(out_vals))
return map(partial(ShardMapTracer, self), out_vma, out_vals)
return ShardMapTracer(self, out_vma, out_vals)
def process_call(self, call_primitive, fun, tracers, params):
raise NotImplementedError(
f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't "
"yet supported. Put a `jax.jit` around the `shard_map`-decorated "
"function, and open a feature request at "
"https://github.com/jax-ml/jax/issues !")
def process_map(self, map_primitive, fun, tracers, params):
raise NotImplementedError(
"Eager evaluation of `pmap` inside a `shard_map` isn't yet supported."
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/jax-ml/jax/issues !")
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
    # Since ShardMapTrace is only used as a base (eager) trace, we can drop the jvp.
del prim, jvp, symbolic_zeros
in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers))
out_vals, out_vma = _run_shmap(fun, self.mesh, self.manual_axes, in_vals,
in_vma, self.check, self.context_mesh)
return map(partial(ShardMapTracer, self), out_vma, out_vals)
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
if symbolic_zeros:
msg = ("custom_vjp symbolic_zeros support with shard_map is not "
"implemented; please open an issue at "
"https://github.com/jax-ml/jax/issues")
raise NotImplementedError(msg)
del prim, fwd, bwd, out_trees, symbolic_zeros
in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers))
out_vals, out_vma = _run_shmap(fun, self.mesh, self.manual_axes, in_vals,
in_vma, self.check, self.context_mesh)
return map(partial(ShardMapTracer, self), out_vma, out_vals)
class ShardMapTracer(core.Tracer):
vma: frozenset[AxisName]
val: JaxType
def __init__(self, trace, vma, val):
self._trace = trace
if isinstance(vma, set):
vma = frozenset(vma)
assert isinstance(vma, frozenset)
self.vma = vma
self.val = val
@property
def aval(self):
aval = core.get_aval(self.val)
out = core.mapped_aval(self._trace.mesh.size, 0, aval)
new_sharding = NamedSharding(
_as_manual_mesh(self._trace.mesh, self._trace.manual_axes),
out.sharding.spec) # pytype: disable=attribute-error
vma = self.vma if config._check_vma.value else frozenset()
return out.update(sharding=new_sharding, vma=vma)
def to_concrete_value(self):
if self.vma == frozenset():
with core.eval_context(), use_abstract_mesh(self._trace.context_mesh):
return core.to_concrete_value(self.val[0])
else:
return None
def __str__(self) -> str:
pb_names = set(self._trace.mesh.axis_names) - self.vma
self = pvary(self, tuple(pb_names))
with core.eval_context(), use_abstract_mesh(self._trace.context_mesh):
blocks = list(self.val)
mesh = self._trace.mesh
axis_names = f"({', '.join(map(str, mesh.axis_names))},)"
return '\n'.join(
f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n"
for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks))
__repr__ = __str__ # for debuggers, like `p x`
def _prim_applier(prim, check_vma, params_tup, mesh, in_specs, out_specs, *args):
def apply(*args):
outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup))
return tree_map(_add_singleton, outs)
out_specs = list(out_specs) if type(out_specs) is tuple else out_specs
return shard_map(apply, mesh=mesh, in_specs=in_specs, out_specs=out_specs,
check_vma=check_vma)(*args)
eager_rules: dict[core.Primitive, Callable] = {}
# TODO(mattjj): working around an apparent XLA or PjRt bug, remove eventually
def _debug_callback_eager_rule(
mesh,
*args,
callback: Callable[..., Any],
effect: debugging.DebugEffect,
partitioned: bool,
):
del effect
with core.eval_context():
all_blocks = zip(*map(list, args))
for (idx, device), blocks in zip(np.ndenumerate(mesh.devices), all_blocks):
callback(*blocks)
return []
eager_rules[debugging.debug_callback_p] = _debug_callback_eager_rule
def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics):
del mesh, srcs, copy_semantics
for device in devices:
if device is not None:
raise ValueError("device_put with explicit device not allowed within "
f"shard_map-decorated functions, but got device {device}")
return xs
eager_rules[dispatch.device_put_p] = _device_put_eager_rule
# Batching
def _modify_specs_axis_data(trace, name, mesh, in_specs, in_dims):
new_in_specs = [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, name)
for sp, d in zip(in_specs, in_dims)]
new_size = trace.axis_data.size // prod(mesh.shape[n] for n in name)
new_axis_data = batching.AxisData(
trace.axis_data.name, new_size, trace.axis_data.spmd_name,
trace.axis_data.explicit_mesh_axis)
return new_in_specs, new_axis_data
def _shard_map_batch(
trace: batching.BatchTrace, prim: core.Primitive, fun: lu.WrappedFun,
in_tracers: Sequence[batching.BatchTracer], mesh: Mesh,
in_specs, out_specs_thunk, check_vma: bool, manual_axes: frozenset
) -> Sequence[batching.BatchTracer]:
in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers))
if any(isinstance(d, batching.RaggedAxis) for d in in_dims):
raise NotImplementedError
spmd_axis_name = trace.axis_data.spmd_name
explicit_mesh_axis = trace.axis_data.explicit_mesh_axis
if spmd_axis_name is not None:
used = {n for spec in in_specs for n in _spec_to_vma(spec)}
if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used:
raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs")
new_in_specs, new_axis_data = _modify_specs_axis_data(
trace, spmd_axis_name, mesh, in_specs, in_dims)
elif explicit_mesh_axis is not None:
used = {n for spec in in_specs for n in _spec_to_vma(spec)}
if set(explicit_mesh_axis) & used:
raise ValueError("vmapped away explicit mesh axis cannot appear in "
"shard_map in_specs")
new_in_specs, new_axis_data = _modify_specs_axis_data(
trace, explicit_mesh_axis, mesh, in_specs, in_dims)
else:
new_in_specs = [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None)
for sp, d in zip(in_specs, in_dims)]
new_axis_data = trace.axis_data
fun, out_dims = batching.batch_subtrace(fun, trace.tag, new_axis_data, tuple(in_dims))
@as_hashable_function(closure=out_specs_thunk)
def new_out_specs_thunk():
return _batch_out_specs(spmd_axis_name, explicit_mesh_axis, out_dims(),
out_specs_thunk())
new_params = dict(mesh=mesh, in_specs=new_in_specs,
out_specs_thunk=new_out_specs_thunk, check_vma=check_vma,
manual_axes=manual_axes)
with core.set_current_trace(trace.parent_trace):
out_vals = prim.bind(fun, *in_vals, **new_params)
make_tracer = partial(batching.BatchTracer, trace,
source_info=source_info_util.current())
return map(make_tracer, out_vals, out_dims())
batching.BatchTrace.process_shard_map = _shard_map_batch
def _batch_out_specs(spmd_name, explicit_mesh_axis, dims, out_specs):
if spmd_name is not None:
used = {n for spec in out_specs for n in _spec_to_vma(spec)}
if not config.disable_vmap_shmap_error.value and set(spmd_name) & used:
raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs")
return [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, spmd_name)
for sp, d in zip(out_specs, dims)]
elif explicit_mesh_axis is not None:
used = {n for spec in out_specs for n in _spec_to_vma(spec)}
if set(explicit_mesh_axis) & used:
raise ValueError("vmapped away explicit mesh axis cannot appear in "
"shard_map out_specs")
return [sp if d is batching.not_mapped else
pxla.batch_spec(sp, d, explicit_mesh_axis)
for sp, d in zip(out_specs, dims)]
else:
return [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None)
for sp, d in zip(out_specs, dims)]
# Autodiff
def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_specs,
out_specs_thunk, check_vma, manual_axes):
primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
  which_nz = [type(t) is not ad.Zero for t in tangents]
tangents = [t if type(t) is not ad.Zero else None for t in tangents]
args, in_tree = tree_flatten((primals, tangents))
f_jvp = ad.jvp_subtrace(f, trace.tag)
f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp)
tangent_in_specs = [sp for sp, nz in zip(in_specs, which_nz) if nz]
@as_hashable_function(closure=out_specs_thunk)
def new_out_specs_thunk():
out_ax = out_specs_thunk()
return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz))
params = dict(mesh=mesh, in_specs=(*in_specs, *tangent_in_specs),
out_specs_thunk=new_out_specs_thunk, check_vma=check_vma,
manual_axes=manual_axes)
f_jvp, out_tree = ad.traceable(f_jvp, in_tree)
result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params)
primal_out, tangent_out = tree_unflatten(out_tree(), result)
tangent_out = [ad.Zero(core.get_aval(p).to_tangent_aval()) if t is None else t
for p, t in zip(primal_out, tangent_out)]
return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)]
ad.JVPTrace.process_shard_map = _shard_map_jvp
def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p,
f: lu.WrappedFun, tracers, mesh, in_specs,
out_specs_thunk, check_vma, manual_axes):
tracers = map(trace.to_jaxpr_tracer, tracers)
in_pvals = [t.pval for t in tracers]
in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
unk_in_specs, known_in_specs = pe.partition_list(in_knowns, in_specs)
in_avals_sharded = map(partial(_shard_aval, mesh, manual_axes, check_vma),
unk_in_specs, in_avals)
f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False)
f = _promote_scalar_residuals(f)
f_known, aux = pe.partial_eval_wrapper_nounits2(
f, (*in_knowns,), (*in_avals_sharded,))
all_names = _all_newly_manual_mesh_names(mesh, manual_axes)
@as_hashable_function(closure=out_specs_thunk)
def known_out_specs():
_, _, out_knowns, res_avals, _, _ = aux()
_, out_known_specs = pe.partition_list(out_knowns, out_specs_thunk())
if check_vma:
res_specs = [P(order_wrt_mesh(mesh, a.vma)) for a in res_avals]
else:
res_specs = [P(all_names)] * len(res_avals)
return (*out_known_specs, *res_specs)
known_params = dict(mesh=mesh, in_specs=(*known_in_specs,),
out_specs_thunk=known_out_specs, check_vma=check_vma,
manual_axes=manual_axes)
out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts),
known_params)
in_fwd, out_fwd, out_knowns, res_avals, jaxpr, env = aux()
num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
out_consts, non_fwd_res = split_list(out, [len(out) - num_res])
assert not jaxpr.constvars
unk_out_specs, _ = pe.partition_list(out_knowns, out_specs_thunk())
known_out_specs_ = known_out_specs()
res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res)
# TODO make res_avals be the full set, not just the non-fwd ones
res_avals_iter = iter(res_avals)
res_specs = []
for f1, f2 in zip(in_fwd, out_fwd):
if f1 is not None:
res_specs.append(known_in_specs[f1])
elif f2 is not None:
res_specs.append(known_out_specs_[f2])
else:
if check_vma:
res_vma = next(res_avals_iter).vma
res_specs.append(P(order_wrt_mesh(mesh, res_vma)))
else:
res_specs.append(P(all_names))
unk_in_specs = (*res_specs,) + (P(),) * len(env) + (*unk_in_specs,) # type: ignore[assignment]
const_tracers = map(trace.new_instantiated_const, res)
env_tracers = map(trace.to_jaxpr_tracer, env)
unk_arg_tracers = [t for t in tracers if not t.is_known()]
out_avals_sharded = [v.aval for v in jaxpr.outvars]
unk_params = dict(mesh=mesh, in_specs=unk_in_specs,
out_specs=unk_out_specs, jaxpr=jaxpr,
check_vma=check_vma, manual_axes=manual_axes)
out_avals = map(partial(_unshard_aval, mesh, check_vma), unk_out_specs,
out_avals_sharded)
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
for a in out_avals]
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
eqn = pe.new_eqn_recipe(trace, (*const_tracers, *env_tracers, *unk_arg_tracers),
out_tracers, shard_map_p, unk_params,
effs, source_info_util.current())
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval
def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun,
tracers, mesh, in_specs, out_specs_thunk, check_vma,
manual_axes):
primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
nzs_in = tuple(type(t) is not ad.Zero for t in tangents)
f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info)
f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk)
all_names = _all_newly_manual_mesh_names(mesh, manual_axes)
@as_hashable_function(closure=linearize_outs_thunk)
def fwd_out_specs_thunk():
res_avals, _, _, _, in_fwd, out_fwd = linearize_outs_thunk()
res_avals = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd)
if f1 is None and f2 is None]
out_specs = out_specs_thunk()
if check_vma:
res_specs = [P(order_wrt_mesh(mesh, a.vma)) for a in res_avals]
else:
res_specs = [P(all_names)] * len(res_avals)
return (*res_specs, *out_specs)
fwd_params = dict(
mesh=mesh, in_specs=in_specs,
out_specs_thunk=fwd_out_specs_thunk, check_vma=check_vma,
manual_axes=manual_axes)
all_fwd_results = shard_map_p.bind_with_trace(
trace.parent_trace, (f_primal, *primals), fwd_params)
res_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk()
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
non_fwd_res = all_fwd_results[:num_res_out]
primals_out = all_fwd_results[num_res_out:]
residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res)
args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None
for aval, f1, f2 in zip(res_avals, in_fwd, out_fwd)]
with (_extend_axis_env(mesh, manual_axes),
use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))),
config._check_vma(check_vma)):
lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote)
out_specs = out_specs_thunk()
res_avals2 = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd)
if f1 is None and f2 is None]
res_avals_iter = iter(res_avals2)
res_specs = []
for f1, f2 in zip(in_fwd, out_fwd):
if f1 is not None:
res_specs.append(in_specs[f1])
elif f2 is not None:
res_specs.append(out_specs[f2])
else:
if check_vma:
res_vma = next(res_avals_iter).vma
res_specs.append(P(order_wrt_mesh(mesh, res_vma)))
else:
res_specs.append(P(all_names))
new_in_specs = (*res_specs, *(P(),) * len(env),
*(ax for ax, nz in zip(in_specs, nzs_in) if nz))
tangent_out_specs = tuple(ax for ax, nz in zip(out_specs_thunk(), nzs_out)
if nz)
@as_hashable_function(closure=tangent_out_specs)
def tangent_out_specs_thunk():
return tangent_out_specs
tangent_params = dict(
mesh=mesh, in_specs=new_in_specs, out_specs_thunk=tangent_out_specs_thunk,
check_vma=check_vma, manual_axes=manual_axes)
# TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here
def f_tangent(*args):
return core.eval_jaxpr(lin_jaxpr, (), *args)
nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
nz_tangents_out = shard_map_p.bind_with_trace(
trace.tangent_trace,
(lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info),
*residuals, *env, *nz_tangents_in), tangent_params)
nz_tangents_out_iter = iter(nz_tangents_out)
tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal)
for nz, primal in zip(nzs_out, primals_out)]
return map(partial(ad.maybe_linearize_tracer, trace), primals_out, nzs_out, tangents_out)
ad.LinearizeTrace.process_shard_map = _shard_map_linearize
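
# Residual out_specs place mesh axis names in the leading dimension, which a
# scalar does not have, so scalar residuals are broadcast to shape (1,) on
# the way out of the forward pass and the singleton axis is stripped again
# by the consuming jaxpr (see _promote_scalar_residuals_jaxpr below).
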
@lu.transformation2
def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs):
ans = f(*args, **kwargs)
_, _, _, _, in_fwd, out_fwd = linearize_outs_thunk()
num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
residuals = ans[:num_res_out]
primals = ans[num_res_out:]
residuals = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
for x in residuals]
return *residuals, *primals
@lu.transformation2
def _promote_scalar_residuals(f: Callable, *args, **kwargs):
jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs)
which = [f1 is None and f2 is None and not v.aval.shape
for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)]
jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which)
out_consts = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
for x in out_consts]
return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env)
def _promote_scalar_residuals_jaxpr(jaxpr: core.Jaxpr, which: Sequence[bool]):
def fun(*res_and_args):
res, args = split_list(res_and_args, [len(jaxpr.constvars)])
res = [_rem_singleton(x) if w else x for x, w in zip(res, which)]
return core.eval_jaxpr(jaxpr, res, *args)
res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
for v, w in zip(jaxpr.constvars, which)]
in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]]
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(fun, debug_info=jaxpr.debug_info), in_avals)
return jaxpr
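
# For example (illustrative shapes): with which=[True] and a scalar residual
# constvar of aval f32[], the rewritten jaxpr takes a f32[1] residual (via
# core.unmapped_aval(1, 0, ...)) and applies _rem_singleton to recover the
# scalar before evaluating the original jaxpr.
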
def _unmentioned2(mesh: Mesh, spec, manual_axes: frozenset[AxisName]
) -> list[AxisName]:
  # We use a filtered-down version of the unmentioned axes to avoid a
  # defensive psum over more devices than required in the
  # transpose-without-check_vma case.
name_set = _spec_to_vma(spec)
return [n for n in _all_mesh_names_except_spmd(mesh, manual_axes)
if n not in name_set]
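
# Illustrative example (hypothetical mesh): with mesh axes ('x', 'y'), both
# manual, and spec = P('x'), the unmentioned axes are ['y'], so the
# transpose's defensive psum reduces only over 'y'.
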
def _shard_map_transpose(out_cts, *args,
jaxpr: core.Jaxpr, mesh, in_specs, out_specs,
check_vma, manual_axes):
mb_div = lambda x, y: x / y if y != 1 else x
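  # When check_vma is False we cannot track replication, so the transpose is
  # defensive: input cotangents are psum-med over the mesh axes their spec
  # does not mention (below), and output cotangents are pre-divided by the
  # product of their unmentioned axis sizes so the extra summands cancel.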
out_cts = [
ad.Zero(_shard_aval(mesh, manual_axes, check_vma, sp, x.aval))
if type(x) is ad.Zero else x if check_vma or dtypes.dtype(x) == dtypes.float0
else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, sp, manual_axes))))
for sp, x in zip(out_specs, out_cts)
]
args = tuple(x if type(x) is not ad.UndefinedPrimal else
ad.UndefinedPrimal(
_shard_aval(mesh, manual_axes, check_vma, sp, x.aval))
for sp, x in zip(in_specs, args))
all_args, in_tree = tree_flatten((out_cts, args))
def fun_trans_callable(out_cts, args):
# TODO(mattjj): when #26811 lands, delete this and just run backward_pass
in_undef = map(ad.is_undefined_primal, args)
res, undefs = partition_list(in_undef, args)
jaxpr_known, jaxpr_unknown, _, _ = pe.partial_eval_jaxpr_nounits(
pe.close_jaxpr(jaxpr), in_undef, False)
res_reshaped = core.jaxpr_as_fun(jaxpr_known)(*res)
in_cts = ad.backward_pass(
jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts
)[len(res_reshaped):]
_, in_ct_specs = partition_list(in_undef, in_specs)
in_cts = [ad.Zero(_unshard_aval(mesh, check_vma, sp, x.aval))
if type(x) is ad.Zero else x if check_vma
else jax.lax.psum(x, tuple(_unmentioned2(mesh, sp, manual_axes)))
for sp, x in zip(in_ct_specs, in_cts)]
res_zeros = [ad_util.zero_from_primal(r) for r in res]
return merge_lists(in_undef, res_zeros, in_cts)
fun_trans = lu.wrap_init(fun_trans_callable, debug_info=jaxpr.debug_info)
fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans)
fun_trans_flat, out_tree = api_util.flatten_fun_nokwargs(fun_trans, in_tree)
new_in_specs = (
[n for n, x in zip(out_specs, out_cts) if type(x) is not ad.Zero] +
[n for n, x in zip(in_specs, args) if type(x) is not ad.UndefinedPrimal])
def new_out_specs_thunk():
return tuple(sp for sp, nz in zip(in_specs, nz_arg_cts()) if nz)
try:
out_flat = shard_map_p.bind(
fun_trans_flat, *all_args, mesh=mesh, in_specs=tuple(new_in_specs),
out_specs_thunk=new_out_specs_thunk, check_vma=check_vma,
manual_axes=manual_axes)
except (FloatingPointError, ZeroDivisionError) as e:
print("Invalid nan value encountered in the backward pass of a shard_map "
"function. Calling the de-optimized backward pass.")
try:
      # TODO(mattjj): Remove this and do `fun_trans.call_wrapped(out_cts, args)`
      # in eager mode so that the outputs of shard_map are not manual.
with jax.disable_jit(True):
_ = shard_map_p.bind(
fun_trans_flat, *all_args, mesh=mesh, in_specs=tuple(new_in_specs),
out_specs_thunk=new_out_specs_thunk, check_vma=check_vma,
manual_axes=manual_axes)
except (FloatingPointError, ZeroDivisionError) as e2:
raise e2 from None
else:
api_util._raise_no_nan_in_deoptimized(e)
return tree_unflatten(out_tree(), out_flat)
ad.primitive_transposes[shard_map_p] = _shard_map_transpose
# Remat
def _partial_eval_jaxpr_custom_rule(
saveable: Callable[..., pe.RematCases_], unks_in: Sequence[bool],
inst_in: Sequence[bool], eqn: core.JaxprEqn
) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool],
list[core.Var]]:
jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh']
check_vma, manual_axes = eqn.params['check_vma'], eqn.params['manual_axes']
with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma),
use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)))):
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
num_out_primals = len(jaxpr_known.outvars) - num_res
in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:]
out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals])
idx_map = {id(v): i for i, v in enumerate(out_vars)}
out_fwd = [idx_map.get(id(v)) for v in res_vars]
which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
mesh = eqn.params['mesh']
with (_extend_axis_env(mesh, manual_axes),
use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))),
config._check_vma(check_vma)):
jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)
jaxpr_staged = core.remove_named_axis_effects(jaxpr_staged, mesh.axis_names)
ins_known, _ = partition_list(unks_in, eqn.invars)
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
_, ins_staged = partition_list(inst_in, eqn.invars)
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
newvar = core.gensym()
residuals, staged_in_res_specs = [], []
for var, w in zip(jaxpr_staged.invars[:num_res], which):
if w:
rn = (P(order_wrt_mesh(mesh, var.aval.vma)) # type: ignore
if check_vma else P(_all_newly_manual_mesh_names(mesh, manual_axes)))
residuals.append(newvar(_unshard_aval(mesh, check_vma, rn, var.aval)))
staged_in_res_specs.append(rn)
if check_vma:
out_res_specs_known = [P(order_wrt_mesh(mesh, var.aval.vma)) # type: ignore
for var, o in zip(res_vars, out_fwd) if o is None]
else:
out_res_specs_known = [
P(_all_newly_manual_mesh_names(mesh, manual_axes))] * sum(which)
params_known, params_staged = _pe_custom_params(
unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd,
out_res_specs_known, staged_in_res_specs,
dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged))
eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
eqn.primitive, params_known, jaxpr_known.effects,
eqn.source_info, eqn.ctx)
full_res = subs_list2(in_fwd, out_fwd, ins_known, out_binders_known, residuals)
eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged,
eqn.primitive, params_staged,
jaxpr_staged.effects, eqn.source_info, eqn.ctx)
assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is core.Var and not inst]
new_inst += [out_binders_known[f] for f in {i for i in out_fwd if i is not None}]
return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \
_partial_eval_jaxpr_custom_rule
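
# Illustrative example (hypothetical jaxpr): if jaxpr_known ends up with
# outvars [p, r0, r1], where p is the single primal output and r0 is just a
# forwarded copy of p, then out_fwd = [0, None]: r0 is pruned from the known
# outputs and rebuilt from p's binder, while r1 remains a real residual.
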
def _add_reshapes(which: Sequence[bool],
jaxpr_known: core.Jaxpr,
jaxpr_staged: core.Jaxpr) -> tuple[core.Jaxpr, core.Jaxpr]:
  # Add singleton axes to residuals that come from jaxpr_known and are scalars.
which_ = [w and not v.aval.shape # pytype: disable=attribute-error
for w, v in zip(which, jaxpr_staged.invars[:len(which)])]
if not any(which_): return jaxpr_known, jaxpr_staged
assert not jaxpr_known.constvars and not jaxpr_staged.constvars
def known(*args):
out = core.eval_jaxpr(jaxpr_known, (), *args)
out_known, res = split_list(out, [len(out) - sum(which)])
res = [_add_singleton(x) if not x.shape else x for x in res]
return [*out_known, *res]
avals_in = [v.aval for v in jaxpr_known.invars]
jaxpr_known, _, (), () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(known, debug_info=jaxpr_known.debug_info), avals_in)
def staged(*args):
res_, ins = split_list(args, [len(which)])
res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)]
return core.eval_jaxpr(jaxpr_staged, (), *res, *ins)
res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
for w, v in zip(which_, jaxpr_staged.invars[:len(which)])]
avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]]
jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info), avals_in)
return jaxpr_known, jaxpr_staged
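
# For example (illustrative): if which_ marks a scalar residual r: f32[],
# jaxpr_known is rewritten to return _add_singleton(r): f32[1] and
# jaxpr_staged to apply _rem_singleton on the way in, so every residual
# crossing the shard_map boundary has rank at least 1.
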
def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
in_fwd, out_fwd, out_res_specs_known, staged_in_res_specs,
params_known, params_staged):
# prune inputs to jaxpr_known according to unks_in
in_specs_known, _ = partition_list(unks_in, params_known['in_specs'])
_, out_specs_known = partition_list(kept_outs_known, params_known['out_specs'])
out_specs_known = out_specs_known + out_res_specs_known
assert len(out_specs_known) == len(params_known['jaxpr'].outvars)
new_params_known = dict(params_known, in_specs=tuple(in_specs_known),
out_specs=tuple(out_specs_known))
  # num_res new residual inputs were prepended to jaxpr_staged; prune the
  # original inputs according to inst_in.
_, in_specs_staged = partition_list(inst_in, params_staged['in_specs'])
iter_staged = iter(staged_in_res_specs)
res_specs = [in_specs_known[f1] if f1 is not None else
out_specs_known[f2] if f2 is not None else
next(iter_staged) for f1, f2 in zip(in_fwd, out_fwd)]
in_specs_staged = res_specs + in_specs_staged
_, out_specs_staged = partition_list(kept_outs_staged, params_staged['out_specs'])
new_params_staged = dict(params_staged, in_specs=tuple(in_specs_staged),
out_specs=tuple(out_specs_staged))
return new_params_known, new_params_staged
# TODO(mattjj): remove this mechanism when we revise mesh scopes
def _all_mesh_names_except_spmd(
mesh: Mesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]:
axis_env = core.get_axis_env()
spmd_names = axis_env.spmd_axis_names
return tuple(name for name in mesh.axis_names
if name not in spmd_names and name in manual_axes)
def _all_newly_manual_mesh_names(
mesh: Mesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]:
axis_env = core.get_axis_env()
vmap_spmd_names = set(axis_env.spmd_axis_names)
if not (ctx_mesh := get_abstract_mesh()).empty:
mesh = ctx_mesh
already_manual_names = set(ctx_mesh.manual_axes)
else:
# TODO(mattjj): remove this mechanism when we revise mesh scopes
already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names
return tuple(name for name in mesh.axis_names
if (name not in vmap_spmd_names | already_manual_names and
name in manual_axes))
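
# Illustrative example (hypothetical mesh): with mesh axes ('x', 'y'),
# manual_axes = frozenset({'x', 'y'}), and 'x' already manual in the current
# abstract mesh, only ('y',) is returned as newly manual.
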
# DCE
# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule?
def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], core.JaxprEqn | None]:
if not any(used_outputs) and not pe.has_effects(eqn):
return [False] * len(eqn.invars), None
mesh = eqn.params["mesh"]
manual_axes = eqn.params["manual_axes"]
check_vma = eqn.params["check_vma"]
with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma),
use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)))):
jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects:
return used_inputs, None
else:
_, in_specs = partition_list(used_inputs, eqn.params['in_specs'])
_, out_specs = partition_list(used_outputs, eqn.params['out_specs'])
new_params = dict(eqn.params, jaxpr=jaxpr, in_specs=tuple(in_specs),
out_specs=tuple(out_specs))
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
new_eqn = pe.new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, used_inputs) if used],
[x for x, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, effs, eqn.source_info, eqn.ctx)
return used_inputs, new_eqn
pe.dce_rules[shard_map_p] = _shard_map_dce
# Implementing pmap in terms of shard_map
def pmap(f, axis_name=None, *, in_axes=0, out_axes=0,
static_broadcasted_argnums=(), devices=None, backend=None,
axis_size=None, donate_argnums=(), global_arg_shapes=None):
devices = tuple(devices) if devices is not None else devices
axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
f, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, out_axes)
def infer_params(*args, **kwargs):
p = _prepare_pmap(f, in_axes, out_axes, static_broadcasted_tuple,
donate_tuple, devices, backend, axis_size, args, kwargs)
for arg in p.flat_args:
dispatch.check_arg(arg)
mesh = Mesh(_get_devices(p, backend), (axis_name,))
_pmapped, in_specs, out_specs = _cached_shard_map(
p.flat_fun, mesh, p.in_axes_flat, p.out_axes_thunk, axis_name)
flat_global_args = host_local_array_to_global_array(
p.flat_args, mesh, list(in_specs))
jitted_f = jax.jit(
_pmapped,
donate_argnums=[i for i, val in enumerate(p.donated_invars) if val])
return jitted_f, flat_global_args, p.out_tree, mesh, out_specs
def wrapped(*args, **kwargs):
(jitted_f, flat_global_args, out_tree, mesh,
out_specs) = infer_params(*args, **kwargs)
outs = jitted_f(*flat_global_args)
outs = global_array_to_host_local_array(outs, mesh, out_specs())
return tree_unflatten(out_tree(), outs)
def lower(*args, **kwargs):
jitted_f, _, _, _, _ = infer_params(*args, **kwargs)
return jitted_f.lower(*args, **kwargs)
wrapped.lower = lower
return wrapped
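
# A rough sketch of the translation (hypothetical example): pmap(f,
# axis_name='i')(xs) builds Mesh(devices, ('i',)), squeezes the mapped axis
# from each input, runs shard_map(..., in_specs=P('i'), out_specs=P('i'))
# under jax.jit, and re-inserts the mapped axis on the outputs.
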
@lu.cache
def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name):
in_specs = tuple(map(partial(_axis_to_spec, axis_name), in_axes_flat))
out_specs = lambda: map(partial(_axis_to_spec, axis_name), out_axes_thunk())
fun = _handle_reshapes(flat_fun, in_axes_flat, out_axes_thunk)
return (_shard_map(fun.call_wrapped, mesh=mesh, in_specs=in_specs,
out_specs=out_specs, check_vma=False,
axis_names=set(mesh.axis_names)),
in_specs, out_specs)
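
# _handle_reshapes bridges pmap's in_axes/out_axes convention with
# shard_map's rank-preserving specs: the mapped axis (size 1 per shard) is
# squeezed away before calling f and re-inserted on the outputs afterwards.
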
@lu.transformation2
def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs):
args = tree_map(lambda x, ax: x if ax is None else jnp.squeeze(x, axis=ax),
list(args), list(in_axes))
out = f(*args)
return tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax),
list(out), list(out_axes_thunk()))
def _axis_to_spec(axis_name, ax):
if isinstance(ax, int):
specs = [None] * ax + [axis_name]
return P(*specs)
elif ax is None:
return P()
else:
raise TypeError(ax)
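
# For example: _axis_to_spec('i', 0) == P('i'),
# _axis_to_spec('i', 2) == P(None, None, 'i'), and
# _axis_to_spec('i', None) == P().
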
def _get_devices(p, backend):
if backend is not None and p.devices is None:
devs = jax.devices(backend=backend)
else:
devs = jax.devices() if p.devices is None else p.devices
if jax.process_count() > 1:
return devs[:p.global_axis_size]
return devs[:p.local_axis_size]