123 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			123 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright 2024 The JAX Authors.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     https://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """Colocated Python top-level API."""
 | |
| 
 | |
| from __future__ import annotations
 | |
| 
 | |
| import collections
 | |
| from typing import Any, Callable, Sequence, Type, overload
 | |
| 
 | |
| import jax
 | |
| from jax._src import api_util
 | |
| from jax._src import util
 | |
| from jax.experimental.colocated_python.func import make_callable
 | |
| from jax.experimental.colocated_python.obj import wrap_class
 | |
| import numpy as np
 | |
| 
 | |
| 
 | |
@overload
def colocated_cpu_devices(
    devices_or_mesh: Sequence[jax.Device],
) -> Sequence[jax.Device]:
  ...


@overload
def colocated_cpu_devices(
    devices_or_mesh: jax.sharding.Mesh,
) -> jax.sharding.Mesh:
  ...


def colocated_cpu_devices(devices_or_mesh):
  """Returns CPU devices (or a CPU mesh) colocated with the given devices/mesh.

  Args:
    devices_or_mesh: Either a `jax.sharding.Mesh` or a sequence of devices.

  Returns:
    For a mesh, a mesh built from the colocated CPU devices. For a sequence,
    the colocated CPU devices in the same order as the input.

  The colocation-id based lookup is tried first. If it raises `ValueError` or
  `AttributeError`, the lookup falls back to the devices of the local CPU
  backend.
  """
  if isinstance(devices_or_mesh, jax.sharding.Mesh):
    return _colocated_cpu_mesh_cached(devices_or_mesh)

  # The cached helpers are keyed on their argument, so normalize to a tuple.
  device_tuple = (
      devices_or_mesh
      if isinstance(devices_or_mesh, tuple)
      else tuple(devices_or_mesh)
  )
  try:
    return _colocated_cpu_devices_cached(device_tuple)
  except (ValueError, AttributeError):
    # The fallback call stays inside the handler so any error it raises keeps
    # the original lookup failure as its context.
    return _colocated_cpu_devices_cached_fallback_to_cpu_backend(device_tuple)
 | |
| 
 | |
| 
 | |
| @util.cache(max_size=1024, trace_context_in_key=False)
 | |
| def _colocated_cpu_devices_cached(
 | |
|     devices: tuple[jax.Device, ...],
 | |
| ) -> Sequence[jax.Device]:
 | |
|   cpu_devices_by_colocation_id = collections.defaultdict(list)
 | |
|   for device in devices[0].client._get_all_devices():  # pylint: disable=protected-access
 | |
|     if device.device_kind == "cpu":
 | |
|       cpu_devices_by_colocation_id[device.colocation_id].append(device)
 | |
|   if not cpu_devices_by_colocation_id:
 | |
|     raise ValueError("No CPU devices found")
 | |
| 
 | |
|   colocated_cpu_devices = []
 | |
|   for device in devices:
 | |
|     matches = cpu_devices_by_colocation_id[device.colocation_id]
 | |
|     if not matches:
 | |
|       raise ValueError(f"Device {device} has no colocated devices")
 | |
|     elif len(matches) > 1:
 | |
|       raise ValueError(
 | |
|           f"Ambiguous colocated devices; device {device} has"
 | |
|           f" {len(matches)} colocated devices: f{matches}"
 | |
|       )
 | |
|     colocated_cpu_devices.append(matches[0])
 | |
|   return colocated_cpu_devices
 | |
| 
 | |
| 
 | |
| @util.cache(max_size=1024, trace_context_in_key=False)
 | |
| def _colocated_cpu_devices_cached_fallback_to_cpu_backend(
 | |
|     devices: tuple[jax.Device, ...],
 | |
| ) -> Sequence[jax.Device]:
 | |
|   # PjRt-IFRT currently defines CPU devices by using a CPU backend.
 | |
|   # TODO(hyeontaek): Remove this fallback path once a PjRt-IFRT backend defines
 | |
|   # CPU devices by its own instead of using a separate CPU backend.
 | |
|   cpu_backend_devices = jax.local_devices(backend="cpu")
 | |
|   device_index_map = {device.id: i for i, device in enumerate(jax.devices())}
 | |
| 
 | |
|   available_devices = devices[: min(len(cpu_backend_devices), len(devices))]
 | |
|   return [
 | |
|       cpu_backend_devices[device_index_map[d.id]] for d in available_devices
 | |
|   ]
 | |
| 
 | |
| 
 | |
@util.cache(max_size=1024, trace_context_in_key=False)
def _colocated_cpu_mesh_cached(mesh: jax.sharding.Mesh) -> jax.sharding.Mesh:
  """Builds a mesh shaped like `mesh` whose devices are their colocated CPUs.

  Device resolution goes through `colocated_cpu_devices` with a device tuple,
  so it reuses that path's per-tuple cache. This function is cached as well,
  so repeated calls with the same mesh avoid building a new `Mesh` each time.
  """
  source_devices = tuple(mesh.devices.flat)
  cpu_devices = colocated_cpu_devices(source_devices)
  # Restore the original mesh's shape; axis names and axis types carry over.
  cpu_device_grid = np.array(cpu_devices).reshape(mesh.axis_sizes)
  return jax.sharding.Mesh(
      cpu_device_grid, mesh.axis_names, axis_types=mesh.axis_types
  )
 | |
| 
 | |
| 
 | |
def colocated_python(fun: Callable[..., Any]):
  """Executes the given Python function on the same devices as the arguments.

  Args:
    fun: The Python function to wrap.

  Returns:
    The callable produced by `make_callable` from `fun`, its source info, and
    its signature.
  """
  source_info = api_util.fun_sourceinfo(fun)
  signature = api_util.fun_signature(fun)
  return make_callable(fun, source_info, signature)
 | |
| 
 | |
| 
 | |
def colocated_python_class(cls: Type[object]) -> Type[object]:
  """Executes the given Python class methods on the same devices as the arguments.

  Args:
    cls: The class to wrap.

  Returns:
    The class produced by `wrap_class` from `cls` and its source info.
  """
  source_info = api_util.fun_sourceinfo(cls)
  return wrap_class(cls, source_info)
 |