# Copyright 2017 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An XLA client in Python."""

from __future__ import annotations

import atexit
from collections.abc import Mapping
import contextlib
import enum
import logging
import os
import threading
from typing import Any, Protocol, Union

from jaxlib import _jax as _xla

# Note this module does *not* depend on any Python protocol buffers. The XLA
# Python bindings are currently packaged both as part of jaxlib and as part
# of TensorFlow. If we use protocol buffers here, then importing both jaxlib
# and TensorFlow may fail with duplicate protocol buffer message definitions.

# Most functions are snake_case for consistency with other modules; some
# method names are CamelCase for consistency with XLA.
# pylint: disable=invalid-name

# Pylint has false positives for type annotations.
# pylint: disable=invalid-sequence-index

ifrt_programs = _xla.ifrt_programs

# An internal, arbitrarily increasing number used to help make
# backward-compatible changes. In JAX, reference this via
# jax._src.lib.jaxlib_extension_version.
_version = 355

# An internal increasing version number for protecting jaxlib code against
# IFRT changes; the canonical value lives in xla/python/version.h.
# In JAX, reference this via jax._src.lib.ifrt_version.
_ifrt_version = _xla.ifrt_version_number

xla_platform_names = {
    'cpu': 'Host',
    'gpu': 'CUDA',
}

logger = logging.getLogger(__name__)

_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]]


def make_cpu_client(
    asynchronous=True,
    distributed_client=None,
    node_id=0,
    num_nodes=1,
    collectives=None,
    num_devices=None,
    get_local_topology_timeout_minutes=None,
    get_global_topology_timeout_minutes=None,
) -> Client:
  register_custom_call_handler('cpu', _xla.register_custom_call_target)
  register_custom_type_id_handler('cpu', _xla.register_custom_type_id)
  return _xla.get_tfrt_cpu_client(
      asynchronous=asynchronous,
      distributed_client=distributed_client,
      node_id=node_id,
      num_nodes=num_nodes,
      collectives=collectives,
      num_devices=num_devices,
      get_local_topology_timeout_minutes=get_local_topology_timeout_minutes,
      get_global_topology_timeout_minutes=get_global_topology_timeout_minutes,
  )
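

# A minimal usage sketch: create a single-process CPU client and inspect it.
# The num_devices value shown is illustrative, not a required setting.
def _example_make_cpu_client() -> None:
  client = make_cpu_client(asynchronous=True, num_devices=2)
  assert client.platform == 'cpu'
  assert len(client.devices()) == 2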


DeviceTopology = _xla.DeviceTopology
get_topology_for_devices = _xla.get_topology_for_devices


def make_tfrt_tpu_c_api_device_topology(
    topology_name: str = '', **kwargs
) -> DeviceTopology:
  """Creates a PJRT C API TopologyDescription."""
  return _xla.get_default_c_api_topology('tpu', topology_name, dict(**kwargs))


def make_c_api_device_topology(
    c_api: Any, topology_name: str = '', **kwargs
) -> DeviceTopology:
  """Creates a PJRT C API TopologyDescription."""
  return _xla.get_c_api_topology(c_api, topology_name, dict(**kwargs))
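

# A small sketch of the topology helpers above, assuming a TPU PJRT plugin
# (e.g. libtpu) is installed; without one, this call fails at runtime.
def _example_tpu_topology() -> DeviceTopology:
  return make_tfrt_tpu_c_api_device_topology(topology_name='')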


def pjrt_plugin_loaded(plugin_name: str) -> bool:
  return _xla.pjrt_plugin_loaded(plugin_name)


def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any:
  return _xla.load_pjrt_plugin(plugin_name, library_path, c_api=None)


def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None:
  return _xla.load_pjrt_plugin(plugin_name, None, c_api)


def pjrt_plugin_initialized(plugin_name: str) -> bool:
  return _xla.pjrt_plugin_initialized(plugin_name)


def initialize_pjrt_plugin(plugin_name: str) -> None:
  """Initializes a PJRT plugin.

  The plugin needs to be loaded first (through load_pjrt_plugin_dynamically or
  static linking) before this method is called.

  Args:
    plugin_name: the name of the PJRT plugin.
  """
  _xla.initialize_pjrt_plugin(plugin_name)


def make_c_api_client(
    plugin_name: str,
    options: _NameValueMapping | None = None,
    distributed_client: _xla.DistributedRuntimeClient | None = None,
):
  """Creates a PJRT C API client for a PJRT plugin.

  It is required that load_pjrt_plugin_dynamically is called once with the same
  plugin_name before this method is called.

  Args:
    plugin_name: the name of the PJRT plugin.
    options: extra platform-specific options.
    distributed_client: distributed client.

  Returns:
    A PJRT C API client for plugin_name.
  """
  if options is None:
    options = {}
  return _xla.get_c_api_client(plugin_name, options, distributed_client)
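

# A sketch of the full plugin workflow using the helpers above. The plugin
# name and library path are hypothetical placeholders; a real PJRT plugin
# supplies its own values.
def _example_plugin_client(
    plugin_name: str = 'my_plugin',
    library_path: str = '/path/to/pjrt_plugin.so',
) -> Client:
  if not pjrt_plugin_loaded(plugin_name):
    load_pjrt_plugin_dynamically(plugin_name, library_path)
  if not pjrt_plugin_initialized(plugin_name):
    initialize_pjrt_plugin(plugin_name)
  return make_c_api_client(plugin_name)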


def generate_pjrt_gpu_plugin_options() -> _NameValueMapping:
  """Generates the PjRt GPU plugin options.

  Returns:
    A dictionary of plugin options.
  """

  options = {}
  options['platform_name'] = 'cuda'
  allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower()
  memory_fraction = os.getenv('XLA_CLIENT_MEM_FRACTION', '')
  deprecated_memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION', '')
  if deprecated_memory_fraction:
    if memory_fraction:
      raise ValueError(
          'XLA_CLIENT_MEM_FRACTION is specified together '
          'with XLA_PYTHON_CLIENT_MEM_FRACTION. '
          'Remove the latter one, it is deprecated.'
      )
    else:
      memory_fraction = deprecated_memory_fraction
  preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE', '')
  collective_memory_size = os.getenv(
      'XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB', ''
  )
  if allocator not in ('default', 'platform', 'bfc', 'cuda_async'):
    raise ValueError(
        'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", '
        '"bfc", or "cuda_async", got "%s"' % allocator
    )
  options['allocator'] = allocator
  if memory_fraction:
    options['memory_fraction'] = float(memory_fraction)
  if preallocate:
    options['preallocate'] = preallocate not in ('false', 'False', '0')
  if collective_memory_size:
    options['collective_memory_size'] = int(collective_memory_size) * (1 << 20)
  return options
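

# A minimal sketch of how the environment variables read above map to plugin
# options; the values set here are illustrative only.
def _example_gpu_plugin_options() -> None:
  os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
  os.environ['XLA_CLIENT_MEM_FRACTION'] = '0.5'
  options = generate_pjrt_gpu_plugin_options()
  assert options['platform_name'] == 'cuda'
  assert options['allocator'] == 'platform'
  assert options['memory_fraction'] == 0.5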


PrimitiveType = _xla.PrimitiveType

Shape = _xla.Shape
Shape.__doc__ = """
A Shape is an object defined in C++ that duck types like the following class:

class Shape:
  '''Represents an XLA shape.

  A shape is either an array shape, having rank-many integer
  dimensions and an element type (represented by a Numpy dtype), or it
  is a tuple shape, having a shape for every tuple component:

    type shape =
        TupleShape of shape list
      | ArrayShape of { dimensions: int list; element_type: dtype }
  '''

  @staticmethod
  def tuple_shape(tuple_shapes) -> Shape:
    "Construct a tuple shape."

  @staticmethod
  def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:

  @staticmethod
  def from_pyval(pyval) -> Shape:
    "Returns a Shape that describes a tuple-tree of Numpy arrays."

  def __init__(self, str) -> Shape:
    "Parses a shape string."
  def __eq__(self, other: Shape) -> bool:
  def __ne__(self, other: Shape) -> bool:
  def __hash__(self):
  def __repr__(self):
  def is_tuple(self) -> bool:
  def is_array(self) -> bool:
  def tuple_shapes(self) -> [Shape]:
  def numpy_dtype(self) -> np.dtype:
    "Like element_type(), but returns dtype('O') for a tuple shape."
  def xla_element_type(self) -> PrimitiveType:
  def element_type(self) -> np.dtype:
  def dimensions(self) -> (int, int, ...):
  def rank(self) -> int:
  def with_major_to_minor_layout_if_absent(self) -> Shape:
    "Returns a copy with missing layouts set to major-to-minor."

  def to_serialized_proto(self) -> bytes:
    "Returns 'shape' as a serialized proto."
"""


ProgramShape = _xla.ProgramShape
ProgramShape.__doc__ = """
A ProgramShape is a C++ object that duck types like the following class.

class ProgramShape:
  def __init__(self, parameter_shapes, result_shape):
  def parameter_shapes(self) -> [Shape]:
  def result_shape(self) -> Shape:
  def __repr__(self):
"""

DeviceAssignment = _xla.DeviceAssignment
DeviceAssignment.__doc__ = """
A DeviceAssignment is a C++ object with the following signature.

def create(assignment):
  '''Builds a device assignment.

  Args:
    assignment: a 2D numpy array of device ordinal integers, indexed by
      [replica][computation_in_replica].

  Returns:
    A device assignment.
  '''

def replica_count():
  '''Returns the number of replicas.'''
def computation_count():
  '''Returns the number of computations per replica.'''
"""


Device = _xla.Device
CompileOptions = _xla.CompileOptions

HostBufferSemantics = _xla.HostBufferSemantics

# An Executable is a C++ class that duck types with the following API:
# class Executable:
#   def local_devices(self) -> [Device]:
#   def execute(self, arguments: [Buffer]) -> Buffer:
#     """Execute on one replica with Buffer arguments and return value."""
#
#   def size_of_generated_code_in_bytes(self) -> int:
#     """Return generated binary size, or -1 if not known."""
#
#   def execute_sharded_on_local_devices(self, arguments: [[Buffer]])
#       -> [Buffer]:
#     """Execute on many replicas with Buffer arguments and return value.
#
#     Args:
#       arguments: A sequence of sequences of Buffers. The i'th element of each
#         sequence comprises the arguments for execution on the i'th local
#         device.
#
#     Returns:
#       A list of the computation's outputs as a list of Buffers for each
#       device.
#     """
#
# There are different implementations of Executable for different backends.


XlaComputation = _xla.XlaComputation
Client = _xla.Client
Memory = _xla.Memory
Array = _xla.Array
ArrayImpl = _xla.ArrayImpl
LoadedExecutable = _xla.LoadedExecutable
Executable = _xla.Executable
DeviceList = _xla.DeviceList
OpSharding = _xla.OpSharding
HloSharding = _xla.HloSharding
Sharding = _xla.Sharding
NamedSharding = _xla.NamedSharding
SingleDeviceSharding = _xla.SingleDeviceSharding
PmapSharding = _xla.PmapSharding
GSPMDSharding = _xla.GSPMDSharding
PjRtLayout = _xla.PjRtLayout
AutotuneCacheMode = _xla.AutotuneCacheMode


def LoadedExecutable_execute(self, arguments, device=None):
  del device
  results = self.execute_sharded(arguments)
  return [x[0] for x in results.disassemble_into_single_device_arrays()]


def LoadedExecutable_execute_with_token(self, arguments, device=None):
  del device
  results = self.execute_sharded(arguments, with_tokens=True)
  return (
      [x[0] for x in results.disassemble_into_single_device_arrays()],
      results.consume_token().get_token(0),
  )


LoadedExecutable.execute = LoadedExecutable_execute
LoadedExecutable.execute_with_token = LoadedExecutable_execute_with_token
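

# A sketch of the calling convention installed above: given an already
# compiled-and-loaded executable and a list of argument arrays (both assumed
# to exist), `execute` returns the computation's outputs and
# `execute_with_token` additionally returns a token for ordering side effects.
def _example_loaded_executable_execute(loaded: LoadedExecutable, args) -> None:
  outputs = loaded.execute(args)
  outputs_with_token, token = loaded.execute_with_token(args)
  del outputs, outputs_with_token, token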


class CustomCallTargetTraits(enum.IntFlag):
  DEFAULT = 0
  # Calls to the custom call are safe to trace into a command buffer: the
  # custom call always launches exactly the same device operations (which may
  # depend on attribute values), so they can be captured and then replayed.
  #
  # Supported only for custom calls implemented with XLA FFI.
  COMMAND_BUFFER_COMPATIBLE = 1


class CustomCallHandler(Protocol):

  def __call__(
      self,
      name: str,
      fn: Any,
      platform: str,
      /,
      api_version: int = ...,
      traits: CustomCallTargetTraits = ...,
  ) -> None:
    ...
_custom_callback_handler: dict[str, CustomCallHandler] = {}
# Key is xla_platform_name, value is a list of
# (function_name, function, api_version, traits) tuples.
_custom_callback: dict[
    str, list[tuple[str, Any, int, CustomCallTargetTraits]]
] = {}
_custom_callback_lock = threading.Lock()


def register_custom_call_target(
    name: str,
    fn: Any,
    platform: str = 'cpu',
    api_version: int = 0,
    traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT,
) -> None:
  """Registers a custom call target.

  Args:
    name: the name of the custom call target.
    fn: a PyCapsule object containing the function pointer.
    platform: the target platform.
    api_version: the XLA FFI version to use. Supported versions are: 0 for the
      untyped FFI and 1 for the typed FFI.
    traits: custom call traits corresponding to XLA FFI handler traits.
  """
  # To support AMD GPUs, we would need xla_platform_names["gpu"] == "ROCM".
  # Since that entry is hardcoded to CUDA, we use the following as a
  # workaround.
  xla_platform_name = xla_platform_names.get(platform, platform)
  with _custom_callback_lock:
    if xla_platform_name in _custom_callback_handler:
      _custom_callback_handler[xla_platform_name](
          name, fn, xla_platform_name, api_version, traits
      )
    else:
      _custom_callback.setdefault(xla_platform_name, []).append(
          (name, fn, api_version, traits)
      )
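

# A sketch of registering an XLA FFI custom call target via the helper above.
# The target name is hypothetical and `capsule` is assumed to be a PyCapsule
# wrapping an XLA FFI handler (api_version=1), e.g. exported by a C extension.
def _example_register_custom_call(capsule: Any) -> None:
  register_custom_call_target(
      'my_custom_op',
      capsule,
      platform='cpu',
      api_version=1,
      traits=CustomCallTargetTraits.DEFAULT,
  )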


def register_custom_call_handler(
    platform: str, handler: CustomCallHandler
) -> None:
  """Registers a custom handler and uses it to register existing custom calls.

  If a custom call handler for the platform already exists, calling this method
  is a no-op and it will not register a new handler.

  Args:
    platform: the target platform.
    handler: the function to register a custom call.
  """
  xla_platform_name = xla_platform_names.get(platform, platform)
  with _custom_callback_lock:
    if xla_platform_name in _custom_callback_handler:
      logger.debug(
          'Custom call handler for %s is already registered. Will not register'
          ' a new one',
          xla_platform_name,
      )
      return
    _custom_callback_handler[xla_platform_name] = handler
    if xla_platform_name in _custom_callback:
      for name, fn, api_version, traits in _custom_callback[xla_platform_name]:
        handler(name, fn, xla_platform_name, api_version, traits)
      del _custom_callback[xla_platform_name]


class CustomTypeIdHandler(Protocol):

  def __call__(self, name: str, capsule: Any) -> None:
    ...


_custom_type_id_handler: dict[str, CustomTypeIdHandler] = {}
_custom_type_id: dict[str, Any] = {}
_custom_type_id_lock = threading.Lock()


def register_custom_type_id(
    type_name: str,
    type_id: Any,
    platform: str = 'cpu',
) -> None:
  """Register a custom type id for use with the FFI.

  Args:
    type_name: a unique name for the type.
    type_id: a PyCapsule object containing a pointer to the ``ffi::TypeId``.
    platform: the target platform.
  """
  xla_platform_name = xla_platform_names.get(platform, platform)
  with _custom_type_id_lock:
    if xla_platform_name in _custom_type_id_handler:
      _custom_type_id_handler[xla_platform_name](type_name, type_id)
    else:
      _custom_type_id.setdefault(xla_platform_name, []).append(
          (type_name, type_id)
      )


def register_custom_type_id_handler(
    platform: str, handler: CustomTypeIdHandler
) -> None:
  """Registers a custom type id handler and uses it to register existing type ids.

  If a custom type id handler for the platform already exists, calling this
  method is a no-op and it will not register a new handler.

  Args:
    platform: the target platform.
    handler: the function to register a custom type id.
  """
  xla_platform_name = xla_platform_names.get(platform, platform)
  with _custom_type_id_lock:
    if xla_platform_name in _custom_type_id_handler:
      logger.debug(
          'Custom type id handler for %s is already registered. Will not '
          'register a new one',
          xla_platform_name,
      )
      return
    _custom_type_id_handler[xla_platform_name] = handler
    if xla_platform_name in _custom_type_id:
      for name, capsule in _custom_type_id[xla_platform_name]:
        handler(name, capsule)
      del _custom_type_id[xla_platform_name]


register_custom_call_partitioner = _xla.register_custom_call_partitioner
encode_inspect_sharding_callback = _xla.encode_inspect_sharding_callback
hlo_sharding_util = _xla.hlo_sharding_util
register_custom_call_as_batch_partitionable = (
    _xla.register_custom_call_as_batch_partitionable
)


Traceback = _xla.Traceback
Frame = _xla.Frame


@contextlib.contextmanager
def tracebacks(enabled=True):
  """Context manager that enables or disables traceback collection."""
  saved = _xla.tracebacks_enabled()
  _xla.set_tracebacks_enabled(enabled)
  try:
    yield
  finally:
    _xla.set_tracebacks_enabled(saved)
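

# A small usage sketch of the context manager above: temporarily disable
# traceback collection and confirm the previous setting is restored on exit.
def _example_tracebacks() -> None:
  before = _xla.tracebacks_enabled()
  with tracebacks(enabled=False):
    assert not _xla.tracebacks_enabled()
  assert _xla.tracebacks_enabled() == before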


@contextlib.contextmanager
def execution_stream_id(new_id: int):
  """Context manager that overwrites and restores the current thread's execution_stream_id."""
  saved = _xla.get_execution_stream_id()
  _xla.set_execution_stream_id(new_id)
  try:
    yield
  finally:
    _xla.set_execution_stream_id(saved)
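

# A parallel sketch for execution_stream_id: run a block under a different
# stream id (the value 1 is illustrative) and restore the previous id on exit.
def _example_execution_stream_id() -> None:
  saved = _xla.get_execution_stream_id()
  with execution_stream_id(1):
    assert _xla.get_execution_stream_id() == 1
  assert _xla.get_execution_stream_id() == saved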


XlaRuntimeError = _xla.XlaRuntimeError

# Perform one last garbage collection of deferred Python references. This is
# mostly to keep ASAN happy.
atexit.register(_xla.collect_garbage)

array_result_handler = _xla.array_result_handler
batched_copy_array_to_devices_with_sharding = (
    _xla.batched_copy_array_to_devices_with_sharding
)
batched_device_put = _xla.batched_device_put
reorder_shards = _xla.reorder_shards
batched_block_until_ready = _xla.batched_block_until_ready
check_and_canonicalize_memory_kind = _xla.check_and_canonicalize_memory_kind
Layout = _xla.Layout
custom_call_targets = _xla.custom_call_targets
ArrayCopySemantics = _xla.ArrayCopySemantics