# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable, Sequence import functools from typing import Any import numpy as np from jax._src import core from jax._src import dispatch from jax._src import effects from jax._src import ffi from jax._src import tree_util from jax._src import util from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.lib import ffi as ffi_lib export = util.set_module("jax.experimental.buffer_callback") Buffer = export(ffi_lib.Buffer) ExecutionStage = export(ffi_lib.ExecutionStage) ExecutionContext = export(ffi_lib.ExecutionContext) def buffer_callback( callback: Callable[..., None], result_shape_dtypes: object, *, has_side_effect: bool = False, vmap_method: str | None = None, input_output_aliases: dict[int, int] | None = None, command_buffer_compatible: bool = False, ): """An experimental callback that operates in place on device buffers. Only supported on CPU and GPU backends. Note that the plan is for this to eventually be replaced by a consolidated callback API built using JAX mutable arrays, but for now this provides a mechanism for prototyping computational kernels using other Python libraries including Numpy, PyTorch, Cupy, and others. Let's start with a simple example: >>> def py_add_one_inplace(ctx, out, x): ... np.asarray(out)[...] = np.asarray(x) + 1 ... >>> x = jnp.array(41, dtype=jnp.int32) >>> out_type = jax.ShapeDtypeStruct(x.shape, x.dtype) >>> add_one = buffer_callback(py_add_one_inplace, out_type) >>> add_one(x) # doctest: +SKIP Array(42, dtype=int32) In this example, we're executing a numpy computation via JAX, and this could have been implemented using :func:`jax.pure_callback`, but in this case, the output is being populated in-place. This means that JAX doesn't need to copy the output arrays upon returning from the callback. Note that even though the callback function operates on mutable buffers, JAX still sees this as an operation that consumes and produces regular immutable JAX arrays. Unlike the other JAX callback APIs, ``buffer_callback`` requires that the user-defined Python function have the following signature: .. code-block:: python def callback(ctx: ExecutionContext, out, *args) -> None: ... where ``ctx`` is an instance of :class:`~jax.experimental.buffer_callback.ExecutionContext`, which mainly provides access to XLA's computation stream when running on GPU, ``out`` is a pytree of mutable :class:`~jax.experimental.buffer_callback.Buffer` objects, and the ``args`` arguments have the same pytree structure as the inputs, but each leaf is :class:`~jax.experimental.buffer_callback.Buffer`. This callback should not return any values, and it should overwrite the ``out`` buffers in place to output values back to JAX. It's important to note that this Python function can't really be called except via ```buffer_callback`` itself, because it's not (yet!) possible to construct mutable JAX buffers directly in Python. The bespoke :class:`~jax.experimental.buffer_callback.Buffer` type is an array-like object that supports the ``__array__`` protocol on CPU, the ``__cuda_array_interface__`` protocol on GPU, and the ``__dlpack__`` protocol on both CPU and GPU. Args: callback: A Python function with the signature and behavior described above. result_shape_dtypes: A pytree whose leaves have ``shape`` and ``dtype`` attributes, with a structure that matches the expected output of the callback function at runtime. :class:`jax.ShapeDtypeStruct` is often used to define leaf values. has_side_effect: Whether the callback has side effects. vmap_method: A string specifying how the callback transforms under :func:`~jax.vmap` as described in the docs for :func:`~jax.pure_callback`. input_output_aliases: a dictionary mapping the index of some inputs to the index of the output that aliases them. These indices are in the flattened inputs and outputs. command_buffer_compatible: if ``True``, the callback will be traced into the command buffer. This means that the Python code should only be executed once, and then the operations will be replayed for every subsequent call. Returns: A new callable that accepts :class:`jax.Array` inputs (and pytrees thereof), and pytree of :class:`jax.Array` objects whose structure matches that of ``result_shape_dtypes``. See Also: - :func:`jax.pure_callback`: callback designed for pure host functions. - :func:`jax.experimental.io_callback`: callback designed for impure host functions. - :func:`jax.debug.callback`: callback designed for general-purpose debugging. - :func:`jax.debug.print`: callback designed for printing. """ flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes) flat_result_avals = tuple( core.ShapedArray(x.shape, x.dtype) for x in flat_shape_dtypes ) def wrapped_callback(*args, **kwargs): flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) in_avals = [core.get_aval(x) for x in flat_args] static_input_output_aliases: tuple[tuple[int, int], ...] = () if input_output_aliases is not None: for i_idx, o_idx in sorted(input_output_aliases.items()): i_idx, o_idx = int(i_idx), int(o_idx) if i_idx >= len(args): raise ValueError( f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " f"with input index {i_idx} outside the range [0, " f"{len(args)}).") if o_idx >= len(flat_result_avals): raise ValueError( f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " f"with output index {o_idx} outside the range [0, " f"{len(flat_result_avals)}).") in_aval = in_avals[i_idx] out_aval = flat_result_avals[o_idx] if not ffi._check_compatible_avals(in_aval, out_aval): raise ValueError( f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " f"referring to an input with abstract value {in_aval} and an " f"output with a different abstract value {out_aval}.") static_input_output_aliases += ((i_idx, o_idx),) out_flat = buffer_callback_p.bind( *flat_args, callback=callback, result_avals=flat_result_avals, in_tree=in_tree, out_tree=out_tree, vmap_method=vmap_method, has_side_effect=has_side_effect, input_output_aliases=static_input_output_aliases, command_buffer_compatible=command_buffer_compatible, ) return tree_util.tree_unflatten(out_tree, out_flat) return wrapped_callback buffer_callback_p = core.Primitive("buffer_callback") buffer_callback_p.multiple_results = True dispatch.prim_requires_devices_during_lowering.add(buffer_callback_p) dispatch.simple_impl(buffer_callback_p) class BufferCallbackEffect(effects.Effect): def __str__(self): return "BufferCallback" _BufferCallbackEffect = BufferCallbackEffect() effects.lowerable_effects.add_type(BufferCallbackEffect) effects.control_flow_allowed_effects.add_type(BufferCallbackEffect) @buffer_callback_p.def_effectful_abstract_eval def _buffer_callback_abstract_eval( *args, result_avals: tuple[core.ShapedArray, ...], has_side_effect: bool, **_, ): del args effects = {_BufferCallbackEffect} if has_side_effect else core.no_effects return result_avals, effects def _buffer_callback_jvp_rule(*args, **kwargs): del args, kwargs raise ValueError( "Buffer callbacks do not support JVP. " "Please use `jax.custom_jvp` to use callbacks while taking gradients.") ad.primitive_jvps[buffer_callback_p] = _buffer_callback_jvp_rule def _buffer_callback_transpose_rule(*args, **kwargs): del args, kwargs raise ValueError( "Buffer callbacks do not support transpose. " "Please use `jax.custom_vjp` to use callbacks while taking gradients.") ad.primitive_transposes[buffer_callback_p] = _buffer_callback_transpose_rule batching.primitive_batchers[buffer_callback_p] = functools.partial( ffi.ffi_batching_rule, buffer_callback_p ) def _buffer_callback_lowering( ctx: mlir.LoweringRuleContext, *args: Any, callback, in_tree: Any, out_tree: Any, has_side_effect: bool, input_output_aliases: Sequence[tuple[int, int]], command_buffer_compatible: bool, **_, ): if len(ctx.module_context.platforms) > 1: raise NotImplementedError("multi-platform lowering for buffer_callback") platform = ctx.module_context.platforms[0] target_name = { "cpu": "xla_buffer_python_cpu_callback", "cuda": "xla_buffer_python_gpu_callback", "rocm": "xla_buffer_python_gpu_callback", }.get(platform) if target_name is None: raise ValueError(f"`buffer_callback` not supported on {platform} backend.") if command_buffer_compatible and platform in ("cuda", "rocm"): target_name += "_cmd_buffer" def wrapped_callback(exec_ctx, *args: Any): args_in, args_out = util.split_list(args, [in_tree.num_leaves]) py_args_in, py_kwargs_in = tree_util.tree_unflatten(in_tree, args_in) py_args_out = tree_util.tree_unflatten(out_tree, args_out) if callback(exec_ctx, py_args_out, *py_args_in, **py_kwargs_in) is not None: raise ValueError("buffer_callback callback must not return any values.") return () ctx.module_context.add_host_callback(wrapped_callback) index = np.uint64(len(ctx.module_context.host_callbacks) - 1) rule = ffi.ffi_lowering( target_name, has_side_effect=has_side_effect, operand_output_aliases=dict(input_output_aliases), ) return rule(ctx, *args, index=index) mlir.register_lowering(buffer_callback_p, _buffer_callback_lowering)