124 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			124 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright 2023 The JAX Authors.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     https://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """Utilities for tracing stateful functions."""
 | |
| 
 | |
| from functools import partial
 | |
| from typing import Callable
 | |
| 
 | |
| import jax
 | |
| from jax._src import core
 | |
| from jax._src import dtypes
 | |
| from jax._src import linear_util as lu
 | |
| from jax._src.interpreters import partial_eval as pe
 | |
| from jax._src.state import AbstractRef
 | |
| from jax._src.state.primitives import ref_get
 | |
| from jax._src.typing import DTypeLike
 | |
| from jax._src.util import safe_map, safe_zip, split_list
 | |
| 
 | |
| map, unsafe_map = safe_map, map
 | |
| zip, unsafe_zip = safe_zip, zip
 | |
| 
 | |
| 
 | |
def hoist_consts_to_refs(
    jaxpr: core.Jaxpr,
    *,
    index: int = 0,
    make_abstract_ref: Callable[[core.AbstractValue], AbstractRef] = lambda aval: AbstractRef(aval)
) -> core.Jaxpr:
  """Turns the constants of ``jaxpr`` into additional input variables.

  Each constant whose aval is not already an ``AbstractRef`` gets a ref aval
  built by ``make_abstract_ref``; constants that already are refs keep their
  aval unchanged. The jaxpr is then re-traced with those avals spliced into
  the input avals at position ``index``.

  Args:
    jaxpr: The jaxpr whose constants should become inputs.
    index: Position in the input list at which the constant inputs are
      inserted. The default of 0 places them ahead of all existing inputs.
    make_abstract_ref: Callable that builds an ``AbstractRef`` (or a subtype)
      from the aval of a non-ref constant.

  Returns:
    ``jaxpr`` itself when it has no constants; otherwise a newly traced jaxpr
    that receives the former constants as ``Ref`` inputs.
  """
  constvars = jaxpr.constvars
  if not constvars:
    # No constants at all: the input jaxpr is handed back unchanged.
    return jaxpr

  # Remember which constants are refs already so they are neither re-wrapped
  # here nor dereferenced inside the traced function.
  already_ref = [isinstance(v.aval, AbstractRef) for v in constvars]
  ref_avals = []
  for was_ref, v in zip(already_ref, constvars):
    ref_avals.append(v.aval if was_ref else make_abstract_ref(v.aval))
  num_consts = len(ref_avals)

  # Splice the constant avals into the existing input avals at `index`.
  arg_avals = [v.aval for v in jaxpr.invars]
  traced_avals = [*arg_avals[:index], *ref_avals, *arg_avals[index:]]

  def _hoist(*flat_args):
    leading, const_refs, trailing = split_list(flat_args, [index, num_consts])
    # Read the value out of every newly introduced `Ref` right away; constants
    # that were refs to begin with are forwarded untouched.
    const_vals = []
    for was_ref, ref in zip(already_ref, const_refs):
      const_vals.append(ref if was_ref else ref_get(ref, ()))
    return core.eval_jaxpr(jaxpr, const_vals, *leading, *trailing)

  wrapped = lu.wrap_init(_hoist, debug_info=jaxpr.debug_info)
  new_jaxpr, _, remaining_consts, () = pe.trace_to_jaxpr_dynamic(
      wrapped, traced_avals)
  assert not remaining_consts, "All consts should have been converted to refs"
  return new_jaxpr
 | |
| 
 | |
| 
 | |
| def val_to_ref_aval(x) -> AbstractRef:
 | |
|   aval = core.get_aval(x)
 | |
|   if type(aval) is not core.ShapedArray:
 | |
|     raise TypeError(f"can't make ref from {x}")
 | |
|   return AbstractRef(aval)
 | |
| 
 | |
| 
 | |
def dtype_bitwidth(dtype: DTypeLike) -> int:
  """Returns the width of ``dtype`` in bits.

  Integral dtypes report the ``bits`` field of their ``iinfo``; every other
  dtype reports its ``itemsize`` in bytes multiplied by 8.
  """
  if not dtypes.isdtype(dtype, "integral"):
    return 8 * dtypes.dtype(dtype).itemsize
  return dtypes.iinfo(dtype).bits
 | |
| 
 | |
| 
 | |
def bitcast(x, dtype: DTypeLike):
  """Reinterprets the bits of ``x`` as ``dtype``.

  When the source and target bitwidths are equal, the result has the same
  shape as ``x`` and inputs of any rank are accepted. When they differ, the
  second-minor dimension is rescaled by ``x_bitwidth / y_bitwidth`` so the
  total number of bits is preserved; that path needs at least two dimensions.

  Args:
    x: The array to reinterpret; must expose ``shape``, ``dtype``, ``reshape``
      and ``swapaxes``.
    dtype: The target dtype.

  Returns:
    The result of ``jax.lax.bitcast_convert_type`` applied to ``x``,
    rearranged so that a bitwidth change shows up in the second-minor
    dimension.

  Raises:
    NotImplementedError: If the bitwidths differ and ``x`` has fewer than two
      dimensions.
    ValueError: If the bitwidths differ and the second-minor dimension of
      ``x`` times the source bitwidth is not a multiple of the target
      bitwidth.
  """
  x_bitwidth = dtype_bitwidth(x.dtype)
  y_bitwidth = dtype_bitwidth(dtype)
  shape = list(x.shape)
  if x_bitwidth != y_bitwidth:
    if len(shape) < 2:
      raise NotImplementedError(
          "Bitcast 1D ref with bitwidth change is not supported."
      )
    # Note: this is only valid on TPU.
    if (shape[-2] * x_bitwidth) % y_bitwidth != 0:
      raise ValueError(
          "Expected input and output shapes are the same after multiplying"
          " the second-minor dimension by the bitwidths."
      )
    # Only rescale the second-minor dimension when the bitwidth changes.
    # Doing it unconditionally indexed `shape[-2]` on rank-0/rank-1 inputs and
    # raised IndexError for an equal-width bitcast that needs no reshaping.
    shape[-2] = shape[-2] * x_bitwidth // y_bitwidth
  if x_bitwidth < y_bitwidth:
    # Widening: group `ratio` consecutive rows into a trailing axis, which
    # `bitcast_convert_type` then folds into one wider element.
    ratio = y_bitwidth // x_bitwidth
    x = x.reshape(*x.shape[:-2], x.shape[-2] // ratio, ratio, -1).swapaxes(
        -1, -2
    )
  y = jax.lax.bitcast_convert_type(x, dtype)
  if x_bitwidth > y_bitwidth:
    # Narrowing: `bitcast_convert_type` appended a trailing axis holding the
    # narrow pieces; move it next to the rows and merge it into them.
    y = y.swapaxes(-1, -2).reshape(shape)
  return y
 | |
| 
 | |
| 
 | |
def eval_bitcast_shape(x, dtype: DTypeLike):
  """Returns the shape ``bitcast(x, dtype)`` would produce.

  The shape is obtained with ``jax.eval_shape`` on a ``ShapeDtypeStruct``
  built from ``x.shape`` and ``x.dtype``.

  Args:
    x: Any object exposing ``shape`` and ``dtype``.
    dtype: The target dtype of the bitcast.

  Returns:
    The ``shape`` attribute of the abstract bitcast result.
  """
  abstract_operand = jax.ShapeDtypeStruct(x.shape, x.dtype)
  abstract_result = jax.eval_shape(
      lambda operand: bitcast(operand, dtype=dtype), abstract_operand
  )
  return abstract_result.shape
 |