# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Callable, Hashable
from typing import Any

from jax._src import traceback_util
from jax.sharding import Mesh, AbstractMesh
from jax._src import shard_map as jshmap

Specs = Any
AxisName = Hashable


@traceback_util.api_boundary
def shard_map(
    f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs,
    check_rep: bool = True, auto: frozenset[AxisName] = frozenset()):
  """Map a function over shards of data.

  Note:
    ``shard_map`` is an experimental API, and still subject to change. For an
    introduction to sharded data, refer to :ref:`sharded-computation`. For a
    more in-depth look at using ``shard_map``, refer to
    `SPMD multi-device parallelism with shard_map`_.

  Args:
    f: callable to be mapped. Each application of ``f``, or "instance" of
      ``f``, takes as input a shard of the mapped-over arguments and produces
      a shard of the output.
    mesh: a ``jax.sharding.Mesh`` representing the array of devices over which
      to shard the data and on which to execute instances of ``f``. The axis
      names of the ``Mesh`` can be used in collective communication operations
      in ``f``. This is typically created by a utility function like
      :func:`jax.experimental.mesh_utils.create_device_mesh`.
    in_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as
      leaves, with a tree structure that is a tree prefix of the args tuple to
      be mapped over. Similar to :class:`~jax.sharding.NamedSharding`, each
      ``PartitionSpec`` represents how the corresponding argument (or subtree
      of arguments) should be sharded along the named axes of ``mesh``. In
      each ``PartitionSpec``, mentioning a ``mesh`` axis name at a position
      expresses sharding the corresponding argument array axis along that
      positional axis; not mentioning an axis name expresses replication. If
      an argument, or argument subtree, has a corresponding spec of None, that
      argument is not sharded.
    out_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as
      leaves, with a tree structure that is a tree prefix of the output of
      ``f``. Each ``PartitionSpec`` represents how the corresponding output
      shards should be concatenated. In each ``PartitionSpec``, mentioning a
      ``mesh`` axis name at a position expresses concatenation of that mesh
      axis's shards along the corresponding positional axis. Not mentioning a
      ``mesh`` axis name expresses a promise that the output values are equal
      along that mesh axis, and that, rather than concatenating, only a single
      value should be produced.
    check_rep: if True (default), enable additional validity checks and
      automatic differentiation optimizations. The validity checks concern
      whether any mesh axis names not mentioned in ``out_specs`` are
      consistent with how the outputs of ``f`` are replicated. Must be set to
      False if using a Pallas kernel in ``f``.
    auto: (experimental) an optional set of axis names from ``mesh`` over
      which we do not shard the data or map the function, but rather we allow
      the compiler to control sharding. These names cannot be used in
      ``in_specs``, ``out_specs``, or in communication collectives in ``f``.

  Returns:
    A callable that applies the input function ``f`` across data sharded
    according to the ``mesh`` and ``in_specs``.

  Examples:
    For examples, refer to :ref:`sharded-computation` or
    `SPMD multi-device parallelism with shard_map`_.
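
    As a minimal sketch (not part of the official docs; it assumes a host
    with at least two devices, and names like ``f``, ``x``, and ``total``
    are placeholders)::

      import jax
      import jax.numpy as jnp
      from jax.sharding import Mesh, PartitionSpec as P
      from jax.experimental import mesh_utils
      from jax.experimental.shard_map import shard_map

      # Build a 1D device mesh with axis name 'i' (assumes >= 2 devices).
      devices = mesh_utils.create_device_mesh((2,))
      mesh = Mesh(devices, axis_names=('i',))

      def f(x_block):
        # Each instance of f sees one shard of x; psum sums across the
        # mesh axis 'i', so every instance returns the same scalar.
        return jax.lax.psum(x_block.sum(), 'i')

      x = jnp.arange(8.)
      # in_specs=P('i') shards x along mesh axis 'i'; out_specs=P() promises
      # the outputs are equal along 'i', so a single (replicated) scalar is
      # produced rather than a concatenation.
      total = shard_map(f, mesh, in_specs=P('i'), out_specs=P())(x)
      # total == 28.0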

  .. _SPMD multi-device parallelism with shard_map: https://docs.jax.dev/en/latest/notebooks/shard_map.html
  """
  axis_names = frozenset(mesh.axis_names) - auto
  return jshmap._shard_map(
      f, mesh=mesh, in_specs=in_specs, out_specs=out_specs,
      check_vma=check_rep, axis_names=axis_names, _skip_mesh_check=True)