2025-08-11 12:24:21 +08:00

255 lines
8.4 KiB
Python

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper for running multi-process tests."""
import os
import pathlib
import re
import signal
import subprocess
import time
from absl import app
from absl import flags
import jax
from jax import config
from jax._src import distributed
try:
import portpicker
except ImportError:
portpicker = None
from absl.testing import absltest
from jax._src import test_util as jtu
# Number of worker processes. Required by the launcher (checked in `_main`)
# and forwarded unchanged to every worker on its command line.
_NUM_PROCESSES = flags.DEFINE_integer(
    name="num_processes",
    default=None,
    help="Number of processes to use.",
)

# When > 0, worker `i` is given CUDA_VISIBLE_DEVICES covering GPUs
# [i * gpus_per_process, (i + 1) * gpus_per_process).
_GPUS_PER_PROCESS = flags.DEFINE_integer(
    name="gpus_per_process",
    default=0,
    help="Number of GPUs per worker process.",
)

# A value >= 0 marks this process as a worker; the default (-1) marks the
# launcher process.
_MULTIPROCESS_TEST_WORKER_ID = flags.DEFINE_integer(
    name="multiprocess_test_worker_id",
    default=-1,
    help="Worker id. Set by main test process; should not be set by users.",
)

# host:port of the JAX distributed controller that workers connect to.
_MULTIPROCESS_TEST_CONTROLLER_ADDRESS = flags.DEFINE_string(
    name="multiprocess_test_controller_address",
    default="",
    help=(
        "Address of the JAX controller. Set by the main test process; "
        "should not be set by users."
    ),
)

# When not None, a worker that exits with a non-zero code is accepted only if
# its stderr matches this regex (see `_main`). Presumably assigned by the test
# module before it calls `main()` — NOTE(review): confirm with callers.
expect_failures_with_regex = None
def main():
  """Test entry point: wires absl flags into JAX config, then runs `_main`.

  `config.config_with_absl()` must run before `app.run` parses flags, so that
  JAX's own configuration options are exposed as absl flags.
  """
  config.config_with_absl()
  app.run(main=_main)
class GracefulKiller:
  """Records (instead of dying on) the first SIGINT or SIGTERM received.

  Creating an instance replaces the process's SIGINT and SIGTERM handlers
  with `exit_gracefully`. Callers poll `kill_now` to find out whether a
  termination request has arrived.
  """

  # Adapted from https://stackoverflow.com/a/31464349
  kill_now = False

  def __init__(self):
    for sig in (signal.SIGINT, signal.SIGTERM):
      signal.signal(sig, self.exit_gracefully)

  def exit_gracefully(self, sig_num, unused_stack_frame):
    """Signal handler: logs which signal arrived and sets `kill_now`."""
    sig_name = signal.Signals(sig_num).name
    print(f"Caught signal: {sig_name} ({sig_num})")
    self.kill_now = True
def _result_summary(stderr):
  """Extracts a one-line unittest result from a worker's stderr.

  Args:
    stderr: Captured stderr text of one worker process.

  Returns:
    The "Ran N tests in Xs" line joined by "; " with the line that follows
    the blank line after it (e.g. "OK"), or "Test crashed?" if no such line
    is present.
  """
  m = re.search(r"Ran \d+ tests? in [\d.]+s\n\n.*", stderr, re.MULTILINE)
  return m.group().replace("\n\n", "; ") if m else "Test crashed?"


def _check_return_codes(retvals, stderrs, expected_regex):
  """Raises AssertionError if any worker failed unexpectedly.

  Explicit raises are used instead of `assert` statements so the check still
  runs when Python is started with -O (which strips asserts); otherwise a
  failing multi-process test would be reported as passing.

  Args:
    retvals: Return codes of the worker processes, indexed by worker id.
    stderrs: Captured stderr text of the worker processes, same order.
    expected_regex: If not None, a failing worker is accepted only when its
      stderr matches this regex. If None, any non-zero return code fails.

  Raises:
    AssertionError: A worker exited non-zero and the failure was not
      explained by `expected_regex`.
  """
  for i, (retval, stderr) in enumerate(zip(retvals, stderrs)):
    if retval == 0:
      continue
    if expected_regex is not None:
      if not re.search(expected_regex, stderr):
        raise AssertionError(
            f"process {i} failed, expected regex: {expected_regex}"
        )
    else:
      raise AssertionError(f"process {i} failed, return value: {retval}")


def _wait_or_terminate(subprocesses):
  """Waits for all children to exit, or terminates them on SIGINT/SIGTERM.

  A `GracefulKiller` is installed so that a termination signal does not kill
  this process outright. Instead, the remaining children are sent SIGTERM,
  given 5 seconds, then sent SIGKILL. In both cases this function returns, so
  the caller can still collect the children's logs.

  Args:
    subprocesses: The `subprocess.Popen` objects of the worker processes.
  """
  killer = GracefulKiller()
  running_procs = dict(enumerate(subprocesses))
  while not killer.kill_now and running_procs:
    time.sleep(0.1)
    for i, proc in list(running_procs.items()):
      if proc.poll() is not None:
        print(f"Process {i} finished.", flush=True)
        running_procs.pop(i)
  if killer.kill_now and running_procs:
    print("Caught termination, terminating remaining children.", flush=True)
    # Send a SIGTERM to each child process, to let it know it should terminate.
    for i, proc in running_procs.items():
      proc.terminate()
      print(f"Process {i} terminated.", flush=True)
    # We give the child process(es) a few seconds for their own cleanup, and
    # keep the rest (up to 15s) for copying the children logs into our own.
    time.sleep(5)
    # Send a SIGKILL (a "hard" kill) to each child process. This is CRITICAL:
    # without it, this process may end up waiting a long time on proc.wait()
    # in the caller, and never get to saving the children logs, making test
    # timeouts very hard to debug.
    for i, proc in running_procs.items():
      proc.kill()
      print(f"Process {i} killed.")
    print("Killed all child processes.", flush=True)


def _main(argv):
  """Runs the tests, either as one worker or as the launcher of all workers.

  If `--multiprocess_test_worker_id` is >= 0, this process is a worker. It
  joins the JAX distributed runtime and runs the tests with absltest.

  Otherwise this process is the launcher. It re-executes `/proc/self/exe`
  `--num_processes` times, each time with the remaining argv plus worker
  flags. It then waits for the workers (or terminates them on SIGINT/SIGTERM),
  prints each worker's logs, and fails if any worker failed unexpectedly.

  Args:
    argv: Positional arguments as passed by `absl.app.run`.

  Raises:
    ValueError: `--num_processes` is not set in the launcher.
    AssertionError: A worker exited non-zero and the failure was not
      explained by the module-level `expect_failures_with_regex`.
  """
  if _MULTIPROCESS_TEST_WORKER_ID.value >= 0:
    jax.distributed.initialize(
        _MULTIPROCESS_TEST_CONTROLLER_ADDRESS.value,
        num_processes=_NUM_PROCESSES.value,
        process_id=_MULTIPROCESS_TEST_WORKER_ID.value,
        initialization_timeout=10,
    )
    absltest.main(testLoader=jtu.JaxTestLoader())
    # absltest.main normally exits the process; return defensively so that a
    # worker can never fall through and launch workers of its own.
    return

  if not argv[0].endswith(".py"):  # Skip the interpreter path if present.
    argv = argv[1:]
  num_processes = _NUM_PROCESSES.value
  if num_processes is None:
    raise ValueError("num_processes must be set")
  gpus_per_process = _GPUS_PER_PROCESS.value
  if portpicker is None:
    jax_port = 9876
  else:
    jax_port = portpicker.pick_unused_port()

  subprocesses = []
  output_filenames = []
  output_files = []
  for i in range(num_processes):
    env = os.environ.copy()
    args = [
        "/proc/self/exe",
        *argv,
        f"--num_processes={num_processes}",
        f"--multiprocess_test_worker_id={i}",
        f"--multiprocess_test_controller_address=localhost:{jax_port}",
        "--logtostderr",
    ]
    if gpus_per_process > 0:
      # Give each worker its own disjoint, contiguous slice of GPUs.
      gpus = range(i * gpus_per_process, (i + 1) * gpus_per_process)
      env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpus))
    undeclared_outputs = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", "/tmp")
    stdout_name = f"{undeclared_outputs}/jax_{i}_stdout.log"
    stderr_name = f"{undeclared_outputs}/jax_{i}_stderr.log"
    stdout = open(stdout_name, "wb")
    stderr = open(stderr_name, "wb")
    print(f"Launching process {i}:")
    print(f"  stdout: {stdout_name}")
    print(f"  stderr: {stderr_name}")
    proc = subprocess.Popen(args, env=env, stdout=stdout, stderr=stderr)
    subprocesses.append(proc)
    output_filenames.append((stdout_name, stderr_name))
    output_files.append((stdout, stderr))
  print(" All launched, running ".center(80, "="), flush=True)

  # Wait for all the children to finish or for a SIGTERM from bazel. If we get
  # SIGTERM, we still want to collect their logs, so kill them and continue.
  _wait_or_terminate(subprocesses)

  retvals = []
  stdouts = []
  stderrs = []
  for proc, fds, (stdout, stderr) in zip(
      subprocesses, output_files, output_filenames
  ):
    retvals.append(proc.wait())
    for fd in fds:
      fd.close()
    stdouts.append(pathlib.Path(stdout).read_text(errors="replace"))
    stderrs.append(pathlib.Path(stderr).read_text(errors="replace"))
  print(" All finished ".center(80, "="), flush=True)

  print(" Summary ".center(80, "="))
  for i, (retval, stdout, stderr) in enumerate(zip(retvals, stdouts, stderrs)):
    print(
        f"Process {i}, ret: {retval}, len(stdout): {len(stdout)}, "
        f"len(stderr): {len(stderr)}; {_result_summary(stderr)}"
    )

  print(" Detailed logs ".center(80, "="))
  for i, (retval, stdout, stderr) in enumerate(zip(retvals, stdouts, stderrs)):
    print(f" Process {i}: return code: {retval} ".center(80, "="))
    if stdout:
      print(f" Process {i} stdout ".center(80, "-"))
      print(stdout)
    if stderr:
      print(f" Process {i} stderr ".center(80, "-"))
      print(stderr)
  print(" Done detailed logs ".center(80, "="), flush=True)

  _check_return_codes(retvals, stderrs, expect_failures_with_regex)
class MultiProcessTest(absltest.TestCase):
  """Base class for tests that run in lockstep across all worker processes.

  Every test method is bracketed by two barriers on the JAX distributed
  client, one in `setUp` and one in `tearDown`. All processes therefore start
  each test together and finish it together, so a failure in one process is
  surfaced in the others too.
  """

  # Timeout passed to `wait_at_barrier`; presumably in milliseconds —
  # NOTE(review): confirm against the distributed client API.
  _BARRIER_TIMEOUT = 10000

  def _wait_at_barrier(self, suffix, deadline_message):
    """Blocks until every process reaches barrier `<test name><suffix>`.

    Args:
      suffix: Appended to the test method name to form the barrier id.
      deadline_message: Message of the RuntimeError raised on timeout.

    Raises:
      RuntimeError: The barrier failed with a DEADLINE_EXCEEDED error.
      jax.errors.JaxRuntimeError: Any other barrier failure, re-raised as-is
        rather than silently ignored (which would leave processes out of
        sync for every later test).
    """
    client = distributed.global_state.client
    try:
      client.wait_at_barrier(
          self._testMethodName + suffix, self._BARRIER_TIMEOUT
      )
    except jax.errors.JaxRuntimeError as e:
      msg = str(e.args[0]) if e.args else ""
      if msg.startswith("DEADLINE_EXCEEDED"):
        raise RuntimeError(deadline_message) from e
      raise

  def setUp(self):
    """Start tests together."""
    super().setUp()
    assert jax.process_count() == _NUM_PROCESSES.value, (
        jax.process_count(),
        _NUM_PROCESSES.value,
    )
    # Make sure all processes are at the same test case.
    self._wait_at_barrier(
        "_start",
        f"Init or some test executed earlier than {self._testMethodName} "
        "failed. Check logs from earlier tests to debug further. We "
        "recommend debugging that specific failed test with "
        "`--test_filter` before running the full test suite again.",
    )

  def tearDown(self):
    """End tests together."""
    # Ensure a shared fate for tests where a subset of processes run different
    # test assertions (i.e. some processes may pass and some processes fail -
    # but the overall test should fail).
    try:
      self._wait_at_barrier(
          "_end",
          f"Test {self._testMethodName} failed in another process. We "
          "recommend debugging that specific failed test with "
          "`--test_filter` before running the full test suite again.",
      )
    finally:
      # Run base-class cleanup even when the barrier fails.
      super().tearDown()