# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable, Hashable, Sequence, Set
import enum
from functools import partial
import inspect
from math import prod
import operator as op
from typing import Any, TypeVar, Union

import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec
from jax._src import ad_util
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import debugging
from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.core import pvary
from jax._src.core import Tracer, typeof
from jax._src.mesh import (AbstractMesh, Mesh, AxisType, use_abstract_mesh,
                           get_abstract_mesh, get_concrete_mesh)
from jax._src.api import _shared_code_pmap, _prepare_pmap
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo, sdy
from jax._src.util import (HashableFunction, HashablePartial, unzip2,
                           as_hashable_function, memoize, partition_list,
                           merge_lists, split_list, subs_list2,
                           fun_name as util_fun_name)
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.interpreters import ad
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
                           tree_structure, tree_leaves, keystr)
from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef,
                                generate_key_paths, KeyPath)
from jax.experimental.multihost_utils import (host_local_array_to_global_array,
                                              global_array_to_host_local_array)

P = PartitionSpec

map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
traceback_util.register_exclusion(__file__)

# API

Specs = Any  # PyTree[PartitionSpec]
AxisName = Hashable

def shard_map(f=None, /, *, out_specs: Specs, axis_names: Set[AxisName] = set(),
              in_specs: Specs | None = None,
              mesh: Mesh | AbstractMesh | None = None, check_vma: bool = True):
  """Map a function over shards of data using a mesh of devices.

  See the docs at https://docs.jax.dev/en/latest/notebooks/shard_map.html.

  Args:
    f: callable to be mapped. Each application of ``f``, or "instance" of
      ``f``, takes as input a shard of the mapped-over arguments and produces a
      shard of the output.
    mesh: (optional, default None) a ``jax.sharding.Mesh`` representing the
      array of devices over which to shard the data and on which to execute
      instances of ``f``. The names of the ``Mesh`` can be used in collective
      communication operations in ``f``. If mesh is None, it will be inferred
      from the context, which can be set via the ``jax.sharding.use_mesh``
      context manager.
    in_specs: (optional, default None) a pytree with
      ``jax.sharding.PartitionSpec`` instances as leaves, with a tree structure
      that is a tree prefix of the args tuple to be mapped over. Similar to
      ``jax.sharding.NamedSharding``, each ``PartitionSpec`` represents how the
      corresponding argument (or subtree of arguments) should be sharded along
      the named axes of ``mesh``. In each ``PartitionSpec``, mentioning a
      ``mesh`` axis name at a position expresses sharding the corresponding
      argument array axis along that positional axis; not mentioning an axis
      name expresses replication. If ``None``, all mesh axes must be of type
      ``Explicit``, in which case the in_specs are inferred from the argument
      types.
    out_specs: a pytree with ``PartitionSpec`` instances as leaves, with a tree
      structure that is a tree prefix of the output of ``f``. Each
      ``PartitionSpec`` represents how the corresponding output shards should
      be concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis
      name at a position expresses concatenation of that mesh axis's shards
      along the corresponding positional axis; not mentioning a ``mesh`` axis
      name expresses a promise that the output values are equal along that
      mesh axis, and that, rather than concatenating, only a single value
      should be produced.
    axis_names: (optional, default set()) set of axis names from ``mesh`` over
      which the function ``f`` is manual. If empty, ``f`` is manual over all
      mesh axes.
    check_vma: (optional) boolean (default True) representing whether to enable
      additional validity checks and automatic differentiation optimizations.
      The validity checks concern whether any mesh axis names not mentioned in
      ``out_specs`` are consistent with how the outputs of ``f`` are
      replicated.

  Returns:
    A callable representing a mapped version of ``f``, which accepts positional
    arguments corresponding to those of ``f`` and produces output corresponding
    to that of ``f``.
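
  Example:

    A minimal usage sketch (the eight-device mesh, the axis name ``'x'``, and
    the array sizes here are illustrative, not required)::

      from functools import partial

      import jax
      import jax.numpy as jnp
      from jax.sharding import PartitionSpec as P

      mesh = jax.make_mesh((8,), ('x',))

      @partial(jax.shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
      def f(x):
        # Each instance sees one (2,)-shaped shard of the (16,)-shaped input.
        return 2 * x

      y = f(jnp.arange(16.))  # y keeps the same sharded layout as the input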
  """
  kwargs = dict(mesh=mesh, in_specs=in_specs, out_specs=out_specs,
                axis_names=axis_names, check_vma=check_vma)
  if f is None:
    return lambda g: _shard_map(g, **kwargs)
  return _shard_map(f, **kwargs)

def _axes_to_pspec(axis_name, axis):
  if axis is None:
    return P()
  return P(*[None] * axis + [axis_name])

class InferFromArgs:

  def __repr__(self):
    return "jax.sharding.Infer"

  def __reduce__(self):
    return (_get_default_infer, ())

Infer = InferFromArgs()

def _get_default_infer():
  return Infer

# TODO(yashkatariya): We need a singleton which users can provide to `in_axes`
# to tell smap to infer in_specs from args when mesh is fully explicit.
def smap(f, /, *, in_axes=Infer, out_axes, axis_name: AxisName):
  if isinstance(axis_name, (list, tuple)):
    raise TypeError(
        f"smap axis_name should be a `str` or a `Hashable`, but got {axis_name}")
  if (in_axes is not None and in_axes is not Infer and
      not isinstance(in_axes, (int, tuple))):
    raise TypeError(
        "smap in_axes must be an int, None, jax.sharding.Infer, or a tuple of"
        " entries corresponding to the positional arguments passed to the"
        f" function, but got {in_axes}.")
  if (in_axes is not Infer and
      not all(isinstance(l, int) for l in tree_leaves(in_axes))):
    raise TypeError(
        "smap in_axes must be an int, None, jax.sharding.Infer, or (nested)"
        f" container with those types as leaves, but got {in_axes}.")
  if not all(isinstance(l, int) for l in tree_leaves(out_axes)):
    raise TypeError("smap out_axes must be an int, None, or (nested) container "
                    f"with those types as leaves, but got {out_axes}.")

  in_specs = (None if in_axes is Infer else
              tree_map(partial(_axes_to_pspec, axis_name), in_axes,
                       is_leaf=lambda x: x is None))
  out_specs = tree_map(partial(_axes_to_pspec, axis_name), out_axes,
                       is_leaf=lambda x: x is None)
  return _shard_map(f, mesh=None, in_specs=in_specs, out_specs=out_specs,
                    axis_names={axis_name}, check_vma=True, _smap=True)
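
# Illustrative note (not part of the API contract): `smap` translates integer
# axes into PartitionSpecs via `_axes_to_pspec` above. For example, with
# axis_name='x', in_axes=0 corresponds to in_specs=P('x') and in_axes=2
# corresponds to in_specs=P(None, None, 'x'); out_axes is translated the same
# way. in_axes=Infer leaves in_specs=None so they are inferred from argument
# types on fully-explicit meshes.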

def _shard_map(f: Callable, *, mesh: Mesh | AbstractMesh | None,
               in_specs: Specs, out_specs: Specs | Callable[[], Specs],
               axis_names: Set[AxisName], check_vma: bool,
               _skip_mesh_check: bool = False, _smap: bool = False) -> Callable:
  if not callable(f):
    raise TypeError("shard_map requires a callable for its first argument, "
                    f"but got {f} of type {type(f)}.")

  @util.wraps(f)
  @traceback_util.api_boundary
  def wrapped(*args):
    nonlocal mesh, axis_names
    mesh, axis_names = _shmap_checks(mesh, axis_names, in_specs, out_specs,
                                     _skip_mesh_check, _smap)
    fun = lu.wrap_init(
        f, debug_info=api_util.debug_info("shard_map", f, args, {}))
    args_flat, in_tree = tree_flatten(args)
    fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)

    try:
      in_specs_flat = broadcast_prefix(
          in_specs, args, is_leaf=lambda x: x is None)
    except ValueError:
      e, *_ = prefix_errors(in_specs, args)
      raise e('shard_map in_specs') from None

    if (in_specs is None and
        all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)):
      arg_s = [typeof(a).sharding for a in args_flat]
      assert all(i is None for i in in_specs_flat), in_specs_flat
      in_specs_flat = [_manual_spec(axis_names, s.spec) for s in arg_s]

    dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat)
                                        if s is not None)
    fun, args_flat = api_util.argnums_partial(fun, dyn_argnums, args_flat, False)
    _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat,
                         args_flat)

    @memoize
    def out_specs_thunk():
      if callable(out_specs):
        out_specs_ = out_specs()
        _check_specs(SpecErrorType.out, out_specs_, axis_names)
      else:
        out_specs_ = out_specs
      dummy = tree_unflatten(out_tree(), [object()] * out_tree().num_leaves)
      try:
        out_specs_flat = broadcast_prefix(out_specs_, dummy)
      except ValueError:
        e, *_ = prefix_errors(out_specs_, dummy)
        raise e('shard_map out_specs') from None
      return tuple(out_specs_flat)

    if check_vma:
      fun = _implicit_pvary_on_output(fun, out_specs_thunk)

    try:
      out_flat = shard_map_p.bind(
          fun, *args_flat, mesh=mesh, in_specs=in_specs_flat,
          out_specs_thunk=out_specs_thunk, check_vma=check_vma,
          manual_axes=axis_names)
    except _SpecError as e:
      fails, = e.args
      if not callable(out_specs):
        msg = _spec_rank_error(SpecErrorType.out, f, out_tree(), out_specs, fails)
        if any(fail is not no_fail and not fail.shape for fail in fails):
          msg += (" In particular, for rank 0 outputs which are not constant "
                  "over the mesh, add at least one (singleton) axis to them so "
                  "that they can be concatenated using out_specs.")
        raise ValueError(msg) from None
    except _RepError as e:
      fails, = e.args
      if not callable(out_specs):
        msg = _inout_vma_error(f, mesh, out_tree(), out_specs, fails)
        raise ValueError(msg) from None
    return tree_unflatten(out_tree(), out_flat)
  return wrapped


def _shmap_checks(mesh, axis_names, in_specs, out_specs, _skip_mesh_check,
                  _smap):
  if mesh is None:
    mesh = get_abstract_mesh()
    if mesh.empty:
      raise ValueError(
          "The context mesh cannot be empty. Use"
          " `jax.sharding.use_mesh(mesh)` to enter into a mesh context")
  else:
    ctx_mesh = get_abstract_mesh()
    if (not _skip_mesh_check and not ctx_mesh.empty and
        mesh.abstract_mesh != ctx_mesh):
      raise ValueError(
          f"The context mesh {ctx_mesh} should match the mesh passed to"
          f" shard_map {mesh}")

  if not isinstance(mesh, (Mesh, AbstractMesh)):
    raise TypeError("shard_map requires a `jax.sharding.Mesh` or a "
                    "`jax.sharding.AbstractMesh` instance for its "
                    f"second argument, but got {mesh} of type {type(mesh)}.")

  if not isinstance(axis_names, (frozenset, set)):
    raise TypeError(
        "`axis_names` argument of shard_map should be of type `frozenset` or"
        f" `set`. Got type: {type(axis_names)}")
  if isinstance(axis_names, set):
    axis_names = frozenset(axis_names)
  if not axis_names:
    axis_names = frozenset(mesh.axis_names)
  if not axis_names.issubset(mesh.axis_names):
    raise ValueError(
        f"jax.shard_map requires axis_names={axis_names} to be a subset of "
        f"mesh.axis_names={mesh.axis_names}")

  if (in_specs is None and
      not all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)):
    axis_types = ', '.join(str(mesh._name_to_type[a]) for a in axis_names)
    if _smap:
      msg = (f"in_axes was not specified when axis_name={axis_names} was of"
             f" type {axis_types}")
    else:
      msg = ("shard_map in_specs argument must be a pytree of"
             " `jax.sharding.PartitionSpec` instances, but it was `None` when"
             f" {axis_names=} are of type {axis_types}")
    raise TypeError(msg)

  if in_specs is not None:
    _check_specs(SpecErrorType.input, in_specs, axis_names)
  if not callable(out_specs):
    _check_specs(SpecErrorType.out, out_specs, axis_names)
  return mesh, axis_names

def _manual_spec(manual_axes, spec: P) -> P:
  out = []  # type: ignore
  for s in spec:
    if s is None:
      out.append(s)
    elif isinstance(s, tuple):
      temp = [p if p in manual_axes else None for p in s]
      while temp and temp[-1] is None:
        temp.pop()
      if None in temp:
        raise ValueError(f"Invalid spec: {spec}")
      out.append(None if len(temp) == 0 else tuple(temp))
    else:
      out.append(s if s in manual_axes else None)
  return P(*out)
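
# Illustrative examples (the axis names are hypothetical):
#   _manual_spec({'x'}, P('x', 'y'))   -> P('x', None)  # 'y' is not manual
#   _manual_spec({'x'}, P(('x', 'y'))) -> P(('x',))     # trailing non-manual
#                                                       # names are dropped
#   _manual_spec({'x'}, P(('y', 'x'))) -> ValueError    # a non-manual name
#                                                       # before a manual one
#                                                       # cannot be expressed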


# Error checking and messages

SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out'])

def _check_specs(error_type: SpecErrorType, specs: Any, manual_axes) -> None:
  if error_type == SpecErrorType.input and specs is None:
    raise TypeError(
        "shard_map in_specs argument must be a pytree of "
        "`jax.sharding.PartitionSpec` instances, but it was None.\n"
        "Instead of `in_specs=None`, did you mean `in_specs=P()`, "
        "where `P = jax.sharding.PartitionSpec`?")

  def check_spec(p):
    if not isinstance(p, PartitionSpec):
      return False
    for names in p:
      names = (names,) if not isinstance(names, tuple) else names
      for name in names:
        if name is not None and name not in manual_axes:
          return False
    return True

  if all(check_spec(p) for p in tree_leaves(specs)):
    return
  prefix = 'in' if error_type == SpecErrorType.input else 'out'
  msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, "
          for key, x in generate_key_paths(specs) if not isinstance(x, P)]
  if not msgs:
    for key, p in generate_key_paths(specs):
      for names in p:
        names = (names,) if not isinstance(names, tuple) else names
        for name in names:
          if name is not None and name not in manual_axes:
            msgs.append(f" {prefix}_specs{keystr(key)} refers to {repr(name)}")
    raise ValueError(
        f"shard_map {prefix}_specs argument must refer to an axis "
        f"marked as manual ({manual_axes}), but:\n\n"
        + '\n\n'.join(msgs) + '\n\n'
        f"Check the {prefix}_specs values passed to shard_map.")
  raise TypeError(
      f"shard_map {prefix}_specs argument must be a pytree of "
      f"`jax.sharding.PartitionSpec` instances, but:\n\n"
      + '\n\n'.join(msgs) + '\n\n'
      f"Check the {prefix}_specs values passed to shard_map.")

class NoFail: pass
no_fail = NoFail()

def _check_specs_vs_args(
    f: Callable, mesh: Mesh | AbstractMesh, in_tree: PyTreeDef, in_specs: Specs,
    dyn_argnums: Sequence[int], in_specs_flat: Sequence[P],
    xs: Sequence) -> None:
  in_avals = map(core.shaped_abstractify, xs)
  fail = [a if not len(p) <= a.ndim else no_fail
          for p, a in zip(in_specs_flat, in_avals)]
  if any(f is not no_fail for f in fail):
    fail = _expand_fail(in_tree, dyn_argnums, fail)
    msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail)
    raise ValueError(msg)
  in_names_flat = tuple(map(_spec_to_names, in_specs_flat))
  fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns)
                   for d, ns in names.items()) else no_fail
          for a, names in zip(in_avals, in_names_flat)]
  if any(f is not no_fail for f in fail):
    fail = _expand_fail(in_tree, dyn_argnums, fail)
    msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail)
    raise ValueError(msg)

def _expand_fail(in_tree: PyTreeDef, dyn_argnums: Sequence[int],
                 fail: Sequence[core.ShapedArray | NoFail]
                 ) -> list[core.ShapedArray | NoFail]:
  fail_: list[core.ShapedArray | NoFail] = [no_fail] * in_tree.num_leaves
  for i, f in zip(dyn_argnums, fail):
    fail_[i] = f
  return fail_

def _spec_rank_error(
    error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs,
    fails: list[core.ShapedArray | NoFail]) -> str:
  fun_name = util_fun_name(f)
  if error_type == SpecErrorType.input:
    prefix, base = 'in', 'args'
    ba = _try_infer_args(f, tree)
  else:
    prefix, base = 'out', f'{fun_name}(*args)'
  msgs = []
  for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
    extra = ""
    if error_type == SpecErrorType.input and ba is not None:
      arg_key, *_ = fail_key
      if arg_key.idx < len(ba.arguments):
        param_name = list(ba.arguments.keys())[arg_key.idx]
        extra = (f", where {base}{arg_key} is bound to {fun_name}'s "
                 f"parameter '{param_name}',")
      else:
        param = list(ba.signature.parameters.values())[-1]
        assert param.kind == inspect.Parameter.VAR_POSITIONAL
        extra = (f", where {base}{arg_key} is the index "
                 f"{arg_key.idx - len(ba.signature.parameters) + 1} component "
                 f"of {fun_name}'s varargs parameter '{param.name}',")
    msgs.append(
        f"* {prefix}_specs{keystr(spec_key)} is {spec} which has length "
        f"{len(spec)}, but "
        f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, "
        f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})")
  assert msgs
  if len(msgs) == 1: msgs = [msgs[0][2:]]  # remove the bullet point
  msg = (f"shard_map applied to the function '{fun_name}' was given an "
         f"{prefix}_specs entry which is too long to be compatible with the "
         f"corresponding {prefix}put value from the function:\n\n"
         + '\n\n'.join(msgs) + '\n\n' +
         f"Entries in {prefix}_specs must be of length no greater than the "
         f"number of axes in the corresponding {prefix}put value.\n\n"
         f"Either revise the spec to be shorter, or modify '{fun_name}' so "
         f"that its {prefix}puts have sufficient rank.")
  if any(not aval.ndim for _, (_, aval) in _iter_paths(tree, specs, fails)):
    msg += (f"\n\nFor scalar values (rank 0), consider using an {prefix}_specs "
            "entry of `P()`, where `P = jax.sharding.PartitionSpec`.")
  return msg

def _spec_divisibility_error(
    f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, specs: Specs,
    fails: list[core.ShapedArray | NoFail]) -> str:
  ba = _try_infer_args(f, tree)
  fun_name = getattr(f, '__name__', str(f))
  msgs = []
  for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
    extra = ""
    if ba is not None:
      arg_key, *_ = fail_key
      if arg_key.idx < len(ba.arguments):
        param_name = list(ba.arguments.keys())[arg_key.idx]
        extra = (f", where args{arg_key} is bound to {fun_name}'s "
                 f"parameter '{param_name}',")
      else:
        param = list(ba.signature.parameters.values())[-1]
        assert param.kind == inspect.Parameter.VAR_POSITIONAL
        extra = (f", where args{arg_key} is the index "
                 f"{arg_key.idx - len(ba.signature.parameters) + 1} component "
                 f"of {fun_name}'s varargs parameter '{param.name}',")
    names = _spec_to_names(spec)
    for d, ns in names.items():
      if aval.shape[d] % prod(mesh.shape[n] for n in ns):
        axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'"
        total = 'total ' if len(ns) > 1 else ''
        sz = prod(mesh.shape[n] for n in ns)
        msgs.append(
            f"* args{keystr(fail_key)} of shape {aval.str_short()}{extra} "
            f"corresponds to in_specs{keystr(spec_key)} of value {spec}, "
            f"which maps array axis {d} (of size {aval.shape[d]}) to mesh "
            f"{axis} (of {total}size {sz}), but {sz} does not evenly divide "
            f"{aval.shape[d]}")
  assert msgs
  if len(msgs) == 1: msgs = [msgs[0][2:]]  # remove the bullet point
  msg = (f"shard_map applied to the function '{fun_name}' was given argument "
         f"arrays with axis sizes that are not evenly divisible by the "
         f"corresponding mesh axis sizes:\n\n"
         f"The mesh given has shape {tuple(mesh.shape.values())} with "
         f"corresponding axis names {mesh.axis_names}.\n\n"
         + '\n\n'.join(msgs) + '\n\n' +
         f"Array arguments' axis sizes must be evenly divisible by the mesh "
         f"axis or axes indicated by the corresponding elements of the "
         f"argument's in_specs entry. Consider checking that in_specs are "
         f"correct, and if so consider changing the mesh axis sizes or else "
         f"padding the input and adapting '{fun_name}' appropriately.")
  return msg

def _inout_vma_error(f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef,
                     specs: Specs, fails: list[set | NoFail]) -> str:
  fun_name = getattr(f, '__name__', str(f))
  msgs = []
  for (spec_key, spec), (fail_key, vma) in _iter_paths(tree, specs, fails):
    unmentioned = _unmentioned(mesh, spec)
    if len(unmentioned) > 1:
      need_vma = ','.join(map(str, order_wrt_mesh(mesh, _spec_to_vma(spec))))
      got_vma = ','.join(map(str, order_wrt_mesh(mesh, vma)))
      diff = ','.join(map(str, order_wrt_mesh(
          mesh, [n for n in unmentioned if n in vma])))
      msgs.append(
          f"* out_specs{keystr(spec_key)} is {spec} which implies that the "
          f"corresponding output value is only varying across mesh axes "
          f"{{{need_vma}}} and not {{{diff}}}, but it was inferred to be "
          f"possibly varying over {{{got_vma}}}")
    else:
      need_rep_, = unmentioned
      msgs.append(
          f"* out_specs{keystr(spec_key)} is {spec} which implies that the "
          f"corresponding output value is replicated across mesh axis "
          f"'{need_rep_}', but could not infer replication over any axes")
  assert msgs
  if len(msgs) == 1: msgs = [msgs[0][2:]]  # remove the bullet point
  msg = (f"shard_map applied to the function '{fun_name}' was given "
         f"out_specs that require replication that can't be statically "
         f"inferred given the mesh:\n\n"
         f"The mesh given has shape {tuple(mesh.shape.values())} with "
         f"corresponding axis names {mesh.axis_names}.\n\n"
         + '\n\n'.join(msgs) + '\n\n' +
         "Check if these output values are meant to be replicated over those "
         "mesh axes. If not, consider revising the corresponding out_specs "
         "entries. If so, consider disabling the check by passing the "
         "check_vma=False argument to `jax.shard_map`.")
  return msg

def _unmentioned(mesh: Mesh | AbstractMesh, spec) -> list[AxisName]:
  vma_set = _spec_to_vma(spec)
  return [n for n in mesh.axis_names if n not in vma_set]


def _try_infer_args(f, tree):
  dummy_args = tree_unflatten(tree, [False] * tree.num_leaves)
  try:
    return inspect.signature(f).bind(*dummy_args)
  except (TypeError, ValueError):
    return None

T = TypeVar('T')
def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail]
                ) -> list[tuple[tuple[KeyPath, P], tuple[KeyPath, T]]]:
  failures = tree_unflatten(tree, fails)
  failures_aug = generate_key_paths(failures)
  specs_ = tree_unflatten(tree_structure(specs), generate_key_paths(specs))
  leaf = lambda x: x is None or type(x) is tuple and len(x) == 2 and type(x[1]) is P
  specs_aug = broadcast_prefix(specs_, failures, is_leaf=leaf)
  return [(s, (fail_key, fail_data)) for s, (fail_key, fail_data)
          in zip(specs_aug, failures_aug)
          if s is not None and fail_data is not no_fail]

# Primitive

@lu.transformation2
def _implicit_pvary_on_output(f, out_specs_thunk, *args, **kwargs):
  out_flat = f(*args, **kwargs)
  return [pvary(o, tuple(_spec_to_vma(sp) - typeof(o).vma))
          for o, sp in zip(out_flat, out_specs_thunk())]

JaxType = Any
MaybeTracer = Union[JaxType, Tracer]

class ShardMapPrimitive(core.Primitive):
  multiple_results = True

  def bind(self, *args, **params):
    return self._true_bind(*args, **params)

  def bind_with_trace(self, trace, fun_and_args, params):
    fun: lu.WrappedFun
    fun, *args = fun_and_args
    return trace.process_shard_map(shard_map_p, fun, args, **params)

  def get_bind_params(self, params):
    new_params = dict(params)
    jaxpr: core.Jaxpr = new_params.pop('jaxpr')
    subfun = lu.hashable_partial(lu.wrap_init(core.eval_jaxpr,
                                              debug_info=jaxpr.debug_info),
                                 jaxpr, ())
    axes = new_params.pop('out_specs')
    new_params['out_specs_thunk'] = HashableFunction(lambda: axes, closure=axes)
    return [subfun], new_params

shard_map_p = ShardMapPrimitive('shard_map')

# Staging

@util.cache(max_size=256, trace_context_in_key=True)
def _as_manual_mesh(mesh, manual_axes: frozenset):
  not_manual = set(mesh.axis_names) - manual_axes
  cur_mesh = get_abstract_mesh()
  if cur_mesh.empty:
    cur_mesh = mesh
  explicit_axes, auto_axes = set(), set()  # type: ignore
  for a in not_manual:
    if cur_mesh._name_to_type[a] == AxisType.Auto:
      auto_axes.add(a)
    else:
      assert cur_mesh._name_to_type[a] == AxisType.Explicit, (
          a, cur_mesh._name_to_type[a])
      explicit_axes.add(a)

  new_axis_types = []
  for n in mesh.axis_names:
    if n in manual_axes:
      new_axis_types.append(AxisType.Manual)
    elif n in auto_axes:
      new_axis_types.append(AxisType.Auto)
    else:
      assert n in explicit_axes
      new_axis_types.append(AxisType.Explicit)
  return AbstractMesh(mesh.axis_sizes, mesh.axis_names,
                      axis_types=tuple(new_axis_types))


def _extend_axis_env(mesh, manual_axes):
  return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items()
                                  if k in manual_axes])

def _shard_map_staging(
    trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun,
    in_tracers: Sequence[Any], *, mesh: Mesh,
    in_specs, out_specs_thunk, check_vma: bool, manual_axes: frozenset,
) -> Sequence[pe.DynamicJaxprTracer]:
  source_info = source_info_util.current()
  to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, source_info=source_info)
  in_tracers = map(to_jaxpr_tracer, in_tracers)
  inner_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))
  in_avals = [t.aval for t in in_tracers]
  in_avals_ = map(partial(_shard_aval, mesh, manual_axes, check_vma), in_specs,
                  in_avals)
  with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh),
        config._check_vma(check_vma)):
    jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_)
  _check_names(out_specs_thunk(), out_avals_)
  if check_vma:
    out_vma = [v.aval.vma for v in jaxpr.outvars]
    _check_vmas(mesh, out_specs_thunk(), out_vma)
  out_avals = map(_check_shapedarray, out_avals_)
  out_avals = [_check_shapedarray(_unshard_aval(mesh, check_vma, spec, aval))
               for spec, aval in zip(out_specs_thunk(), out_avals)]
  out_tracers = [pe.DynamicJaxprTracer(trace, a, source_info) for a in out_avals]
  invars = map(trace.getvar, in_tracers)
  constvars = map(trace.getvar, map(to_jaxpr_tracer, consts))
  outvars = map(trace.makevar, out_tracers)
  in_specs_staged = (P(),) * len(consts) + tuple(in_specs)  # type: ignore
  with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh),
        config._check_vma(check_vma)):
    jaxpr = pe.convert_constvars_jaxpr(jaxpr)
  params = dict(mesh=mesh, in_specs=in_specs_staged,
                out_specs=tuple(out_specs_thunk()), jaxpr=jaxpr,
                check_vma=check_vma, manual_axes=manual_axes)
  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
  eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params,
                         effs, source_info)
  trace.frame.add_eqn(eqn)
  return out_tracers
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging

# TODO add underscore version, for direct-linearize to consume

def _spec_to_names(spec: PartitionSpec):
  return {i: names if isinstance(names, tuple) else (names,)
          for i, names in enumerate(spec) if names is not None}

def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray:
  assert isinstance(aval, core.ShapedArray)
  return aval

def _shard_aval(mesh: Mesh, manual_axes, check_vma, spec,
                aval: core.AbstractValue) -> core.AbstractValue:
  if type(aval) in core.shard_aval_handlers:
    return core.shard_aval_handlers[type(aval)](mesh, manual_axes, check_vma,
                                                spec, aval)
  raise NotImplementedError(f"Unsupported aval type: {type(aval)}")

def _unshard_aval(mesh: Mesh, check_vma, spec,
                  aval: core.AbstractValue) -> core.AbstractValue:
  if type(aval) in core.unshard_aval_handlers:
    return core.unshard_aval_handlers[type(aval)](mesh, check_vma, spec, aval)
  else:
    raise NotImplementedError(f"Unsupported aval type: {type(aval)}")

def _shard_shaped_array(mesh: Mesh, manual_axes: frozenset, check_vma,
                        spec, aval: core.AbstractValue) -> core.AbstractValue:
  assert isinstance(aval, core.ShapedArray)
  names = _spec_to_names(spec)
  new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
                    for i, sz in enumerate(aval.shape))
  manual_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))
  new_sharding = NamedSharding(manual_mesh, aval.sharding.spec)
  vma = _spec_to_vma(spec) if check_vma else frozenset()
  vma = vma | aval.vma
  return aval.update(shape=new_shape, sharding=new_sharding, vma=vma)
core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array
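
# Shape bookkeeping sketch (the mesh and shapes are illustrative only): on a
# mesh with axis sizes {'x': 4, 'y': 2}, sharding a ShapedArray of shape
# (16, 8) by spec P('x', 'y') yields a per-instance shape of (4, 4);
# _unshard_shaped_array below inverts this by multiplying the sizes back out.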

def _unshard_shaped_array(mesh: Mesh, check_vma, spec, aval: core.AbstractValue
                          ) -> core.AbstractValue:
  assert isinstance(aval, core.ShapedArray)
  names = _spec_to_names(spec)
  new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
                    for i, sz in enumerate(aval.shape))
  names_spec = spec._normalized_spec_for_aval(aval.ndim)
  if aval.ndim == 0:
    out_spec = P()
  else:
    out_spec = []  # type: ignore
    for name_s, aval_s in zip(names_spec, aval.sharding.spec):
      if name_s and not aval_s:
        out_spec.append(name_s)
      elif aval_s and not name_s:
        out_spec.append(aval_s)
      elif not name_s and not aval_s:
        out_spec.append(None)
      else:
        assert name_s and aval_s
        name_s = name_s if isinstance(name_s, tuple) else (name_s,)
        aval_s = aval_s if isinstance(aval_s, tuple) else (aval_s,)
        out_spec.append(name_s + aval_s)
    out_spec = PartitionSpec(*out_spec)
  new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else
              get_abstract_mesh())
  new_sharding = NamedSharding(new_mesh, out_spec)
  manual_axes = set(new_mesh.manual_axes)
  vma = (frozenset(v for v in aval.vma if v in manual_axes)
         if check_vma else frozenset())
  return aval.update(shape=new_shape, sharding=new_sharding, vma=vma)
core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array

# Type-checking

def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_specs, out_specs,
                         check_vma, manual_axes):
  # TODO(mattjj,parkers): check auto
  for v, x, in_spec in zip(jaxpr.invars, in_atoms, in_specs):
    if not core.typecompat(v.aval, _shard_aval(
        mesh, manual_axes, check_vma, in_spec, x.aval)):
      raise core.JaxprTypeError("shard_map argument avals not compatible with "
                                "jaxpr binder avals and in_specs")
  with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
    core.check_jaxpr(jaxpr)
  if check_vma:
    out_vma = [v.aval.vma for v in jaxpr.outvars]
    for vma, out_spec in zip(out_vma, out_specs):
      if not _valid_repeats(mesh, vma, out_spec):
        raise core.JaxprTypeError(
            "shard_map can't prove output is sufficiently replicated")
  out_avals_sharded = [x.aval for x in jaxpr.outvars]
  out_avals = map(partial(_unshard_aval, mesh, check_vma), out_specs,
                  out_avals_sharded)
  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
  return out_avals, effs
core.custom_typechecks[shard_map_p] = _shard_map_typecheck


def _valid_repeats(mesh: Mesh, vma: Set[AxisName], spec) -> bool:
  um = set(_unmentioned(mesh, spec)) - set(mesh.manual_axes)
  if any(u in vma for u in um):
    return False
  return True

# Lowering

def _shardy_shard_map_sharding(
    ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec, aval_in
) -> sharding_impls.SdyArray:
  ns = _make_scoped_manual_sharding(ctx, mesh, spec)
  if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
    ns = sharding_impls.physical_sharding(aval_in, ns)
    aval_in = core.physical_aval(aval_in)
  sdy_sharding = ns._to_sdy_sharding(aval_in.ndim)
  if len(manual_axes) < len(mesh.axis_names):
    for dim_sharding in sdy_sharding.dim_shardings:
      dim_sharding.is_open = True
  return sdy_sharding


def _shardy_shard_map_token_sharding(
    ctx: mlir.LoweringRuleContext, mesh
) -> ir.Attribute:
  ns = _make_scoped_manual_sharding(ctx, mesh, P())
  return ns._to_sdy_sharding(0)


def _get_spmdaxis_ctx_mesh(mesh):
  if isinstance(mesh, AbstractMesh):
    concrete_mesh = get_concrete_mesh()
    return concrete_mesh if concrete_mesh is not None else mesh
  return mesh


def _shard_map_lowering_shardy(
    ctx, in_nodes, jaxpr, mesh, in_specs, out_specs, manual_axes, check_vma):
  axis_ctx = ctx.module_context.axis_context
  in_avals_ = [v.aval for v in jaxpr.invars]
  if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
    # Nested `ManualComputationOp`s cannot refer to axes that are already
    # manual. So figure out what axes are free thus far.
    shardy_manual_axes = frozenset(mesh.axis_names) - axis_ctx.manual_axes
  else:
    shardy_manual_axes = manual_axes
  new_axis_context = sharding_impls.SPMDAxisContext(
      _get_spmdaxis_ctx_mesh(mesh), manual_axes)
  sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)

  tokens = [ctx.tokens_in.get(eff) for eff in ctx.tokens_in.effects()]
  num_tokens = len(tokens)
  manual_axes = order_wrt_mesh(mesh, shardy_manual_axes)
  if np.prod([mesh.shape[a] for a in manual_axes]) == 1:
    # No need for a `ManualComputationOp` if all manual axes are size 1.
    with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
      out_nodes, tokens_out = mlir.jaxpr_subcomp(
          sub_ctx, jaxpr, ctx.name_stack,
          mlir.TokenSet(zip(ctx.tokens_in.effects(), tokens)),
          (), *in_nodes,
          dim_var_values=ctx.dim_var_values)
    ctx.set_tokens_out(tokens_out)
    return out_nodes

  in_shardings = list(
      map(partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes),
          in_specs, ctx.avals_in))
  num_dim_vars = len(ctx.dim_var_values)
  in_shardings = ([_shardy_shard_map_token_sharding(ctx, mesh)]
                  * (num_tokens + num_dim_vars) + in_shardings)
  in_shardings = sharding_impls.SdyArrayList(in_shardings).build()

  out_shardings = list(
      map(partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes),
          out_specs, ctx.avals_out))
  out_shardings = [
      _shardy_shard_map_token_sharding(ctx, mesh)] * num_tokens + out_shardings
  out_shardings = sharding_impls.SdyArrayList(out_shardings).build()

  output_types = ([hlo.TokenType.get()] * num_tokens +
                  list(map(mlir.aval_to_ir_type, ctx.avals_out)))

  args = (*ctx.dim_var_values, *tokens, *in_nodes)
  manual_computation_op = sdy.ManualComputationOp(
      output_types,
      mlir.flatten_ir_values(args),
      in_shardings, out_shardings,
      sdy.ManualAxesAttr.get(
          ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes])))
  block = ir.Block.create_at_start(
      manual_computation_op.body,
      (*(i if isinstance(i, ir.Type) else i.type for i in ctx.dim_var_values),
       *([hlo.TokenType.get()] * num_tokens),
       *map(mlir.aval_to_ir_type, in_avals_)))
  with (ir.InsertionPoint(block), _extend_axis_env(mesh, manual_axes),
        config._check_vma(check_vma)):
    out_nodes_, tokens_out = mlir.jaxpr_subcomp(
        sub_ctx, jaxpr, ctx.name_stack,
        mlir.TokenSet(zip(
            ctx.tokens_in.effects(), block.arguments[:num_tokens])),
        (), *block.arguments[num_tokens+num_dim_vars:],
        dim_var_values=ctx.dim_var_values)
    sdy.ReturnOp([ir.Value(x) for x in (*[v for _, v in tokens_out.items()],
                                        *out_nodes_)])
  num_tokens = len(tokens_out.effects())
  tokens_out = tokens_out.update_tokens(mlir.TokenSet(zip(
      ctx.tokens_in.effects(), manual_computation_op.results[:num_tokens])))
  ctx.set_tokens_out(tokens_out)

  return manual_computation_op.results[num_tokens:]


def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_specs, out_specs,
                        check_vma, manual_axes):
  if config.use_shardy_partitioner.value:
    return _shard_map_lowering_shardy(
        ctx, in_nodes, jaxpr, mesh, in_specs, out_specs, manual_axes, check_vma)

  in_avals_ = [v.aval for v in jaxpr.invars]
  out_avals_ = [x.aval for x in jaxpr.outvars]
  in_nodes_ = map(partial(_xla_shard, ctx, mesh, manual_axes), in_specs,
                  ctx.avals_in, in_avals_, in_nodes)
  new_axis_context = sharding_impls.SPMDAxisContext(
      _get_spmdaxis_ctx_mesh(mesh), manual_axes)
  sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
  with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
    out_nodes_, tokens_out = mlir.call_lowering(
        "shmap_body", ctx.name_stack, jaxpr, None, sub_ctx, in_avals_,
        out_avals_, ctx.tokens_in, *in_nodes_,
        dim_var_values=ctx.dim_var_values,
        arg_names=map(_pspec_mhlo_attrs, in_specs, in_avals_),
        result_names=map(_pspec_mhlo_attrs, out_specs, out_avals_))
  ctx.set_tokens_out(tokens_out)
  return map(partial(_xla_unshard, ctx, mesh, manual_axes), out_specs,
             out_avals_, ctx.avals_out, out_nodes_)
mlir.register_lowering(shard_map_p, _shard_map_lowering)

def _make_scoped_manual_sharding(ctx, mesh, spec):
  axis_ctx = ctx.module_context.axis_context
  mesh = mesh.abstract_mesh
  if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
    mesh = mesh.update_axis_types(
        {a: AxisType.Manual for a in axis_ctx.manual_axes})
  return NamedSharding(mesh, spec)

def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec,
               aval_in, aval_out, x):
  if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1:
    return x
  ns = _make_scoped_manual_sharding(ctx, mesh, spec)
  if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
    ns = sharding_impls.physical_sharding(aval_in, ns)
    aval_in = core.physical_aval(aval_in)
  shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto()
  unspecified = (set(range(aval_in.ndim))
                 if len(manual_axes) < len(mesh.axis_names) else set())
  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,
                                  unspecified_dims=unspecified)
  manual_proto = pxla.manual_proto(
      aval_in, manual_axes | set(mesh.manual_axes), mesh)
  return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto,
                                         unspecified)

def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec,
                 aval_in, aval_out, x):
  if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1:
    return x
  ns = _make_scoped_manual_sharding(ctx, mesh, spec)
  if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
    ns = sharding_impls.physical_sharding(aval_out, ns)
    aval_out = core.physical_aval(aval_out)
  unspecified = (set(range(aval_in.ndim))
                 if len(manual_axes) < len(mesh.axis_names) else set())
  if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
    aval_in = core.physical_aval(aval_in)
  manual_proto = pxla.manual_proto(
      aval_in, manual_axes | set(mesh.manual_axes), mesh)
  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto,
                                  unspecified_dims=unspecified)
  shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()
  return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto,
                                         unspecified)

def _pspec_mhlo_attrs(spec, aval: core.AbstractValue) -> str:
  if isinstance(aval, core.ShapedArray):
    names = _spec_to_names(spec)
    return str(map(names.get, range(aval.ndim)))
  return ''

# Eager evaluation

def get_mesh_from_args(args_flat, mesh):
  for a in args_flat:
    if hasattr(a, 'sharding') and isinstance(a.sharding, NamedSharding):
      if a.sharding.mesh.shape_tuple != mesh.shape_tuple:
        aval = core.shaped_abstractify(a)
        raise ValueError(
            f"Mesh shape of the input {a.sharding.mesh.shape_tuple} does not"
            " match the mesh shape passed to shard_map"
            f" {mesh.shape_tuple} for shape {aval.str_short()}")
      mesh = a.sharding.mesh
  if isinstance(mesh, AbstractMesh):
    raise ValueError(
        "Please pass `jax.Array`s with a `NamedSharding` as input to"
        " `shard_map` when passing `AbstractMesh` to the mesh argument.")
  assert isinstance(mesh, Mesh)
  return mesh

def _vma_to_spec(mesh, vma):
  return P(order_wrt_mesh(mesh, vma))

def _spec_to_vma(spec):
  return frozenset(p for s in spec if s is not None
                   for p in (s if isinstance(s, tuple) else (s,)))

def order_wrt_mesh(mesh, x):
  return tuple(a for a in mesh.axis_names if a in x)
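
# Illustrative examples (the axis names are hypothetical):
# _spec_to_vma(P('x', ('y', 'z'))) is frozenset({'x', 'y', 'z'}), and
# order_wrt_mesh puts such a set back into mesh-axis order, e.g.
# order_wrt_mesh(mesh, {'z', 'x'}) is ('x', 'z') for a mesh whose axis_names
# are ('x', 'y', 'z').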

def _shard_map_impl(trace, prim, fun, args, *, mesh, in_specs, out_specs_thunk,
                    check_vma, manual_axes):
  if len(manual_axes) < len(mesh.axis_names):
    raise NotImplementedError
  del prim
  if isinstance(mesh, AbstractMesh):
    concrete_mesh = get_concrete_mesh()
    mesh = concrete_mesh if concrete_mesh is not None else mesh
    mesh = get_mesh_from_args(args, mesh)
  cur_mesh = get_abstract_mesh()
  args = map(partial(_unmatch_spec, mesh, check_vma, context_mesh=cur_mesh),
             in_specs, args)
  in_vma = map(_spec_to_vma, in_specs)
  outs, out_vma = _run_shmap(fun, mesh, manual_axes, args, in_vma, check_vma,
                             cur_mesh)
  out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs]
  _check_names(out_specs_thunk(), out_avals)  # pytype: disable=wrong-arg-types
  if check_vma:
    _check_vmas(mesh, out_specs_thunk(), out_vma)
    src_pspecs = tuple(_vma_to_spec(mesh, r) for r in out_vma)
  else:
    src_pspecs = tuple(P(mesh.axis_names) for _ in out_vma)
  dst_pspecs = out_specs_thunk()
  return map(partial(_match_spec, mesh, check_vma), src_pspecs, dst_pspecs,
             outs)
core.EvalTrace.process_shard_map = _shard_map_impl

def _run_shmap(f, mesh, manual_axes, args, vmas, check_vma, context_mesh):
  trace = ShardMapTrace(mesh, manual_axes, check_vma, context_mesh)
  in_tracers = map(partial(ShardMapTracer, trace), vmas, args)
  inner_mesh = _as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))
  with (core.set_current_trace(trace), _extend_axis_env(mesh, manual_axes),
        use_abstract_mesh(inner_mesh), config._check_vma(check_vma)):
    ans = f.call_wrapped(*in_tracers)
    outs, out_vma = unzip2(map(trace.to_val_vma_pair, ans))
  return outs, out_vma


def _unmatch_spec(mesh: Mesh, check_vma, in_spec, x: JaxType, context_mesh
                  ) -> JaxType:
  with (core.eval_context(), jax.disable_jit(False),
        use_abstract_mesh(context_mesh)):
    return jax.jit(HashablePartial(_unmatch, mesh, check_vma, in_spec))(x)

def _unmatch(mesh, check_vma, in_spec, x):
  if check_vma:
    used_axes = _spec_to_vma(in_spec)
    dst = P(order_wrt_mesh(mesh, used_axes))
  else:
    dst = P(mesh.axis_names)
    check_vma = False
  return shard_map(_add_singleton, mesh=mesh, in_specs=(in_spec,),
                   out_specs=dst, check_vma=check_vma)(x)

def _check_names(specs, avals: Sequence[core.ShapedArray]) -> None:
  fail = [a if sp and len(sp) > a.ndim else no_fail
          for sp, a in zip(specs, avals)]
  if any(f is not no_fail for f in fail):
    raise _SpecError(fail)

class _SpecError(Exception):
  pass

def _check_vmas(mesh, specs, vmas):
  fail = [vma if not _valid_repeats(mesh, vma, sp) else no_fail
          for sp, vma in zip(specs, vmas)]
  if any(f is not no_fail for f in fail):
    raise _RepError(fail)

class _RepError(Exception):
  pass

def _match_spec(mesh: Mesh, check_vma, src_pspec: PartitionSpec,
                dst_pspec: PartitionSpec, x: JaxType) -> JaxType:
  fn = HashablePartial(_match, mesh, check_vma, src_pspec, dst_pspec)
  with core.eval_context(), jax.disable_jit(False):
    return jax.jit(fn, out_shardings=NamedSharding(mesh, dst_pspec))(x)

def _match(mesh, check_vma, src_pspec, dst_pspec, x):
  return shard_map(_rem_singleton, mesh=mesh, in_specs=src_pspec,
                   out_specs=dst_pspec, check_vma=check_vma)(x)

def _rem_singleton(x): return jnp.squeeze(x, axis=0)
def _add_singleton(x): return jnp.expand_dims(x, axis=0)
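
# Note on the convention above: eager shard_map keeps each value as a global
# array whose leading axis stacks the per-device shards, so _unmatch adds a
# singleton leading axis inside the shard-local view while _match (via
# _rem_singleton) removes it when translating back to the requested out_specs.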

def _maybe_check_special(outs):
  if not config.debug_nans.value and not config.debug_infs.value: return
  bufs = [s.data for leaf in tree_leaves(outs)
          for s in getattr(leaf, 'addressable_shards', [])]
  try:
    dispatch.check_special('shard_map', bufs)
  except api_util.InternalFloatingPointError as e:
    raise FloatingPointError(
        f'Invalid value ({e.ty}) encountered in sharded computation.') from None
class ShardMapTrace(core.Trace):
  __slots__ = ("mesh", "manual_axes", "check", "context_mesh")

  mesh: Mesh
  manual_axes: frozenset[AxisName]
  check: bool
  context_mesh: AbstractMesh

  def __init__(self, mesh, manual_axes, check, context_mesh):
    super().__init__()
    self.mesh = mesh
    self.manual_axes = manual_axes
    self.check = check
    self.context_mesh = context_mesh

  def to_val_vma_pair(self, val):
    if isinstance(val, ShardMapTracer):
      return val.val, val.vma
    elif isinstance(val, Tracer):
      raise Exception(f"Shouldn't have any non-shard_map tracers: {val}")
    else:
      val_ = _unmatch_spec(self.mesh, self.check, P(), val, self.context_mesh)
      return val_, frozenset()

  def process_primitive(self, prim, tracers, params):
    in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers))
    if self.check:
      out_avals, _ = prim.abstract_eval(*(typeof(t) for t in tracers), **params)
      out_avals = tuple(out_avals) if type(out_avals) is list else out_avals
      out_vma = tree_map(lambda a: a.vma, out_avals)
      in_specs = tuple(map(partial(_vma_to_spec, self.mesh), in_vma))
      out_specs = tree_map(partial(_vma_to_spec, self.mesh), out_vma)
    else:
      out_vma = frozenset()
      in_specs = out_specs = P(self.mesh.axis_names)

    eager_rule = eager_rules.get(prim)
    if eager_rule:
      out_vals = eager_rule(self.mesh, *in_vals, **params)
    else:
      f = HashablePartial(
          _prim_applier, prim, self.check, tuple(params.items()), self.mesh,
          in_specs, out_specs)
      with (core.eval_context(), jax.disable_jit(False), jax.debug_nans(False),
            jax.debug_infs(False), use_abstract_mesh(self.context_mesh)):
        out_vals = jax.jit(f)(*in_vals)
      _maybe_check_special(out_vals)
    if prim.multiple_results:
      out_vma = (out_vma if isinstance(out_vma, (list, tuple))
                 else [out_vma] * len(out_vals))
      return map(partial(ShardMapTracer, self), out_vma, out_vals)
    return ShardMapTracer(self, out_vma, out_vals)

  def process_call(self, call_primitive, fun, tracers, params):
    raise NotImplementedError(
        f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't "
        "yet supported. Put a `jax.jit` around the `shard_map`-decorated "
        "function, and open a feature request at "
        "https://github.com/jax-ml/jax/issues !")

  def process_map(self, map_primitive, fun, tracers, params):
    raise NotImplementedError(
        "Eager evaluation of `pmap` inside a `shard_map` isn't yet supported. "
        "Put a `jax.jit` around the `shard_map`-decorated function, and open "
        "a feature request at https://github.com/jax-ml/jax/issues !")

  def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
    # Since ShardMapTrace is only used as a base main, we can drop the jvp.
    del prim, jvp, symbolic_zeros
    in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers))
    out_vals, out_vma = _run_shmap(fun, self.mesh, self.manual_axes, in_vals,
                                   in_vma, self.check, self.context_mesh)
    return map(partial(ShardMapTracer, self), out_vma, out_vals)

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
                              symbolic_zeros):
    if symbolic_zeros:
      msg = ("custom_vjp symbolic_zeros support with shard_map is not "
             "implemented; please open an issue at "
             "https://github.com/jax-ml/jax/issues")
      raise NotImplementedError(msg)
    del prim, fwd, bwd, out_trees, symbolic_zeros
    in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers))
    out_vals, out_vma = _run_shmap(fun, self.mesh, self.manual_axes, in_vals,
                                   in_vma, self.check, self.context_mesh)
    return map(partial(ShardMapTracer, self), out_vma, out_vals)

class ShardMapTracer(core.Tracer):
  vma: frozenset[AxisName]
  val: JaxType

  def __init__(self, trace, vma, val):
    self._trace = trace
    if isinstance(vma, set):
      vma = frozenset(vma)
    assert isinstance(vma, frozenset)
    self.vma = vma
    self.val = val

  @property
  def aval(self):
    aval = core.get_aval(self.val)
    out = core.mapped_aval(self._trace.mesh.size, 0, aval)
    new_sharding = NamedSharding(
        _as_manual_mesh(self._trace.mesh, self._trace.manual_axes),
        out.sharding.spec)  # pytype: disable=attribute-error
    vma = self.vma if config._check_vma.value else frozenset()
    return out.update(sharding=new_sharding, vma=vma)

  def to_concrete_value(self):
    if self.vma == frozenset():
      with core.eval_context(), use_abstract_mesh(self._trace.context_mesh):
        return core.to_concrete_value(self.val[0])
    else:
      return None

  def __str__(self) -> str:
    pb_names = set(self._trace.mesh.axis_names) - self.vma
    self = pvary(self, tuple(pb_names))
    with core.eval_context(), use_abstract_mesh(self._trace.context_mesh):
      blocks = list(self.val)
    mesh = self._trace.mesh
    axis_names = f"({', '.join(map(str, mesh.axis_names))},)"
    return '\n'.join(
        f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n"
        for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks))

  __repr__ = __str__  # for debuggers, like `p x`

def _prim_applier(prim, check_vma, params_tup, mesh, in_specs, out_specs, *args):
  def apply(*args):
    outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup))
    return tree_map(_add_singleton, outs)
  out_specs = list(out_specs) if type(out_specs) is tuple else out_specs
  return shard_map(apply, mesh=mesh, in_specs=in_specs, out_specs=out_specs,
                   check_vma=check_vma)(*args)

eager_rules: dict[core.Primitive, Callable] = {}


# TODO(mattjj): working around an apparent XLA or PjRt bug, remove eventually
def _debug_callback_eager_rule(
    mesh,
    *args,
    callback: Callable[..., Any],
    effect: debugging.DebugEffect,
    partitioned: bool,
):
  del effect
  with core.eval_context():
    all_blocks = zip(*map(list, args))
  for (idx, device), blocks in zip(np.ndenumerate(mesh.devices), all_blocks):
    callback(*blocks)
  return []


eager_rules[debugging.debug_callback_p] = _debug_callback_eager_rule

def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics):
  del mesh, srcs, copy_semantics
  for device in devices:
    if device is not None:
      raise ValueError("device_put with explicit device not allowed within "
                       f"shard_map-decorated functions, but got device {device}")
  return xs
eager_rules[dispatch.device_put_p] = _device_put_eager_rule


# Batching

def _modify_specs_axis_data(trace, name, mesh, in_specs, in_dims):
  new_in_specs = [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, name)
                  for sp, d in zip(in_specs, in_dims)]
  new_size = trace.axis_data.size // prod(mesh.shape[n] for n in name)
  new_axis_data = batching.AxisData(
      trace.axis_data.name, new_size, trace.axis_data.spmd_name,
      trace.axis_data.explicit_mesh_axis)
  return new_in_specs, new_axis_data

def _shard_map_batch(
    trace: batching.BatchTrace, prim: core.Primitive, fun: lu.WrappedFun,
    in_tracers: Sequence[batching.BatchTracer], mesh: Mesh,
    in_specs, out_specs_thunk, check_vma: bool, manual_axes: frozenset
) -> Sequence[batching.BatchTracer]:
  in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers))
  if any(isinstance(d, batching.RaggedAxis) for d in in_dims):
    raise NotImplementedError
  spmd_axis_name = trace.axis_data.spmd_name
  explicit_mesh_axis = trace.axis_data.explicit_mesh_axis
  if spmd_axis_name is not None:
    used = {n for spec in in_specs for n in _spec_to_vma(spec)}
    if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used:
      raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs")
    new_in_specs, new_axis_data = _modify_specs_axis_data(
        trace, spmd_axis_name, mesh, in_specs, in_dims)
  elif explicit_mesh_axis is not None:
    used = {n for spec in in_specs for n in _spec_to_vma(spec)}
    if set(explicit_mesh_axis) & used:
      raise ValueError("vmapped away explicit mesh axis cannot appear in "
                       "shard_map in_specs")
    new_in_specs, new_axis_data = _modify_specs_axis_data(
        trace, explicit_mesh_axis, mesh, in_specs, in_dims)
  else:
    new_in_specs = [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None)
                    for sp, d in zip(in_specs, in_dims)]
    new_axis_data = trace.axis_data
  fun, out_dims = batching.batch_subtrace(fun, trace.tag, new_axis_data, tuple(in_dims))

  @as_hashable_function(closure=out_specs_thunk)
  def new_out_specs_thunk():
    return _batch_out_specs(spmd_axis_name, explicit_mesh_axis, out_dims(),
                            out_specs_thunk())

  new_params = dict(mesh=mesh, in_specs=new_in_specs,
                    out_specs_thunk=new_out_specs_thunk, check_vma=check_vma,
                    manual_axes=manual_axes)
  with core.set_current_trace(trace.parent_trace):
    out_vals = prim.bind(fun, *in_vals, **new_params)
  make_tracer = partial(batching.BatchTracer, trace,
                        source_info=source_info_util.current())
  return map(make_tracer, out_vals, out_dims())
batching.BatchTrace.process_shard_map = _shard_map_batch

def _batch_out_specs(spmd_name, explicit_mesh_axis, dims, out_specs):
  if spmd_name is not None:
    used = {n for spec in out_specs for n in _spec_to_vma(spec)}
    if not config.disable_vmap_shmap_error.value and set(spmd_name) & used:
      raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs")
    return [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, spmd_name)
            for sp, d in zip(out_specs, dims)]
  elif explicit_mesh_axis is not None:
    used = {n for spec in out_specs for n in _spec_to_vma(spec)}
    if set(explicit_mesh_axis) & used:
      raise ValueError("vmapped away explicit mesh axis cannot appear in "
                       "shard_map out_specs")
    return [sp if d is batching.not_mapped else
            pxla.batch_spec(sp, d, explicit_mesh_axis)
            for sp, d in zip(out_specs, dims)]
  else:
    return [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None)
            for sp, d in zip(out_specs, dims)]


# Autodiff

def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_specs,
                   out_specs_thunk, check_vma, manual_axes):
  primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
  which_nz = [type(t) is not ad.Zero for t in tangents]
  tangents = [t if type(t) is not ad.Zero else None for t in tangents]
  args, in_tree = tree_flatten((primals, tangents))
  f_jvp = ad.jvp_subtrace(f, trace.tag)
  f_jvp, which_nz_out = ad.nonzero_tangent_outputs(f_jvp)
  tangent_in_specs = [sp for sp, nz in zip(in_specs, which_nz) if nz]

  @as_hashable_function(closure=out_specs_thunk)
  def new_out_specs_thunk():
    out_ax = out_specs_thunk()
    return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz))
  params = dict(mesh=mesh, in_specs=(*in_specs, *tangent_in_specs),
                out_specs_thunk=new_out_specs_thunk, check_vma=check_vma,
                manual_axes=manual_axes)
  f_jvp, out_tree = ad.traceable(f_jvp, in_tree)
  result = shard_map_p.bind_with_trace(trace.parent_trace, (f_jvp,) + tuple(args), params)
  primal_out, tangent_out = tree_unflatten(out_tree(), result)
  tangent_out = [ad.Zero(core.get_aval(p).to_tangent_aval()) if t is None else t
                 for p, t in zip(primal_out, tangent_out)]
  return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)]
ad.JVPTrace.process_shard_map = _shard_map_jvp

def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p,
                            f: lu.WrappedFun, tracers, mesh, in_specs,
                            out_specs_thunk, check_vma, manual_axes):
  tracers = map(trace.to_jaxpr_tracer, tracers)
  in_pvals = [t.pval for t in tracers]
  in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
  unk_in_specs, known_in_specs = pe.partition_list(in_knowns, in_specs)
  in_avals_sharded = map(partial(_shard_aval, mesh, manual_axes, check_vma),
                         unk_in_specs, in_avals)
  f = pe.trace_to_subjaxpr_nounits_fwd2(f, trace.tag, f.debug_info, False)
  f = _promote_scalar_residuals(f)
  f_known, aux = pe.partial_eval_wrapper_nounits2(
      f, (*in_knowns,), (*in_avals_sharded,))
  all_names = _all_newly_manual_mesh_names(mesh, manual_axes)

  @as_hashable_function(closure=out_specs_thunk)
  def known_out_specs():
    _, _, out_knowns, res_avals, _, _ = aux()
    _, out_known_specs = pe.partition_list(out_knowns, out_specs_thunk())
    if check_vma:
      res_specs = [P(order_wrt_mesh(mesh, a.vma)) for a in res_avals]
    else:
      res_specs = [P(all_names)] * len(res_avals)
    return (*out_known_specs, *res_specs)

  known_params = dict(mesh=mesh, in_specs=(*known_in_specs,),
                      out_specs_thunk=known_out_specs, check_vma=check_vma,
                      manual_axes=manual_axes)
  out = shard_map_p.bind_with_trace(trace.parent_trace, (f_known, *in_consts),
                                    known_params)
  in_fwd, out_fwd, out_knowns, res_avals, jaxpr, env = aux()
  num_res = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
  out_consts, non_fwd_res = split_list(out, [len(out) - num_res])
  assert not jaxpr.constvars
  unk_out_specs, _ = pe.partition_list(out_knowns, out_specs_thunk())
  known_out_specs_ = known_out_specs()
  res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res)
  # TODO make res_avals be the full set, not just the non-fwd ones
  res_avals_iter = iter(res_avals)
  res_specs = []
  for f1, f2 in zip(in_fwd, out_fwd):
    if f1 is not None:
      res_specs.append(known_in_specs[f1])
    elif f2 is not None:
      res_specs.append(known_out_specs_[f2])
    else:
      if check_vma:
        res_vma = next(res_avals_iter).vma
        res_specs.append(P(order_wrt_mesh(mesh, res_vma)))
      else:
        res_specs.append(P(all_names))
  unk_in_specs = (*res_specs,) + (P(),) * len(env) + (*unk_in_specs,)  # type: ignore[assignment]
  const_tracers = map(trace.new_instantiated_const, res)
  env_tracers = map(trace.to_jaxpr_tracer, env)
  unk_arg_tracers = [t for t in tracers if not t.is_known()]
  out_avals_sharded = [v.aval for v in jaxpr.outvars]
  unk_params = dict(mesh=mesh, in_specs=unk_in_specs,
                    out_specs=unk_out_specs, jaxpr=jaxpr,
                    check_vma=check_vma, manual_axes=manual_axes)
  out_avals = map(partial(_unshard_aval, mesh, check_vma), unk_out_specs,
                  out_avals_sharded)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
                 for a in out_avals]
  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
  eqn = pe.new_eqn_recipe(trace, (*const_tracers, *env_tracers, *unk_arg_tracers),
                          out_tracers, shard_map_p, unk_params,
                          effs, source_info_util.current())
  for t in out_tracers: t.recipe = eqn
  return merge_lists(out_knowns, out_tracers, out_consts)
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval
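
# This rule fires when a shard_map is partially evaluated, e.g. while staging
# out a VJP (a sketch of one path that can reach it: _, g_vjp = jax.vjp(g, x)).
# The known part runs now as one shard_map call, emitting residuals as extra
# outputs; the unknown part is staged as a second shard_map whose extra inputs
# are those residuals. Residuals that merely forward an input or a known
# output are deduplicated via in_fwd/out_fwd rather than materialized twice.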

def _shard_map_linearize(trace, shard_map_p, f: lu.WrappedFun,
                         tracers, mesh, in_specs, out_specs_thunk, check_vma,
                         manual_axes):
  primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
  nzs_in = tuple(type(t) is not ad.Zero for t in tangents)
  f_primal, linearize_outs_thunk = ad.linearize_subtrace(f, trace.tag, nzs_in, f.debug_info)
  f_primal = _promote_scalar_residuals_lin(f_primal, linearize_outs_thunk)
  all_names = _all_newly_manual_mesh_names(mesh, manual_axes)

  @as_hashable_function(closure=linearize_outs_thunk)
  def fwd_out_specs_thunk():
    res_avals, _, _, _, in_fwd, out_fwd = linearize_outs_thunk()
    res_avals = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd)
                 if f1 is None and f2 is None]
    out_specs = out_specs_thunk()
    if check_vma:
      res_specs = [P(order_wrt_mesh(mesh, a.vma)) for a in res_avals]
    else:
      res_specs = [P(all_names)] * len(res_avals)
    return (*res_specs, *out_specs)
  fwd_params = dict(
      mesh=mesh, in_specs=in_specs,
      out_specs_thunk=fwd_out_specs_thunk, check_vma=check_vma,
      manual_axes=manual_axes)
  all_fwd_results = shard_map_p.bind_with_trace(
      trace.parent_trace, (f_primal, *primals), fwd_params)
  res_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk()
  num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
  non_fwd_res = all_fwd_results[:num_res_out]
  primals_out = all_fwd_results[num_res_out:]
  residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res)
  args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None
                     for aval, f1, f2 in zip(res_avals, in_fwd, out_fwd)]
  with (_extend_axis_env(mesh, manual_axes),
        use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))),
        config._check_vma(check_vma)):
    lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote)
  out_specs = out_specs_thunk()
  res_avals2 = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd)
                if f1 is None and f2 is None]
  res_avals_iter = iter(res_avals2)
  res_specs = []
  for f1, f2 in zip(in_fwd, out_fwd):
    if f1 is not None:
      res_specs.append(in_specs[f1])
    elif f2 is not None:
      res_specs.append(out_specs[f2])
    else:
      if check_vma:
        res_vma = next(res_avals_iter).vma
        res_specs.append(P(order_wrt_mesh(mesh, res_vma)))
      else:
        res_specs.append(P(all_names))
  new_in_specs = (*res_specs, *(P(),) * len(env),
                  *(ax for ax, nz in zip(in_specs, nzs_in) if nz))
  tangent_out_specs = tuple(ax for ax, nz in zip(out_specs_thunk(), nzs_out)
                            if nz)
  @as_hashable_function(closure=tangent_out_specs)
  def tangent_out_specs_thunk():
    return tangent_out_specs
  tangent_params = dict(
      mesh=mesh, in_specs=new_in_specs, out_specs_thunk=tangent_out_specs_thunk,
      check_vma=check_vma, manual_axes=manual_axes)

  # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here
  def f_tangent(*args):
    return core.eval_jaxpr(lin_jaxpr, (), *args)

  nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
  nz_tangents_out = shard_map_p.bind_with_trace(
      trace.tangent_trace,
      (lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info),
       *residuals, *env, *nz_tangents_in), tangent_params)
  nz_tangents_out_iter = iter(nz_tangents_out)
  tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal)
                  for nz, primal in zip(nzs_out, primals_out)]
  return map(partial(ad.maybe_linearize_tracer, trace), primals_out, nzs_out, tangents_out)
ad.LinearizeTrace.process_shard_map = _shard_map_linearize
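
# Sketch (illustrative names as above):
#
#   y, g_lin = jax.linearize(g, x)  # primal runs now via one shard_map call
#   y_dot = g_lin(x_dot)            # tangent runs later via a second one
#
# Scalar residuals are promoted to shape (1,) first (see the helpers below)
# so every residual has a dimension to which a sharding spec can apply.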

@lu.transformation2
def _promote_scalar_residuals_lin(f, linearize_outs_thunk, *args, **kwargs):
  ans = f(*args, **kwargs)
  _, _, _, _, in_fwd, out_fwd = linearize_outs_thunk()
  num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
  residuals = ans[:num_res_out]
  primals = ans[num_res_out:]
  residuals = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
               for x in residuals]
  return *residuals, *primals

@lu.transformation2
def _promote_scalar_residuals(f: Callable, *args, **kwargs):
  jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs)
  which = [f1 is None and f2 is None and not v.aval.shape
           for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)]
  jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which)
  out_consts = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
                for x in out_consts]
  return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env)

def _promote_scalar_residuals_jaxpr(jaxpr: core.Jaxpr, which: Sequence[bool]):
  def fun(*res_and_args):
    res, args = split_list(res_and_args, [len(jaxpr.constvars)])
    res = [_rem_singleton(x) if w else x for x, w in zip(res, which)]
    return core.eval_jaxpr(jaxpr, res, *args)
  res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
               for v, w in zip(jaxpr.constvars, which)]
  in_avals = [*res_avals, *[v.aval for v in jaxpr.invars]]
  jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(fun, debug_info=jaxpr.debug_info), in_avals)
  return jaxpr
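
# E.g. a residual of shape () is promoted to shape (1,) above (via
# jax.lax.broadcast) and squeezed back with _rem_singleton at its use sites,
# so that even scalar residuals have a dimension over which they can be
# sharded between the known and staged computations.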


def _unmentioned2(mesh: Mesh, spec, manual_axes: frozenset[AxisName]
                  ) -> list[AxisName]:
  # We use a filtered-down version of unmentioned to avoid a defensive psum
  # over more chips than required in the transpose-no-check-vma case.
  name_set = _spec_to_vma(spec)
  return [n for n in _all_mesh_names_except_spmd(mesh, manual_axes)
          if n not in name_set]


def _shard_map_transpose(out_cts, *args,
                         jaxpr: core.Jaxpr, mesh, in_specs, out_specs,
                         check_vma, manual_axes):
  mb_div = lambda x, y: x / y if y != 1 else x
  out_cts = [
      ad.Zero(_shard_aval(mesh, manual_axes, check_vma, sp, x.aval))
      if type(x) is ad.Zero else x if check_vma or dtypes.dtype(x) == dtypes.float0
      else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, sp, manual_axes))))
      for sp, x in zip(out_specs, out_cts)
  ]
  args = tuple(x if type(x) is not ad.UndefinedPrimal else
               ad.UndefinedPrimal(
                   _shard_aval(mesh, manual_axes, check_vma, sp, x.aval))
               for sp, x in zip(in_specs, args))
  all_args, in_tree = tree_flatten((out_cts, args))

  def fun_trans_callable(out_cts, args):
    # TODO(mattjj): when #26811 lands, delete this and just run backward_pass
    in_undef = map(ad.is_undefined_primal, args)
    res, undefs = partition_list(in_undef, args)
    jaxpr_known, jaxpr_unknown, _, _ = pe.partial_eval_jaxpr_nounits(
        pe.close_jaxpr(jaxpr), in_undef, False)
    res_reshaped = core.jaxpr_as_fun(jaxpr_known)(*res)
    in_cts = ad.backward_pass(
        jaxpr_unknown.jaxpr, False, (), (*res_reshaped, *undefs), out_cts
    )[len(res_reshaped):]
    _, in_ct_specs = partition_list(in_undef, in_specs)
    in_cts = [ad.Zero(_unshard_aval(mesh, check_vma, sp, x.aval))
              if type(x) is ad.Zero else x if check_vma
              else jax.lax.psum(x, tuple(_unmentioned2(mesh, sp, manual_axes)))
              for sp, x in zip(in_ct_specs, in_cts)]
    res_zeros = [ad_util.zero_from_primal(r) for r in res]
    return merge_lists(in_undef, res_zeros, in_cts)

  fun_trans = lu.wrap_init(fun_trans_callable, debug_info=jaxpr.debug_info)
  fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans)
  fun_trans_flat, out_tree = api_util.flatten_fun_nokwargs(fun_trans, in_tree)

  new_in_specs = (
      [n for n, x in zip(out_specs, out_cts) if type(x) is not ad.Zero] +
      [n for n, x in zip(in_specs, args) if type(x) is not ad.UndefinedPrimal])

  def new_out_specs_thunk():
    return tuple(sp for sp, nz in zip(in_specs, nz_arg_cts()) if nz)

  try:
    out_flat = shard_map_p.bind(
        fun_trans_flat, *all_args, mesh=mesh, in_specs=tuple(new_in_specs),
        out_specs_thunk=new_out_specs_thunk, check_vma=check_vma,
        manual_axes=manual_axes)
  except (FloatingPointError, ZeroDivisionError) as e:
    print("Invalid nan value encountered in the backward pass of a shard_map "
          "function. Calling the de-optimized backward pass.")
    try:
      # TODO(mattjj): Remove this and do `fun_trans.call_wrapped(out_cts, args)`
      # in eager mode so that the outputs of shmap are not manual.
      with jax.disable_jit(True):
        _ = shard_map_p.bind(
            fun_trans_flat, *all_args, mesh=mesh, in_specs=tuple(new_in_specs),
            out_specs_thunk=new_out_specs_thunk, check_vma=check_vma,
            manual_axes=manual_axes)
    except (FloatingPointError, ZeroDivisionError) as e2:
      raise e2 from None
    else:
      api_util._raise_no_nan_in_deoptimized(e)
  return tree_unflatten(out_tree(), out_flat)
ad.primitive_transposes[shard_map_p] = _shard_map_transpose
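
# A reverse-mode sketch that reaches the transpose rule above (illustrative
# names as in earlier comments):
#
#   jax.grad(lambda x: g(x).sum())(x)
#
# With check_vma off, unmentioned mesh axes are handled defensively: output
# cotangents are divided by the product of the unmentioned axis sizes, and
# input cotangents are psum'd over them (hence _unmentioned2 above).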
# Remat
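#
# These rules let jax.remat / jax.checkpoint partially evaluate through a
# shard_map. Sketch (illustrative g/mesh as above):
#
#   h = jax.checkpoint(g)
#   jax.grad(lambda x: h(x).sum())(x)  # residuals recomputed in the bwd pass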

def _partial_eval_jaxpr_custom_rule(
    saveable: Callable[..., pe.RematCases_], unks_in: Sequence[bool],
    inst_in: Sequence[bool], eqn: core.JaxprEqn
) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool],
           list[core.Var]]:
  jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh']
  check_vma, manual_axes = eqn.params['check_vma'], eqn.params['manual_axes']
  with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma),
        use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)))):
    jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
        pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
  num_out_primals = len(jaxpr_known.outvars) - num_res
  in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:]
  out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals])
  idx_map = {id(v): i for i, v in enumerate(out_vars)}
  out_fwd = [idx_map.get(id(v)) for v in res_vars]
  which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
  mesh = eqn.params['mesh']
  with (_extend_axis_env(mesh, manual_axes),
        use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes))),
        config._check_vma(check_vma)):
    jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
    jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
  jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)
  jaxpr_staged = core.remove_named_axis_effects(jaxpr_staged, mesh.axis_names)
  ins_known, _ = partition_list(unks_in, eqn.invars)
  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
  _, ins_staged = partition_list(inst_in, eqn.invars)
  _, out_binders_staged = partition_list(inst_out, eqn.outvars)
  newvar = core.gensym()
  residuals, staged_in_res_specs = [], []
  for var, w in zip(jaxpr_staged.invars[:num_res], which):
    if w:
      rn = (P(order_wrt_mesh(mesh, var.aval.vma))  # type: ignore
            if check_vma else P(_all_newly_manual_mesh_names(mesh, manual_axes)))
      residuals.append(newvar(_unshard_aval(mesh, check_vma, rn, var.aval)))
      staged_in_res_specs.append(rn)
  if check_vma:
    out_res_specs_known = [P(order_wrt_mesh(mesh, var.aval.vma))  # type: ignore
                           for var, o in zip(res_vars, out_fwd) if o is None]
  else:
    out_res_specs_known = [
        P(_all_newly_manual_mesh_names(mesh, manual_axes))] * sum(which)
  params_known, params_staged = _pe_custom_params(
      unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd,
      out_res_specs_known, staged_in_res_specs,
      dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged))
  eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
                               eqn.primitive, params_known, jaxpr_known.effects,
                               eqn.source_info, eqn.ctx)
  full_res = subs_list2(in_fwd, out_fwd, ins_known, out_binders_known, residuals)
  eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged,
                                eqn.primitive, params_staged,
                                jaxpr_staged.effects, eqn.source_info, eqn.ctx)
  assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
  new_inst = [x for x, inst in zip(eqn.invars, inst_in)
              if type(x) is core.Var and not inst]
  new_inst += [out_binders_known[f] for f in {i for i in out_fwd if i is not None}]
  return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \
    _partial_eval_jaxpr_custom_rule

def _add_reshapes(which: Sequence[bool],
                  jaxpr_known: core.Jaxpr,
                  jaxpr_staged: core.Jaxpr) -> tuple[core.Jaxpr, core.Jaxpr]:
  # add singleton axes to residuals which are from jaxpr_known and are scalars
  which_ = [w and not v.aval.shape  # pytype: disable=attribute-error
            for w, v in zip(which, jaxpr_staged.invars[:len(which)])]
  if not any(which_): return jaxpr_known, jaxpr_staged
  assert not jaxpr_known.constvars and not jaxpr_staged.constvars

  def known(*args):
    out = core.eval_jaxpr(jaxpr_known, (), *args)
    out_known, res = split_list(out, [len(out) - sum(which)])
    res = [_add_singleton(x) if not x.shape else x for x in res]
    return [*out_known, *res]
  avals_in = [v.aval for v in jaxpr_known.invars]
  jaxpr_known, _, (), () = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(known, debug_info=jaxpr_known.debug_info), avals_in)

  def staged(*args):
    res_, ins = split_list(args, [len(which)])
    res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)]
    return core.eval_jaxpr(jaxpr_staged, (), *res, *ins)
  res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
               for w, v in zip(which_, jaxpr_staged.invars[:len(which)])]
  avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]]
  jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info), avals_in)

  return jaxpr_known, jaxpr_staged

def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
                      in_fwd, out_fwd, out_res_specs_known, staged_in_res_specs,
                      params_known, params_staged):
  # prune inputs to jaxpr_known according to unks_in
  in_specs_known, _ = partition_list(unks_in, params_known['in_specs'])
  _, out_specs_known = partition_list(kept_outs_known, params_known['out_specs'])
  out_specs_known = out_specs_known + out_res_specs_known
  assert len(out_specs_known) == len(params_known['jaxpr'].outvars)
  new_params_known = dict(params_known, in_specs=tuple(in_specs_known),
                          out_specs=tuple(out_specs_known))

  # add num_res new inputs to jaxpr_staged, pruning according to inst_in
  _, in_specs_staged = partition_list(inst_in, params_staged['in_specs'])
  iter_staged = iter(staged_in_res_specs)
  res_specs = [in_specs_known[f1] if f1 is not None else
               out_specs_known[f2] if f2 is not None else
               next(iter_staged) for f1, f2 in zip(in_fwd, out_fwd)]

  in_specs_staged = res_specs + in_specs_staged
  _, out_specs_staged = partition_list(kept_outs_staged, params_staged['out_specs'])
  new_params_staged = dict(params_staged, in_specs=tuple(in_specs_staged),
                           out_specs=tuple(out_specs_staged))
  return new_params_known, new_params_staged

# TODO(mattjj): remove this mechanism when we revise mesh scopes
def _all_mesh_names_except_spmd(
    mesh: Mesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]:
  axis_env = core.get_axis_env()
  spmd_names = axis_env.spmd_axis_names
  return tuple(name for name in mesh.axis_names
               if name not in spmd_names and name in manual_axes)

def _all_newly_manual_mesh_names(
    mesh: Mesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]:
  axis_env = core.get_axis_env()
  vmap_spmd_names = set(axis_env.spmd_axis_names)
  if not (ctx_mesh := get_abstract_mesh()).empty:
    mesh = ctx_mesh
    already_manual_names = set(ctx_mesh.manual_axes)
  else:
    # TODO(mattjj): remove this mechanism when we revise mesh scopes
    already_manual_names = set(axis_env.axis_sizes)  # may include vmap axis_names
  return tuple(name for name in mesh.axis_names
               if (name not in vmap_spmd_names | already_manual_names and
                   name in manual_axes))


# DCE
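#
# Dead-code elimination prunes unused shard_map outputs and any inputs that
# only they depended on. Sketch (illustrative names): if only y0 of
#
#   y0, y1 = shard_map(g2, mesh=mesh, in_specs=..., out_specs=...)(x0, x1)
#
# is used under jit, the rule below shrinks the equation to compute just y0.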

# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule?
def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
                   ) -> tuple[list[bool], core.JaxprEqn | None]:
  if not any(used_outputs) and not pe.has_effects(eqn):
    return [False] * len(eqn.invars), None
  mesh = eqn.params["mesh"]
  manual_axes = eqn.params["manual_axes"]
  check_vma = eqn.params["check_vma"]
  with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma),
        use_abstract_mesh(_as_manual_mesh(mesh, manual_axes | set(mesh.manual_axes)))):
    jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
  if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects:
    return used_inputs, None
  else:
    _, in_specs = partition_list(used_inputs, eqn.params['in_specs'])
    _, out_specs = partition_list(used_outputs, eqn.params['out_specs'])
    new_params = dict(eqn.params, jaxpr=jaxpr, in_specs=tuple(in_specs),
                      out_specs=tuple(out_specs))
    effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
    new_eqn = pe.new_jaxpr_eqn(
        [v for v, used in zip(eqn.invars, used_inputs) if used],
        [x for x, used in zip(eqn.outvars, used_outputs) if used],
        eqn.primitive, new_params, effs, eqn.source_info, eqn.ctx)
    return used_inputs, new_eqn
pe.dce_rules[shard_map_p] = _shard_map_dce

# Implementing pmap in terms of shard_map
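#
# The wrapper below rebuilds the classic jax.pmap API on top of shard_map: the
# mapped axis becomes a one-dimensional device mesh, in_axes/out_axes become
# PartitionSpecs, and the shard_mapped function is jit-compiled. Sketch:
#
#   out = pmap(lambda x: x * 2)(jnp.arange(jax.device_count()))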

def pmap(f, axis_name=None, *, in_axes=0, out_axes=0,
         static_broadcasted_argnums=(), devices=None, backend=None,
         axis_size=None, donate_argnums=(), global_arg_shapes=None):
  devices = tuple(devices) if devices is not None else devices
  axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap(
      f, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, out_axes)

  def infer_params(*args, **kwargs):
    p = _prepare_pmap(f, in_axes, out_axes, static_broadcasted_tuple,
                      donate_tuple, devices, backend, axis_size, args, kwargs)
    for arg in p.flat_args:
      dispatch.check_arg(arg)
    mesh = Mesh(_get_devices(p, backend), (axis_name,))
    _pmapped, in_specs, out_specs = _cached_shard_map(
        p.flat_fun, mesh, p.in_axes_flat, p.out_axes_thunk, axis_name)
    flat_global_args = host_local_array_to_global_array(
        p.flat_args, mesh, list(in_specs))
    jitted_f = jax.jit(
        _pmapped,
        donate_argnums=[i for i, val in enumerate(p.donated_invars) if val])
    return jitted_f, flat_global_args, p.out_tree, mesh, out_specs

  def wrapped(*args, **kwargs):
    (jitted_f, flat_global_args, out_tree, mesh,
     out_specs) = infer_params(*args, **kwargs)
    outs = jitted_f(*flat_global_args)
    outs = global_array_to_host_local_array(outs, mesh, out_specs())
    return tree_unflatten(out_tree(), outs)

  def lower(*args, **kwargs):
    jitted_f, _, _, _, _ = infer_params(*args, **kwargs)
    return jitted_f.lower(*args, **kwargs)
  wrapped.lower = lower

  return wrapped


@lu.cache
def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name):
  in_specs = tuple(map(partial(_axis_to_spec, axis_name), in_axes_flat))
  out_specs = lambda: map(partial(_axis_to_spec, axis_name), out_axes_thunk())
  fun = _handle_reshapes(flat_fun, in_axes_flat, out_axes_thunk)
  return (_shard_map(fun.call_wrapped, mesh=mesh, in_specs=in_specs,
                     out_specs=out_specs, check_vma=False,
                     axis_names=set(mesh.axis_names)),
          in_specs, out_specs)

@lu.transformation2
def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs):
  args = tree_map(lambda x, ax: x if ax is None else jnp.squeeze(x, axis=ax),
                  list(args), list(in_axes))
  out = f(*args)
  return tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax),
                  list(out), list(out_axes_thunk()))

def _axis_to_spec(axis_name, ax):
  if isinstance(ax, int):
    specs = [None] * ax + [axis_name]
    return P(*specs)
  elif ax is None:
    return P()
  else:
    raise TypeError(ax)
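
# E.g. _axis_to_spec('i', 0) == P('i') and _axis_to_spec('i', 2) ==
# P(None, None, 'i'): the mapped axis position becomes the dimension sharded
# over axis_name, while _axis_to_spec('i', None) == P() leaves the argument
# fully replicated.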

def _get_devices(p, backend):
  if backend is not None and p.devices is None:
    devs = jax.devices(backend=backend)
  else:
    devs = jax.devices() if p.devices is None else p.devices
  if jax.process_count() > 1:
    return devs[:p.global_axis_size]
  return devs[:p.local_axis_size]