175 lines
5.5 KiB
Python
175 lines
5.5 KiB
Python
# Copyright 2025 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Colocated Python object API implementation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import inspect
|
|
import random
|
|
import threading
|
|
from typing import Any, Callable, Type
|
|
|
|
import jax
|
|
from jax._src import api_util
|
|
from jax._src import tree_util
|
|
from jax._src.traceback_util import api_boundary
|
|
from jax._src.util import wraps
|
|
from jax.experimental.colocated_python import func
|
|
from jax.experimental.colocated_python import obj_backend
|
|
|
|
|
|
class _InstanceRegistry:
  """Thread-safe registry of object instances created on the controller.

  Each instance is identified by a random integer uid. For every uid the
  registry keeps the set of devices that have been recorded for it through
  `update_devices()`.
  """

  def __init__(self) -> None:
    # Guards every read and write of `_storage`.
    self._lock = threading.Lock()
    # uid -> devices recorded for that instance so far.
    self._storage: dict[int, set[jax.Device]] = {}

  def new_instance(self) -> int:
    """Returns a new unique identifier for an instance on the controller.

    The identifier is a random 63-bit integer and is registered with an empty
    device set before it is returned.
    """
    with self._lock:
      # Draw while holding the lock and redraw on a collision. An `assert`
      # would be stripped under `python -O`, which would let a colliding uid
      # silently replace the device set of a live instance.
      uid = random.getrandbits(63)
      while uid in self._storage:
        uid = random.getrandbits(63)
      self._storage[uid] = set()
    return uid

  def update_devices(self, uid: int, device_set: set[jax.Device]) -> None:
    """Updates the set of devices on which it is live.

    Args:
      uid: Identifier previously returned by `new_instance()`.
      device_set: Devices to add to the ones already recorded for `uid`.

    Raises:
      KeyError: If `uid` is not currently registered.
    """
    with self._lock:
      self._storage[uid] |= device_set

  def pop_instance(self, uid: int) -> set[jax.Device]:
    """Removes the instance and returns the set of devices on which it is live.

    Raises:
      KeyError: If `uid` is not currently registered.
    """
    with self._lock:
      return self._storage.pop(uid)
|
|
|
|
|
|
# Module-level registry instance. `wrap_class` registers and removes instances
# through it, and `_update_instance_devices` records their devices in it.
SINGLETON_INSTANCE_REGISTRY = _InstanceRegistry()
|
|
|
|
|
|
@jax._src.util.cache(max_size=4096)
def _update_instance_devices(
    uid: int, shardings: tuple[jax.sharding.Sharding, ...]
) -> None:
  """Caching version of _InstanceRegistry.update_devices().

  Merges the device sets of all `shardings` and records the result for `uid`
  in `SINGLETON_INSTANCE_REGISTRY`. The `cache` decorator (limited to 4096
  entries) is what makes this the caching variant; presumably it is keyed on
  the `(uid, shardings)` arguments.

  Args:
    uid: Identifier of a registered instance.
    shardings: Shardings whose devices should be recorded for `uid`.
  """
  merged_devices: set[jax.Device] = set().union(
      *(sharding.device_set for sharding in shardings)
  )
  SINGLETON_INSTANCE_REGISTRY.update_devices(uid, merged_devices)
|
|
|
|
|
|
def _make_method(
    cls: Type[object],
    cls_sourceinfo: str | None,
    uid: int,
    init_args: tuple[Any, ...],
    init_kwargs: dict[str, Any],
    method_name: str,
    original_method: Callable[..., Any],
):
  """Builds the controller-side replacement for one method of `cls`.

  Args:
    cls: Original class whose method is being wrapped.
    cls_sourceinfo: Source information forwarded to `func.make_callable`.
    uid: Identifier of the instance this method belongs to.
    init_args: Positional arguments used to construct the object on the backend.
    init_kwargs: Keyword arguments used to construct the object on the backend.
    method_name: Attribute name of the method to invoke on the backend object.
    original_method: Original function; supplies the signature and the
      metadata copied onto the returned wrapper.

  Returns:
    A wrapper that records the devices of the argument shardings for `uid` and
    then invokes the colocated Python callable.
  """

  # Initializer to use when the object is not present in the backend.
  def initializer() -> object:
    return cls(*init_args, **init_kwargs)

  # Method to call on the backend.
  def method(*args, **kwargs):
    target = obj_backend.SINGLETON_OBJECT_STORE.get_or_create(uid, initializer)
    bound_method = getattr(target, method_name)
    return bound_method(*args, **kwargs)

  # Colocated Python callable for the controller.
  colocated_callable = func.make_callable(
      method, cls_sourceinfo, api_util.fun_signature(original_method)
  )

  # Outer wrapper of the method for the controller. It records the devices of
  # the argument shardings for this instance before invoking the callable.
  @wraps(original_method)
  @api_boundary
  def method_wrapper(*args, **kwargs):
    if not args:
      raise NotImplementedError(
          'Method calls with no arguments are not yet supported.'
      )
    # TODO(hyeontaek): Instead of inspecting argument shardings, get shardings
    # from final specialization of the function. This may require lowering
    # `_update_instance_devices` into the function API.
    leaves = tree_util.tree_leaves((args, kwargs))
    leaf_shardings = tuple(func._get_spec(leaf).sharding for leaf in leaves)
    _update_instance_devices(uid, leaf_shardings)
    return colocated_callable(*args, **kwargs)

  return method_wrapper
|
|
|
|
|
|
def wrap_class(
    cls: Type[object],
    cls_sourceinfo: str | None,
) -> Type[object]:
  """Returns a controller-side wrapper class for `cls`.

  Constructing the wrapper registers a new instance uid and attaches one
  wrapper per plain function found on `cls`, each built by `_make_method`.
  Deleting the wrapper unregisters the uid and, if any devices were recorded
  for it, runs a colocated callable that removes the object from the backend
  object store on those devices.

  Args:
    cls: Class to wrap.
    cls_sourceinfo: Source information forwarded to `func.make_callable`.

  Returns:
    The wrapper class, carrying the `__name__` and `__doc__` of `cls`.
  """

  class WrappedClass:

    @wraps(cls.__init__)
    def __init__(self, *init_args, **init_kwargs) -> None:
      uid = SINGLETON_INSTANCE_REGISTRY.new_instance()
      self._colocated_python_uid = uid

      # WrappedClass defines lazy initialization and colocated deletion logic.
      # WrappedClass is not serializable even if the original class may be
      # serializable.
      excluded_names = ('__init__', '__del__', '__reduce__', '__reduce_ex__')

      for attr_name in dir(cls):
        original_member = getattr(cls, attr_name)
        # Only plain functions that are not excluded above get a wrapper.
        if not inspect.isfunction(original_member) or attr_name in excluded_names:
          continue

        # TODO(hyeontaek): Support method specialization similar to function
        # specialization.
        setattr(
            self,
            attr_name,
            _make_method(
                cls,
                cls_sourceinfo,
                uid,
                init_args,
                init_kwargs,
                attr_name,
                original_member,
            ),
        )

    def __del__(self) -> None:
      uid = self._colocated_python_uid
      devices = SINGLETON_INSTANCE_REGISTRY.pop_instance(uid)
      if not devices:
        # No devices were recorded for this instance; nothing to remove.
        return

      def remove_object() -> None:
        obj_backend.SINGLETON_OBJECT_STORE.remove(uid)

      # TODO(hyeontaek): Request "best-effort" non-SPMD execution that tries
      # to run this function on any healthy processes instead of failing when
      # any process of the execution is unhealthy.
      destructor = func.make_callable(remove_object, cls_sourceinfo, None)
      destructor.specialize(devices=devices)()  # type: ignore[attribute-error]

  WrappedClass.__name__ = cls.__name__
  WrappedClass.__doc__ = cls.__doc__
  return WrappedClass
|