2025-08-11 12:24:21 +08:00

140 lines
4.9 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lowering of jaxprs into XLA (HLO) computations.
from __future__ import annotations
from collections.abc import Callable
from functools import partial
from typing import Any, Union
import numpy as np
from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.util import safe_zip, safe_map
from jax._src.typing import Shape
from jax._src.lib import xla_client as xc
# Shadow the builtins with the length-checking variants from jax._src.util;
# keep the originals reachable under unsafe_* names.
unsafe_map, map = map, safe_map
unsafe_zip, zip = zip, safe_zip
# Types
def identity(x): return x
_scalar_types = dtypes.python_scalar_dtypes.keys()
# Utilities
# HLO instructions optionally can be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see
# sharding_to_proto). For array outputs, the annotation is either an int per
# dimension specifying the number of ways that dimension divided (i.e. the total
# number of shards is the product), or None to indicate the array should be
# replicated. Tuple outputs are represented as tuples thereof. XLA supports
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
# checkers don't support recursive types), so we only represent one level of
# nesting in this type definition.
SpatialSharding = Union[Shape, None, tuple[Union[Shape, None], ...]]
def sharding_to_proto(sharding: SpatialSharding):
"""Converts a SpatialSharding to an OpSharding.
See
https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
for details on the OpSharding proto.
"""
proto = xc.OpSharding()
if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
assert all(s is None or isinstance(s, tuple) for s in sharding)
return tuple_sharding_proto(list(map(sharding_to_proto, sharding)))
if sharding is None:
proto.type = xc.OpSharding.Type.REPLICATED
else:
proto.type = xc.OpSharding.Type.OTHER
proto.tile_assignment_dimensions = list(sharding) # type: ignore
proto.tile_assignment_devices = list(range(np.prod(sharding))) # type: ignore
return proto
def tuple_sharding_proto(elems):
  """Wrap a list of ``xc.OpSharding`` protos into one TUPLE-typed proto."""
  result = xc.OpSharding()
  # Every element must itself already be an OpSharding proto.
  assert all(isinstance(e, xc.OpSharding) for e in elems)
  result.type = xc.OpSharding.Type.TUPLE
  result.tuple_shardings = elems
  return result
### handlers
# IR constants
class InvalidInputException(Exception):
  """Raised by canonicalize_dtype when an argument is not a valid JAX type."""
  pass
# TODO(mattjj): try to remove this canonicalize_dtype stuff
def canonicalize_dtype(x):
  """Canonicalize ``x`` via the handler registered for its type.

  Looks up ``canonicalize_dtype_handlers`` first by exact type, then along
  the type's MRO.  As a deprecated fallback, objects exposing
  ``__jax_array__`` are converted and retried.  Raises
  ``InvalidInputException`` when no handler applies.
  """
  exact = canonicalize_dtype_handlers.get(type(x))
  if exact:
    return exact(x)
  # Fall back to base classes (the MRO starts with type(x) itself, so this
  # re-checks the exact type harmlessly before visiting the bases).
  for base in type(x).__mro__:
    inherited = canonicalize_dtype_handlers.get(base)
    if inherited:
      return inherited(x)
  if hasattr(x, '__jax_array__'):
    deprecations.warn(
      'jax-abstract-dunder-array',
      ('Triggering of __jax_array__() during abstractification is deprecated.'
       ' To avoid this error, either explicitly convert your object using'
       ' jax.numpy.array(), or register your object as a pytree.'),
      stacklevel=6)
    return canonicalize_dtype(x.__jax_array__())
  raise InvalidInputException(
      f"Argument '{x}' of type {type(x)} is not a valid JAX type.")
def _canonicalize_masked_array_dtype(x):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
"Use arr.filled() to convert the value to a standard numpy array.")
def _canonicalize_ndarray_dtype(x):
return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
def _canonicalize_python_scalar_dtype(typ, x):
return np.asarray(
x, dtypes.canonicalize_dtype(dtypes._scalar_type_to_dtype(typ, x)))
# Registry mapping input types to their canonicalization handlers, consulted
# by canonicalize_dtype.  Later entries win on key collision, matching the
# original incremental-update construction order.
canonicalize_dtype_handlers: dict[Any, Callable] = {
    **{t: _canonicalize_ndarray_dtype for t in numpy_scalar_types},
    np.ndarray: _canonicalize_ndarray_dtype,
    np.ma.MaskedArray: _canonicalize_masked_array_dtype,
    **{t: partial(_canonicalize_python_scalar_dtype, t) for t in _scalar_types},
    # These JAX-internal types pass through untouched.
    core.Token: identity,
    core.DArray: identity,
    core.MutableArray: identity,
}
# Primitives that should be lowered in "initial style" (registered below).
initial_style_primitives: set[core.Primitive] = set()

def register_initial_style_primitive(prim: core.Primitive):
  """Record ``prim`` in the module-level initial-style primitive set."""
  initial_style_primitives.add(prim)