# Copyright 2024 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Colocated Python function API implementation.""" from __future__ import annotations import dataclasses import inspect import random import threading from typing import Any, Callable, Sequence import jax from jax._src import api from jax._src import tree_util from jax._src import util from jax._src.interpreters import pxla from jax._src.lib import xla_client as xc from jax._src.traceback_util import api_boundary from jax._src.util import wraps from jax.experimental.colocated_python import func_backend from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize, _serialize_specs from jax.extend.ifrt_programs import ifrt_programs ShapeDtypeStructTree = Any # PyTree[api.ShapeDtypeStruct] @dataclasses.dataclass(frozen=True, slots=True) class FunctionInfo: """User function wrapped by colocated_python.""" fun: Callable[..., Any] fun_sourceinfo: str | None fun_signature: inspect.Signature | None @dataclasses.dataclass(frozen=True, slots=True) class Specialization: """Specialization for a colocated_python function.""" in_specs_treedef: tree_util.PyTreeDef | None = None in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None out_specs_treedef: tree_util.PyTreeDef | None = None out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None devices: xc.DeviceList | None = None def update( self, *, in_specs_treedef: tree_util.PyTreeDef | None = None, in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None, out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None, out_specs_treedef: tree_util.PyTreeDef | None = None, out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None, devices: Sequence[jax.Device] | xc.DeviceList | None = None, ): """Creates a new specialization with overrides.""" if in_specs_treedef is None: in_specs_treedef = self.in_specs_treedef elif self.in_specs_treedef is not None: raise ValueError("in_specs already specified") if in_specs_leaves is None: in_specs_leaves = self.in_specs_leaves elif self.in_specs_leaves is not None: raise ValueError("in_specs already specified") if out_specs_fn is None: out_specs_fn = self.out_specs_fn elif self.out_specs_fn is not None: raise ValueError("out_specs_fn already specified") if out_specs_treedef is None: out_specs_treedef = self.out_specs_treedef elif self.out_specs_treedef is not None: raise ValueError("out_specs already specified") if out_specs_leaves is None: out_specs_leaves = self.out_specs_leaves elif self.out_specs_leaves is not None: raise ValueError("out_specs already specified") if devices is None: devices = self.devices elif self.devices is not None: raise ValueError("devices already specified") elif not isinstance(devices, xc.DeviceList): devices = xc.DeviceList(tuple(devices)) return Specialization( in_specs_treedef, in_specs_leaves, out_specs_fn, out_specs_treedef, out_specs_leaves, devices, ) def _get_spec(x: Any) -> api.ShapeDtypeStruct: """Extracts a spec for a value, which must be a JAX Array.""" # TODO(hyeontaek): Allow Python values and automatically apply `shard_arg` # with a suitable sharding and layout. if not isinstance(x, jax.Array): raise ValueError( "colocated_python only supports jax.Array as input and output, but got" f" {type(x)}." ) return api.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) def _infer_devices_from_args(args: Sequence[Any]) -> xc.DeviceList | None: """Returns a representative device list from function call arguments.""" device_list_set: set[xc.DeviceList] = set() for x in args: sharding = getattr(x, "sharding", None) if sharding is not None: device_list_set.add(x.sharding._internal_device_list) if not device_list_set: return None if len(device_list_set) != 1: raise ValueError( "All arguments must use the same device list, but got" f" multiple device lists: {device_list_set}." ) return device_list_set.pop() def _compile_to_executable( name: str, fun: Callable[..., Any], in_specs_treedef: tree_util.PyTreeDef, in_specs_leaves: tuple[api.ShapeDtypeStruct, ...], out_specs_treedef: tree_util.PyTreeDef, out_specs_leaves: tuple[api.ShapeDtypeStruct, ...], devices: xc.DeviceList, ) -> Callable[..., Any]: """Compiles a Python function into a runtime executable.""" fun_and_specialization = ( fun, in_specs_treedef, in_specs_leaves, out_specs_treedef, out_specs_leaves, devices, ) pickled_function = _serialize(fun_and_specialization) program = ifrt_programs.make_colocated_python_program( name, pickled_function, devices, in_specs_leaves, out_specs_leaves ) ifrt_client = devices[0].client out_sdss = tuple( jax.core.ShapedArray(sds.shape, sds.dtype) for sds in out_specs_leaves ) out_shardings = tuple(sds.sharding for sds in out_specs_leaves) try: compile_options = ifrt_programs.make_colocated_python_compile_options() loaded_executable = ifrt_client.compile_ifrt_program( program, compile_options ) out_handlers = pxla.global_avals_to_results_handler( out_sdss, out_shardings, committed=True # type: ignore ).handlers def call(*args, **kwargs): args_leaves = tree_util.tree_leaves((args, kwargs)) execute_result = loaded_executable.execute_sharded( args_leaves, with_tokens=False ) results = execute_result.consume_with_handlers(out_handlers) return tree_util.tree_unflatten(out_specs_treedef, results) return call except jax.errors.JaxRuntimeError as e: # TODO(hyeontaek): Implement colocated Python support in McJAX and remove # this fallback path. if "PjRtCompiler requires an HloProgram" in str(e): return fun raise def _make_output_specs_and_push_result_fun( info: FunctionInfo, specialization: Specialization, uid: int ) -> Callable[..., Any]: """Creates a function that computes output specs and pushes the result to the result store.""" assert specialization.in_specs_treedef is not None assert specialization.in_specs_leaves is not None assert specialization.out_specs_treedef is None assert specialization.out_specs_leaves is None assert specialization.devices is not None devices = specialization.devices def lowered_fun(*args, **kwargs) -> jax.Array: result = info.fun(*args, **kwargs) result_leaves, out_treedef = tree_util.tree_flatten(result) out_spec_leaves = tuple(_get_spec(x) for x in result_leaves) func_backend.SINGLETON_RESULT_STORE.push(uid, result_leaves) return _serialize_specs(out_treedef, out_spec_leaves, devices) out_specs_leaves, out_specs_treedef = tree_util.tree_flatten( _make_specs_for_serialized_specs(specialization.devices), ) name = getattr(info.fun, "__name__", "unknown") name = f"{name}_output_specs_and_push_result" return _compile_to_executable( name=name, fun=lowered_fun, in_specs_treedef=specialization.in_specs_treedef, in_specs_leaves=specialization.in_specs_leaves, out_specs_treedef=out_specs_treedef, out_specs_leaves=tuple(out_specs_leaves), devices=specialization.devices, ) def _make_pop_result_fun( info: FunctionInfo, specialization: Specialization, uid: int ) -> Callable[..., Any]: """Makes a function that pops results from the result store.""" assert specialization.out_specs_treedef is not None assert specialization.out_specs_leaves is not None assert specialization.devices is not None out_specs_treedef = specialization.out_specs_treedef def lowered_fun(): result_leaves = func_backend.SINGLETON_RESULT_STORE.pop(uid) return tree_util.tree_unflatten(out_specs_treedef, result_leaves) in_specs_leaves, in_specs_treedef = tree_util.tree_flatten(( # args (), # kwargs {}, )) name = getattr(info.fun, "__name__", "unknown") name = f"{name}_pop_result" return _compile_to_executable( name=name, fun=lowered_fun, in_specs_treedef=in_specs_treedef, in_specs_leaves=tuple(in_specs_leaves), out_specs_treedef=specialization.out_specs_treedef, out_specs_leaves=specialization.out_specs_leaves, devices=specialization.devices, ) def _make_async_execution_fun( info: FunctionInfo, specialization: Specialization ) -> Callable[..., Any]: """Makes a function that asynchronously executes the function.""" assert specialization.in_specs_treedef is not None assert specialization.in_specs_leaves is not None assert specialization.out_specs_treedef is not None assert specialization.out_specs_leaves is not None assert specialization.devices is not None name = getattr(info.fun, "__name__", "unknown") return _compile_to_executable( name=name, fun=info.fun, in_specs_treedef=specialization.in_specs_treedef, in_specs_leaves=specialization.in_specs_leaves, out_specs_treedef=specialization.out_specs_treedef, out_specs_leaves=specialization.out_specs_leaves, devices=specialization.devices, ) @jax._src.util.cache(max_size=None) def _get_specialized_func( info: FunctionInfo, specialization: Specialization ) -> Callable[..., Any]: """Returns a specialized function for the given specialization.""" util.test_event("colocated_python_func._get_specialized_func") assert specialization.in_specs_treedef is not None assert specialization.in_specs_leaves is not None assert specialization.devices is not None uid = random.getrandbits(63) mutex = threading.Lock() # Asynchronous execution function that has known output_specs. async_execution_func = None def specialized_func(*args, **kwargs): """Specialized function to be executed with given args and kwargs.""" nonlocal specialization, async_execution_func with mutex: if async_execution_func is None: if specialization.out_specs_treedef is None: if specialization.out_specs_fn is None: serialized_out_specs = _make_output_specs_and_push_result_fun( info, specialization, uid )(*args, **kwargs) # Waits for the output_specs. This may block. out_specs_treedef, out_specs_leaves = _deserialize_specs( serialized_out_specs ) # Subsequent calls would use async_execution_func with discovered # output_specs. specialization = specialization.update( out_specs_treedef=out_specs_treedef, out_specs_leaves=out_specs_leaves, ) async_execution_func = _make_async_execution_fun( info, specialization ) return _make_pop_result_fun(info, specialization, uid)() else: # Compute out_specs using out_specs_fn and inputs. args_specs, kwargs_specs = tree_util.tree_map( _get_spec, (args, kwargs) ) out_specs = specialization.out_specs_fn(*args_specs, **kwargs_specs) # Type checking is ignored to silence mypy error: Incompatible types # in assignment (expression has type "list[Any]", variable has type # "tuple[ShapeDtypeStruct, ...]") [assignment] out_specs_leaves, out_specs_treedef = tree_util.tree_flatten( # type: ignore[assignment] out_specs ) specialization = specialization.update( out_specs_treedef=out_specs_treedef, out_specs_leaves=tuple(out_specs_leaves), ) async_execution_func = _make_async_execution_fun( info, specialization ) # Fall-through. else: async_execution_func = _make_async_execution_fun(info, specialization) # Fall-through. # Asynchronous execution runs outside of the mutex to allow concurrent # execution for inline executors. return async_execution_func(*args, **kwargs) return specialized_func def make_callable( fun: Callable[..., Any], fun_sourceinfo: str | None, fun_signature: inspect.Signature | None, ): """Makes a colocated Python callable.""" return _make_callable( FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization() ) def _make_callable(info: FunctionInfo, specialization: Specialization): """Internal implementation of make_callable.""" def specialize( in_specs: ShapeDtypeStructTree | None = None, out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None, devices: Sequence[jax.Device] | None = None, ): """Returns a colocated Python callable with extra specialization. Args: in_specs: Optionally specifies the expected input specs. Input specs are expressed as a `PyTree[ShapeDtypeStruct]` for `(args, kwargs)` of a function call. out_specs_fn: Optionally specifies a function that computes the output specs from input specs. If unspecified, colocated_python will compute the output specs during the very first execution, and this execution will be synchronous. devices: Optionally specifies the devices to execute the function on. Must be provided if in_specs has no leaves because devices cannot be inferred from input specs or arguments. Returns: A colocated Python callable with extra specialization. """ # TODO(hyeontaek): Allow unspecified devices for zero-leaf `in_specs` if # `out_specs_fn(in_specs)` returns at least one leaf that we can use for # inferring `devices`. if in_specs is None: in_specs_leaves, in_specs_treedef = None, None else: in_specs_leaves_list, in_specs_treedef = tree_util.tree_flatten(in_specs) in_specs_leaves = tuple(in_specs_leaves_list) return _make_callable( info, specialization.update( in_specs_treedef=in_specs_treedef, in_specs_leaves=in_specs_leaves, out_specs_fn=out_specs_fn, devices=devices, ), ) @api_boundary def __call__(*args, **kwargs): """Executes the function. If the output specs are not known, the very first execution will be synchronous. """ args_leaves, in_specs_treedef = tree_util.tree_flatten((args, kwargs)) in_specs_leaves = tuple(_get_spec(x) for x in args_leaves) if specialization.in_specs_treedef is None: # Allow input polymorphism by applying input_specs specialization # temporarily for this call. return _make_callable( info, specialization.update( in_specs_treedef=in_specs_treedef, in_specs_leaves=in_specs_leaves, ), )(*args, **kwargs) if specialization.devices is None: devices = _infer_devices_from_args(args_leaves) if devices is None: raise ValueError( "No devices found. colocated_python function without input" " arguments must be first specialized with devices." ) # Allow device polymorphism by applying devices specialization temporarily # for this call. return _make_callable(info, specialization.update(devices=devices))( *args, **kwargs ) # Assertion is added to silence mypy error: Unsupported operand types for != # ("PyTreeDef" and "None") [operator] assert isinstance(specialization.in_specs_treedef, tree_util.PyTreeDef) # If input_specs is known, verify that it matches actual inputs. if (specialization.in_specs_treedef != in_specs_treedef or specialization.in_specs_leaves != in_specs_leaves): raise ValueError( "Input specs in specialization and input specs of arguments must have" " the same pytree structure, but they have the following structural" " differences:\n" + ("\n".join( f" - {tree_util.keystr(path)} is a {thing1} in value 1 and" f" a {thing2} in value 2, so {explanation}.\n" for path, thing1, thing2, explanation in tree_util.equality_errors_pytreedef( specialization.in_specs_treedef, in_specs_treedef )))) return _get_specialized_func(info, specialization)(*args, **kwargs) __call__ = wraps(info.fun)(__call__) __call__.specialize = specialize return __call__