# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Allows JAX to call TensorFlow functions with support for autodiff.

**Experimental: please give feedback, and expect changes.**

This module introduces the function :func:`call_tf` that allows JAX to call
TensorFlow functions.

For examples and details, see
https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax.

"""

from __future__ import annotations

from collections.abc import Callable, Sequence
import dataclasses
import functools
from typing import Any

from absl import logging
import jax
from jax import dlpack
from jax import dtypes
from jax import numpy as jnp
from jax import tree_util
from jax._src import ad_util
from jax._src import core
from jax._src import effects
from jax._src import util
from jax._src.lib import _jax
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib.mlir.dialects import hlo
from jax.experimental.jax2tf import jax2tf as jax2tf_internal
from jax._src.interpreters import mlir
import ml_dtypes
import numpy as np
import tensorflow as tf


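# Use the "safe" variants of map and zip, which check that all arguments
# have the same length.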
map = util.safe_map
zip = util.safe_zip

TfConcreteFunction = Any
TfVal = jax2tf_internal.TfVal

# The platforms for which to use DLPack to avoid copying (only works on GPU
# and CPU at the moment, and only for Array). For CPU we don't need
# DLPack, if we are careful.
_DLPACK_PLATFORMS = ("gpu",)

class UnspecifiedOutputShapeDtype:
  pass

def call_tf(
    callable_tf: Callable,
    has_side_effects=True,
    ordered=False,
    output_shape_dtype=UnspecifiedOutputShapeDtype(),
    call_tf_graph=False,
) -> Callable:
|   """Calls a TensorFlow function from JAX, with support for reverse autodiff.
 | ||
| 
 | ||
|   The ``callable_tf`` will be called with TensorFlow-compatible arguments (
 | ||
|   numpy.ndarray, ``tf.Tensor`` or ``tf.Variable``) or pytrees thereof. The
 | ||
|   function must return the same type of results.
 | ||
| 
 | ||
|   If ``call_tf`` appears in a JAX staging context (:func:`jax.jit`,
 | ||
|   or :func:`jax.pmap`, or a control-flow primitive) then
 | ||
|   ``callable_tf`` will be compiled with ``tf.function(callable_tf,
 | ||
|   jit_compile=True)``
 | ||
|   and the resulting XLA computation will be embedded in JAX's XLA computation.
 | ||
| 
 | ||
|   If ``call_tf`` appears outside a JAX staging context, it will be called inline
 | ||
|   using TensorFlow eager mode.
 | ||
| 
 | ||
|   The ``call_tf`` supports JAX's reverse-mode autodiff, in which case the
 | ||
|   ``callable_tf`` will be differentiated using ``tf.GradientTape``. This means
 | ||
|   that the gradient will be TensorFlow-accurate, e.g., will respect the
 | ||
|   custom gradients that may be defined for the code in ``callable_tf``.
 | ||
| 
 | ||
|   For an example and more details see the
 | ||
|   `README
 | ||
|   <https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.
 | ||
| 
 | ||
|   Args:
 | ||
|     callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow
 | ||
|       arguments.
 | ||
|     has_side_effects: if True then it ensures that instances of this primitive
 | ||
|       are not removed or replicated by JAX optimizations such as dead-code
 | ||
|       elimination.
 | ||
|     ordered: If true, calls are modeled as having ordered effects.
 | ||
|     output_shape_dtype: An optional declaration of the expected shape and dtype
 | ||
|       of the result of the called TensorFlow function. If given it will be used
 | ||
|       during JAX tracing to form the abstract values of the results of the
 | ||
|       `call_tf`. If not given then we form a `tf.Graph` for the called
 | ||
|       TensorFlow function and we use the TensorFlow-inferred shapes and types.
 | ||
|       Must be a pytree matching the structure of the nested structure returned
 | ||
|       from the TensorFlow function, containing objects with `.shape` and
 | ||
|       `.dtype` attributes, e.g., `jax.ShapeDtypeStruct` or `jax.Array`.
 | ||
|     call_tf_graph: EXPERIMENTAL, DO NOT USE. We may change the name in the
 | ||
|       future.
 | ||
| 
 | ||
|   Returns: a JAX callable that can be invoked with JAX pytree arguments, in
 | ||
|     op-by-op mode or in a staged context. This callable can be used with JAX's
 | ||
|     reverse-mode autodiff (:func:`jax.grad`).
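
  Example (a minimal usage sketch following the pattern of the README linked
  above; the TF helper ``cos_tf`` below is illustrative)::

    import tensorflow as tf
    import jax
    import jax.numpy as jnp
    from jax.experimental import jax2tf

    def cos_tf(x):
      return tf.math.cos(x)

    # Compute cos with TF and square with JAX; jax.grad differentiates
    # through the TF call via tf.GradientTape.
    def cos_tf_sqr(x):
      return jnp.square(jax2tf.call_tf(cos_tf)(x))

    print(jax.grad(cos_tf_sqr)(jnp.float32(0.7)))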
  """
  @jax.custom_vjp
  def make_call(*args_jax):
    """We wrap it all in `make_call` so that we can attach custom VJP."""

    args_flat_jax, args_treedef = tree_util.tree_flatten(args_jax)
    # Canonicalize the arguments; e.g., makes them x32 if JAX is in 32-bit mode
    def canonical_arg(v):
      v = v if getattr(v, "dtype", None) else np.asarray(v)
      dtype = dtypes.canonicalize_dtype(v.dtype)
      if dtype != v.dtype:
        v = v.astype(dtype)
      return v

    args_flat_jax = tuple(map(canonical_arg, args_flat_jax))
    def make_tensorspec(a_jax):
      a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype)
      a_tf_shape = [d if core.is_constant_dim(d) else None for d in a_jax.shape]
      return tf.TensorSpec(a_tf_shape, a_tf_dtype)
    args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax))

    if not isinstance(output_shape_dtype, UnspecifiedOutputShapeDtype):
      output_shape_dtype_flat, output_shape_dtype_tree = tree_util.tree_flatten(output_shape_dtype)
      output_avals = tuple(core.ShapedArray(st.shape, st.dtype) for st in output_shape_dtype_flat)
    else:
      output_avals, output_shape_dtype_tree = None, None

    res_treedef = None  # We'll store here the result treedef
    res_tf_flat = None  # For error reporting
    # The function below will be called at least once, either in eager
    # mode during jax2tf_call_tf or in graph mode during _get_concrete_function_tf()
    def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]:
      args_tf = args_treedef.unflatten(args_tf_flat)
      res_tf = callable_tf(*args_tf)

      # b/279454591: When `callable_tf` is a tf function with zero outputs, it
      # returns a `StatefulPartitionedCall` (if the function is stateful) or
      # `PartitionedCall` (if the function is stateless) op instead of
      # tf.Tensors. We work around this issue by replacing the output `res_tf`
      # with an empty list.

      if isinstance(res_tf, tf.Operation):
        assert (
            res_tf.type == "StatefulPartitionedCall"
            or res_tf.type == "PartitionedCall"
        )
        t_out = res_tf.get_attr("Tout")
        # t_out should be an empty list.
        assert not t_out, (
            "The TF function returned an unexpected result, please check its"
            f" function body. res_tf = {res_tf}"
        )
        res_tf = t_out

      nonlocal res_treedef, res_tf_flat
      res_tf_flat, res_treedef_now = tree_util.tree_flatten(res_tf)
      assert res_treedef is None or res_treedef == res_treedef_now, (
          f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}")
      res_treedef = res_treedef_now
      if output_avals is not None:
        if res_treedef != output_shape_dtype_tree:
          raise ValueError(
              "The pytree of the TensorFlow function results does not match the "
              "pytree of the declared output_shape_dtype:\n"
              f"results pytree: {res_treedef}\noutput_shape_dtype tree: {output_shape_dtype_tree}")
        assert len(output_avals) == len(res_tf_flat)

      checked_res_tf_flat = [
          check_tf_result(i, r_tf, r_aval)
          for i, (r_tf, r_aval) in enumerate(
              zip(res_tf_flat,
                  (output_avals
                   if output_avals is not None
                   else (None,) * len(res_tf_flat))))]
      return checked_res_tf_flat

    # Prepare a tf.function ahead of time, to cache the concrete functions. This
    # won't be used in op-by-op execution mode.
    function_flat_tf = tf.function(
        callable_flat_tf, autograph=False, jit_compile=not call_tf_graph)

    res_jax_flat = call_tf_p.bind(
        *args_flat_jax,
        # Carry the actual function such that op-by-op call can call in TF eager mode.
        callable_flat_tf=callable_flat_tf,
        function_flat_tf=function_flat_tf,
        args_flat_sig_tf=args_flat_sig_tf,
        output_avals=output_avals,
        has_side_effects=has_side_effects,
        ordered=ordered,
        call_tf_graph=call_tf_graph,
    )

    # We must have called callable_flat_tf by now
    assert res_treedef is not None
    return res_treedef.unflatten(res_jax_flat)

  # Define the fwd and bwd custom_vjp functions
  def make_call_vjp_fwd(*args_jax):
    # Return the primal arguments as the residual
    return make_call(*args_jax), args_jax

  def make_call_vjp_bwd(residual_jax, ct_res_jax):
    args_jax = residual_jax  # residual is the primal argument

    def tf_vjp_fun(args_tf, ct_res_tf):
      """Invoke TF gradient."""

      # TF does not like us to watch non-float vars or Nones.
      def replace_non_float_or_none(arg_tf):
        if arg_tf is not None and (
            arg_tf.dtype.is_floating or arg_tf.dtype.is_complex
        ):
          return arg_tf
        else:
          # When watched, this will be ignored. When used in results it will
          # result in a floating 0. gradient, which JAX will ignore (and
          # replace it with a float0)
          return tf.zeros((), dtype=tf.float32)

      watched_args_tf = tf.nest.map_structure(
          replace_non_float_or_none, args_tf
      )
      with tf.GradientTape(persistent=True) as tape:
        tape.watch(watched_args_tf)
        res = callable_tf(*args_tf)

      tf.nest.assert_same_structure(res, ct_res_tf)
      dres_darg = tape.gradient(
          tf.nest.map_structure(replace_non_float_or_none, res),
          sources=watched_args_tf,
          output_gradients=ct_res_tf,
          unconnected_gradients=tf.UnconnectedGradients.ZERO,
      )

      dres_darg = tree_util.tree_map(
          lambda x: x if x is None else tf.convert_to_tensor(x),
          dres_darg,
      )

      # callable_tf may mutate (the structure of) args_tf, thus we check against
      # watched_args_tf which should be structurally the same as the original
      # args_tf.
      tf.nest.assert_same_structure(dres_darg, watched_args_tf)
      return dres_darg

    # Use call_tf to call the VJP function
    ct_args_jax = call_tf(tf_vjp_fun)(args_jax, ct_res_jax)
    # We must make the float0s that JAX expects
    def fix_float0(arg_jax, ct_arg_jax):
      if arg_jax is None:
        return None
      arg_dtype = dtypes.result_type(arg_jax)  # May be scalar
      ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype)
      if ct_arg_dtype != ct_arg_jax.dtype:
        return ad_util.zeros_like_aval(core.ShapedArray(np.shape(arg_jax),
                                                        ct_arg_dtype))
      return ct_arg_jax

    ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax,
                                           is_leaf=lambda x: x is None)
    return ct_args_jax_fixed

  make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
  return util.wraps(callable_tf)(make_call)


def check_tf_result(idx: int, r_tf: TfVal, r_aval: core.ShapedArray | None) -> TfVal:
  # Check that the TF function returns values of expected types. This
  # improves error reporting, preventing hard-to-diagnose errors downstream.
  try:
    jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf)
  except Exception as e:
    msg = ("The called TF function returns a result that is not "
           f"convertible to JAX: {r_tf}.")
    raise ValueError(msg) from e

  if r_aval is None:
    return r_tf
  # We convert to TF type, and canonicalize to 32-bit if necessary
  r_aval_dtype_tf = jax2tf_internal._to_tf_dtype(r_aval.dtype)
  # Checking shapes is trickier in the presence of dynamic shapes. I wish we
  # could check at runtime that the returned shape matches the declared shape.
  # I wish that tf.ensure_shape did this, but it can only take shapes that
  # contain None, not computed shapes. However, in eager mode we should be able
  # to resolve the declared shapes to constants and we get better checking.
  if tf.executing_eagerly():
    r_aval_shape_tf = jax2tf_internal._eval_shape(r_aval.shape)
  else:
    r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval)
  # We do as much checking as we can here, instead of relying on tf.ensure_shape
  # because the latter gives different errors in eager vs. compiled mode.
  # TODO(b/279454591): This strange error comes from TF. An eager function is
  # supposed to return a TF value with a concrete shape, but sometimes it does
  # not. Here we turn the exception into a warning and bypass the check. This
  # case needs to be revisited on the TF side.
  try:
    _ = len(r_tf.shape)
  except ValueError as e:
    msg = (
        "The shape check cannot be performed because the shape of the "
        f"`r_tf` tensor cannot be obtained. r_tf = {r_tf}, r_aval = {r_aval}. "
    )
    msg += str(e)
    logging.warning(msg)
    return r_tf
  if (r_tf.dtype != r_aval_dtype_tf or
      len(r_tf.shape) != len(r_aval_shape_tf) or
      any(r_aval_d is not None and r_tf_d is not None and r_aval_d != r_tf_d
          for r_tf_d, r_aval_d in zip(r_tf.shape, r_aval_shape_tf))):
    msg = ("The shapes or dtypes returned by the TensorFlow function "
           "do not match the declared output_shape_dtype:\n"
           f"Result[{idx}] is {r_tf.dtype}[{r_tf.shape}] vs. expected {r_aval_dtype_tf}[{r_aval_shape_tf}]")
    raise ValueError(msg)
  # At this point tf.ensure_shape does not do much; it should never throw an
  # error, although it may refine the shape a bit.
  return tf.ensure_shape(r_tf, r_aval_shape_tf)


call_tf_p = core.Primitive("call_tf")
call_tf_p.multiple_results = True

# The impl will be used in op-by-op mode and calls callable_tf in TF eager mode.
def _call_tf_impl(*args_jax_flat, callable_flat_tf, **_):
  # On GPU we use dlpack to avoid copies of data to the host.
  def _arg_jax_to_tf(arg_jax):
    if (isinstance(arg_jax, jax.Array) and
        list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and
        arg_jax.dtype.type in dlpack.SUPPORTED_DTYPES):
      return tf.experimental.dlpack.from_dlpack(arg_jax.__dlpack__())
    # The following avoids copies to the host on CPU, always for Array
    # and even for ndarray if they are sufficiently aligned.
    # TODO(necula): on TPU this copies to the host!
    if getattr(arg_jax, 'dtype', None) == dtypes.float0:
      return tf.zeros(shape=arg_jax.shape,
                      dtype=jax2tf_internal._tf_np_dtype_for_float0)
    return tf.constant(np.asarray(arg_jax))

  args_tf_flat = tuple(map(_arg_jax_to_tf, args_jax_flat))
  with jax2tf_internal.inside_call_tf():
    # Call in TF eager mode
    res_tf_flat = callable_flat_tf(*args_tf_flat)

  def _res_tf_to_jax(res_tf: TfVal):
    res_tf, jax_dtype = jax2tf_internal._tfval_to_tensor_jax_dtype(res_tf)
    if isinstance(res_tf, tf.Tensor) and jax_dtype.type in dlpack.SUPPORTED_DTYPES:
      res_tf_platform = tf.DeviceSpec.from_string(res_tf.backing_device).device_type
      res_jax_platform = res_tf_platform.lower()
      if res_jax_platform in _DLPACK_PLATFORMS:
        res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
        return jax.dlpack.from_dlpack(res_dlpack)

    # When working with a bfloat16 scalar tf.Tensor, np.asarray() can fail.
    # To handle this special case, we create a numpy copy.
    if res_tf.shape == tf.TensorShape([]) and res_tf.dtype == tf.bfloat16:
      return jax.device_put(jnp.array(res_tf.numpy()))
    else:
      return jax.device_put(np.asarray(res_tf))

  return list(map(_res_tf_to_jax, res_tf_flat))


call_tf_p.def_impl(_call_tf_impl)

@functools.lru_cache(maxsize=128)
def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf):  # -> tf.ConcreteFunction
  with jax2tf_internal.inside_call_tf():
    return function_flat_tf.get_concrete_function(*args_flat_sig_tf)


# Mark the effectful instances of call_tf
@dataclasses.dataclass(frozen=True)
class CallTfEffect(effects.Effect):
  __str__ = lambda _: "CallTfEffect"

call_tf_effect = CallTfEffect()

effects.lowerable_effects.add_type(CallTfEffect)
effects.control_flow_allowed_effects.add_type(CallTfEffect)
effects.remat_allowed_effects.add_type(CallTfEffect)
effects.custom_derivatives_allowed_effects.add_type(CallTfEffect)


class CallTfOrderedEffect(effects.Effect):
  __str__ = lambda _: "CallTfOrderedEffect"


call_tf_ordered_effect = CallTfOrderedEffect()

effects.lowerable_effects.add_type(CallTfOrderedEffect)
effects.control_flow_allowed_effects.add_type(CallTfOrderedEffect)
effects.remat_allowed_effects.add_type(CallTfOrderedEffect)
effects.custom_derivatives_allowed_effects.add_type(CallTfOrderedEffect)
effects.ordered_effects.add_type(CallTfOrderedEffect)
effects.shardable_ordered_effects.add_type(CallTfOrderedEffect)


def _call_tf_abstract_eval(
    *args_flat_avals,
    function_flat_tf,
    args_flat_sig_tf,
    has_side_effects,
    ordered,
    output_avals,
    call_tf_graph,
    **__,
):
  # Called only when we form a Jaxpr, i.e., under jit, scan, etc.
  effects = set()
  if ordered:
    effects.add(call_tf_ordered_effect)
  elif has_side_effects:
    effects.add(call_tf_effect)

  # If no output_avals is given, then we ask TF to infer the output shapes.
  # We call this even if output_avals is given because it will ensure that
  # callable_flat_tf is called. Since _get_concrete_function_tf is cached,
  # there is only a small cost to calling it more often than needed.
  concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf,
                                                        args_flat_sig_tf)

  # In the case that the tf.function has no return value
  if len(concrete_function_flat_tf.outputs) == 0:
    return (), effects

  if output_avals is not None:
    return output_avals, effects

  def is_fully_known_shape(s):
    return s.rank is not None and all(d is not None for d in s)

  if all(is_fully_known_shape(s)
         for s in concrete_function_flat_tf.output_shapes):
    avals_from_tf = tuple(
        # We convert to JAX type, and canonicalize to 32-bit if necessary
        core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype))
        for dtype, shape in zip(concrete_function_flat_tf.output_dtypes,
                                concrete_function_flat_tf.output_shapes))
    return avals_from_tf, effects

  msg = ("call_tf cannot call functions whose output has dynamic shape. "
         f"Found output shapes: {concrete_function_flat_tf.output_shapes}. "
         "Consider using the `output_shape_dtype` argument to call_tf. "
         "\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
         " for a discussion.")
  raise ValueError(msg)


call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval)

def _mlir_type_to_numpy_dtype(type: ir.Type) -> np.dtype:
  """Converts an MLIR scalar type to a NumPy dtype."""

  if ir.IntegerType.isinstance(type):
    type = ir.IntegerType(type)
    width = type.width
    if width == 1:
      return np.dtype(np.bool_)
    elif width == 8:
      return np.dtype(np.uint8 if type.is_unsigned else np.int8)
    elif width == 16:
      return np.dtype(np.uint16 if type.is_unsigned else np.int16)
    elif width == 32:
      return np.dtype(np.uint32 if type.is_unsigned else np.int32)
    elif width == 64:
      return np.dtype(np.uint64 if type.is_unsigned else np.int64)
    else:
      raise ValueError(f"Unsupported integer width: {width}")

  elif ir.F16Type.isinstance(type):
    return np.dtype(np.float16)
  elif ir.F32Type.isinstance(type):
    return np.dtype(np.float32)
  elif ir.F64Type.isinstance(type):
    return np.dtype(np.float64)
  elif ir.BF16Type.isinstance(type):
    return np.dtype(ml_dtypes.bfloat16)

  elif ir.ComplexType.isinstance(type):
    element_type = ir.ComplexType(type).element_type
    if ir.F32Type.isinstance(element_type):
      return np.dtype(np.complex64)
    elif ir.F64Type.isinstance(element_type):
      return np.dtype(np.complex128)
    else:
      raise ValueError(f"Unsupported complex element type: {element_type}")

  else:
    raise TypeError(f"Unsupported MLIR type for NumPy conversion: {type}")


def _call_tf_lowering(
    ctx: mlir.LoweringRuleContext,
    *args_op,
    platform,
    function_flat_tf,
    args_flat_sig_tf,
    has_side_effects,
    ordered,
    call_tf_graph,
    output_avals,
    **_,
):
  # We use the same TF lowering device as for the embedding JAX computation.
  # One example when this is needed is when the code refers to variables on one
  # device, or for sharding annotations (only supported on TPU).

  if platform in ["cpu", "tpu"]:
    tf_platform = platform.upper()
  elif platform == "cuda":
    tf_platform = "GPU"
  else:
    raise ValueError(f"platform {platform} not supported")

  concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf)

  captured_inputs = []
  if concrete_function_flat_tf.captured_inputs:
    # The function uses either captured variables or tensors.
    msg = (
        "call_tf works best with a TensorFlow function that does not capture "
        "variables or tensors from the context. "
        "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. "
        f"The following captures were found: {concrete_function_flat_tf.captured_inputs}")
    logging.warning(msg)
    for inp in concrete_function_flat_tf.captured_inputs:
      if inp.dtype == tf.resource:  # A variable; lookup by handle
        inp_vars = [v for v in concrete_function_flat_tf.variables if inp is v.handle]
        assert len(inp_vars) == 1, f"Found {inp_vars}"
        captured_inputs.append(inp_vars[0])
      else:
        captured_inputs.append(inp)

  # The following use case happens when we call_tf a restored saved model that
  # includes parameters (hence functions closing over tf.Variable), and then
  # we jax2tf.convert it with native serialization, under tf.function (or
  # for saving to saved model). The `np.asarray(inp)` fails because it thinks
  # it is in TF graph mode. The `tf.init_scope()` lifts out of function-building
  # graph scopes, and allows us to read the values of the variables.
  with tf.init_scope():
    captured_ops = tuple(
        mlir.ir_constant(np.asarray(inp))
        for inp in captured_inputs
    )

  if call_tf_graph:
    with jax2tf_internal.inside_call_tf():
      return emit_tf_embedded_graph_custom_call(
          ctx,
          concrete_function_flat_tf,
          tuple(args_op) + captured_ops,
          has_side_effects,
          ordered,
          output_avals,
      )

  def convert_to_spec(x):
    if isinstance(x, tf.TensorSpec):
      return x
    else:
      return tf.TensorSpec.from_tensor(x)

  args_tf_flat = [convert_to_spec(a) for a in args_flat_sig_tf]

  with jax2tf_internal.inside_call_tf():
    try:
      func_tf_hlo = function_flat_tf.experimental_get_compiler_ir(
          *args_tf_flat
      )(stage="hlo_serialized", platform_name=tf_platform)
    except Exception as e:
      msg = ("Error compiling TensorFlow function (see below for the caught exception)." +
             "\ncall_tf can be used " +
             "in a staged context (under jax.jit, lax.scan, etc.) only with " +
             "compilable functions with static output shapes.\n" +
             "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion." +
             "\n\nCaught TensorFlow exception: " + str(e))
      raise ValueError(msg) from e

  stablehlo = _jax.mlir.hlo_to_stablehlo(func_tf_hlo)
  submodule = ir.Module.parse(stablehlo)
  symtab = ir.SymbolTable(submodule.operation)
  callee_result_types = symtab["main"].type.results
  fn = mlir.merge_mlir_modules(ctx.module_context.module,
                               f"call_tf_{function_flat_tf.name}",
                               submodule,
                               dst_symtab=ctx.module_context.symbol_table)
  call = func_dialect.CallOp(callee_result_types,
                             ir.FlatSymbolRefAttr.get(fn),
                             tuple(args_op) + captured_ops)
  flat_results = call.results

  if ordered:
    raise NotImplementedError(
        "ordered=True is not supported in the jitted context without"
        " `call_tf_graph=True`"
    )

  outputs = []
  for op, res_type in zip(flat_results, callee_result_types):
    if not res_type.has_static_shape:
      msg = (
          "Compiled TensorFlow function has dynamic output shape "
          + f"{res_type}. call_tf can be used in a staged context (under jax.jit,"
          " lax.scan, etc.) only with compilable functions with static"
          " output shapes. See"
          " https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
          " for a discussion."
      )
      raise ValueError(msg)

    res_dtype = _mlir_type_to_numpy_dtype(res_type.element_type)
    # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
    jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
    if res_dtype != jax_res_dtype:
      op = hlo.ConvertOp(
          mlir.aval_to_ir_type(core.ShapedArray(res_type.shape, jax_res_dtype)),
          op,
      ).result
    outputs.append(op)
  return outputs


def _register_call_lowering(platform):
  mlir.register_lowering(call_tf_p, functools.partial(_call_tf_lowering,
                                                      platform=platform),
                         platform=platform)
for platform in ("cpu", "cuda", "tpu"):
  _register_call_lowering(platform)

# Support the call_tf under jax2tf.convert in eager mode
def _jax2tf_call_tf(*args: TfVal,
                    callable_flat_tf: Callable,
                    **_) -> TfVal:
  with jax2tf_internal.inside_call_tf():
    res_tf_flat = callable_flat_tf(*args)
  return res_tf_flat

jax2tf_internal.tf_impl[call_tf_p] = _jax2tf_call_tf


def emit_tf_embedded_graph_custom_call(
    ctx: mlir.LoweringRuleContext,
    concrete_function_flat_tf,
    operands: Sequence[ir.Value],
    has_side_effects,
    ordered,
    output_avals,
):
  """Emits a custom call referencing a tf.Graph embedding of the TF function.

  All information about the called function is stored in tf.metadata.
  This includes:
  (1) The called function name, which will be used by the runtime to execute
  the callback.
  (2) The called function's index in the XLACallModule `function_list`
  attribute.
  """
  call_tf_concrete_function_list = jax2tf_internal.get_thread_local_state_call_tf_concrete_function_list()
  if call_tf_concrete_function_list is None:
    raise ValueError(
        "call_tf_graph=True currently only supports exporting via jax2tf.convert."
    )
  # TODO(necula): It is dangerous to modify global state when lowering because
  # there are a number of lowering caches that only cache the StableHLO.
  # See call_tf_test.py:test_multi_platform_call_tf_graph.
  called_index = add_to_call_tf_concrete_function_list(
      concrete_function_flat_tf, call_tf_concrete_function_list)
  tf_backend_config = {
      "has_token_input_output": ir.BoolAttr.get(ordered),
      "called_index": mlir.i64_attr(called_index),
  }
  result_avals = ctx.avals_out if ctx.avals_out is not None else ()

  operands = list(operands)
  result_types = list(
      mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals])
  )
  if ordered:
    operands.insert(0, ctx.tokens_in.get(call_tf_ordered_effect))
    result_types.insert(0, mlir.token_type())

  custom_call = hlo.CustomCallOp(
      result_types,
      operands,
      call_target_name=ir.StringAttr.get("tf.call_tf_function"),
      has_side_effect=ir.BoolAttr.get(has_side_effects),
      api_version=mlir.i32_attr(2),
      called_computations=ir.ArrayAttr.get([]),
      backend_config=ir.StringAttr.get(""),
  )
  # Store the TF metadata in an unregistered attribute
  custom_call.attributes["tf.backend_config"] = ir.DictAttr.get(
      tf_backend_config
  )

  results = list(custom_call.results)
  if ordered:
    token = results.pop(0)
    ctx.set_tokens_out(mlir.TokenSet({call_tf_ordered_effect: token}))

  return results


def add_to_call_tf_concrete_function_list(concrete_tf_fn: Any, call_tf_concrete_function_list: list[Any]) -> int:
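  """Returns the index of `concrete_tf_fn` in the list, appending it first if
  it is not already present."""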
  try:
    called_index = call_tf_concrete_function_list.index(concrete_tf_fn)
  except ValueError:
    called_index = len(call_tf_concrete_function_list)
    call_tf_concrete_function_list.append(concrete_tf_fn)
  return called_index