# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
import collections
import dataclasses
import functools
from typing import Any, Union

from jax._src import config
from jax._src.util import use_cpp_class, cache, use_cpp_method
from jax._src.lib import jaxlib_extension_version
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir.dialects import sdy
from jax._src import mesh as mesh_lib
from jax._src.mesh import AxisType
from jax._src.partition_spec import PartitionSpec
from jax._src import sharding as JSharding
import numpy as np

Shape = tuple[int, ...]
Device = xc.Device
Index = tuple[slice, ...]
XLADeviceAssignment = Sequence[Device]


class AUTO:

  def __init__(self, mesh: mesh_lib.Mesh):
    self.mesh = mesh

  def _to_sdy_sharding(self, ndim: int) -> SdyArray:
    dim_shardings = [SdyDim(axes=[], is_open=True) for _ in range(ndim)]
    return SdyArray(mesh_shape=self.mesh.shape_tuple,
                    dim_shardings=dim_shardings)


class UnspecifiedValue:
  def __repr__(self):
    return "UnspecifiedValue"
UNSPECIFIED = UnspecifiedValue()


MeshAxisName = Any

"""
ArrayMapping specifies how an ndarray should map to mesh axes.

Note that the ordering is crucial for the cases when this mapping is
non-injective (i.e. when multiple mesh axes map to the same positional axis).
Then, the order of entries of the mapping determines a major-to-minor order on
mesh axes, according to which chunks of the value along the repeated dimension
will be assigned.

For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape
{'x': 2, 'y': 3}. The second dimension of the value would get chunked into 6
pieces, and assigned to the mesh in a way that treats 'y' as the fastest
changing (minor) dimension. In this case, that would mean that a flat list of
chunks would get assigned to a flattened list of mesh devices without any
modifications. If the mapping was {'y': 1, 'x': 1}, then the mesh devices
ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = collections.OrderedDict[MeshAxisName, int]
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTO, UnspecifiedValue]
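
# Illustrative sketch only (not used by the implementation): for a
# hypothetical mesh of shape {'x': 2, 'y': 3}, `get_array_mapping` (defined
# below) turns a PartitionSpec into such a mapping, and the entry order
# encodes the major-to-minor assignment described above:
#
#   get_array_mapping(PartitionSpec(None, ('x', 'y')))
#   # -> mapping with 'x' before 'y': chunk k of dimension 1 lands on
#   #    mesh.devices[k // 3, k % 3]
#   get_array_mapping(PartitionSpec(None, ('y', 'x')))
#   # -> mapping with 'y' before 'x': the device mesh is transposed before
#   #    flattening, so chunk k lands on mesh.devices[k % 2, k // 2]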


def _unpickle_named_sharding(mesh, spec, memory_kind, logical_device_ids):
  return NamedSharding(mesh, spec, memory_kind=memory_kind,
                       _logical_device_ids=logical_device_ids)


@use_cpp_class(xc.NamedSharding)
class NamedSharding(JSharding.Sharding):
  r"""A :class:`NamedSharding` expresses sharding using named axes.

  A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and
  :class:`PartitionSpec` which describes how to shard an array across that
  mesh.

  A :class:`Mesh` is a multidimensional NumPy array of JAX devices, where each
  axis of the mesh has a name, e.g. ``'x'`` or ``'y'``.

  A :class:`PartitionSpec` is a tuple, whose elements can be a ``None``, a mesh
  axis, or a tuple of mesh axes. Each element describes how an input dimension
  is partitioned across zero or more mesh dimensions. For example,
  ``PartitionSpec('x', 'y')`` says that the first dimension of data is sharded
  across the ``x`` axis of the mesh, and the second dimension is sharded
  across the ``y`` axis of the mesh.

  The Distributed arrays and automatic parallelization
  (https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names)
  tutorial has more details and diagrams that explain how
  :class:`Mesh` and :class:`PartitionSpec` are used.

  Args:
    mesh: A :class:`jax.sharding.Mesh` object.
    spec: A :class:`jax.sharding.PartitionSpec` object.

  Examples:

    >>> from jax.sharding import Mesh
    >>> from jax.sharding import PartitionSpec as P
    >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
    >>> spec = P('x', 'y')
    >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
  """

  mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh
  spec: PartitionSpec
  _memory_kind: str | None
  _logical_device_ids: tuple[int, ...] | None

  @use_cpp_method()
  def __init__(
      self, mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh, spec: PartitionSpec,
      *, memory_kind: str | None = None, _logical_device_ids=None):
    self.mesh = mesh
    self.spec = spec
    self._memory_kind = memory_kind
    self._logical_device_ids = _logical_device_ids
    check_pspec(self.mesh, self.spec)

  def __repr__(self):
    mem = ('' if self.memory_kind is None
           else f', memory_kind={self.memory_kind}')
    ldi = ('' if self._logical_device_ids is None else
           f', logical_device_ids={self._logical_device_ids}')
    mesh_repr = f"{str(self.mesh)}"
    return f'NamedSharding(mesh={mesh_repr}, spec={self.spec}{mem}{ldi})'

  def __reduce__(self):
    return (_unpickle_named_sharding,
            (self.mesh, self.spec, self.memory_kind,
             self._logical_device_ids))

  @property
  def memory_kind(self) -> str | None:
    return self._memory_kind

  @use_cpp_method(jaxlib_extension_version >= 353)
  def __hash__(self):
    if not hasattr(self, '_hash'):
      self._hash = hash(
          (self.mesh, self.memory_kind, self.spec, self._logical_device_ids))
    return self._hash

  @use_cpp_method(jaxlib_extension_version >= 353)
  def __eq__(self, other):
    if not isinstance(other, NamedSharding):
      return False
    if self is other:
      return True
    if (self.spec != other.spec
        or self.memory_kind != other.memory_kind
        or self._logical_device_ids != other._logical_device_ids):
      return False
    return self.mesh is other.mesh or self.mesh == other.mesh

  def check_compatible_aval(self, aval_shape: Shape) -> None:
    if len(aval_shape) < len(self.spec):
      extra_msg = (' For scalars the PartitionSpec should be P()'
                   if len(aval_shape) == 0 else '')
      raise ValueError(
          f"Sharding {self} is only valid for values of rank at least "
          f"{len(self.spec)}, but was applied to a value of rank "
          f"{len(aval_shape)}.{extra_msg}")
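
  # Illustrative sketch of the rank rule enforced by `check_compatible_aval`
  # above (assuming a hypothetical mesh with axes ('x', 'y')): a spec may be
  # shorter than the value's rank (trailing dimensions are simply unsharded),
  # but never longer, and scalars require an empty spec.
  #
  #   NamedSharding(mesh, P('x', 'y'))  # valid for values of rank >= 2
  #   NamedSharding(mesh, P('x'))       # valid for values of rank >= 1
  #   NamedSharding(mesh, P())          # valid for any value, incl. scalars
  #
  # Applying the first of these to a rank-1 value raises a ValueError.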

  @property
  def num_devices(self) -> int:
    return self.mesh.size

  @property
  def device_set(self) -> set[Device]:
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      raise ValueError(
          'device_set is not implemented for `jax.sharding.AbstractMesh`.')
    return self.mesh._flat_devices_set

  @property
  def _device_assignment(self) -> XLADeviceAssignment:
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      raise ValueError('_device_assignment is not implemented for'
                       ' `jax.sharding.AbstractMesh`.')
    return self.mesh._flat_devices_tuple

  @property
  def is_fully_addressable(self) -> bool:
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      raise ValueError('is_fully_addressable is not implemented for '
                       '`jax.sharding.AbstractMesh`.')
    # Speed up `is_fully_addressable` since there is a high chance that the
    # mesh across multiple NamedSharding objects will be the same.
    if config.enable_empty_arrays.value:
      return self._internal_device_list.is_fully_addressable  # type: ignore
    return not self.mesh.is_multi_process

  @property
  def _is_concrete(self) -> bool:
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      return False
    return True

  @property
  def addressable_devices(self) -> set[Device]:
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      raise ValueError('addressable_devices is not implemented for '
                       '`jax.sharding.AbstractMesh`.')
    # Override addressable devices because there is a high chance that the
    # mesh across multiple NamedSharding objects will be the same.
    return self.mesh._local_devices_set

  @functools.cached_property
  def is_fully_replicated(self) -> bool:
    if self.mesh.size == 1:
      return True
    array_mapping = get_array_mapping(self.spec)
    mesh_shape = self.mesh.shape
    num_partitions = 1
    for name in array_mapping:  # type: ignore
      num_partitions *= mesh_shape[name]
    return num_partitions == 1

  def with_memory_kind(self, kind: str) -> NamedSharding:
    return self.update(memory_kind=kind)

  def update(self, **kwargs) -> NamedSharding:
    spec = kwargs.pop("spec", self.spec)
    if not isinstance(spec, PartitionSpec):
      spec = PartitionSpec(*spec)
    return NamedSharding(
        mesh=kwargs.pop("mesh", self.mesh),
        spec=spec,
        memory_kind=kwargs.pop("memory_kind", self.memory_kind),
        _logical_device_ids=kwargs.pop("_logical_device_ids",
                                       self._logical_device_ids))

  def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
    return named_sharding_to_xla_hlo_sharding(self, num_dimensions)

  def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray:
    dim_shardings = [SdyDim(axes=[], is_open=False)
                     for _ in range(num_dimensions)]
    for i, dim_spec in enumerate(self.spec):
      if dim_spec is PartitionSpec.UNCONSTRAINED:
        dim_shardings[i].is_open = True
      elif dim_spec is None:
        # Already empty and closed sharding.
        pass
      else:
        dim_spec = dim_spec if isinstance(dim_spec, tuple) else (dim_spec,)
        dim_shardings[i].axes = dim_spec
    return SdyArray(mesh_shape=self.mesh.shape_tuple,
                    dim_shardings=dim_shardings,
                    logical_device_ids=self._logical_device_ids,
                    unreduced_axes=self.spec.unreduced)

NamedSharding.__module__ = 'jax.sharding'
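
# Illustrative sketch of `NamedSharding.update` and `with_memory_kind`
# (assuming a hypothetical mesh with axes ('x', 'y') and a backend that
# supports a 'pinned_host' memory kind): both return a new sharding, and any
# field that is not passed keeps its current value.
#
#   s = NamedSharding(mesh, PartitionSpec('x'))
#   s.update(spec=PartitionSpec('x', 'y'))  # same mesh, same memory kind
#   s.with_memory_kind('pinned_host')       # same mesh, same spec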


def get_array_mapping(
    axis_resources: PartitionSpec | AUTO | UnspecifiedValue
) -> ArrayMappingOrAutoOrUnspecified:
  if isinstance(axis_resources, (AUTO, UnspecifiedValue)):
    return axis_resources
  d = collections.OrderedDict()
  for i, axes in enumerate(axis_resources):
    if axes is None or axes is PartitionSpec.UNCONSTRAINED:
      continue
    axes = axes if isinstance(axes, tuple) else (axes,)
    for axis in axes:
      d[axis] = i
  return d


@dataclasses.dataclass
class SdyDim:
  axes: Sequence[str]
  is_open: bool
  priority: int | None = None

  def build(self) -> sdy.DimensionShardingAttr:
    return sdy.DimensionShardingAttr.get(
        [sdy.AxisRefAttr.get(axis) for axis in self.axes],
        is_closed=not self.is_open, priority=self.priority)

  def __repr__(self):
    return f'SdyDim({self._custom_repr()})'

  def _custom_repr(self):
    axes_repr = ', '.join(f"'{a}'" for a in self.axes)
    open_repr = ''
    if self.is_open:
      open_repr = ', ?' if self.axes else '?'
    priority_repr = '' if self.priority is None else f'p{self.priority}'
    return f'{{{axes_repr}{open_repr}}}{priority_repr}'


def _get_axes(axes, mesh_shape):
  if not axes:
    return ()
  assert mesh_shape is not None
  # Sort wrt mesh axis names so order is deterministic and doesn't hang in
  # McJAX.
  return tuple(n for n, _ in mesh_shape if n in axes)


@dataclasses.dataclass(kw_only=True)
class SdyArray:
  mesh_shape: tuple[tuple[str, int], ...] | None
  dim_shardings: Sequence[SdyDim]
  logical_device_ids: tuple[int, ...] | None = None
  replicated_axes: tuple[str, ...] = ()
  unreduced_axes: frozenset[str] = frozenset()

  def build(self) -> sdy.TensorShardingAttr:
    if self.mesh_shape is None:
      mesh_attr = sdy.MeshAttr.get([])
    else:
      ldi = ([] if self.logical_device_ids is None else
             list(self.logical_device_ids))
      mesh_attr = sdy.MeshAttr.get(
          [sdy.MeshAxisAttr.get(name, size)
           for name, size in self.mesh_shape], ldi)
    replicated_axes = _get_axes(self.replicated_axes, self.mesh_shape)
    unreduced_axes = _get_axes(self.unreduced_axes, self.mesh_shape)
    return sdy.TensorShardingAttr.get(
        mesh_attr,
        [dim_sharding.build() for dim_sharding in self.dim_shardings],
        replicated_axes=[sdy.AxisRefAttr.get(axis)
                         for axis in replicated_axes],
        unreduced_axes=[sdy.AxisRefAttr.get(axis)
                        for axis in unreduced_axes])

  def __repr__(self):
    dim_sharding_repr = ', '.join(
        d._custom_repr() for d in self.dim_shardings)
    device_id_repr = (f', device_ids={self.logical_device_ids}'
                      if self.logical_device_ids is not None else '')
    rar = (f', replicated_axes={self.replicated_axes}'
           if self.replicated_axes else '')
    return f"SdyArray([{dim_sharding_repr}]{device_id_repr}{rar})"


# TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra
# parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)`
def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArray, mesh):
  if mesh._any_axis_auto:
    dim_shardings, used_axes = [], []  # type: ignore
    for d in sdy_sharding.dim_shardings:
      # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes as
      # open?
      dim_shardings.append(SdyDim(axes=[], is_open=True)
                           if not d.axes and not d.is_open else d)
      used_axes.extend(d.axes)
    remaining_axes = set(mesh.axis_names) - set(used_axes)
    replicated_axes = tuple(
        r for r in remaining_axes
        if mesh._name_to_type[r] == mesh_lib.AxisType.Explicit)
    return SdyArray(mesh_shape=sdy_sharding.mesh_shape,
                    dim_shardings=dim_shardings,
                    logical_device_ids=sdy_sharding.logical_device_ids,
                    replicated_axes=replicated_axes)
  return sdy_sharding
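
# Illustrative sketch of `modify_sdy_sharding_wrt_axis_types` (assuming a
# hypothetical mesh whose 'x' axis is Explicit and whose 'y' axis is Auto):
# when any mesh axis is Auto, empty closed dimensions become open so the
# partitioner may shard them, and unused Explicit axes are recorded as
# replicated.
#
#   sdy = NamedSharding(mesh, PartitionSpec(None, None))._to_sdy_sharding(2)
#   modify_sdy_sharding_wrt_axis_types(sdy, mesh)
#   # -> roughly SdyArray([{?}, {?}], replicated_axes=('x',))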


@cache(max_size=4096, trace_context_in_key=False)
def named_sharding_to_xla_hlo_sharding(
    self, num_dimensions: int) -> xc.HloSharding:
  mesh_shape = self.mesh.shape
  array_mapping = get_array_mapping(self.spec)
  mesh_axis_pos = {name: i for i, name in enumerate(self.mesh.axis_names)}

  special_axes = {}
  manual_axes = frozenset(self.mesh.manual_axes)
  if manual_axes:
    axis_names = self.mesh.axis_names
    for manual_axis in manual_axes:
      special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL

  replicated_mesh_axes = []
  for i, (axis_name, axis_val) in enumerate(mesh_shape.items()):
    if axis_name not in array_mapping:  # type: ignore
      replicated_mesh_axes.append((i, axis_val))

  if len(replicated_mesh_axes) == len(mesh_shape) and not special_axes:
    return xc.HloSharding.replicate()

  mesh_permutation = []
  new_mesh_shape = [1] * num_dimensions
  for name, pos in sorted(array_mapping.items(), key=lambda x: x[1]):  # type: ignore
    new_mesh_shape[pos] *= mesh_shape[name]
    mesh_permutation.append(mesh_axis_pos[name])

  last_tile_dims = []
  if replicated_mesh_axes:
    axes_by_type: dict[Any, list[int]] = collections.defaultdict(list)
    size_by_type = collections.defaultdict(lambda: 1)  # type: ignore
    assert {x[0] for x in replicated_mesh_axes}.issuperset(
        set(special_axes.keys()))
    for i, size in replicated_mesh_axes:
      ty = special_axes.get(i, xc.OpSharding.Type.REPLICATED)
      axes_by_type[ty].append(i)
      size_by_type[ty] *= size
    for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
      last_tile_dims.append(ty)
      new_mesh_shape.append(size_by_type[ty])
      mesh_permutation.extend(axes)

  # Explanation of the parameters of `HloSharding.iota_tile`.
  # This is the HloShardingV2 format:
  #   * dims: How many ways each dimension is sharded. Replicated/Manual dims
  #     are added at the end.
  #   * reshape_dims: This is just the shape of the mesh.
  #   * transpose_perm: This is the order in which mesh axes in PartitionSpec
  #     appear relative to mesh.axis_names order.
  #   * subgroup_types: List of OpSharding types. The type can be REPLICATED
  #     or MANUAL.
  # Let's see an example:
  #   Consider input_shape=(8, 4, 2, 2), mesh={'a': 2, 'b': 2, 'c': 2, 'd': 2}
  #   and partition_spec=P(None, ('d', 'b'), 'c').
  #   Arguments to iota_tile will be:
  #     dims = [1, 4, 2, 1, 2]  # 'a' is replicated hence `2` is at the end.
  #     reshape_dims = [2, 2, 2, 2]
  #     transpose_perm = [3, 1, 2, 0]  # 'a' is replicated hence 0 is at the end
  #     subgroup_types = [xc.OpSharding.Type.REPLICATED]
  dims = new_mesh_shape
  reshape_dims = self.mesh.axis_sizes
  if self._logical_device_ids is None:
    return xc.HloSharding.iota_tile(
        dims=dims, reshape_dims=reshape_dims,
        transpose_perm=mesh_permutation, subgroup_types=last_tile_dims)
  else:
    return xc.HloSharding.subgroup_with_device_ordering(
        np.asarray(self._logical_device_ids)
        .reshape(reshape_dims).transpose(mesh_permutation).reshape(dims),
        subgroup_types=last_tile_dims)


def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
  if not array_mapping:
    return PartitionSpec()
  max_index = -1
  reverse_map = collections.defaultdict(list)
  for axis, index in array_mapping.items():
    reverse_map[index].append(axis)
    if index > max_index:
      max_index = index
  partitions = []
  for i in range(max_index + 1):
    axis = reverse_map[i]
    if axis:
      partitions.append(axis[0] if len(axis) == 1 else tuple(axis))
    else:
      partitions.append(None)
  return PartitionSpec(*partitions)
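
# Illustrative sketch of `array_mapping_to_axis_resources`, the rough inverse
# of `get_array_mapping` (axis names 'x' and 'y' below are hypothetical):
#
#   array_mapping_to_axis_resources(collections.OrderedDict([('x', 0)]))
#   # -> PartitionSpec('x',)
#   array_mapping_to_axis_resources(
#       collections.OrderedDict([('x', 1), ('y', 1)]))
#   # -> PartitionSpec(None, ('x', 'y'))  # dimension 0 is left unsharded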


@cache(max_size=128, trace_context_in_key=False)
def check_pspec(mesh, spec, _manual_axes=frozenset()):
  _check_unique_resources(spec, "NamedSharding spec", mesh)
  _check_mesh_resource_axis(mesh, spec)
  _check_mesh_unreduced(mesh, spec)


class DuplicateSpecError(Exception):
  def __init__(self, message, mesh, pspec):
    super().__init__(message)
    self.message = message
    self.mesh = mesh
    self.pspec = pspec

  def __str__(self):
    return f"{self.message}"


def _check_unique_resources(pspec: PartitionSpec, arg_name: str, mesh=None
                            ) -> None:
  resource_counts: dict[MeshAxisName, int] = {}
  duplicate = False
  for d in pspec:
    if d is PartitionSpec.UNCONSTRAINED or d is None:
      continue
    d = d if isinstance(d, tuple) else (d,)
    for resource in d:
      count = resource_counts.get(resource, 0)
      if count > 0:
        duplicate = True
      resource_counts[resource] = count + 1
  if duplicate:
    multiple_uses = [r for r, c in resource_counts.items() if c > 1]
    raise DuplicateSpecError(
        message=(
            f'A single {arg_name} specification can map every mesh axis to at'
            f' most one positional dimension, but {pspec} has duplicate'
            f' entries for {mesh_lib.show_axes(multiple_uses)}'),
        mesh=mesh, pspec=pspec)


def _check_mesh_resource_axis(mesh, pspec):
  for p in pspec:
    if p is PartitionSpec.UNCONSTRAINED or p is None:
      continue
    p = p if isinstance(p, tuple) else (p,)
    for r in p:
      if r not in mesh.axis_names:
        raise ValueError(
            f"Resource axis: {r} of {pspec} "
            f"is not found in mesh: {tuple(mesh.shape.keys())}.")
    if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p):
      raise ValueError(
          'AxisTypes should be the same in a tuple subset of PartitionSpec:'
          f' {pspec}. Got subset {p} with axis'
          f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')
  if (AxisType.Auto not in mesh._axis_types_dict and
      PartitionSpec.UNCONSTRAINED in pspec):
    raise ValueError(
        f'{pspec} cannot contain'
        ' `P.UNCONSTRAINED` when no mesh axis_types are `Auto`. Got mesh'
        f' axis_types: {mesh._axis_types_dict}')


def _check_mesh_unreduced(mesh, pspec):
  for u in pspec.unreduced:
    if u not in mesh.axis_names:
      raise ValueError(
          f'Unreduced axis {u} was not found in {mesh.axis_names=}. '
          f'Got {pspec=}')
    if mesh._name_to_type[u] in (AxisType.Auto, AxisType.Manual):
      raise ValueError(
          'Unreduced axes can only refer to mesh axes of type'
          f' `Explicit`. Got unreduced axes: {pspec.unreduced} and'
          f' mesh: {mesh}')

  for u in pspec.reduced:
    if u not in mesh.axis_names:
      raise ValueError(
          f'Reduced axis {u} was not found in {mesh.axis_names=}. '
          f'Got {pspec=}')
    if mesh._name_to_type[u] in (AxisType.Auto, AxisType.Manual):
      raise ValueError(
          'Reduced axes can only refer to mesh axes of type'
          f' `Explicit`. Got reduced axes: {pspec.reduced} and'
          f' mesh: {mesh}')
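
# Illustrative sketch of the checks above (assuming a hypothetical mesh with
# axes ('x', 'y'), both of type Explicit):
#
#   check_pspec(mesh, PartitionSpec('x', 'y'))  # ok
#   check_pspec(mesh, PartitionSpec('x', 'x'))  # DuplicateSpecError: 'x' is
#                                               # used for two dimensions
#   check_pspec(mesh, PartitionSpec('z'))       # ValueError: 'z' is not a
#                                               # mesh axis
#   check_pspec(mesh, PartitionSpec(PartitionSpec.UNCONSTRAINED))
#                                               # ValueError: requires an
#                                               # Auto mesh axis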