605 lines
20 KiB
Python
605 lines
20 KiB
Python
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from collections import defaultdict
|
|
from dataclasses import replace
|
|
import itertools as it
|
|
from typing import Sequence
|
|
import numpy as np
|
|
|
|
from jax._src import ad_util
|
|
from jax._src import core, util
|
|
from jax._src import ops
|
|
from jax._src import prng
|
|
from jax._src import random
|
|
from jax._src import shard_map
|
|
from jax._src.lax import (
|
|
ann,
|
|
convolution,
|
|
fft,
|
|
lax,
|
|
linalg,
|
|
parallel as lax_parallel,
|
|
slicing,
|
|
special,
|
|
windowed_reductions,
|
|
)
|
|
from jax.experimental import roofline
|
|
|
|
# One FMA (Fused Multiply Add) takes 2 flops to compute.
_FMA_FLOPS_FACTOR = 2

# Register the standard roofline for every `core.Primitive` found in the
# namespaces of the modules listed below. Primitives that need a dedicated
# rule are registered individually further down in this file.
for prim in it.chain.from_iterable(
    vars(module).values()
    for module in (
        ad_util,
        ann,
        convolution,
        fft,
        lax,
        linalg,
        ops,
        prng,
        random,
        shard_map,
        slicing,
        special,
        windowed_reductions,
    )
):
  if isinstance(prim, core.Primitive):
    roofline.register_standard_roofline(prim)
|
|
|
|
|
|
def _unary_p_roofline(
|
|
ctx: roofline.RooflineRuleContext,
|
|
*args,
|
|
**kw,
|
|
) -> roofline.RooflineResult:
|
|
(x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
|
|
out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
|
|
return roofline.RooflineResult(
|
|
unfused_flops=x.size,
|
|
unfused_hbm_bytes=(
|
|
x.dtype.itemsize * x.size + out.dtype.itemsize * out.size
|
|
),
|
|
)
|
|
|
|
# Every elementwise unary primitive below shares `_unary_p_roofline`.
for prim in (
    lax.abs_p,
    lax.acos_p,
    lax.asin_p,
    lax.atan_p,
    lax.cbrt_p,
    lax.ceil_p,
    lax.conj_p,
    lax.cos_p,
    lax.cosh_p,
    lax.exp_p,
    lax.expm1_p,
    lax.floor_p,
    lax.imag_p,
    lax.integer_pow_p,
    lax.is_finite_p,
    lax.log_p,
    lax.log1p_p,
    lax.logistic_p,
    lax.neg_p,
    lax.not_p,
    lax.real_p,
    lax.round_p,
    lax.rsqrt_p,
    lax.sign_p,
    lax.sin_p,
    lax.sinh_p,
    lax.sqrt_p,
    lax.square_p,
    lax.tan_p,
    special.bessel_i0e_p,
    special.bessel_i1e_p,
    special.digamma_p,
    special.erf_inv_p,
    special.erf_p,
    special.erfc_p,
    special.lgamma_p,
):
  roofline.register_roofline(prim)(_unary_p_roofline)

roofline.register_standard_roofline(core.pvary_p)
|
|
|
|
def _binary_p_roofline(
|
|
ctx: roofline.RooflineRuleContext,
|
|
*args,
|
|
**kw,
|
|
) -> roofline.RooflineResult:
|
|
lhs, rhs = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
|
|
broadcasted_shape = [
|
|
max(l, r) for l, r in it.zip_longest(lhs.shape, rhs.shape, fillvalue=1)
|
|
]
|
|
out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
|
|
return roofline.RooflineResult(
|
|
unfused_flops=int(np.prod(broadcasted_shape)),
|
|
unfused_hbm_bytes=(
|
|
lhs.dtype.itemsize * lhs.size
|
|
+ rhs.dtype.itemsize * rhs.size
|
|
+ out.dtype.itemsize * out.size
|
|
),
|
|
)
|
|
|
|
|
|
# Every elementwise binary primitive below shares `_binary_p_roofline`.
for prim in (
    lax.add_p,
    lax.sub_p,
    lax.mul_p,
    lax.div_p,
    lax.rem_p,
    lax.and_p,
    lax.or_p,
    lax.xor_p,
    lax.gt_p,
    lax.lt_p,
    lax.ge_p,
    lax.le_p,
    lax.eq_p,
    lax.ne_p,
    lax.min_p,
    lax.max_p,
):
  roofline.register_roofline(prim)(_binary_p_roofline)
|
|
|
|
|
|
@roofline.register_roofline(lax.dot_general_p)
def _dot_general_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    dimension_numbers: lax.DotDimensionNumbers,
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule for `dot_general`.

  Flops are `2 * lhs.size * rhs.size` divided by the product of the lhs
  contracting dimension sizes and by the product of the lhs batch dimension
  sizes. HBM bytes always include the output; each operand is counted only
  when the context does not mark it as pinned in VMEM.

  Args:
    ctx: Rule context; `ctx.avals_in` must hold exactly two avals.
    *args: Unused.
    dimension_numbers: `((lhs_contract, rhs_contract), (lhs_batch,
      rhs_batch))`; only the lhs entries are read.
    **kw: Unused.

  Returns:
    A result whose fused and unfused flops/HBM-byte fields carry the same
    values.
  """
  lhs, rhs = map(roofline.RooflineShape.from_aval, ctx.avals_in)
  out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
  (lhs_contract, _), (lhs_batch, _) = dimension_numbers

  contract_extent = np.prod([lhs.shape[dim] for dim in lhs_contract])
  batch_extent = np.prod([lhs.shape[dim] for dim in lhs_batch])
  total_flops = int(
      _FMA_FLOPS_FACTOR * lhs.size * rhs.size / contract_extent / batch_extent
  )

  # The output is always written; operands pinned in VMEM are not re-read.
  traffic = out.bytes
  if not ctx.pin_lhs_in_vmem:
    traffic += lhs.bytes
  if not ctx.pin_rhs_in_vmem:
    traffic += rhs.bytes

  return roofline.RooflineResult(
      flops=total_flops,
      unfused_flops=total_flops,
      hbm_bytes=traffic,
      unfused_hbm_bytes=traffic,
  )
|
|
|
|
|
|
def _get_spatial_valid_position_count_for_one_dim(
|
|
window_dim_stride: int,
|
|
base_dilation: int,
|
|
window_dilation: int,
|
|
kernel_limit: int,
|
|
input_limit: int,
|
|
output_limit: int,
|
|
padding: tuple[int, int],
|
|
) -> int:
|
|
"""Gets the valid position count for conv for a single spatial dimension.
|
|
|
|
Args:
|
|
window_dim_stride: The stride of the window along this dimension.
|
|
base_dilation: The base dilation factor along this dimension.
|
|
window_dilation: The window dilation factor along this dimension.
|
|
kernel_limit: The size of the kernel along this dimension.
|
|
input_limit: The size of the input along this dimension.
|
|
output_limit: The size of the output along this dimension.
|
|
padding: The padding applied to the input along this dimension.
|
|
"""
|
|
padding_low = padding[0]
|
|
padding_high = padding[1]
|
|
|
|
# These two conditions will create an N^2 iteration pattern with only N
|
|
# valid elements. This is a performance optimization and produces the same
|
|
# result as the whole loop.
|
|
if (
|
|
input_limit == output_limit
|
|
and kernel_limit == output_limit
|
|
and input_limit == base_dilation
|
|
and window_dilation == 1
|
|
and max(1, input_limit - 1) == window_dim_stride
|
|
and padding_low == 0
|
|
and padding_high == 0
|
|
):
|
|
return input_limit
|
|
|
|
if (
|
|
input_limit == 1
|
|
and kernel_limit == output_limit
|
|
and window_dilation == 1
|
|
and base_dilation == 1
|
|
and window_dim_stride == 1
|
|
and padding_low == output_limit - 1
|
|
and padding_high == output_limit - 1
|
|
):
|
|
return output_limit
|
|
|
|
valid_position_count = 0
|
|
# Loop over each point in the kernel
|
|
for kernel_idx in range(kernel_limit):
|
|
|
|
# Skip loop for trivial stride and base_dilation
|
|
if window_dim_stride == 1 and base_dilation == 1:
|
|
undilated_index_base = padding_low - kernel_idx * window_dilation
|
|
upper_limit = min(
|
|
input_limit + undilated_index_base,
|
|
output_limit,
|
|
)
|
|
lower_limit = max(0, undilated_index_base)
|
|
|
|
valid_position_count += max(upper_limit - lower_limit, 0)
|
|
continue
|
|
|
|
# Loop over each point in the output
|
|
for output_idx in range(output_limit):
|
|
# Calculate lhs (input) index without taking base dilation into account
|
|
undilated_index = (
|
|
output_idx * window_dim_stride
|
|
- padding_low
|
|
+ kernel_idx * window_dilation
|
|
)
|
|
# Calculate the actual lhs (input) index after dilation
|
|
lhs_spatial_index = int(undilated_index / base_dilation)
|
|
|
|
# Skip if the lhs (input) index is to be dilated.
|
|
if undilated_index != lhs_spatial_index * base_dilation:
|
|
continue
|
|
# Skip if input index is not in bound.
|
|
if lhs_spatial_index < 0 or lhs_spatial_index >= input_limit:
|
|
continue
|
|
|
|
valid_position_count += 1
|
|
return valid_position_count
|
|
|
|
|
|
def _get_spatial_valid_position_count(
    dnums: convolution.ConvDimensionNumbers,
    lhs: roofline.RooflineShape,
    rhs: roofline.RooflineShape,
    out: roofline.RooflineShape,
    window_strides: Sequence[int],
    padding: Sequence[tuple[int, int]],
    lhs_dilation: Sequence[int],
    rhs_dilation: Sequence[int],
) -> int:
  """Gets the number of valid spatial positions for conv_general_dilated.

  The result is the product, over all spatial dimensions, of the
  per-dimension valid position counts.

  Args:
    dnums: The dimension numbers for the convolution.
    lhs: The shape of the left-hand side of the convolution.
    rhs: The shape of the right-hand side of the convolution.
    out: The shape of the output of the convolution.
    window_strides: The stride of the window along each spatial dimension.
    padding: The padding applied to the input along each spatial dimension.
    lhs_dilation: The dilation factor for the left-hand side along each spatial
      dimension.
    rhs_dilation: The dilation factor for the right-hand side along each spatial
      dimension.

  Returns:
    The product of the valid position counts of every spatial dimension.
  """
  # The first two entries of each spec are skipped; the remaining entries
  # are treated as the spatial dimensions.
  lhs_spatial = dnums.lhs_spec[2:]
  rhs_spatial = dnums.rhs_spec[2:]
  out_spatial = dnums.out_spec[2:]

  total = 1
  for d, lhs_dim in enumerate(lhs_spatial):
    total *= _get_spatial_valid_position_count_for_one_dim(
        window_dim_stride=window_strides[d],
        base_dilation=lhs_dilation[d],
        window_dilation=rhs_dilation[d],
        kernel_limit=rhs.shape[rhs_spatial[d]],
        input_limit=lhs.shape[lhs_dim],
        output_limit=out.shape[out_spatial[d]],
        padding=padding[d],
    )
  return total
|
|
|
|
|
|
def _calculate_conv_flops(
    lhs: roofline.RooflineShape,
    rhs: roofline.RooflineShape,
    out: roofline.RooflineShape,
    window_strides: Sequence[int],
    padding: Sequence[tuple[int, int]],
    lhs_dilation: Sequence[int],
    rhs_dilation: Sequence[int],
    dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers,
    batch_group_count: int,
) -> int:
  """Calculates roofline unfused flops for Jax's conv_general_dilated primitive.

  The FMA count is the number of valid spatial positions multiplied by the
  rhs size at `rhs_spec[1]`, the output size at `out_spec[1]`, and the lhs
  size at `lhs_spec[0]` divided by `batch_group_count`. Each FMA is charged
  `_FMA_FLOPS_FACTOR` flops.

  See `jax.lax.conv_general_dilated` for details on the arguments.
  """
  dnums = convolution.conv_dimension_numbers(
      lhs.shape, rhs.shape, dimension_numbers
  )
  valid_positions = _get_spatial_valid_position_count(
      dnums, lhs, rhs, out, window_strides, padding, lhs_dilation, rhs_dilation
  )

  batch_per_group = lhs.shape[dnums.lhs_spec[0]] / batch_group_count
  out_features = out.shape[dnums.out_spec[1]]
  in_features = rhs.shape[dnums.rhs_spec[1]]

  return int(
      in_features
      * out_features
      * batch_per_group
      * valid_positions
      * _FMA_FLOPS_FACTOR
  )
|
|
|
|
|
|
@roofline.register_roofline(convolution.conv_general_dilated_p)
def _conv_general_dilated_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    window_strides: Sequence[int],
    padding: Sequence[tuple[int, int]],
    lhs_dilation: Sequence[int],
    rhs_dilation: Sequence[int],
    dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers,
    batch_group_count: int,
    **kw,
) -> roofline.RooflineResult:
  """Roofline for Jax's conv_general_dilated primitive.

  Unfused flops come from `_calculate_conv_flops`; unfused HBM bytes are the
  bytes of both inputs plus the bytes of the first output.

  See `jax.lax.conv_general_dilated` for details on the arguments.
  """
  lhs, rhs = map(roofline.RooflineShape.from_aval, ctx.avals_in)
  out = roofline.RooflineShape.from_aval(ctx.avals_out[0])

  conv_flops = _calculate_conv_flops(
      lhs=lhs,
      rhs=rhs,
      out=out,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      dimension_numbers=dimension_numbers,
      batch_group_count=batch_group_count,
  )
  lhs_bytes = lhs.dtype.itemsize * lhs.size
  rhs_bytes = rhs.dtype.itemsize * rhs.size
  out_bytes = out.dtype.itemsize * out.size

  return roofline.RooflineResult(
      unfused_flops=conv_flops,
      unfused_hbm_bytes=lhs_bytes + rhs_bytes + out_bytes,
  )
|
|
|
|
|
|
def _return_zeros_if_one_sized_axis(
|
|
ctx: roofline.RooflineRuleContext, axes: tuple[str, ...]
|
|
) -> roofline.RooflineResult | None:
|
|
assert ctx.mesh
|
|
axes_size = np.prod([ctx.mesh.shape[axis] for axis in axes])
|
|
if axes_size > 1:
|
|
return None
|
|
return roofline.RooflineResult(
|
|
ici_bytes={axis: 0 for axis in axes},
|
|
ici_latency={axis: 0 for axis in axes},
|
|
)
|
|
|
|
|
|
def _ring_collective_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axes: tuple[str, ...],
    is_reduce: bool = True,
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule for collectives modelled as ring phases over mesh axes.

  Args:
    ctx: Rule context; `ctx.mesh` must be set.
    *args: Unused.
    axes: Mesh axis names the collective runs over.
    is_reduce: When true, the starting shard size is the total input bytes
      divided by the product of the axis sizes; otherwise it is the total
      input bytes.
    **kw: Unused.

  Returns:
    A result assigning the same `ici_bytes` and the same `ici_latency` to
    every axis in `axes`.
  """
  trivial = _return_zeros_if_one_sized_axis(ctx, axes)
  if trivial is not None:
    return trivial

  assert ctx.mesh
  axis_sizes = ctx.mesh.shape
  shard_bytes = roofline.RooflineShape.total_bytes(ctx.avals_in)
  if is_reduce:
    shard_bytes /= np.prod([axis_sizes[name] for name in axes])

  # We model the slowest color as the bottleneck.
  ordered = sorted(axes, key=lambda name: axis_sizes[name], reverse=True)
  phase_count = len(ordered)

  # Phase split.
  shard_bytes //= phase_count
  transferred = 0
  for name in ordered:
    extent = axis_sizes[name]
    # Do phase.
    transferred += shard_bytes * (extent - 1)
    # Increase shard size.
    shard_bytes *= extent

  # Bottleneck is the longest axis.
  latency = axis_sizes[ordered[0]] * phase_count

  return roofline.RooflineResult(
      ici_bytes=dict.fromkeys(ordered, int(transferred)),
      ici_latency=dict.fromkeys(ordered, int(latency)),
  )
|
|
|
|
|
|
# reduce_scatter and all_gather receive their mesh axes as `axis_name`; the
# ring rule expects that value under the keyword `axes`.
roofline.register_roofline(lax_parallel.reduce_scatter_p)(
    lambda *positional, axis_name, **params: _ring_collective_roofline(
        *positional, axes=axis_name, **params
    )
)
roofline.register_roofline(lax_parallel.all_gather_p)(
    lambda *positional, axis_name, **params: _ring_collective_roofline(
        *positional, axes=axis_name, is_reduce=False, **params
    )
)
|
|
|
|
|
|
def _scalar_collective_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axes: tuple[str, ...],
    **kw,
) -> roofline.RooflineResult:
  """Ring-collective roofline with every input replaced by one element.

  Each input aval is swapped for a shape-`(1,)` aval of the same dtype, and
  the ring rule is then evaluated with `is_reduce=False`.

  Args:
    ctx: Rule context for the collective.
    *args: Forwarded to `_ring_collective_roofline`.
    axes: Mesh axis names the collective runs over.
    **kw: Forwarded to `_ring_collective_roofline`.

  Returns:
    The result of the ring rule on the single-element inputs.
  """
  dtypes = [
      roofline.RooflineShape.from_aval(aval).dtype for aval in ctx.avals_in
  ]
  single_element_avals = [core.ShapedArray((1,), dtype) for dtype in dtypes]
  scalar_ctx = replace(ctx, avals_in=single_element_avals)
  return _ring_collective_roofline(
      scalar_ctx, *args, axes=axes, is_reduce=False, **kw
  )
|
|
|
|
|
|
# pmin and pmax share the single-element collective rule.
for prim in (lax_parallel.pmin_p, lax_parallel.pmax_p):
  roofline.register_roofline(prim)(_scalar_collective_roofline)
|
|
|
|
|
|
@roofline.register_roofline(lax_parallel.psum_invariant_p)
def _psum2_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axes: tuple[str, ...],
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule for `psum_invariant`: twice the ring-collective cost.

  Args:
    ctx: Rule context for the collective.
    *args: Forwarded to `_ring_collective_roofline`.
    axes: Mesh axis names the sum runs over.
    **kw: Forwarded to `_ring_collective_roofline`.

  Returns:
    A result whose per-axis `ici_bytes` and `ici_latency` are double those
    of the ring rule.
  """
  ring = _ring_collective_roofline(ctx, *args, axes=axes, **kw)
  return roofline.RooflineResult(
      ici_bytes={name: 2 * cost for name, cost in ring.ici_bytes.items()},
      ici_latency={name: 2 * cost for name, cost in ring.ici_latency.items()},
  )
|
|
|
|
|
|
@roofline.register_roofline(lax_parallel.all_to_all_p)
def _all_to_all_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axis_name: tuple[str, ...],
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule for `all_to_all`, using a bisection-bandwidth model.

  Args:
    ctx: Rule context; `ctx.mesh` must be set.
    *args: Unused.
    axis_name: Mesh axis names the exchange runs over.
    **kw: Unused.

  Returns:
    A result assigning the same `ici_bytes` and the same `ici_latency` to
    every axis in `axis_name`.
  """
  trivial = _return_zeros_if_one_sized_axis(ctx, axis_name)
  if trivial is not None:
    return trivial

  assert ctx.mesh
  axis_sizes = ctx.mesh.shape
  total_bytes = roofline.RooflineShape.total_bytes(ctx.avals_in) * np.prod(
      [axis_sizes[name] for name in axis_name]
  )

  narrowest = min(axis_name, key=lambda name: axis_sizes[name])
  axis_count = len(axis_name)
  bisection_bw = axis_sizes[narrowest] ** (axis_count - 1)
  if axis_sizes[narrowest] > 2:
    # Times 2 because of wraparound.
    bisection_bw *= 2

  # Half the data needs to cross the bisection on average.
  bytes_across = total_bytes / 2 / bisection_bw

  # The latency is the max number of hops across the mesh.
  hops = sum(axis_sizes[name] / 2 for name in axis_name)

  return roofline.RooflineResult(
      ici_bytes={name: int(bytes_across) for name in axis_name},
      ici_latency={name: int(hops) for name in axis_name},
  )
|
|
|
|
|
|
@roofline.register_roofline(lax_parallel.ppermute_p)
def _ppermute_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axis_name: tuple[str, ...],
    perm: tuple[tuple[int, int], ...],
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule for `ppermute`, based on per-link contention.

  Every (src, dst) pair is walked one mesh dimension at a time along the
  shorter wraparound direction. Each traversed link gets one unit of load.
  `ici_bytes` is the shard size times the load of the busiest link, and
  `ici_latency` is the largest hop count over all pairs.

  Args:
    ctx: Rule context; `ctx.mesh` must be set.
    *args: Unused.
    axis_name: Mesh axis names the permutation runs over. Axes missing from
      the mesh are treated as having size 1.
    perm: (source, destination) pairs of linearized device indices.
    **kw: Unused.

  Returns:
    A result assigning the same `ici_bytes` and the same `ici_latency` to
    every axis in `axis_name`.
  """
  trivial = _return_zeros_if_one_sized_axis(ctx, axis_name)
  if trivial is not None:
    return trivial

  assert ctx.mesh
  axis_sizes = ctx.mesh.shape
  mesh_dims: list[int] = [axis_sizes.get(name, 1) for name in axis_name]
  shard_size = roofline.RooflineShape.total_bytes(ctx.avals_in)

  # Load per link, keyed by the sorted pair of endpoint coordinates.
  link_load: dict[tuple[tuple[int, ...], ...], float] = defaultdict(float)
  max_hops = 0

  for src, dst in perm:
    if src == dst:
      continue
    # Perms are linearized.
    src_coords = tuple(int(c) for c in np.unravel_index(src, mesh_dims))
    dst_coords = tuple(int(c) for c in np.unravel_index(dst, mesh_dims))

    hops_for_pair = 0
    for dim, (extent, start, stop) in enumerate(
        zip(mesh_dims, src_coords, dst_coords)
    ):
      if start == stop:
        continue

      # Calculate distance with wraparound.
      forward = (stop - start) % extent
      backward = (start - stop) % extent
      step = 1 if forward <= backward else -1

      # NOTE(review): links along this dimension are recorded with every
      # other coordinate held at the source's value — confirm this matches
      # the intended routing model.
      position = start
      while position != stop:
        here = util.tuple_update(src_coords, dim, position)
        following = (position + step) % extent
        there = util.tuple_update(here, dim, following)
        link_load[tuple(sorted([here, there]))] += 1
        position = following

      hops_for_pair += min(forward, backward)

    max_hops = max(max_hops, hops_for_pair)

  busiest_link_bytes = shard_size * max(link_load.values(), default=0)
  return roofline.RooflineResult(
      ici_bytes={name: int(busiest_link_bytes) for name in axis_name},
      ici_latency={name: int(max_hops) for name in axis_name},
  )
|
|
|
|
|
|
@roofline.register_roofline(lax.reduce_sum_p)
|
|
def _reduce_sum_p_roofline(
|
|
ctx: roofline.RooflineRuleContext,
|
|
*args,
|
|
axes: tuple[int, ...],
|
|
**kw,
|
|
) -> roofline.RooflineResult:
|
|
(x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
|
|
domain_size = np.prod([x.shape[i] for i in axes])
|
|
other_axes = set(range(len(x.shape))) - set(axes)
|
|
result_size = np.prod([x.shape[i] for i in other_axes])
|
|
|
|
return roofline.RooflineResult(
|
|
# To add n values, we do n - 1 add operations, and we have to do that
|
|
# for every element in the result.
|
|
unfused_flops=int((domain_size - 1) * result_size),
|
|
# Size of input, plus output. (We assume that the output is also used
|
|
# as accumulator.)
|
|
unfused_hbm_bytes=int(x.dtype.itemsize * (x.size + result_size)),
|
|
)
|