140 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			140 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright 2018 The JAX Authors.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     https://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| # Lowering of jaxprs into XLA (HLO) computations.
 | |
| 
 | |
| from __future__ import annotations
 | |
| 
 | |
| from collections.abc import Callable
 | |
| from functools import partial
 | |
| from typing import Any, Union
 | |
| 
 | |
| import numpy as np
 | |
| 
 | |
| from jax._src import core
 | |
| from jax._src import deprecations
 | |
| from jax._src import dtypes
 | |
| from jax._src.abstract_arrays import numpy_scalar_types
 | |
| from jax._src.util import safe_zip, safe_map
 | |
| 
 | |
| from jax._src.typing import Shape
 | |
| 
 | |
| from jax._src.lib import xla_client as xc
 | |
| 
 | |
# Keep handles on the builtins under `unsafe_*` names, then shadow `map` and
# `zip` module-wide with the length-checking variants from jax._src.util.
unsafe_map, unsafe_zip = map, zip
map, zip = safe_map, safe_zip
 | |
| 
 | |
| # Types
 | |
| 
 | |
def identity(x):
  """Returns its argument unchanged (used as a no-op dtype handler)."""
  result = x
  return result
 | |
| 
 | |
# The scalar types registered in dtypes.python_scalar_dtypes; each one gets a
# _canonicalize_python_scalar_dtype handler in the table built below.
# NOTE(review): if python_scalar_dtypes is a dict, `.keys()` is a live view,
# not a snapshot -- later changes to that mapping are visible here.
_scalar_types = dtypes.python_scalar_dtypes.keys()
 | |
| 
 | |
| # Utilities
 | |
| 
 | |
# HLO instructions can optionally be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see
# sharding_to_proto). For array outputs, the annotation is either an int per
# dimension specifying the number of ways that dimension is divided (i.e. the
# total number of shards is the product), or None to indicate the array should
# be replicated. Tuple outputs are represented as tuples thereof. XLA supports
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
# checkers don't support recursive types), so we only represent one level of
# nesting in this type definition.
#
# Note that a tuple can be either a Shape (a tuple of ints) or a tuple of
# per-output shardings; sharding_to_proto tells them apart by inspecting the
# tuple's first element.
SpatialSharding = Union[Shape, None, tuple[Union[Shape, None], ...]]
 | |
| 
 | |
| 
 | |
def sharding_to_proto(sharding: SpatialSharding):
  """Converts a SpatialSharding to an OpSharding.

  See the ``OpSharding`` message in XLA's ``xla_data.proto``
  (https://github.com/openxla/xla/blob/main/xla/xla_data.proto) for details on
  the OpSharding proto.

  Args:
    sharding: one of

      * ``None``: the value is replicated.
      * A sequence of integers (Python or numpy ints): the number of ways each
        dimension is split. The number of tiles is their product, and tiles
        are assigned to devices ``0 .. product - 1`` in order. An empty tuple
        is treated as a rank-0 shape, i.e. a single tile on device 0.
      * A non-empty tuple whose first element is ``None`` or a tuple: one
        sharding per element of a tuple-shaped output (one nesting level).

  Returns:
    An ``xc.OpSharding`` of type REPLICATED, OTHER or TUPLE respectively.
  """
  import math
  import operator

  proto = xc.OpSharding()
  # A non-empty tuple whose first entry is not an integer is a tuple of
  # per-element shardings. `()` must be excluded explicitly because
  # `sharding[0]` would raise IndexError. `np.integer` is accepted so that
  # shapes built from numpy values are not misclassified as tuple shardings.
  if (isinstance(sharding, tuple) and len(sharding) > 0
      and not isinstance(sharding[0], (int, np.integer))):
    assert all(s is None or isinstance(s, tuple) for s in sharding)
    return tuple_sharding_proto(list(map(sharding_to_proto, sharding)))

  if sharding is None:
    proto.type = xc.OpSharding.Type.REPLICATED
  else:
    # Normalize numpy ints to Python ints for the proto setters. Non-integral
    # entries still raise TypeError, as they did before.
    dims = [operator.index(d) for d in sharding]
    proto.type = xc.OpSharding.Type.OTHER
    proto.tile_assignment_dimensions = dims  # type: ignore
    # math.prod returns the int 1 for an empty shape. np.prod(()) returns the
    # float 1.0, which `range` rejects.
    proto.tile_assignment_devices = list(range(math.prod(dims)))  # type: ignore
  return proto
 | |
| 
 | |
def tuple_sharding_proto(elems):
  """Builds a TUPLE-typed OpSharding from per-element OpShardings.

  Args:
    elems: an iterable of ``xc.OpSharding`` protos, one per element of the
      tuple-shaped output. Any iterable is accepted (lists, tuples,
      generators); it is materialized into a list first.

  Returns:
    An ``xc.OpSharding`` of type TUPLE whose ``tuple_shardings`` are ``elems``
    in order.
  """
  # Materialize up front. The validation pass below would otherwise consume a
  # one-shot iterator before the proto setter, which needs a sequence, ever
  # sees it.
  elems = list(elems)
  proto = xc.OpSharding()
  assert all(isinstance(e, type(proto)) for e in elems)
  proto.type = xc.OpSharding.Type.TUPLE
  proto.tuple_shardings = elems
  return proto
 | |
| 
 | |
| 
 | |
| ### handlers
 | |
| 
 | |
| 
 | |
| # IR constants
 | |
| 
 | |
class InvalidInputException(Exception):
  """Raised when an argument is not of a type JAX knows how to accept.

  ``canonicalize_dtype`` raises this when no handler is registered for the
  argument's type and the argument has no ``__jax_array__`` method.
  """
 | |
| 
 | |
| 
 | |
| # TODO(mattjj): try to remove this canonicalize_dtype stuff
 | |
def canonicalize_dtype(x):
  """Returns ``x`` converted by the dtype handler registered for its type.

  Handlers are looked up in ``canonicalize_dtype_handlers``: first by the
  exact type of ``x``, then by each class in that type's MRO, so a handler
  registered for a class also covers its subclasses. An object with no
  matching handler but with a ``__jax_array__`` method is converted through
  that method (with a deprecation warning), and the result is canonicalized
  in turn.

  Raises:
    InvalidInputException: if no handler matches and ``x`` has no
      ``__jax_array__`` method.
  """
  cls = type(x)
  # Fast path: a handler registered for the exact type.
  handler = canonicalize_dtype_handlers.get(cls)
  if handler:
    return handler(x)
  # Slow path: the first base class (in MRO order) that has a handler.
  for base in cls.__mro__:
    handler = canonicalize_dtype_handlers.get(base)
    if handler:
      return handler(x)
  if not hasattr(x, '__jax_array__'):
    raise InvalidInputException(
        f"Argument '{x}' of type {type(x)} is not a valid JAX type.")
  # NOTE(review): stacklevel=6 presumably skips JAX-internal frames so the
  # warning is attributed to user code -- confirm if the call chain changes.
  deprecations.warn(
      'jax-abstract-dunder-array',
      ('Triggering of __jax_array__() during abstractification is deprecated.'
       ' To avoid this error, either explicitly convert your object using'
       ' jax.numpy.array(), or register your object as a pytree.'),
      stacklevel=6)
  return canonicalize_dtype(x.__jax_array__())
 | |
| 
 | |
def _canonicalize_masked_array_dtype(x):
  """Dtype handler for ``np.ma.MaskedArray``: always rejects the input.

  Raises:
    ValueError: unconditionally, telling the caller to call ``.filled()``.
  """
  del x  # Unused: every masked array is rejected, whatever its contents.
  message = ("numpy masked arrays are not supported as direct inputs to JAX "
             "functions. Use arr.filled() to convert the value to a standard "
             "numpy array.")
  raise ValueError(message)
 | |
| 
 | |
def _canonicalize_ndarray_dtype(x):
  """Dtype handler for numpy arrays and numpy scalars.

  Returns ``x`` as a numpy array whose dtype is whatever
  ``dtypes.canonicalize_dtype`` maps ``x.dtype`` to.
  """
  target_dtype = dtypes.canonicalize_dtype(x.dtype)
  return np.asarray(x, dtype=target_dtype)
 | |
| 
 | |
def _canonicalize_python_scalar_dtype(typ, x):
  """Dtype handler for Python scalars; returns a 0-d numpy array.

  Args:
    typ: the scalar's Python type. It is bound per type with
      ``functools.partial`` when the handler table is built.
    x: the scalar value.
  """
  scalar_dtype = dtypes._scalar_type_to_dtype(typ, x)
  canonical = dtypes.canonicalize_dtype(scalar_dtype)
  return np.asarray(x, dtype=canonical)
 | |
| 
 | |
# Maps a type to the function canonicalize_dtype uses to convert values of
# that type (and, through canonicalize_dtype's MRO walk, of its subclasses).
# Built as one literal whose insertion order matches registering the entries
# one at a time.
canonicalize_dtype_handlers: dict[Any, Callable] = {
    # numpy scalars and arrays: cast to the canonical dtype.
    **{t: _canonicalize_ndarray_dtype for t in numpy_scalar_types},
    np.ndarray: _canonicalize_ndarray_dtype,
    # Masked arrays are rejected with an explanatory ValueError.
    np.ma.MaskedArray: _canonicalize_masked_array_dtype,
    # Python scalars: the handler needs the scalar's type, so bind it here.
    **{t: partial(_canonicalize_python_scalar_dtype, t) for t in _scalar_types},
    # These types are passed through unchanged.
    core.Token: identity,
    core.DArray: identity,
    core.MutableArray: identity,
}
 | |
| 
 | |
# Set of primitives recorded by register_initial_style_primitive.
# NOTE(review): what "initial style" implies for these primitives is defined
# by the code that reads this set, not here -- confirm with its users.
initial_style_primitives: set[core.Primitive] = set()
 | |
| 
 | |
def register_initial_style_primitive(prim: core.Primitive):
  """Adds ``prim`` to the module-level ``initial_style_primitives`` set.

  Because the registry is a set, registering the same primitive more than
  once has no further effect.
  """
  registry = initial_style_primitives
  registry.add(prim)
 |