175 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			175 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright 2025 The JAX Authors.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     https://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """Colocated Python object API implementation."""
 | |
| 
 | |
| from __future__ import annotations
 | |
| 
 | |
| import inspect
 | |
| import random
 | |
| import threading
 | |
| from typing import Any, Callable, Type
 | |
| 
 | |
| import jax
 | |
| from jax._src import api_util
 | |
| from jax._src import tree_util
 | |
| from jax._src.traceback_util import api_boundary
 | |
| from jax._src.util import wraps
 | |
| from jax.experimental.colocated_python import func
 | |
| from jax.experimental.colocated_python import obj_backend
 | |
| 
 | |
| 
 | |
class _InstanceRegistry:
  """Registry of object instances.

  Maps each instance uid to the set of devices recorded for it through
  `update_devices()`. Every access to the mapping happens under `self._lock`.
  """

  def __init__(self) -> None:
    self._lock = threading.Lock()
    # uid -> devices recorded for that instance so far.
    self._storage: dict[int, set[jax.Device]] = {}

  def new_instance(self) -> int:
    """Returns a new unique identifier for an instance on the controller.

    The identifier is a random 63-bit integer that is not currently registered.
    The new instance starts with an empty device set.
    """
    with self._lock:
      # Redraw on the (extremely unlikely) collision instead of asserting:
      # `assert` is stripped under `-O`, in which case a collision would
      # silently reset the device set of an instance that is still registered.
      while True:
        uid = random.getrandbits(63)
        if uid not in self._storage:
          self._storage[uid] = set()
          return uid

  def update_devices(self, uid: int, device_set: set[jax.Device]) -> None:
    """Updates the set of devices on which it is live.

    Args:
      uid: Identifier previously returned by `new_instance()`.
      device_set: Devices to add to the set recorded for `uid`.

    Raises:
      KeyError: If `uid` is not registered.
    """
    with self._lock:
      self._storage[uid] |= device_set

  def pop_instance(self, uid: int) -> set[jax.Device]:
    """Removes the instance and returns the set of devices on which it is live.

    Raises:
      KeyError: If `uid` is not registered.
    """
    with self._lock:
      return self._storage.pop(uid)
 | |
| 
 | |
| 
 | |
# Module-level registry instance. `_update_instance_devices()` records devices
# in it, and `WrappedClass.__init__`/`__del__` register and pop instances.
SINGLETON_INSTANCE_REGISTRY = _InstanceRegistry()
 | |
| 
 | |
| 
 | |
@jax._src.util.cache(max_size=4096)
def _update_instance_devices(
    uid: int, shardings: tuple[jax.sharding.Sharding, ...]
) -> None:
  """Caching version of _InstanceRegistry.update_devices().

  Collects the devices of every sharding in `shardings` and records them for
  instance `uid` in `SINGLETON_INSTANCE_REGISTRY`. The call is memoized by the
  `cache` decorator (up to 4096 entries), keyed on its arguments.

  Args:
    uid: Identifier of the instance whose device set is extended.
    shardings: Shardings whose `device_set`s are merged and recorded.
  """
  devices = {
      device for sharding in shardings for device in sharding.device_set
  }
  SINGLETON_INSTANCE_REGISTRY.update_devices(uid, devices)
 | |
| 
 | |
| 
 | |
def _make_method(
    cls: Type[object],
    cls_sourceinfo: str | None,
    uid: int,
    init_args: tuple[Any, ...],
    init_kwargs: dict[str, Any],
    method_name: str,
    original_method: Callable[..., Any],
):
  """Builds the controller-side wrapper for one method of a wrapped object.

  Args:
    cls: Original class; instantiated by `initializer` when the backend object
      store has no object for `uid`.
    cls_sourceinfo: Source info forwarded to `func.make_callable`.
    uid: Identifier of the instance the method belongs to.
    init_args: Positional arguments used to construct the object.
    init_kwargs: Keyword arguments used to construct the object.
    method_name: Attribute name of the method to invoke on the object.
    original_method: The method as defined on `cls`; supplies the signature
      and the metadata copied onto the returned wrapper.

  Returns:
    A wrapper that records the devices of its arguments' shardings for `uid`
    and then invokes the colocated Python callable. The wrapper raises
    `NotImplementedError` when called without positional arguments.
  """

  def initializer() -> object:
    # Used when the object is not present in the backend.
    return cls(*init_args, **init_kwargs)

  def method(*args, **kwargs):
    # Method to call on the backend: look up (or create) the object for `uid`
    # and dispatch to the requested method.
    target = obj_backend.SINGLETON_OBJECT_STORE.get_or_create(uid, initializer)
    bound_method = getattr(target, method_name)
    return bound_method(*args, **kwargs)

  # Colocated Python callable for the controller.
  colocated_callable = func.make_callable(
      method, cls_sourceinfo, api_util.fun_signature(original_method)
  )

  # Outer wrapper of the method for the controller. It tracks the devices of
  # the argument shardings in the instance registry before each call.
  @wraps(original_method)
  @api_boundary
  def method_wrapper(*args, **kwargs):
    if not args:
      raise NotImplementedError(
          'Method calls with no arguments are not yet supported.'
      )
    # TODO(hyeontaek): Instead of inspecting argument shardings, get shardings
    # from final specialization of the function. This may require lowering
    # `_update_instance_devices` into the function API.
    leaves = tree_util.tree_leaves((args, kwargs))
    leaf_shardings = tuple(func._get_spec(leaf).sharding for leaf in leaves)
    _update_instance_devices(uid, leaf_shardings)
    return colocated_callable(*args, **kwargs)

  return method_wrapper
 | |
| 
 | |
| 
 | |
def wrap_class(
    cls: Type[object],
    cls_sourceinfo: str | None,
) -> Type[object]:
  """Creates the controller-side wrapper class for `cls`.

  Constructing the wrapper registers a new instance uid and installs, as
  instance attributes, one wrapper built by `_make_method()` for each function
  found on `cls`. Deleting the wrapper pops the uid from the registry and, if
  any devices were recorded for it, runs a colocated callable on those devices
  that removes the object from the backend object store.

  Args:
    cls: Class to wrap.
    cls_sourceinfo: Source info forwarded to `func.make_callable`.

  Returns:
    The wrapper class, carrying the `__name__` and `__doc__` of `cls`.
  """
  # WrappedClass defines lazy initialization and colocated deletion logic.
  # WrappedClass is not serializable even if the original class may be
  # serializable. These members are therefore never forwarded.
  unforwarded_names = ('__init__', '__del__', '__reduce__', '__reduce_ex__')

  class WrappedClass:

    @wraps(cls.__init__)
    def __init__(self, *init_args, **init_kwargs) -> None:
      uid = SINGLETON_INSTANCE_REGISTRY.new_instance()
      self._colocated_python_uid = uid
      for attr_name in dir(cls):
        original_member = getattr(cls, attr_name)
        # Only plain functions are forwarded, minus the names handled by
        # WrappedClass itself.
        if (
            not inspect.isfunction(original_member)
            or attr_name in unforwarded_names
        ):
          continue

        # TODO(hyeontaek): Support method specialization similar to function
        # specialization.
        setattr(
            self,
            attr_name,
            _make_method(
                cls,
                cls_sourceinfo,
                uid,
                init_args,
                init_kwargs,
                attr_name,
                original_member,
            ),
        )

    def __del__(self) -> None:
      uid = self._colocated_python_uid
      devices = SINGLETON_INSTANCE_REGISTRY.pop_instance(uid)
      if not devices:
        # No devices were recorded for this instance; nothing to remove.
        return

      def remove_object() -> None:
        obj_backend.SINGLETON_OBJECT_STORE.remove(uid)

      # TODO(hyeontaek): Request "best-effort" non-SPMD execution that tries
      # to run this function on any healthy processes instead of failing when
      # any process of the execution is unhealthy.
      unspecialized = func.make_callable(remove_object, cls_sourceinfo, None)
      destructor = unspecialized.specialize(  # type: ignore[attribute-error]
          devices=devices
      )
      destructor()

  WrappedClass.__name__ = cls.__name__
  WrappedClass.__doc__ = cls.__doc__
  return WrappedClass
 |