# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import contextlib
import dataclasses
import itertools
import json
import math
from typing import Callable, ParamSpec, TypeAlias, TypeVar
import warnings

import jax
from jax._src import stages
from jax._src.lib import xla_client
import jax.numpy as jnp
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import scf
import numpy as np

from .utils import *  # noqa: F403

try:
  from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
except ImportError:
  has_registrations = False
else:
  # TODO(slebedev): Remove the if once the minimum jaxlib is 0.4.36.
  has_registrations = hasattr(mosaic_gpu_lib._mosaic_gpu_ext, "registrations")
  if has_registrations:
    for name, handler in mosaic_gpu_lib._mosaic_gpu_ext.registrations():
      xla_client.register_custom_call_target(
          name, handler, platform="CUDA", api_version=1
      )

# ruff: noqa: F405
# mypy: ignore-errors

T = TypeVar("T")
P = ParamSpec("P")


def _event_record(args, *, copy_before):
  flat_args, treedef = jax.tree.flatten(args)
  event, *flat_outs = jax.ffi.ffi_call(
      "mgpu_event_record",
      result_shape_dtypes=(jax.core.ShapedArray((), jnp.uint64), *flat_args),
      input_output_aliases={i: i + 1 for i in range(len(flat_args))},
  )(*flat_args, copy_before=copy_before)
  return event, treedef.unflatten(flat_outs)


def _event_elapsed(start_event, end_event):
  return jax.ffi.ffi_call(
      "mgpu_event_elapsed",
      result_shape_dtypes=jax.core.ShapedArray((), jnp.float32),
  )(start_event, end_event)


def _measure_events(
    f: Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> tuple[T, float]:
  if not has_registrations:
    raise RuntimeError(
        "This function requires jaxlib >=0.4.36 with CUDA support."
    )

  if not (args or kwargs):
    # We require at least one argument and at least one output to ensure
    # that there is a data dependency between `_event_record` calls in
    # the resulting HLO program.
    raise ValueError("Can only measure functions with arguments")

  @jax.jit
  def run(*args, **kwargs):
    start_event, (args, kwargs) = _event_record(
        (args, kwargs), copy_before=True
    )
    end_event, outs = _event_record(f(*args, **kwargs), copy_before=False)
    if jax.tree.structure(outs).num_leaves == 0:
      raise ValueError("Can only measure functions with at least one output")
    return outs, _event_elapsed(start_event, end_event)

  jax.block_until_ready(run(*args, **kwargs))  # Warmup.
  outs, elapsed = run(*args, **kwargs)
  return outs, float(elapsed)


Timings: TypeAlias = list[tuple[str, float]] | float | None


@dataclasses.dataclass(frozen=True, kw_only=True)
class Cupti:
  """CUPTI-based profiler."""

  # If `True`, detach CUPTI from the process after measurement.
  finalize: bool = True

  def measure(
      self, f: Callable[P, T], *, aggregate: bool = True
  ) -> Callable[P, tuple[T, Timings]]:
    if not isinstance(f, (stages.Wrapped, stages.Compiled)):
      f = jax.jit(f)

    def wrapper(*args: P.args, **kwargs: P.kwargs):
      jax.block_until_ready(f(*args, **kwargs))  # Warmup.
      ext = mosaic_gpu_lib._mosaic_gpu_ext
      ext._cupti_init()
      try:
        results = jax.block_until_ready(f(*args, **kwargs))
      finally:
        timings = ext._cupti_get_timings(self.finalize)

      if not timings:
        return results, None
      elif aggregate:
        return results, sum(item[1] for item in timings)
      else:
        return results, timings

    return wrapper


def measure(
    f: Callable[P, T], *, mode: str = "events", aggregate: bool = True
) -> Callable[P, tuple[T, Timings]]:
  """Sets up a function ``f`` for profiling on GPU.

  ``measure`` is a higher-order function that augments the argument ``f`` to
  return GPU runtime in milliseconds, in addition to its proper outputs.

  Args:
    f: The function to measure. It must accept at least one argument and
      return at least one output to be measurable.
    mode: The mode of operation. Possible values are:

      - "cupti", for CUPTI-based profiling.
      - "events", for CUDA events-based profiling.

      The two modes use different measurement methodologies and should not
      be treated as interchangeable backends. See the Notes section below
      for a discussion of the differences.
    aggregate: Whether to report an aggregate runtime. When ``False`` (only
      supported by ``mode="cupti"``), the per-kernel timings are returned as
      a list of tuples ``(<kernel name>, <runtime>)``.

  Returns:
    A new function ``g`` that returns the measured GPU runtime as its last
    additional output. Otherwise ``g`` accepts the same inputs and returns
    the same outputs as ``f``.

  Notes:
    `CUPTI (CUDA Profiling Tools Interface)
    <https://docs.nvidia.com/cupti/>`_ is a high-accuracy, high-precision
    profiling and tracing API, used in particular by Nsight Systems and
    Nsight Compute. When using ``measure`` with ``mode="cupti"``, device
    (GPU) execution runtimes are recorded for each kernel launched during
    the execution of the function. In that mode, setting ``aggregate=True``
    sums the individual kernel runtimes to arrive at an aggregate
    measurement. The "gaps" between the kernels, when the device is idle,
    are not included in the aggregate.

    The CUPTI API only allows a single "subscriber". This means that the
    CUPTI-based profiler will fail when the program is run using tools that
    make use of CUPTI, such as CUDA-GDB, Compute Sanitizer, Nsight Systems,
    or Nsight Compute.

    ``mode="events"`` uses a different approach: a CUDA event is recorded
    before and after the function ``f`` is executed. The reported runtime is
    the time elapsed between the two events. In particular, the measurement
    includes:

    - any potential "gaps" between the kernels when the device is idle
    - any potential "gaps" between the "before" event and the start of the
      first kernel, or between the end of the last kernel and the "after"
      event

    In an attempt to minimize the second effect, the events-based
    implementation may internally execute ``f`` more than once, so as to
    warm it up and exclude compilation time from the measurement.
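
  Example:
    A minimal sketch of measuring a dot product; the operands below are
    arbitrary placeholders, not part of this module::

      x = jnp.ones((1024, 1024), jnp.float32)
      timed_dot = measure(jnp.dot, mode="cupti")
      out, runtime_ms = timed_dot(x, x)  # Sum of per-kernel runtimes.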
""" # fmt: skip match mode: case "cupti": return Cupti().measure(f, aggregate=aggregate) case "events": if not aggregate: raise ValueError(f"{aggregate=} is not supported with {mode=}") def measure_events_wrapper(*args, **kwargs): return _measure_events(f, *args, **kwargs) return measure_events_wrapper case _: raise ValueError(f"Unrecognized profiler mode {mode}") class ProfilerSpec: ENTER = 0 EXIT = 1 << 31 def __init__(self, entries_per_warpgroup: int): self.entries_per_warpgroup = entries_per_warpgroup self.interned_names = {} def _num_warpgroups( self, grid: tuple[int, ...], block: tuple[int, ...] ) -> int: if math.prod(block) % WARPGROUP_SIZE: raise ValueError("Block size is not a multiple of warpgroup size") return math.prod(grid) * math.prod(block) // WARPGROUP_SIZE def mlir_buffer_type( self, grid: tuple[int, ...], block: tuple[int, ...] ) -> ir.Type: return ir.MemRefType.get( (self._num_warpgroups(grid, block) * self.entries_per_warpgroup,), ir.IntegerType.get_signless(32), ) def jax_buffer_type( self, grid: tuple[int, ...], block: tuple[int, ...] ) -> ir.Type: return jax.ShapeDtypeStruct( (self._num_warpgroups(grid, block) * self.entries_per_warpgroup,), jnp.uint32, ) def smem_i32_elements(self, block: tuple[int, ...]): num_warpgroups = self._num_warpgroups((), block) return int(num_warpgroups * self.entries_per_warpgroup) def smem_bytes(self, block: tuple[int, ...]): bytes_per_entry = 4 return self.smem_i32_elements(block) * bytes_per_entry def intern_name(self, name: str) -> int: if (name_id := self.interned_names.get(name, None)) is not None: return name_id name_id = self.interned_names[name] = len(self.interned_names) if name_id & self.EXIT: raise RuntimeError("Allocated too many names") return name_id def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]): buffer = np.asarray(buffer) num_blocks = math.prod(grid) warpgroups_per_block = self._num_warpgroups((), block) entries = buffer.reshape( num_blocks, warpgroups_per_block, self.entries_per_warpgroup ) start_times = entries[..., 0] sm_ids = entries[..., 1] entries_used = entries[..., 2] if np.any(entries_used > self.entries_per_warpgroup - 2): raise RuntimeError("Insufficient space to capture a full trace") traces = entries[..., 3:] # Estimate the overhead of profiling. time_events = traces[:, :, 1::2] valid_times_mask = np.arange(traces.shape[-1])[1::2] < (entries_used[..., None] - 3) # 12 cycles is a ballpark estimate for H100 profiling_overhead = (time_events[:, :, 1:] - time_events[:, :, :-1]).min( where=valid_times_mask[:, :, 1:], initial=12 ) profiling_overhead = max(0, profiling_overhead - 1) unintern = {v: k for k, v in self.interned_names.items()} events = [] for block_idx, wg_idx in np.ndindex(num_blocks, warpgroups_per_block): valid_entries = (entries_used[block_idx, wg_idx] - 3) local_clock_offset = None assert valid_entries % 2 == 0, valid_entries start_time = start_times[block_idx, wg_idx] block_events = [] last_time = float("-inf") for i in range(0, valid_entries, 2): tag = traces[block_idx, wg_idx, i] time = traces[block_idx, wg_idx, i + 1] if local_clock_offset is None: local_clock_offset = time time -= local_clock_offset time -= (i // 2) * profiling_overhead # Account for the overhead of profiling. 
        if time < 0:
          break  # Detect a timer wraparound
        name_id = tag
        begin = True
        if name_id & ProfilerSpec.EXIT:
          name_id = name_id ^ ProfilerSpec.EXIT
          begin = False
        name = unintern[name_id]
        if last_time >= time:
          if last_time - time > 10:
            warnings.warn(
                "Profiler clock went significantly backwards for event"
                f" {'start' if begin else 'end'} `{name}`: {last_time} ->"
                f" {time}"
            )
          time = last_time + 1
        last_time = time
        block_events.append({
            "name": name,
            "ph": "B" if begin else "E",
            "ts": float(start_time + time) / 1e3,
            "pid": 1 + int(sm_ids[block_idx, wg_idx]),
            "tid": 1 + wg_idx + warpgroups_per_block * block_idx,
        })
      else:  # If we didn't break
        if block_events:
          events.append(block_events)
    events = sorted(events, key=lambda x: x[0]["ts"])
    flat_events = list(itertools.chain.from_iterable(events))
    return json.dump({"displayTimeUnit": "ns", "traceEvents": flat_events}, f)


class OnDeviceProfiler:

  def __init__(
      self, spec: ProfilerSpec, smem_buffer: ir.Value, gmem_buffer: ir.Value
  ):
    self.spec = spec
    self.start = globaltimer("low")
    i32 = ir.IntegerType.get_signless(32)
    index = ir.IndexType.get()
    self.entries_per_wg = spec.entries_per_warpgroup
    wg_idx = warpgroup_idx(sync=False)
    self.smem_buffer = memref_slice(
        smem_buffer,
        ds(
            arith.index_cast(
                index, arith.muli(wg_idx, c(self.entries_per_wg, i32))
            ),
            self.entries_per_wg,
        ),
    )
    self.smem_buffer_ptr = memref_ptr(self.smem_buffer, memory_space=3)
    self.gmem_buffer = gmem_buffer
    self.is_profiling_thread = arith.cmpi(
        arith.CmpIPredicate.eq,
        arith.remui(thread_idx(), c(WARPGROUP_SIZE, i32)),
        c(0, i32),
    )
    # Hopefully mem2reg will remove the allocation.
    self.offset = memref.alloca(ir.MemRefType.get((), i32), [], [])
    memref.store(c(0, i32), self.offset, [])

  @contextlib.contextmanager
  def record(self, name: str):
    i32 = ir.IntegerType.get_signless(32)
    name_id = self.spec.intern_name(name)

    def store(modifier):
      cur = memref.load(self.offset, [])
      i64 = ir.IntegerType.get_signless(64)
      base_addr = arith.addi(
          llvm.ptrtoint(i64, self.smem_buffer_ptr),
          arith.extui(i64, arith.muli(cur, c(4, i32))),
      )
      llvm.inline_asm(
          ir.Type.parse("!llvm.void"),
          [self.is_profiling_thread, base_addr, c(modifier | name_id, i32)],
          """
          @$0 st.shared.v2.u32 [$1], {$2, %clock};
          """,
          "b,l,r",
          has_side_effects=True,
      )
      memref.store(
          arith.addi(cur, c(2, cur.type)),
          self.offset,
          [],
      )

    store(ProfilerSpec.ENTER)
    yield
    store(ProfilerSpec.EXIT)

  def finalize(self, grid: tuple[int, ...], block: tuple[int, ...]):
    index = ir.IndexType.get()
    i32 = ir.IntegerType.get_signless(32)

    gpu.barrier()  # Make sure all warpgroups are done.
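
    # Copy this warpgroup's trace from SMEM to its slice of the GMEM buffer.
    # Per-warpgroup layout: [0] start timestamp, [1] SM id, [2] number of
    # i32 entries used (header included), [3:] the (tag, clock) event pairs.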
    block_idx = c(0, index)
    for dim in gpu.Dimension:  # pytype: disable=wrong-arg-types
      block_idx = arith.addi(
          arith.muli(block_idx, gpu.grid_dim(dim)), gpu.block_id(dim)
      )
    wg_idx = warpgroup_idx(sync=False)
    wg_per_block = math.prod(block) // WARPGROUP_SIZE
    global_wg_idx = arith.addi(
        arith.muli(block_idx, c(wg_per_block, index)),
        arith.index_cast(index, wg_idx),
    )
    start_offset = arith.muli(global_wg_idx, c(self.entries_per_wg, index))
    wg_gmem_buffer = memref.subview(
        self.gmem_buffer,
        [start_offset],
        [self.entries_per_wg],
        [1],
        result_type=ir.Type.parse(
            f"memref<{self.entries_per_wg}xi32, strided<[1], offset: ?>>"
        ),
    )
    thread_in_wg = arith.remui(thread_idx(), c(128, i32))
    if_first = scf.IfOp(
        arith.cmpi(arith.CmpIPredicate.eq, thread_in_wg, c(0, i32))
    )
    with ir.InsertionPoint(if_first.then_block):
      memref.store(self.start, wg_gmem_buffer, [c(0, index)])
      memref.store(smid(), wg_gmem_buffer, [c(1, index)])
      memref.store(
          arith.addi(memref.load(self.offset, []), c(3, i32)),
          wg_gmem_buffer,
          [c(2, index)],
      )
      for_op = scf.ForOp(
          c(0, index),
          c(self.entries_per_wg - 3, index),
          c(1, index),
      )
      with ir.InsertionPoint(for_op.body):
        x = memref.load(self.smem_buffer, [for_op.induction_variable])
        memref.store(
            x,
            wg_gmem_buffer,
            [arith.addi(for_op.induction_variable, c(3, index))],
        )
        scf.yield_([])
      scf.yield_([])