# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper for running multi-process tests."""

import os
import pathlib
import re
import signal
import subprocess
import time

from absl import app
from absl import flags
import jax
from jax import config
from jax._src import distributed
try:
  import portpicker
except ImportError:
  portpicker = None

from absl.testing import absltest
from jax._src import test_util as jtu


_NUM_PROCESSES = flags.DEFINE_integer(
    "num_processes", None, "Number of processes to use."
)

_GPUS_PER_PROCESS = flags.DEFINE_integer(
    "gpus_per_process",
    0,
    "Number of GPUs per worker process.",
)

_MULTIPROCESS_TEST_WORKER_ID = flags.DEFINE_integer(
    "multiprocess_test_worker_id",
    -1,
    "Worker id. Set by main test process; should not be set by users.",
)

_MULTIPROCESS_TEST_CONTROLLER_ADDRESS = flags.DEFINE_string(
    "multiprocess_test_controller_address",
    "",
    "Address of the JAX controller. Set by the main test process; should not be"
    " set by users.",
)
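

# A test module may set this to a regex to declare that child processes are
# expected to fail; _main() then requires each failing child's stderr to match
# the regex instead of requiring a zero return code.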
expect_failures_with_regex = None
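

# Entry point for test modules: registers JAX's configuration options with absl
# (config.config_with_absl) so app.run() can parse them along with the flags
# defined above, then dispatches to _main().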
def main():
  config.config_with_absl()
  app.run(_main)


class GracefulKiller:
  """Add a signal handler that sets a flag if SIGINT or SIGTERM are caught."""

  # From https://stackoverflow.com/a/31464349
  kill_now = False

  def __init__(self):
    signal.signal(signal.SIGINT, self.exit_gracefully)
    signal.signal(signal.SIGTERM, self.exit_gracefully)

  def exit_gracefully(self, sig_num, unused_stack_frame):
    print(f"Caught signal: {signal.Signals(sig_num).name} ({sig_num})")
    self.kill_now = True


def _main(argv):
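  # This binary runs in two modes. When --multiprocess_test_worker_id is set
  # (by the parent), we are a worker: connect to the coordinator and run the
  # test cases. Otherwise we are the parent, which launches num_processes
  # copies of itself below and then collects their logs and return codes.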
  if _MULTIPROCESS_TEST_WORKER_ID.value >= 0:
    jax.distributed.initialize(
        _MULTIPROCESS_TEST_CONTROLLER_ADDRESS.value,
        num_processes=_NUM_PROCESSES.value,
        process_id=_MULTIPROCESS_TEST_WORKER_ID.value,
        initialization_timeout=10,
    )
    absltest.main(testLoader=jtu.JaxTestLoader())
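    # absltest.main() exits the process when the tests finish, so workers never
    # fall through to the launcher code below.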

  if not argv[0].endswith(".py"):  # Skip the interpreter path if present.
    argv = argv[1:]

  num_processes = _NUM_PROCESSES.value
  if num_processes is None:
    raise ValueError("num_processes must be set")
  gpus_per_process = _GPUS_PER_PROCESS.value
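  # Pick a free port for the JAX distributed coordinator. If portpicker is not
  # available, fall back to a fixed port, which may collide with another
  # process that happens to be listening on it.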
  if portpicker is None:
    jax_port = 9876
  else:
    jax_port = portpicker.pick_unused_port()
  subprocesses = []
  output_filenames = []
  output_files = []
  for i in range(num_processes):
    env = os.environ.copy()
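
    # Re-execute the currently running test binary (/proc/self/exe is the
    # Linux path to it), forwarding the original arguments plus the
    # worker-specific flags, so each child runs the same tests as a worker.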
    args = [
        "/proc/self/exe",
        *argv,
        f"--num_processes={num_processes}",
        f"--multiprocess_test_worker_id={i}",
        f"--multiprocess_test_controller_address=localhost:{jax_port}",
        "--logtostderr",
    ]
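
    # Give each worker a disjoint slice of the machine's GPUs by restricting
    # which devices CUDA exposes to it.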
    if gpus_per_process > 0:
      gpus = range(i * gpus_per_process, (i + 1) * gpus_per_process)
      env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpus))
    undeclared_outputs = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", "/tmp")
    stdout_name = f"{undeclared_outputs}/jax_{i}_stdout.log"
    stderr_name = f"{undeclared_outputs}/jax_{i}_stderr.log"
    stdout = open(stdout_name, "wb")
    stderr = open(stderr_name, "wb")
    print(f"Launching process {i}:")
    print(f" stdout: {stdout_name}")
    print(f" stderr: {stderr_name}")
    proc = subprocess.Popen(args, env=env, stdout=stdout, stderr=stderr)
    subprocesses.append(proc)
    output_filenames.append((stdout_name, stderr_name))
    output_files.append((stdout, stderr))

  print(" All launched, running ".center(80, "="), flush=True)

  # Wait for all the children to finish or for a SIGTERM from bazel. If we get
  # SIGTERM, we still want to collect their logs, so kill them and continue.
  killer = GracefulKiller()
  running_procs = dict(enumerate(subprocesses))
  while not killer.kill_now and running_procs:
    time.sleep(0.1)
    for i, proc in list(running_procs.items()):
      if proc.poll() is not None:
        print(f"Process {i} finished.", flush=True)
        running_procs.pop(i)
  if killer.kill_now and running_procs:
    print("Caught termination, terminating remaining children.", flush=True)

    # Send a SIGTERM to each child process, to let it know it should terminate.
    for i, proc in running_procs.items():
      proc.terminate()
      print(f"Process {i} terminated.", flush=True)

    # We give the child process(es) a few seconds for their own cleanup, and
    # keep the rest (up to 15s) for copying the children logs into our own.
    time.sleep(5)

    # Send a SIGKILL (a "hard" kill) to each child process. This is CRITICAL:
    # without it, this process may end up waiting a long time on the proc.wait()
    # below, and never get to saving the children logs, making test timeouts
    # very hard to debug.
    for i, proc in running_procs.items():
      proc.kill()
      print(f"Process {i} killed.")
    print("Killed all child processes.", flush=True)

  retvals = []
  stdouts = []
  stderrs = []
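  # Collect each child's return code and captured logs. Note the pairing:
  # output_files holds the open file objects (closed here), while
  # output_filenames holds the paths, whose contents are read back below.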
  for proc, fds, (stdout, stderr) in zip(
      subprocesses, output_files, output_filenames
  ):
    retvals.append(proc.wait())
    for fd in fds:
      fd.close()
    stdouts.append(pathlib.Path(stdout).read_text(errors="replace"))
    stderrs.append(pathlib.Path(stderr).read_text(errors="replace"))

  print(" All finished ".center(80, "="), flush=True)

  print(" Summary ".center(80, "="))
  for i, (retval, stdout, stderr) in enumerate(zip(retvals, stdouts, stderrs)):
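    # Pull the unittest summary out of the worker's stderr, e.g.
    # "Ran 3 tests in 1.2s" followed by "OK" or "FAILED (...)"; if it is
    # missing, the worker most likely crashed before finishing its tests.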
    m = re.search(r"Ran \d+ tests? in [\d.]+s\n\n.*", stderr, re.MULTILINE)
    result = m.group().replace("\n\n", "; ") if m else "Test crashed?"
    print(
        f"Process {i}, ret: {retval}, len(stdout): {len(stdout)}, "
        f"len(stderr): {len(stderr)}; {result}"
    )

  print(" Detailed logs ".center(80, "="))
  for i, (retval, stdout, stderr) in enumerate(zip(retvals, stdouts, stderrs)):
    print(f" Process {i}: return code: {retval} ".center(80, "="))
    if stdout:
      print(f" Process {i} stdout ".center(80, "-"))
      print(stdout)
    if stderr:
      print(f" Process {i} stderr ".center(80, "-"))
      print(stderr)

  print(" Done detailed logs ".center(80, "="), flush=True)
  for i, (retval, stderr) in enumerate(zip(retvals, stderrs)):
    if retval != 0:
      if expect_failures_with_regex is not None:
        assert re.search(
            expect_failures_with_regex, stderr
        ), f"process {i} failed, expected regex: {expect_failures_with_regex}"
      else:
        assert retval == 0, f"process {i} failed, return value: {retval}"


class MultiProcessTest(absltest.TestCase):
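  """Base class for tests run by this launcher.

  setUp and tearDown wait at distributed barriers so that all processes enter
  and leave each test case together; a timeout at either barrier usually means
  that another process failed or fell behind.
  """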

  def setUp(self):
    """Start tests together."""
    super().setUp()
    assert jax.process_count() == _NUM_PROCESSES.value, (
        jax.process_count(),
        _NUM_PROCESSES.value,
    )
    # Make sure all processes are at the same test case.
    client = distributed.global_state.client
    try:
      client.wait_at_barrier(self._testMethodName + "_start", 10000)
    except jax.errors.JaxRuntimeError as e:
      msg, *_ = e.args
      if msg.startswith("DEADLINE_EXCEEDED"):
        raise RuntimeError(
            f"Initialization or a test that ran before {self._testMethodName} "
            "failed. Check logs from earlier tests to debug further. We "
            "recommend debugging that specific failed test with "
            "`--test_filter` before running the full test suite again."
        ) from e

  def tearDown(self):
    """End tests together."""
    client = distributed.global_state.client
    # Ensure a shared fate for tests where a subset of processes run different
    # test assertions (i.e. some processes may pass and some processes fail -
    # but the overall test should fail).
    try:
      client.wait_at_barrier(self._testMethodName + "_end", 10000)
    except jax.errors.JaxRuntimeError as e:
      msg, *_ = e.args
      if msg.startswith("DEADLINE_EXCEEDED"):
        raise RuntimeError(
            f"Test {self._testMethodName} failed in another process. We "
            "recommend debugging that specific failed test with "
            "`--test_filter` before running the full test suite again."
        ) from e
    super().tearDown()