140 lines
4.9 KiB
Python
140 lines
4.9 KiB
Python
# Copyright 2018 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# Lowering of jaxprs into XLA (HLO) computations.
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from functools import partial
|
|
from typing import Any, Union
|
|
|
|
import numpy as np
|
|
|
|
from jax._src import core
|
|
from jax._src import deprecations
|
|
from jax._src import dtypes
|
|
from jax._src.abstract_arrays import numpy_scalar_types
|
|
from jax._src.util import safe_zip, safe_map
|
|
|
|
from jax._src.typing import Shape
|
|
|
|
from jax._src.lib import xla_client as xc
|
|
|
|
# For the rest of this module, `map` and `zip` refer to the safe_* variants
# from jax._src.util; the builtins stay reachable as unsafe_map / unsafe_zip.
unsafe_map, unsafe_zip = map, zip
map, zip = safe_map, safe_zip
|
|
|
|
# Types
|
|
|
|
def identity(x):
  """Return ``x`` unchanged; registered below as a pass-through handler."""
  return x
|
|
|
|
# Python scalar types keyed in dtypes.python_scalar_dtypes. This is a keys
# view of that mapping, not a copy. Each type gets a
# _canonicalize_python_scalar_dtype handler registered further down.
_scalar_types = dtypes.python_scalar_dtypes.keys()
|
|
|
|
# Utilities
|
|
|
|
# HLO instructions optionally can be annotated to say how the output should be
|
|
# spatially partitioned (represented in XLA as OpSharding protos, see
|
|
# sharding_to_proto). For array outputs, the annotation is either an int per
|
|
# dimension specifying the number of ways that dimension is divided (i.e. the total
|
|
# number of shards is the product), or None to indicate the array should be
|
|
# replicated. Tuple outputs are represented as tuples thereof. XLA supports
|
|
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
|
|
# checkers don't support recursive types), so we only represent one level of
|
|
# nesting in this type definition.
|
|
_ShapeOrNone = Union[Shape, None]
SpatialSharding = Union[Shape, None, tuple[_ShapeOrNone, ...]]
|
|
|
|
|
|
def sharding_to_proto(sharding: SpatialSharding):
  """Converts a SpatialSharding to an OpSharding.

  See
  https://github.com/openxla/xla/blob/main/xla/xla_data.proto
  for details on the OpSharding proto.

  Args:
    sharding: ``None`` (replicated), a sequence of per-dimension tile counts,
      or a tuple whose elements are each ``None`` or a tuple of tile counts.

  Returns:
    An ``xc.OpSharding``:
      - REPLICATED for ``None``.
      - OTHER for a shape, with ``tile_assignment_dimensions`` set to the tile
        counts and ``tile_assignment_devices`` set to ``0 .. prod(counts) - 1``
        in order.
      - TUPLE wrapping the recursively converted elements.
  """
  # A non-empty tuple whose first entry is not an integer is a tuple of
  # shardings rather than a shape. numpy integers count as dimensions too, and
  # an empty tuple is treated as a rank-0 shape instead of raising IndexError.
  if (isinstance(sharding, tuple) and sharding
      and not isinstance(sharding[0], (int, np.integer))):
    assert all(s is None or isinstance(s, tuple) for s in sharding)
    return tuple_sharding_proto(list(map(sharding_to_proto, sharding)))

  proto = xc.OpSharding()
  if sharding is None:
    proto.type = xc.OpSharding.Type.REPLICATED
  else:
    # Normalize to Python ints so numpy integer dimensions are accepted.
    dims = [int(d) for d in sharding]  # type: ignore
    proto.type = xc.OpSharding.Type.OTHER
    proto.tile_assignment_dimensions = dims
    # Plain np.prod of an empty sequence is the float 1.0, which range()
    # rejects; use an integer dtype and convert to a Python int.
    num_tiles = int(np.prod(dims, dtype=np.int64))
    proto.tile_assignment_devices = list(range(num_tiles))
  return proto
|
|
|
|
def tuple_sharding_proto(elems):
  """Builds a TUPLE-typed OpSharding from per-element OpSharding protos.

  Args:
    elems: iterable of ``xc.OpSharding`` protos, one per tuple element.

  Returns:
    An ``xc.OpSharding`` of type TUPLE whose ``tuple_shardings`` are ``elems``
    in order.
  """
  # Materialize first. Otherwise the assert below would exhaust a one-shot
  # iterator and tuple_shardings would be assigned an empty iterator.
  elems = list(elems)
  proto = xc.OpSharding()
  assert all(isinstance(e, type(proto)) for e in elems)
  proto.type = xc.OpSharding.Type.TUPLE
  proto.tuple_shardings = elems
  return proto
|
|
|
|
|
|
### handlers
|
|
|
|
|
|
# IR constants
|
|
|
|
class InvalidInputException(Exception):
  """Raised by canonicalize_dtype when an argument is not a valid JAX type."""
|
|
|
|
|
|
# TODO(mattjj): try to remove this canonicalize_dtype stuff
|
|
def canonicalize_dtype(x):
  """Converts ``x`` using the handler registered for its type.

  Lookup order:
    1. The exact type of ``x``.
    2. Each class in ``type(x).__mro__``, in order.
    3. If ``x`` defines ``__jax_array__``, emit a deprecation warning and
       recurse on the result of ``x.__jax_array__()``.

  Raises:
    InvalidInputException: if no handler applies and ``x`` has no
      ``__jax_array__``.
  """
  x_type = type(x)
  # Fast path: a handler registered for the exact type.
  handler = canonicalize_dtype_handlers.get(x_type)
  if handler:
    return handler(x)
  # Otherwise use the first handler registered for any base class.
  for base in x_type.__mro__:
    handler = canonicalize_dtype_handlers.get(base)
    if handler:
      return handler(x)
  if not hasattr(x, '__jax_array__'):
    raise InvalidInputException(
        f"Argument '{x}' of type {type(x)} is not a valid JAX type.")
  # Keep this call directly in this frame: stacklevel=6 is relative to it.
  deprecations.warn(
      'jax-abstract-dunder-array',
      ('Triggering of __jax_array__() during abstractification is deprecated.'
       ' To avoid this error, either explicitly convert your object using'
       ' jax.numpy.array(), or register your object as a pytree.'),
      stacklevel=6)
  return canonicalize_dtype(x.__jax_array__())
|
|
|
|
def _canonicalize_masked_array_dtype(x):
|
|
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
|
|
"Use arr.filled() to convert the value to a standard numpy array.")
|
|
|
|
def _canonicalize_ndarray_dtype(x):
|
|
return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
|
|
|
|
def _canonicalize_python_scalar_dtype(typ, x):
|
|
return np.asarray(
|
|
x, dtypes.canonicalize_dtype(dtypes._scalar_type_to_dtype(typ, x)))
|
|
|
|
# Maps a Python type to the function canonicalize_dtype applies to values of
# that type. Entries are listed in the same order, and later entries override
# earlier ones in the same way, as successive dict.update calls would.
canonicalize_dtype_handlers: dict[Any, Callable] = {
    **{t: _canonicalize_ndarray_dtype for t in numpy_scalar_types},
    np.ndarray: _canonicalize_ndarray_dtype,
    np.ma.MaskedArray: _canonicalize_masked_array_dtype,
    **{t: partial(_canonicalize_python_scalar_dtype, t) for t in _scalar_types},
    # These values are passed through unchanged.
    core.Token: identity,
    core.DArray: identity,
    core.MutableArray: identity,
}
|
|
|
|
# Primitives recorded through register_initial_style_primitive.
initial_style_primitives: set[core.Primitive] = set()


def register_initial_style_primitive(prim: core.Primitive):
  """Records ``prim`` in the module-level ``initial_style_primitives`` set.

  Registering the same primitive more than once has no further effect.
  """
  initial_style_primitives.add(prim)
|