# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python function API implementation."""
from __future__ import annotations
import dataclasses
import inspect
import random
import threading
from typing import Any, Callable, Sequence
import jax
from jax._src import api
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import pxla
from jax._src.lib import xla_client as xc
from jax._src.traceback_util import api_boundary
from jax._src.util import wraps
from jax.experimental.colocated_python import func_backend
from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize, _serialize_specs
from jax.extend.ifrt_programs import ifrt_programs
ShapeDtypeStructTree = Any # PyTree[api.ShapeDtypeStruct]
@dataclasses.dataclass(frozen=True, slots=True)
class FunctionInfo:
"""User function wrapped by colocated_python."""

  fun: Callable[..., Any]
fun_sourceinfo: str | None
fun_signature: inspect.Signature | None


@dataclasses.dataclass(frozen=True, slots=True)
class Specialization:
"""Specialization for a colocated_python function."""

  in_specs_treedef: tree_util.PyTreeDef | None = None
in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None
out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None
out_specs_treedef: tree_util.PyTreeDef | None = None
out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None
devices: xc.DeviceList | None = None

  def update(
self,
*,
in_specs_treedef: tree_util.PyTreeDef | None = None,
in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None,
out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None,
out_specs_treedef: tree_util.PyTreeDef | None = None,
out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None,
devices: Sequence[jax.Device] | xc.DeviceList | None = None,
):
"""Creates a new specialization with overrides."""
if in_specs_treedef is None:
in_specs_treedef = self.in_specs_treedef
elif self.in_specs_treedef is not None:
raise ValueError("in_specs already specified")
if in_specs_leaves is None:
in_specs_leaves = self.in_specs_leaves
elif self.in_specs_leaves is not None:
raise ValueError("in_specs already specified")
if out_specs_fn is None:
out_specs_fn = self.out_specs_fn
elif self.out_specs_fn is not None:
raise ValueError("out_specs_fn already specified")
if out_specs_treedef is None:
out_specs_treedef = self.out_specs_treedef
elif self.out_specs_treedef is not None:
raise ValueError("out_specs already specified")
if out_specs_leaves is None:
out_specs_leaves = self.out_specs_leaves
elif self.out_specs_leaves is not None:
raise ValueError("out_specs already specified")
if devices is None:
devices = self.devices
elif self.devices is not None:
raise ValueError("devices already specified")
elif not isinstance(devices, xc.DeviceList):
devices = xc.DeviceList(tuple(devices))
return Specialization(
in_specs_treedef,
in_specs_leaves,
out_specs_fn,
out_specs_treedef,
out_specs_leaves,
devices,
)
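

# Specialization fields are write-once: `update` raises if a field is set
# twice. An illustrative sketch (values hypothetical; assumes at least two
# CPU devices are available):
#
#   s = Specialization().update(devices=jax.devices("cpu")[:2])  # OK
#   s = s.update(devices=jax.devices("cpu")[:1])  # ValueError: already set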


def _get_spec(x: Any) -> api.ShapeDtypeStruct:
"""Extracts a spec for a value, which must be a JAX Array."""
# TODO(hyeontaek): Allow Python values and automatically apply `shard_arg`
# with a suitable sharding and layout.
if not isinstance(x, jax.Array):
raise ValueError(
"colocated_python only supports jax.Array as input and output, but got"
f" {type(x)}."
)
return api.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
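

# Example (illustrative): for `x = jax.device_put(jnp.zeros((2, 3)), device)`,
# `_get_spec(x)` returns a `ShapeDtypeStruct` with shape (2, 3), dtype
# float32, and a `SingleDeviceSharding` for `device`.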


def _infer_devices_from_args(args: Sequence[Any]) -> xc.DeviceList | None:
  """Returns a representative device list from function call arguments."""
  device_list_set: set[xc.DeviceList] = set()
  for x in args:
    sharding = getattr(x, "sharding", None)
    if sharding is not None:
      device_list_set.add(sharding._internal_device_list)
if not device_list_set:
return None
if len(device_list_set) != 1:
raise ValueError(
"All arguments must use the same device list, but got"
f" multiple device lists: {device_list_set}."
)
return device_list_set.pop()
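

# Example (illustrative): if every array argument is sharded across the same
# two devices, the shared `DeviceList` for those devices is returned; with no
# array arguments the result is `None`; mixing device lists raises ValueError.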


def _compile_to_executable(
name: str,
fun: Callable[..., Any],
in_specs_treedef: tree_util.PyTreeDef,
in_specs_leaves: tuple[api.ShapeDtypeStruct, ...],
out_specs_treedef: tree_util.PyTreeDef,
out_specs_leaves: tuple[api.ShapeDtypeStruct, ...],
devices: xc.DeviceList,
) -> Callable[..., Any]:
"""Compiles a Python function into a runtime executable."""
fun_and_specialization = (
fun,
in_specs_treedef,
in_specs_leaves,
out_specs_treedef,
out_specs_leaves,
devices,
)
pickled_function = _serialize(fun_and_specialization)
program = ifrt_programs.make_colocated_python_program(
name, pickled_function, devices, in_specs_leaves, out_specs_leaves
)
ifrt_client = devices[0].client
out_sdss = tuple(
jax.core.ShapedArray(sds.shape, sds.dtype) for sds in out_specs_leaves
)
out_shardings = tuple(sds.sharding for sds in out_specs_leaves)
try:
compile_options = ifrt_programs.make_colocated_python_compile_options()
loaded_executable = ifrt_client.compile_ifrt_program(
program, compile_options
)
out_handlers = pxla.global_avals_to_results_handler(
out_sdss, out_shardings, committed=True # type: ignore
).handlers

    def call(*args, **kwargs):
args_leaves = tree_util.tree_leaves((args, kwargs))
execute_result = loaded_executable.execute_sharded(
args_leaves, with_tokens=False
)
results = execute_result.consume_with_handlers(out_handlers)
return tree_util.tree_unflatten(out_specs_treedef, results)

    return call
except jax.errors.JaxRuntimeError as e:
# TODO(hyeontaek): Implement colocated Python support in McJAX and remove
# this fallback path.
if "PjRtCompiler requires an HloProgram" in str(e):
return fun
raise
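

# Note: the user function travels to the colocated executor as pickled bytes
# inside an IFRT colocated Python program; only `name`, the pickled payload,
# the device list, and the input/output specs are needed to build the program
# on the client side.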


def _make_output_specs_and_push_result_fun(
info: FunctionInfo, specialization: Specialization, uid: int
) -> Callable[..., Any]:
"""Creates a function that computes output specs and pushes the result to the result store."""
assert specialization.in_specs_treedef is not None
assert specialization.in_specs_leaves is not None
assert specialization.out_specs_treedef is None
assert specialization.out_specs_leaves is None
assert specialization.devices is not None
devices = specialization.devices

  def lowered_fun(*args, **kwargs) -> jax.Array:
result = info.fun(*args, **kwargs)
result_leaves, out_treedef = tree_util.tree_flatten(result)
out_spec_leaves = tuple(_get_spec(x) for x in result_leaves)
func_backend.SINGLETON_RESULT_STORE.push(uid, result_leaves)
return _serialize_specs(out_treedef, out_spec_leaves, devices)

  out_specs_leaves, out_specs_treedef = tree_util.tree_flatten(
_make_specs_for_serialized_specs(specialization.devices),
)
name = getattr(info.fun, "__name__", "unknown")
name = f"{name}_output_specs_and_push_result"
return _compile_to_executable(
name=name,
fun=lowered_fun,
in_specs_treedef=specialization.in_specs_treedef,
in_specs_leaves=specialization.in_specs_leaves,
out_specs_treedef=out_specs_treedef,
out_specs_leaves=tuple(out_specs_leaves),
devices=specialization.devices,
)


def _make_pop_result_fun(
info: FunctionInfo, specialization: Specialization, uid: int
) -> Callable[..., Any]:
"""Makes a function that pops results from the result store."""
assert specialization.out_specs_treedef is not None
assert specialization.out_specs_leaves is not None
assert specialization.devices is not None
out_specs_treedef = specialization.out_specs_treedef

  def lowered_fun():
result_leaves = func_backend.SINGLETON_RESULT_STORE.pop(uid)
return tree_util.tree_unflatten(out_specs_treedef, result_leaves)

  in_specs_leaves, in_specs_treedef = tree_util.tree_flatten((
# args
(),
# kwargs
{},
))
name = getattr(info.fun, "__name__", "unknown")
name = f"{name}_pop_result"
return _compile_to_executable(
name=name,
fun=lowered_fun,
in_specs_treedef=in_specs_treedef,
in_specs_leaves=tuple(in_specs_leaves),
out_specs_treedef=specialization.out_specs_treedef,
out_specs_leaves=specialization.out_specs_leaves,
devices=specialization.devices,
)


def _make_async_execution_fun(
    info: FunctionInfo, specialization: Specialization
) -> Callable[..., Any]:
  """Makes a function that asynchronously executes the user function."""
assert specialization.in_specs_treedef is not None
assert specialization.in_specs_leaves is not None
assert specialization.out_specs_treedef is not None
assert specialization.out_specs_leaves is not None
assert specialization.devices is not None
name = getattr(info.fun, "__name__", "unknown")
return _compile_to_executable(
name=name,
fun=info.fun,
in_specs_treedef=specialization.in_specs_treedef,
in_specs_leaves=specialization.in_specs_leaves,
out_specs_treedef=specialization.out_specs_treedef,
out_specs_leaves=specialization.out_specs_leaves,
devices=specialization.devices,
)
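

# The three helpers above implement the first-call protocol used by
# `_get_specialized_func` below when output specs are not yet known:
#
#   1. `_make_output_specs_and_push_result_fun` runs the user function on the
#      colocated devices, stores the result leaves under `uid`, and returns
#      serialized output specs.
#   2. `_make_pop_result_fun` then retrieves the stored result, so the first
#      call's work is not repeated.
#   3. `_make_async_execution_fun` serves all subsequent calls directly.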


@util.cache(max_size=None)
def _get_specialized_func(
info: FunctionInfo, specialization: Specialization
) -> Callable[..., Any]:
"""Returns a specialized function for the given specialization."""
util.test_event("colocated_python_func._get_specialized_func")
assert specialization.in_specs_treedef is not None
assert specialization.in_specs_leaves is not None
assert specialization.devices is not None
uid = random.getrandbits(63)
mutex = threading.Lock()
# Asynchronous execution function that has known output_specs.
async_execution_func = None

  def specialized_func(*args, **kwargs):
"""Specialized function to be executed with given args and kwargs."""
nonlocal specialization, async_execution_func
with mutex:
if async_execution_func is None:
if specialization.out_specs_treedef is None:
if specialization.out_specs_fn is None:
serialized_out_specs = _make_output_specs_and_push_result_fun(
info, specialization, uid
)(*args, **kwargs)
# Waits for the output_specs. This may block.
out_specs_treedef, out_specs_leaves = _deserialize_specs(
serialized_out_specs
)
            # Subsequent calls will use `async_execution_func` with the
            # discovered output specs.
specialization = specialization.update(
out_specs_treedef=out_specs_treedef,
out_specs_leaves=out_specs_leaves,
)
async_execution_func = _make_async_execution_fun(
info, specialization
)
return _make_pop_result_fun(info, specialization, uid)()
else:
# Compute out_specs using out_specs_fn and inputs.
args_specs, kwargs_specs = tree_util.tree_map(
_get_spec, (args, kwargs)
)
out_specs = specialization.out_specs_fn(*args_specs, **kwargs_specs)
# Type checking is ignored to silence mypy error: Incompatible types
# in assignment (expression has type "list[Any]", variable has type
# "tuple[ShapeDtypeStruct, ...]") [assignment]
out_specs_leaves, out_specs_treedef = tree_util.tree_flatten( # type: ignore[assignment]
out_specs
)
specialization = specialization.update(
out_specs_treedef=out_specs_treedef,
out_specs_leaves=tuple(out_specs_leaves),
)
async_execution_func = _make_async_execution_fun(
info, specialization
)
# Fall-through.
else:
async_execution_func = _make_async_execution_fun(info, specialization)
# Fall-through.
# Asynchronous execution runs outside of the mutex to allow concurrent
# execution for inline executors.
return async_execution_func(*args, **kwargs)

  return specialized_func


def make_callable(
fun: Callable[..., Any],
fun_sourceinfo: str | None,
fun_signature: inspect.Signature | None,
):
"""Makes a colocated Python callable."""
return _make_callable(
FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization()
)
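

# Example usage (an illustrative sketch, not part of this module; assumes the
# public `colocated_python` decorator in `jax.experimental.colocated_python`
# wraps `make_callable` and that at least one device is available):
#
#   import jax
#   import jax.numpy as jnp
#   from jax.experimental import colocated_python
#
#   @colocated_python.colocated_python
#   def add_one(x):
#     return x + 1
#
#   x = jax.device_put(jnp.zeros((4,)), jax.devices()[0])
#   y = add_one(x)  # First call is synchronous: output specs are discovered.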


def _make_callable(info: FunctionInfo, specialization: Specialization):
"""Internal implementation of make_callable."""

  def specialize(
in_specs: ShapeDtypeStructTree | None = None,
out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None,
devices: Sequence[jax.Device] | None = None,
):
"""Returns a colocated Python callable with extra specialization.
Args:
in_specs: Optionally specifies the expected input specs. Input specs are
expressed as a `PyTree[ShapeDtypeStruct]` for `(args, kwargs)` of a
function call.
out_specs_fn: Optionally specifies a function that computes the output
specs from input specs. If unspecified, colocated_python will compute
the output specs during the very first execution, and this execution
will be synchronous.
devices: Optionally specifies the devices to execute the function on. Must
be provided if in_specs has no leaves because devices cannot be inferred
from input specs or arguments.
Returns:
A colocated Python callable with extra specialization.
"""
# TODO(hyeontaek): Allow unspecified devices for zero-leaf `in_specs` if
# `out_specs_fn(in_specs)` returns at least one leaf that we can use for
# inferring `devices`.
if in_specs is None:
in_specs_leaves, in_specs_treedef = None, None
else:
in_specs_leaves_list, in_specs_treedef = tree_util.tree_flatten(in_specs)
in_specs_leaves = tuple(in_specs_leaves_list)

    return _make_callable(
info,
specialization.update(
in_specs_treedef=in_specs_treedef,
in_specs_leaves=in_specs_leaves,
out_specs_fn=out_specs_fn,
devices=devices,
),
)
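
  # Example (illustrative sketch; `fn`, `sharding`, and `cpu_devices` are
  # hypothetical):
  #
  #   in_spec = jax.ShapeDtypeStruct((4,), jnp.float32, sharding=sharding)
  #   fn = fn.specialize(
  #       in_specs=((in_spec,), {}),  # pytree for (args, kwargs)
  #       out_specs_fn=lambda x: x,   # output spec mirrors the input spec
  #       devices=cpu_devices,
  #   )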
@api_boundary
def __call__(*args, **kwargs):
"""Executes the function.
If the output specs are not known, the very first execution will be
synchronous.
"""
args_leaves, in_specs_treedef = tree_util.tree_flatten((args, kwargs))
in_specs_leaves = tuple(_get_spec(x) for x in args_leaves)
if specialization.in_specs_treedef is None:
# Allow input polymorphism by applying input_specs specialization
# temporarily for this call.
return _make_callable(
info,
specialization.update(
in_specs_treedef=in_specs_treedef,
in_specs_leaves=in_specs_leaves,
),
)(*args, **kwargs)
if specialization.devices is None:
devices = _infer_devices_from_args(args_leaves)
if devices is None:
raise ValueError(
"No devices found. colocated_python function without input"
" arguments must be first specialized with devices."
)
# Allow device polymorphism by applying devices specialization temporarily
# for this call.
return _make_callable(info, specialization.update(devices=devices))(
*args, **kwargs
)
# Assertion is added to silence mypy error: Unsupported operand types for !=
# ("PyTreeDef" and "None") [operator]
assert isinstance(specialization.in_specs_treedef, tree_util.PyTreeDef)
# If input_specs is known, verify that it matches actual inputs.
    if (specialization.in_specs_treedef != in_specs_treedef
        or specialization.in_specs_leaves != in_specs_leaves):
      raise ValueError(
          "Input specs in specialization and input specs of arguments must"
          " have the same pytree structure, but they have the following"
          " structural differences:\n"
          + "\n".join(
              f"  - {tree_util.keystr(path)} is a {thing1} in value 1 and"
              f" a {thing2} in value 2, so {explanation}.\n"
              for path, thing1, thing2, explanation
              in tree_util.equality_errors_pytreedef(
                  specialization.in_specs_treedef, in_specs_treedef
              )
          )
      )
return _get_specialized_func(info, specialization)(*args, **kwargs)

  __call__ = wraps(info.fun)(__call__)
__call__.specialize = specialize
return __call__