# Snapshot metadata (not part of the original source):
# 2025-08-11 12:24:21 +08:00 — 550 lines, 17 KiB, Python.
# Copyright 2017 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An XLA client in Python."""
from __future__ import annotations
import atexit
from collections.abc import Mapping
import contextlib
import enum
import logging
import os
import threading
from typing import Any, Protocol, Union
from jaxlib import _jax as _xla
# Note this module does *not* depend on any Python protocol buffers. The XLA
# Python bindings are currently packaged both as part of jaxlib and as part
# of TensorFlow. If we use protocol buffers here, then importing both jaxlib
# and TensorFlow may fail with duplicate protocol buffer message definitions.
# Most functions are snake_case for consistency with other modules, some
# method names are CamelCase for consistency with XLA.
# pylint: disable=invalid-name
# Pylint has false positives for type annotations.
# pylint: disable=invalid-sequence-index
# IFRT program abstractions re-exported from the C++ bindings.
ifrt_programs = _xla.ifrt_programs
# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version.
_version = 355
# An internal increasing version number for protecting jaxlib code against
# ifrt changes.
# lives in xla/python/version.h.
# In JAX, reference this via jax._src.lib.ifrt_version.
_ifrt_version = _xla.ifrt_version_number
# Maps lowercase platform aliases to the platform names XLA uses internally.
# NOTE: 'gpu' is hardcoded to 'CUDA'; other platforms (e.g. ROCm) fall through
# unmapped — see the workaround note in register_custom_call_target.
xla_platform_names = {
    'cpu': 'Host',
    'gpu': 'CUDA',
}
logger = logging.getLogger(__name__)
# Shape of the extra-options mappings accepted by client/plugin constructors.
_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]]
def make_cpu_client(
    asynchronous=True,
    distributed_client=None,
    node_id=0,
    num_nodes=1,
    collectives=None,
    num_devices=None,
    get_local_topology_timeout_minutes=None,
    get_global_topology_timeout_minutes=None,
) -> Client:
  """Returns a TFRT CPU client, installing CPU registration handlers first.

  Installing the handlers flushes any custom-call targets / type ids that
  were queued for 'cpu' before a handler existed.
  """
  register_custom_call_handler('cpu', _xla.register_custom_call_target)
  register_custom_type_id_handler('cpu', _xla.register_custom_type_id)
  client_kwargs = dict(
      asynchronous=asynchronous,
      distributed_client=distributed_client,
      node_id=node_id,
      num_nodes=num_nodes,
      collectives=collectives,
      num_devices=num_devices,
      get_local_topology_timeout_minutes=get_local_topology_timeout_minutes,
      get_global_topology_timeout_minutes=get_global_topology_timeout_minutes,
  )
  return _xla.get_tfrt_cpu_client(**client_kwargs)
# Topology descriptions, used e.g. for ahead-of-time compilation without
# attached devices.
DeviceTopology = _xla.DeviceTopology
get_topology_for_devices = _xla.get_topology_for_devices
def make_tfrt_tpu_c_api_device_topology(
    topology_name: str = '', **kwargs
) -> DeviceTopology:
  """Creates a PJRT C API TopologyDescription for the TPU platform."""
  # Copy kwargs so the bindings receive a plain dict of options.
  options = dict(kwargs)
  return _xla.get_default_c_api_topology('tpu', topology_name, options)
def make_c_api_device_topology(
    c_api: Any, topology_name: str = '', **kwargs
) -> DeviceTopology:
  """Creates a PJRT C API TopologyDescription from a raw C API handle."""
  options = dict(kwargs)
  return _xla.get_c_api_topology(c_api, topology_name, options)
def pjrt_plugin_loaded(plugin_name: str) -> bool:
  """Returns whether a PJRT plugin with this name has been loaded."""
  is_loaded = _xla.pjrt_plugin_loaded(plugin_name)
  return is_loaded
def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any:
  """Loads a PJRT plugin shared library from library_path.

  Returns whatever the bindings return for the loaded plugin.
  """
  result = _xla.load_pjrt_plugin(plugin_name, library_path, c_api=None)
  return result
def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None:
  """Registers an already-resident PJRT plugin via its in-memory C API handle."""
  # No library path: the plugin is supplied directly as a C API pointer.
  return _xla.load_pjrt_plugin(plugin_name, None, c_api)
def pjrt_plugin_initialized(plugin_name: str) -> bool:
  """Returns whether the named PJRT plugin has been initialized."""
  initialized = _xla.pjrt_plugin_initialized(plugin_name)
  return initialized
def initialize_pjrt_plugin(plugin_name: str) -> None:
  """Initializes a PJRT plugin.

  The plugin needs to be loaded first (through load_pjrt_plugin_dynamically
  or static linking) before this method is called.

  Args:
    plugin_name: the name of the PJRT plugin.
  """
  _xla.initialize_pjrt_plugin(plugin_name)
def make_c_api_client(
    plugin_name: str,
    options: _NameValueMapping | None = None,
    distributed_client: _xla.DistributedRuntimeClient | None = None,
):
  """Creates a PJRT C API client for a PJRT plugin.

  It is required that load_pjrt_plugin_dynamically is called once with the
  same plugin_name before this method is called.

  Args:
    plugin_name: the name of the PJRT plugin.
    options: extra platform-specific options.
    distributed_client: distributed client.

  Returns:
    A PJRT C API client for plugin_name.
  """
  effective_options = {} if options is None else options
  return _xla.get_c_api_client(
      plugin_name, effective_options, distributed_client
  )
def generate_pjrt_gpu_plugin_options() -> _NameValueMapping:
  """Generates the PjRt GPU plugin options.

  Translates the XLA_* environment variables into the option dictionary
  understood by the PJRT GPU plugin.

  Returns:
    A dictionary of plugin options.

  Raises:
    ValueError: if both the current and deprecated memory-fraction variables
      are set, or if the allocator variable holds an unknown value.
  """
  opts: dict[str, Any] = {'platform_name': 'cuda'}
  allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower()
  memory_fraction = os.getenv('XLA_CLIENT_MEM_FRACTION', '')
  deprecated_memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION', '')
  if deprecated_memory_fraction:
    if memory_fraction:
      raise ValueError(
          'XLA_CLIENT_MEM_FRACTION is specified together '
          'with XLA_PYTHON_CLIENT_MEM_FRACTION. '
          'Remove the latter one, it is deprecated.'
      )
    # Fall back to the deprecated variable when only it is set.
    memory_fraction = deprecated_memory_fraction
  preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE', '')
  collective_memory_size = os.getenv(
      'XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB', ''
  )
  if allocator not in ('default', 'platform', 'bfc', 'cuda_async'):
    raise ValueError(
        'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", '
        '"bfc", or "cuda_async", got "%s"' % allocator
    )
  opts['allocator'] = allocator
  if memory_fraction:
    opts['memory_fraction'] = float(memory_fraction)
  if preallocate:
    # Anything other than an explicit false/0 enables preallocation.
    opts['preallocate'] = preallocate not in ('false', 'False', '0')
  if collective_memory_size:
    # The env var is in MiB; the plugin expects bytes.
    opts['collective_memory_size'] = int(collective_memory_size) * (1 << 20)
  return opts
# Core type re-exports from the C++ bindings.
PrimitiveType = _xla.PrimitiveType
Shape = _xla.Shape
# Shape is implemented in C++; attach a Python-side docstring sketching its
# duck-typed API for readers and help().
Shape.__doc__ = """
A Shape is an object defined in C++ that duck types like the following class:
class Shape:
'''Represents an XLA shape.
A shape is either an array shape, having rank-many integer
dimensions and an element type (represented by a Numpy dtype), or it
is a tuple shape, having a shape for every tuple component:
type shape =
TupleShape of shape list
| ArrayShape of { dimensions: int list; element_type: dtype }
'''
@staticmethod
def tuple_shape(tuple_shapes) -> Shape:
"Construct a tuple shape."
@staticmethod
def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:
@staticmethod
def from_pyval(pyval) -> Shape:
"Returns a Shape that describes a tuple-tree of Numpy arrays."
def __init__(self, str) -> Shape:
"Parses a shape string."
def __eq__(self, other: Shape) -> bool:
def __ne__(self, other: Shape) -> bool:
def __hash__(self):
def __repr__(self):
def is_tuple(self) -> bool:
def is_array(self) -> bool:
def tuple_shapes(self) -> [Shape]:
def numpy_dtype(self) -> np.dtype:
"Like element_type(), but returns dtype('O') for a tuple shape."
def xla_element_type(self) -> PrimitiveType:
def element_type(self) -> np.dtype:
def dimensions(self) -> (int, int, ...):
def rank(self) -> int:
def with_major_to_minor_layout_if_absent(self) -> Shape:
"Returns a copy with missing layouts set to major-to-minor."
def to_serialized_proto(self) -> bytes:
"Returns 'shape' as a serialized proto."
"""
ProgramShape = _xla.ProgramShape
# ProgramShape is implemented in C++; document its duck-typed API here.
ProgramShape.__doc__ = """
A ProgramShape is a C++ object that duck types like the following class.
class ProgramShape:
def __init__(self, parameter_shapes, result_shape):
def parameter_shapes(self) -> [Shape]:
def result_shape(self) -> Shape:
def __repr__(self):
"""
DeviceAssignment = _xla.DeviceAssignment
# DeviceAssignment is implemented in C++; document its duck-typed API here.
DeviceAssignment.__doc__ = """
A DeviceAssignment is a C++ object with the following signature.
def create(assignment):
'''Builds a device assignment.
Args:
assignment: a 2D numpy array of device ordinal integers, indexed by
[replica][computation_in_replica].
Returns:
A device assignment.
'''
def replica_count():
'''Returns the number of replicas.'''
def computation_count():
'''Returns the number of computations per replica.'''
"""
# Additional re-exports used throughout compilation and execution.
Device = _xla.Device
CompileOptions = _xla.CompileOptions
HostBufferSemantics = _xla.HostBufferSemantics
# An Executable is a C++ class that duck types with the following API:
# class Executable:
# def local_devices(self) -> [Device]:
# def execute(self, arguments : [Buffer]) -> Buffer:
# """Execute on one replica with Buffer arguments and return value."""
#
# def size_of_generated_code_in_bytes(self) -> int:
# """Return generated binary size, or -1 if not known."""
#
# def execute_sharded_on_local_devices(self, arguments: [[Buffer]])
# -> [Buffer]:
# """Execute on many replicas with Buffer arguments and return value.
#
# Args:
# arguments: A sequence of sequences of Buffers. The i'th element of each
# sequence comprises the arguments for execution on the i'th local
# device.
#
# Returns:
# A list of the computation's outputs as a list of Buffers for each
# device.
# """
#
# There are different implementations of Executable for different backends.
# Core runtime types re-exported from the C++ extension module; this module
# only gives them stable Python names.
XlaComputation = _xla.XlaComputation
Client = _xla.Client
Memory = _xla.Memory
Array = _xla.Array
ArrayImpl = _xla.ArrayImpl
LoadedExecutable = _xla.LoadedExecutable
Executable = _xla.Executable
DeviceList = _xla.DeviceList
# Sharding-related types.
OpSharding = _xla.OpSharding
HloSharding = _xla.HloSharding
Sharding = _xla.Sharding
NamedSharding = _xla.NamedSharding
SingleDeviceSharding = _xla.SingleDeviceSharding
PmapSharding = _xla.PmapSharding
GSPMDSharding = _xla.GSPMDSharding
PjRtLayout = _xla.PjRtLayout
AutotuneCacheMode = _xla.AutotuneCacheMode
def LoadedExecutable_execute(self, arguments, device=None):
  """Executes on one replica and returns the first shard of each output."""
  del device  # Unused; kept for call-site compatibility.
  sharded = self.execute_sharded(arguments)
  per_output_shards = sharded.disassemble_into_single_device_arrays()
  return [shards[0] for shards in per_output_shards]
def LoadedExecutable_execute_with_token(self, arguments, device=None):
  """Executes on one replica; returns (outputs, token).

  Like LoadedExecutable_execute, but also consumes the execution token so
  callers can order side effects against this execution.
  """
  del device  # Unused; kept for call-site compatibility.
  sharded = self.execute_sharded(arguments, with_tokens=True)
  outputs = [
      shards[0] for shards in sharded.disassemble_into_single_device_arrays()
  ]
  token = sharded.consume_token().get_token(0)
  return (outputs, token)
# Install the Python wrappers above as methods on the C++ LoadedExecutable.
LoadedExecutable.execute = LoadedExecutable_execute
LoadedExecutable.execute_with_token = LoadedExecutable_execute_with_token
class CustomCallTargetTraits(enum.IntFlag):
  """Bit flags describing properties of a custom call target."""

  # No special traits.
  DEFAULT = 0
  # Calls to custom call are safe to trace into the command buffer. It means
  # that calls to custom call always launch exactly the same device operations
  # (can depend on attribute values) that can be captured and then replayed.
  #
  # Supported only for custom calls implemented with XLA FFI.
  COMMAND_BUFFER_COMPATIBLE = 1
class CustomCallHandler(Protocol):
  """Structural type of a per-platform custom-call registration function."""

  def __call__(
      self,
      name: str,
      fn: Any,
      platform: str,
      /,
      api_version: int = ...,
      traits: CustomCallTargetTraits = ...,
  ) -> None:
    ...
# Registered per-platform custom-call handlers, keyed by xla_platform_name.
_custom_callback_handler: dict[str, CustomCallHandler] = {}
# Key is xla_platform_name, value is a list of
# (function_name, function, api_version, traits) tuples queued until a handler
# for that platform is registered.
_custom_callback: dict[
    str, list[tuple[str, Any, int, CustomCallTargetTraits]]
] = {}
# Guards the two dicts above.
_custom_callback_lock = threading.Lock()
def register_custom_call_target(
    name: str,
    fn: Any,
    platform: str = 'cpu',
    api_version: int = 0,
    traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT,
) -> None:
  """Registers a custom call target.

  If a handler for the platform is already installed the target is registered
  immediately; otherwise it is queued until register_custom_call_handler runs
  for that platform.

  Args:
    name: bytes containing the name of the function.
    fn: a PyCapsule object containing the function pointer.
    platform: the target platform.
    api_version: the XLA FFI version to use. Supported versions are: 0 for the
      untyped FFI and 1 for the typed FFI.
    traits: custom call traits corresponding to XLA FFI handler traits.
  """
  # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM"
  # Since that is hardcoded to CUDA, we are using the following as workaround.
  xla_platform_name = xla_platform_names.get(platform, platform)
  with _custom_callback_lock:
    handler = _custom_callback_handler.get(xla_platform_name)
    if handler is not None:
      handler(name, fn, xla_platform_name, api_version, traits)
      return
    pending = _custom_callback.setdefault(xla_platform_name, [])
    pending.append((name, fn, api_version, traits))
def register_custom_call_handler(
    platform: str, handler: CustomCallHandler
) -> None:
  """Registers a custom handler and use it to register existing custom calls.

  If a custom call handler for the platform already exist, calling this method
  is a no-op and it will not register a new handler.

  Args:
    platform: the target platform.
    handler: the function to register a custom call.
  """
  xla_platform_name = xla_platform_names.get(platform, platform)
  with _custom_callback_lock:
    if xla_platform_name in _custom_callback_handler:
      # First handler wins; later registrations are silently ignored.
      logger.debug(
          'Custom call handler for %s is already register. Will not register a'
          ' new one',
          xla_platform_name,
      )
      return
    _custom_callback_handler[xla_platform_name] = handler
    # Flush any targets queued before this handler existed.
    pending = _custom_callback.pop(xla_platform_name, ())
    for name, fn, api_version, traits in pending:
      handler(name, fn, xla_platform_name, api_version, traits)
class CustomTypeIdHandler(Protocol):
  """Structural type of a per-platform FFI type-id registration function."""

  def __call__(self, name: str, capsule: Any) -> None:
    ...
# Registered per-platform type-id handlers, keyed by xla_platform_name.
_custom_type_id_handler: dict[str, CustomTypeIdHandler] = {}
# Queued (type_name, type_id) pairs per platform, flushed once a handler is
# registered for that platform.
_custom_type_id: dict[str, list[tuple[str, Any]]] = {}
# Guards the two dicts above.
_custom_type_id_lock = threading.Lock()
def register_custom_type_id(
    type_name: str,
    type_id: Any,
    platform: str = 'cpu',
) -> None:
  """Register a custom type id for use with the FFI.

  If a type-id handler for the platform is already installed the id is
  registered immediately; otherwise it is queued until
  register_custom_type_id_handler runs for that platform.

  Args:
    type_name: a unique name for the type.
    type_id: a PyCapsule object containing a pointer to the ``ffi::TypeId``.
    platform: the target platform.
  """
  xla_platform_name = xla_platform_names.get(platform, platform)
  with _custom_type_id_lock:
    handler = _custom_type_id_handler.get(xla_platform_name)
    if handler is not None:
      handler(type_name, type_id)
      return
    pending = _custom_type_id.setdefault(xla_platform_name, [])
    pending.append((type_name, type_id))
def register_custom_type_id_handler(
    platform: str, handler: CustomTypeIdHandler
) -> None:
  """Register a custom type id handler and use it to register existing type ids.

  If a custom type id handler for the platform already exists, calling this
  method is a no-op and it will not register a new handler.

  Args:
    platform: the target platform.
    handler: the function to register a custom type id.
  """
  xla_platform_name = xla_platform_names.get(platform, platform)
  # Fix: take _custom_type_id_lock — the lock register_custom_type_id uses to
  # guard _custom_type_id_handler/_custom_type_id. The previous code took
  # _custom_callback_lock here, so the two functions could mutate the type-id
  # registries concurrently without mutual exclusion.
  with _custom_type_id_lock:
    if xla_platform_name in _custom_type_id_handler:
      # First handler wins; later registrations are silently ignored.
      logger.debug(
          'Custom type id handler for %s is already registered. Will not '
          'register a new one',
          xla_platform_name,
      )
      return
    _custom_type_id_handler[xla_platform_name] = handler
    # Flush any type ids queued before this handler existed.
    if xla_platform_name in _custom_type_id:
      for name, capsule in _custom_type_id[xla_platform_name]:
        handler(name, capsule)
      del _custom_type_id[xla_platform_name]
# Custom-call partitioning hooks re-exported from the C++ bindings.
register_custom_call_partitioner = _xla.register_custom_call_partitioner
encode_inspect_sharding_callback = _xla.encode_inspect_sharding_callback
hlo_sharding_util = _xla.hlo_sharding_util
register_custom_call_as_batch_partitionable = (
    _xla.register_custom_call_as_batch_partitionable
)
# Traceback machinery; collection can be toggled with tracebacks() below.
Traceback = _xla.Traceback
Frame = _xla.Frame
@contextlib.contextmanager
def tracebacks(enabled=True):
  """Context manager that enables or disables traceback collection."""
  previous = _xla.tracebacks_enabled()
  _xla.set_tracebacks_enabled(enabled)
  try:
    yield
  finally:
    # Restore the setting even if the body raises.
    _xla.set_tracebacks_enabled(previous)
@contextlib.contextmanager
def execution_stream_id(new_id: int):
  """Context manager that overwrites and restores the current thread's execution_stream_id."""
  previous = _xla.get_execution_stream_id()
  _xla.set_execution_stream_id(new_id)
  try:
    yield
  finally:
    # Restore the previous id even if the body raises.
    _xla.set_execution_stream_id(previous)
# The exception type raised by the XLA runtime.
XlaRuntimeError = _xla.XlaRuntimeError
# Perform one last garbage collection of deferred Python references. This is
# mostly to keep ASAN happy.
atexit.register(_xla.collect_garbage)
# Array/buffer helpers re-exported from the C++ bindings.
array_result_handler = _xla.array_result_handler
batched_copy_array_to_devices_with_sharding = (
    _xla.batched_copy_array_to_devices_with_sharding
)
batched_device_put = _xla.batched_device_put
reorder_shards = _xla.reorder_shards
batched_block_until_ready = _xla.batched_block_until_ready
check_and_canonicalize_memory_kind = _xla.check_and_canonicalize_memory_kind
Layout = _xla.Layout
custom_call_targets = _xla.custom_call_targets
ArrayCopySemantics = _xla.ArrayCopySemantics