# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Callable, Sequence
import functools
from typing import Any

import numpy as np

from jax._src import core
from jax._src import dispatch
from jax._src import effects
from jax._src import ffi
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lib import ffi as ffi_lib

export = util.set_module("jax.experimental.buffer_callback")

Buffer = export(ffi_lib.Buffer)
ExecutionStage = export(ffi_lib.ExecutionStage)
ExecutionContext = export(ffi_lib.ExecutionContext)


def buffer_callback(
    callback: Callable[..., None],
    result_shape_dtypes: object,
    *,
    has_side_effect: bool = False,
    vmap_method: str | None = None,
    input_output_aliases: dict[int, int] | None = None,
    command_buffer_compatible: bool = False,
):
"""An experimental callback that operates in place on device buffers.
|
|
|
|
Only supported on CPU and GPU backends.
|
|
|
|
Note that the plan is for this to eventually be replaced by a consolidated
|
|
callback API built using JAX mutable arrays, but for now this provides a
|
|
mechanism for prototyping computational kernels using other Python libraries
|
|
including Numpy, PyTorch, Cupy, and others.
|
|
|
|
Let's start with a simple example:
|
|
|
|
>>> def py_add_one_inplace(ctx, out, x):
|
|
... np.asarray(out)[...] = np.asarray(x) + 1
|
|
...
|
|
>>> x = jnp.array(41, dtype=jnp.int32)
|
|
>>> out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
|
|
>>> add_one = buffer_callback(py_add_one_inplace, out_type)
|
|
>>> add_one(x) # doctest: +SKIP
|
|
Array(42, dtype=int32)
|
|
|
|
  In this example, we're executing a NumPy computation via JAX. This could
  also have been implemented using :func:`jax.pure_callback`, but here the
  output is populated in place, which means that JAX doesn't need to copy the
  output arrays upon returning from the callback. Note that even though the
  callback function operates on mutable buffers, JAX still sees this as an
  operation that consumes and produces regular immutable JAX arrays.

  Unlike the other JAX callback APIs, ``buffer_callback`` requires that the
  user-defined Python function have the following signature:

  .. code-block:: python

      def callback(ctx: ExecutionContext, out, *args) -> None:
        ...

  where ``ctx`` is an instance of
  :class:`~jax.experimental.buffer_callback.ExecutionContext`, which mainly
  provides access to XLA's computation stream when running on GPU, ``out`` is
  a pytree of mutable :class:`~jax.experimental.buffer_callback.Buffer`
  objects, and the ``args`` arguments have the same pytree structure as the
  inputs, but each leaf is a
  :class:`~jax.experimental.buffer_callback.Buffer`. This callback must not
  return any values; instead, it should overwrite the ``out`` buffers in
  place to pass results back to JAX.

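  For example, a callback with multiple inputs and multiple outputs might look
  like the following sketch (the names here are illustrative, and the NumPy
  views assume a CPU backend):

  .. code-block:: python

      def py_sum_and_diff(ctx, out, x, y):
        out_sum, out_diff = out  # same structure as result_shape_dtypes
        np.asarray(out_sum)[...] = np.asarray(x) + np.asarray(y)
        np.asarray(out_diff)[...] = np.asarray(x) - np.asarray(y)

      out_types = (jax.ShapeDtypeStruct((3,), jnp.float32),
                   jax.ShapeDtypeStruct((3,), jnp.float32))
      sum_and_diff = buffer_callback(py_sum_and_diff, out_types)
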
  It's important to note that this Python function can't really be called
  except via ``buffer_callback`` itself, because it's not (yet!) possible to
  construct mutable JAX buffers directly in Python.

  The bespoke :class:`~jax.experimental.buffer_callback.Buffer` type is an
  array-like object that supports the ``__array__`` protocol on CPU, the
  ``__cuda_array_interface__`` protocol on GPU, and the ``__dlpack__``
  protocol on both CPU and GPU.

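  For example, a callback body might view its buffers through whichever
  protocol its backend supports. The sketch below is illustrative; the GPU
  variant assumes a CUDA-enabled PyTorch installation:

  .. code-block:: python

      def py_double_cpu(ctx, out, x):
        # On CPU, np.asarray gives a writable view of the buffer, as in the
        # example above.
        np.asarray(out)[...] = 2 * np.asarray(x)

      def py_double_gpu(ctx, out, x):
        # On GPU, any DLPack-aware library can view the buffer; PyTorch is
        # shown here as one option.
        import torch
        torch.from_dlpack(out).copy_(2 * torch.from_dlpack(x))
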
  Args:
    callback: A Python function with the signature and behavior described
      above.
    result_shape_dtypes: A pytree whose leaves have ``shape`` and ``dtype``
      attributes, with a structure that matches the expected output of the
      callback function at runtime. :class:`jax.ShapeDtypeStruct` is often
      used to define leaf values.
    has_side_effect: Whether the callback has side effects.
    vmap_method: A string specifying how the callback transforms under
      :func:`~jax.vmap` as described in the docs for
      :func:`~jax.pure_callback`.
    input_output_aliases: A dictionary mapping the index of some inputs to
      the index of the output that aliases them. These indices refer to the
      flattened inputs and outputs.
    command_buffer_compatible: If ``True``, the callback will be traced into
      the command buffer. This means that the Python code will only be
      executed once, and the recorded operations will be replayed on every
      subsequent call.

  Returns:
    A new callable that accepts :class:`jax.Array` inputs (and pytrees
    thereof), and returns a pytree of :class:`jax.Array` objects whose
    structure matches that of ``result_shape_dtypes``.

  See Also:
    - :func:`jax.pure_callback`: callback designed for pure host functions.
    - :func:`jax.experimental.io_callback`: callback designed for impure host
      functions.
    - :func:`jax.debug.callback`: callback designed for general-purpose
      debugging.
    - :func:`jax.debug.print`: callback designed for printing.
  """
  flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes)
  flat_result_avals = tuple(
      core.ShapedArray(x.shape, x.dtype) for x in flat_shape_dtypes
  )

  def wrapped_callback(*args, **kwargs):
    flat_args, in_tree = tree_util.tree_flatten((args, kwargs))

    in_avals = [core.get_aval(x) for x in flat_args]
    static_input_output_aliases: tuple[tuple[int, int], ...] = ()
    if input_output_aliases is not None:
      for i_idx, o_idx in sorted(input_output_aliases.items()):
        i_idx, o_idx = int(i_idx), int(o_idx)
        # Aliasing indices refer to the flattened inputs and outputs, so the
        # bounds are checked against the flattened argument list.
        if i_idx >= len(flat_args):
          raise ValueError(
              f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
              f"with input index {i_idx} outside the range [0, "
              f"{len(flat_args)}).")
        if o_idx >= len(flat_result_avals):
          raise ValueError(
              f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
              f"with output index {o_idx} outside the range [0, "
              f"{len(flat_result_avals)}).")
        in_aval = in_avals[i_idx]
        out_aval = flat_result_avals[o_idx]
        if not ffi._check_compatible_avals(in_aval, out_aval):
          raise ValueError(
              f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
              f"referring to an input with abstract value {in_aval} and an "
              f"output with a different abstract value {out_aval}.")
        static_input_output_aliases += ((i_idx, o_idx),)

    out_flat = buffer_callback_p.bind(
        *flat_args,
        callback=callback,
        result_avals=flat_result_avals,
        in_tree=in_tree,
        out_tree=out_tree,
        vmap_method=vmap_method,
        has_side_effect=has_side_effect,
        input_output_aliases=static_input_output_aliases,
        command_buffer_compatible=command_buffer_compatible,
    )
    return tree_util.tree_unflatten(out_tree, out_flat)

  return wrapped_callback


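# The registrations below wire ``buffer_callback_p`` into JAX's machinery: an
# effectful abstract eval rule computes the output avals, the autodiff rules
# raise informative errors (use ``jax.custom_jvp``/``jax.custom_vjp``
# instead), batching reuses the generic FFI batching rule, and the lowering
# rule emits an XLA custom call that dispatches back into Python.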
buffer_callback_p = core.Primitive("buffer_callback")
buffer_callback_p.multiple_results = True
dispatch.prim_requires_devices_during_lowering.add(buffer_callback_p)
dispatch.simple_impl(buffer_callback_p)


class BufferCallbackEffect(effects.Effect):
  def __str__(self):
    return "BufferCallback"

_BufferCallbackEffect = BufferCallbackEffect()
effects.lowerable_effects.add_type(BufferCallbackEffect)
effects.control_flow_allowed_effects.add_type(BufferCallbackEffect)


@buffer_callback_p.def_effectful_abstract_eval
def _buffer_callback_abstract_eval(
    *args,
    result_avals: tuple[core.ShapedArray, ...],
    has_side_effect: bool,
    **_,
):
  del args
  effects = {_BufferCallbackEffect} if has_side_effect else core.no_effects
  return result_avals, effects


def _buffer_callback_jvp_rule(*args, **kwargs):
  del args, kwargs
  raise ValueError(
      "Buffer callbacks do not support JVP. "
      "Please use `jax.custom_jvp` to use callbacks while taking gradients.")

ad.primitive_jvps[buffer_callback_p] = _buffer_callback_jvp_rule


def _buffer_callback_transpose_rule(*args, **kwargs):
  del args, kwargs
  raise ValueError(
      "Buffer callbacks do not support transpose. "
      "Please use `jax.custom_vjp` to use callbacks while taking gradients.")

ad.primitive_transposes[buffer_callback_p] = _buffer_callback_transpose_rule

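# As the errors above suggest, differentiation is supported by wrapping the
# callback yourself. A minimal sketch (illustrative, not part of this module),
# reusing ``py_add_one_inplace`` and ``out_type`` from the docstring example:
#
#   add_one = buffer_callback(py_add_one_inplace, out_type)
#
#   @jax.custom_jvp
#   def add_one_diff(x):
#     return add_one(x)
#
#   @add_one_diff.defjvp
#   def _add_one_jvp(primals, tangents):
#     (x,), (t,) = primals, tangents
#     return add_one_diff(x), t  # d(x + 1)/dx = 1: the tangent passes through
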
batching.primitive_batchers[buffer_callback_p] = functools.partial(
    ffi.ffi_batching_rule, buffer_callback_p
)


def _buffer_callback_lowering(
    ctx: mlir.LoweringRuleContext,
    *args: Any,
    callback,
    in_tree: Any,
    out_tree: Any,
    has_side_effect: bool,
    input_output_aliases: Sequence[tuple[int, int]],
    command_buffer_compatible: bool,
    **_,
):
  if len(ctx.module_context.platforms) > 1:
    raise NotImplementedError("multi-platform lowering for buffer_callback")
  platform = ctx.module_context.platforms[0]
  target_name = {
      "cpu": "xla_buffer_python_cpu_callback",
      "cuda": "xla_buffer_python_gpu_callback",
      "rocm": "xla_buffer_python_gpu_callback",
  }.get(platform)
  if target_name is None:
    raise ValueError(f"`buffer_callback` not supported on {platform} backend.")

  if command_buffer_compatible and platform in ("cuda", "rocm"):
    target_name += "_cmd_buffer"

  def wrapped_callback(exec_ctx, *args: Any):
    # The runtime passes all input buffers followed by all output buffers;
    # split them apart and unflatten back into the user-facing pytrees.
    args_in, args_out = util.split_list(args, [in_tree.num_leaves])
    py_args_in, py_kwargs_in = tree_util.tree_unflatten(in_tree, args_in)
    py_args_out = tree_util.tree_unflatten(out_tree, args_out)
    if callback(exec_ctx, py_args_out, *py_args_in, **py_kwargs_in) is not None:
      raise ValueError("buffer_callback callback must not return any values.")
    return ()

  ctx.module_context.add_host_callback(wrapped_callback)
  # The custom call refers to the registered Python callback by its index in
  # the module's host callback list.
  index = np.uint64(len(ctx.module_context.host_callbacks) - 1)
  rule = ffi.ffi_lowering(
      target_name,
      has_side_effect=has_side_effect,
      operand_output_aliases=dict(input_output_aliases),
  )
  return rule(ctx, *args, index=index)


mlir.register_lowering(buffer_callback_p, _buffer_callback_lowering)