# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import contextlib
import dataclasses
import itertools
import json
import math
from typing import Callable, ParamSpec, TypeAlias, TypeVar
import warnings

import jax
from jax._src import stages
from jax._src.lib import xla_client
import jax.numpy as jnp
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import scf
import numpy as np

from .utils import *  # noqa: F403

try:
  from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
except ImportError:
  has_registrations = False
else:
  # TODO(slebedev): Remove the if once the minimum jaxlib is 0.4.36.
  has_registrations = hasattr(mosaic_gpu_lib._mosaic_gpu_ext, "registrations")
  if has_registrations:
    for name, handler in mosaic_gpu_lib._mosaic_gpu_ext.registrations():
      xla_client.register_custom_call_target(
          name, handler, platform="CUDA", api_version=1
      )

# ruff: noqa: F405
# mypy: ignore-errors

T = TypeVar("T")
P = ParamSpec("P")


def _event_record(args, *, copy_before):
  flat_args, treedef = jax.tree.flatten(args)
  event, *flat_outs = jax.ffi.ffi_call(
      "mgpu_event_record",
      result_shape_dtypes=(jax.core.ShapedArray((), jnp.uint64), *flat_args),
      input_output_aliases={i: i + 1 for i in range(len(flat_args))},
  )(*flat_args, copy_before=copy_before)
  return event, treedef.unflatten(flat_outs)


def _event_elapsed(start_event, end_event):
  return jax.ffi.ffi_call(
      "mgpu_event_elapsed",
      result_shape_dtypes=jax.core.ShapedArray((), jnp.float32),
  )(start_event, end_event)


def _measure_events(
    f: Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> tuple[T, float]:
  if not has_registrations:
    raise RuntimeError(
        "This function requires jaxlib >=0.4.36 with CUDA support."
    )

  if not (args or kwargs):
    # We require at least one argument and at least one output to ensure
    # that there is a data dependency between `_event_record` calls in
    # the resulting HLO program.
    raise ValueError("Can only measure functions with arguments")

  @jax.jit
  def run(*args, **kwargs):
    start_event, (args, kwargs) = _event_record(
        (args, kwargs), copy_before=True
    )
    end_event, outs = _event_record(f(*args, **kwargs), copy_before=False)
    if jax.tree.structure(outs).num_leaves == 0:
      raise ValueError("Can only measure functions with at least one output")
    return outs, _event_elapsed(start_event, end_event)

  jax.block_until_ready(run(*args, **kwargs))  # Warmup.
  outs, elapsed = run(*args, **kwargs)
  return outs, float(elapsed)


Timings: TypeAlias = list[tuple[str, float]] | float | None


@dataclasses.dataclass(frozen=True, kw_only=True)
class Cupti:
  """CUPTI-based profiler."""

  # If `True`, detach CUPTI from the process after measurement.
  finalize: bool = True

  def measure(
      self, f: Callable[P, T], *, aggregate: bool = True
  ) -> Callable[P, tuple[T, Timings]]:
    if not isinstance(f, (stages.Wrapped, stages.Compiled)):
      f = jax.jit(f)

    def wrapper(*args: P.args, **kwargs: P.kwargs):
      jax.block_until_ready(f(*args, **kwargs))  # Warmup.
      ext = mosaic_gpu_lib._mosaic_gpu_ext
      ext._cupti_init()
      try:
        results = jax.block_until_ready(f(*args, **kwargs))
      finally:
        timings = ext._cupti_get_timings(self.finalize)

      if not timings:
        return results, None
      elif aggregate:
        return results, sum(item[1] for item in timings)
      else:
        return results, timings

    return wrapper


def measure(
    f: Callable[P, T], *, mode: str = "events", aggregate: bool = True
) -> Callable[P, tuple[T, Timings]]:
  """Sets up a function ``f`` for profiling on GPU.

  ``measure`` is a higher-order function that augments the argument ``f`` to
  return GPU runtime in milliseconds, in addition to its proper outputs.

  Args:
    f: The function to measure. It must accept at least one argument and return
      at least one output to be measurable.
    mode: The mode of operation. Possible values are:

      - "cupti", for CUPTI-based profiling.
      - "events", for CUDA events-based profiling.

      The two modes use different measurement methodologies and should not be
      treated as interchangeable backends. See the Notes section for important
      discussion.
    aggregate: Whether to report an aggregate runtime. When ``False`` (only
      supported by ``mode="cupti"``), the per-kernel timings are returned as a
      list of tuples ``(<kernel name>, <runtime in ms>)``.

  Returns:
    A new function ``g`` that returns the measured GPU runtime as its last
    additional output. In all other respects, ``g`` accepts the same inputs
    and returns the same outputs as ``f``.

  Notes:
    `CUPTI (CUDA Profiling Tools Interface)
    <https://docs.nvidia.com/cupti/index.html>`_ is a high-accuracy,
    high-precision profiling and tracing API, used in particular by Nsight
    Systems and Nsight Compute. When using ``measure`` with ``mode="cupti"``,
    device (GPU) execution runtimes are recorded for each kernel launched
    during the execution of the function. In that mode, setting
    ``aggregate=True`` will sum the individual kernel runtimes to arrive at an
    aggregate measurement. The "gaps" between the kernels when the device is
    idle are not included in the aggregate.

    The CUPTI API only allows a single "subscriber". This means that the
    CUPTI-based profiler will fail when the program is run using tools that
    make use of CUPTI, such as CUDA-GDB, Compute Sanitizer, Nsight Systems, or
    Nsight Compute.

    ``mode="events"`` uses a different approach: a CUDA event is recorded
    before and after the function ``f`` is executed. The reported runtime is
    the time elapsed between the two events. In particular, included in the
    measurement are:

    - any potential "gaps" between the kernels when the device is idle
    - any potential "gaps" between the "before" event and the start of the
      first kernel, or between the end of the last kernel and the "after" event

    In an attempt to minimize the second effect, internally the events-based
    implementation may execute ``f`` more than once to "warm up" and exclude
    compilation time from the measurement.
  """  # fmt: skip
  match mode:
    case "cupti":
      return Cupti().measure(f, aggregate=aggregate)
    case "events":
      if not aggregate:
        raise ValueError(f"{aggregate=} is not supported with {mode=}")
      def measure_events_wrapper(*args, **kwargs):
        return _measure_events(f, *args, **kwargs)
      return measure_events_wrapper
    case _:
      raise ValueError(f"Unrecognized profiler mode {mode}")


class ProfilerSpec:
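  """Static description of the on-device profiler buffers.

  Each warpgroup owns ``entries_per_warpgroup`` consecutive ``i32`` slots of
  the profile buffer: three header slots (start timestamp, SM id, and the
  number of slots used), followed by ``(tag, timestamp)`` pairs, where the
  tag is an interned event name with the ``EXIT`` bit set for event ends.
  """
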
  ENTER = 0
  EXIT = 1 << 31

  def __init__(self, entries_per_warpgroup: int):
    self.entries_per_warpgroup = entries_per_warpgroup
    self.interned_names = {}

  def _num_warpgroups(
      self, grid: tuple[int, ...], block: tuple[int, ...]
  ) -> int:
    if math.prod(block) % WARPGROUP_SIZE:
      raise ValueError("Block size is not a multiple of warpgroup size")
    return math.prod(grid) * math.prod(block) // WARPGROUP_SIZE

  def mlir_buffer_type(
      self, grid: tuple[int, ...], block: tuple[int, ...]
  ) -> ir.Type:
    return ir.MemRefType.get(
        (self._num_warpgroups(grid, block) * self.entries_per_warpgroup,),
        ir.IntegerType.get_signless(32),
    )

  def jax_buffer_type(
      self, grid: tuple[int, ...], block: tuple[int, ...]
  ) -> jax.ShapeDtypeStruct:
    return jax.ShapeDtypeStruct(
        (self._num_warpgroups(grid, block) * self.entries_per_warpgroup,),
        jnp.uint32,
    )

  def smem_i32_elements(self, block: tuple[int, ...]):
    num_warpgroups = self._num_warpgroups((), block)
    return int(num_warpgroups * self.entries_per_warpgroup)

  def smem_bytes(self, block: tuple[int, ...]):
    bytes_per_entry = 4
    return self.smem_i32_elements(block) * bytes_per_entry

  def intern_name(self, name: str) -> int:
    if (name_id := self.interned_names.get(name, None)) is not None:
      return name_id
    name_id = self.interned_names[name] = len(self.interned_names)
    if name_id & self.EXIT:
      raise RuntimeError("Allocated too many names")
    return name_id

  def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]):
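    """Writes the collected trace to the file ``f`` as Trace Event Format JSON.

    ``grid`` and ``block`` must match the launch configuration that produced
    ``buffer``. Each begin/end pair recorded via ``OnDeviceProfiler.record``
    becomes a "B"/"E" event, with the SM id as the pid and the global
    warpgroup index as the tid, so the output can be loaded into Perfetto or
    chrome://tracing.
    """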
    buffer = np.asarray(buffer)
    num_blocks = math.prod(grid)
    warpgroups_per_block = self._num_warpgroups((), block)
    entries = buffer.reshape(
        num_blocks, warpgroups_per_block, self.entries_per_warpgroup
    )
    start_times = entries[..., 0]
    sm_ids = entries[..., 1]
    entries_used = entries[..., 2]
    if np.any(entries_used > self.entries_per_warpgroup - 2):
      raise RuntimeError("Insufficient space to capture a full trace")
    traces = entries[..., 3:]

    # Estimate the overhead of profiling.
    time_events = traces[:, :, 1::2]
    valid_times_mask = np.arange(traces.shape[-1])[1::2] < (entries_used[..., None] - 3)
    # 12 cycles is a ballpark estimate for H100
    profiling_overhead = (time_events[:, :, 1:] - time_events[:, :, :-1]).min(
        where=valid_times_mask[:, :, 1:], initial=12
    )
    profiling_overhead = max(0, profiling_overhead - 1)

    unintern = {v: k for k, v in self.interned_names.items()}
    events = []
    for block_idx, wg_idx in np.ndindex(num_blocks, warpgroups_per_block):
      valid_entries = (entries_used[block_idx, wg_idx] - 3)
      local_clock_offset = None
      assert valid_entries % 2 == 0, valid_entries
      start_time = start_times[block_idx, wg_idx]
      block_events = []
      last_time = float("-inf")
      for i in range(0, valid_entries, 2):
        tag = traces[block_idx, wg_idx, i]
        time = traces[block_idx, wg_idx, i + 1]
        if local_clock_offset is None:
          local_clock_offset = time
        time -= local_clock_offset
        time -= (i // 2) * profiling_overhead  # Account for the overhead of profiling.
        if time < 0:
          break  # Detect a timer wraparound
        name_id = tag
        begin = True
        if name_id & ProfilerSpec.EXIT:
          name_id = name_id ^ ProfilerSpec.EXIT
          begin = False
        name = unintern[name_id]
        if last_time >= time:
          if last_time - time > 10:
            warnings.warn(
                "Profiler clock went significantly backwards for event"
                f" {'start' if begin else 'end'} `{name}`: {last_time} ->"
                f" {time}"
            )
          time = last_time + 1
        last_time = time
        block_events.append({
            "name": name,
            "ph": "B" if begin else "E",
            "ts": float(start_time + time) / 1e3,
            "pid": 1 + int(sm_ids[block_idx, wg_idx]),
            "tid": 1 + wg_idx + warpgroups_per_block * block_idx,
        })
      else:  # If we didn't break
        if block_events:
          events.append(block_events)
    events = sorted(events, key=lambda x: x[0]["ts"])
    flat_events = list(itertools.chain.from_iterable(events))
    return json.dump({"displayTimeUnit": "ns", "traceEvents": flat_events}, f)


class OnDeviceProfiler:
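  """Emits profiling instrumentation while a kernel is being lowered.

  Events recorded via ``record`` are first written to this warpgroup's slice
  of ``smem_buffer``; ``finalize`` then copies them to ``gmem_buffer``, from
  where ``ProfilerSpec.dump`` can post-process them on the host.
  """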

  def __init__(self, spec: ProfilerSpec, smem_buffer: ir.Value, gmem_buffer: ir.Value):
    self.spec = spec
    self.start = globaltimer("low")
    i32 = ir.IntegerType.get_signless(32)
    index = ir.IndexType.get()
    self.entries_per_wg = spec.entries_per_warpgroup
    wg_idx = warpgroup_idx(sync=False)
    self.smem_buffer = memref_slice(
        smem_buffer,
        ds(
            arith.index_cast(
                index, arith.muli(wg_idx, c(self.entries_per_wg, i32))
            ),
            self.entries_per_wg,
        ),
    )
    self.smem_buffer_ptr = memref_ptr(self.smem_buffer, memory_space=3)
    self.gmem_buffer = gmem_buffer
    self.is_profiling_thread = arith.cmpi(
        arith.CmpIPredicate.eq,
        arith.remui(thread_idx(), c(WARPGROUP_SIZE, i32)),
        c(0, i32),
    )
    # Hopefully mem2reg will remove the allocation.
    self.offset = memref.alloca(ir.MemRefType.get((), i32), [], [])
    memref.store(c(0, i32), self.offset, [])

  @contextlib.contextmanager
  def record(self, name: str):
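    """Context manager that records begin/end timestamps for a named event.

    Illustrative usage from kernel-lowering code (the event name is just an
    example)::

      with profiler.record("epilogue"):
        ...  # emit the MLIR ops whose execution should be traced
    """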
    i32 = ir.IntegerType.get_signless(32)
    name_id = self.spec.intern_name(name)
    def store(modifier):
      cur = memref.load(self.offset, [])
      i64 = ir.IntegerType.get_signless(64)
      base_addr = arith.addi(
          llvm.ptrtoint(i64, self.smem_buffer_ptr),
          arith.extui(i64, arith.muli(cur, c(4, i32))),
      )
      llvm.inline_asm(
          ir.Type.parse("!llvm.void"),
          [self.is_profiling_thread, base_addr, c(modifier | name_id, i32)],
          """
          @$0 st.shared.v2.u32 [$1], {$2, %clock};
          """,
          "b,l,r",
          has_side_effects=True,
      )
      memref.store(
          arith.addi(cur, c(2, cur.type)),
          self.offset,
          [],
      )
    store(ProfilerSpec.ENTER)
    yield
    store(ProfilerSpec.EXIT)

  def finalize(self, grid: tuple[int, ...], block: tuple[int, ...]):
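    """Copies this warpgroup's trace from SMEM to its slice of ``gmem_buffer``.

    Should be emitted at the end of the kernel for all threads: it begins with
    a ``gpu.barrier`` to make sure every warpgroup has finished recording.
    """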
    index = ir.IndexType.get()
    i32 = ir.IntegerType.get_signless(32)

    gpu.barrier()  # Make sure all warpgroups are done.

    block_idx = c(0, index)
    for dim in gpu.Dimension:  # pytype: disable=wrong-arg-types
      block_idx = arith.addi(
          arith.muli(block_idx, gpu.grid_dim(dim)), gpu.block_id(dim)
      )
    wg_idx = warpgroup_idx(sync=False)
    wg_per_block = math.prod(block) // WARPGROUP_SIZE
    global_wg_idx = arith.addi(
        arith.muli(block_idx, c(wg_per_block, index)),
        arith.index_cast(index, wg_idx),
    )
    start_offset = arith.muli(global_wg_idx, c(self.entries_per_wg, index))
    wg_gmem_buffer = memref.subview(
        self.gmem_buffer, [start_offset], [self.entries_per_wg], [1],
        result_type=ir.Type.parse(
            f"memref<{self.entries_per_wg}xi32, strided<[1], offset: ?>>"
        ),
    )
    thread_in_wg = arith.remui(thread_idx(), c(128, i32))
    if_first = scf.IfOp(
        arith.cmpi(arith.CmpIPredicate.eq, thread_in_wg, c(0, i32))
    )
    with ir.InsertionPoint(if_first.then_block):
      memref.store(self.start, wg_gmem_buffer, [c(0, index)])
      memref.store(smid(), wg_gmem_buffer, [c(1, index)])
      memref.store(
          arith.addi(memref.load(self.offset, []), c(3, i32)),
          wg_gmem_buffer,
          [c(2, index)],
      )

      for_op = scf.ForOp(
          c(0, index),
          c(self.entries_per_wg - 3, index),
          c(1, index),
      )
      with ir.InsertionPoint(for_op.body):
        x = memref.load(self.smem_buffer, [for_op.induction_variable])
        memref.store(
            x,
            wg_gmem_buffer,
            [arith.addi(for_op.induction_variable, c(3, index))],
        )
        scf.yield_([])
      scf.yield_([])