2025-08-11 12:24:21 +08:00

83 lines
4.3 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable, Hashable
from typing import Any
from jax._src import traceback_util
from jax.sharding import Mesh, AbstractMesh
from jax._src import shard_map as jshmap
# Type alias for the ``in_specs`` / ``out_specs`` arguments of ``shard_map``:
# pytrees whose leaves are ``PartitionSpec`` instances (a ``None`` spec means
# "not sharded"). Typed as ``Any`` because the tree structure mirrors whatever
# arguments/outputs the mapped function has.
Specs = Any
# Type alias for a mesh axis name; any hashable value is accepted.
AxisName = Hashable
@traceback_util.api_boundary
def shard_map(
    f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs,
    check_rep: bool = True, auto: frozenset[AxisName] = frozenset()):
  """Map a function over shards of data.

  Note:
    ``shard_map`` is an experimental API, and still subject to change. For an
    introduction to sharded data, refer to :ref:`sharded-computation`. For a
    more in-depth look at using ``shard_map``, refer to
    `SPMD multi-device parallelism with shard_map`_.

  Args:
    f: callable to be mapped. Each application of ``f``, or "instance" of
      ``f``, takes as input a shard of the mapped-over arguments and produces
      a shard of the output.
    mesh: a ``jax.sharding.Mesh`` (or ``jax.sharding.AbstractMesh``)
      representing the array of devices over which to shard the data and on
      which to execute instances of ``f``. The names of the mesh axes can be
      used in collective communication operations in ``f``. A concrete mesh is
      typically created by a utility function like
      :func:`jax.experimental.mesh_utils.create_device_mesh`.
    in_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as
      leaves, with a tree structure that is a tree prefix of the args tuple to
      be mapped over. Similar to :class:`~jax.sharding.NamedSharding`, each
      ``PartitionSpec`` represents how the corresponding argument (or subtree
      of arguments) should be sharded along the named axes of ``mesh``. In
      each ``PartitionSpec``, mentioning a ``mesh`` axis name at a position
      expresses sharding the corresponding argument array axis along that
      positional axis; not mentioning an axis name expresses replication. If
      an argument, or argument subtree, has a corresponding spec of None, that
      argument is not sharded.
    out_specs: a pytree with :class:`~jax.sharding.PartitionSpec` instances as
      leaves, with a tree structure that is a tree prefix of the output of
      ``f``. Each ``PartitionSpec`` represents how the corresponding output
      shards should be concatenated. In each ``PartitionSpec``, mentioning a
      ``mesh`` axis name at a position expresses concatenation of that mesh
      axis's shards along the corresponding positional axis. Not mentioning a
      ``mesh`` axis name expresses a promise that the output values are equal
      along that mesh axis, and that rather than concatenating only a single
      value should be produced.
    check_rep: If True (default), enable additional validity checks and
      automatic differentiation optimizations. The validity checks concern
      whether any mesh axis names not mentioned in ``out_specs`` are
      consistent with how the outputs of ``f`` are replicated. Must be set
      False if using a Pallas kernel in ``f``. This flag is forwarded to the
      underlying implementation as ``check_vma``.
    auto: (experimental) an optional set of axis names from ``mesh`` over
      which we do not shard the data or map the function, but rather we allow
      the compiler to control sharding. These names cannot be used in
      ``in_specs``, ``out_specs``, or in communication collectives in ``f``.
      Must be a ``frozenset`` or ``set``. Names that are not axes of ``mesh``
      are ignored.

  Returns:
    A callable that applies the input function ``f`` across data sharded
    according to the ``mesh`` and ``in_specs``.

  Examples:
    For examples, refer to :ref:`sharded-computation` or
    `SPMD multi-device parallelism with shard_map`_.

  .. _SPMD multi-device parallelism with shard_map: https://docs.jax.dev/en/latest/notebooks/shard_map.html
  """
  every_axis = frozenset(mesh.axis_names)
  # The axes that ``f`` maps over manually are all mesh axes except those
  # handed to the compiler via ``auto``. The ``-`` operator (not
  # ``.difference``) is deliberate: it only accepts set-like ``auto`` values.
  manual_axes = every_axis - auto
  forwarded = dict(
      mesh=mesh,
      in_specs=in_specs,
      out_specs=out_specs,
      axis_names=manual_axes,
      # The implementation names this flag ``check_vma``; this entry point
      # exposes it as ``check_rep``.
      check_vma=check_rep,
      # NOTE(review): this wrapper always asks the implementation to skip its
      # mesh check; see jax._src.shard_map for what that check covers.
      _skip_mesh_check=True,
  )
  return jshmap._shard_map(f, **forwarded)