# Copyright 2025 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import dataclasses
import math

from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import memref
import numpy as np

from . import utils
from . import fragmented_array as fa
from . import mma_utils
from .launch_context import LaunchContext

# MyPy does a terrible job with the MLIR API.
# mypy: ignore-errors


TMEM_ROWS = 128
TCGEN05_SMEM_DESCRIPTOR_BIT = 1 << 46
LAYOUT = fa.TCGEN05_LAYOUT
ROW_LAYOUT = fa.TCGEN05_ROW_LAYOUT
COL_LAYOUT = fa.TCGEN05_COL_LAYOUT

# A layout resembling the logical organization of TMEM. The 128 rows in a tile
# are assigned to the 128 lanes of the warpgroup. Useful when the result needs
# to be processed in registers and then stored back into TMEM. It should not be
# used if the result is to be written back to SMEM, as there is no good way to
# store it without bank conflicts.
#
# We use a vector_dim of 2 so that each vector always spans a multiple of
# 32 bits, even when the data is 16-bit.
TMEM_NATIVE_LAYOUT = fa.TiledLayout(
    fa.Tiling(((128, 2), (32, 2))),
    warp_dim=-4,
    lane_dims=(-2,),
    vector_dim=-1,
)


def create_instr_descriptor(
    m: int,
    n: int,
    acc_dtype,
    input_dtype,
    transpose_a: bool = False,
    transpose_b: bool = False,
):
  f32 = ir.F32Type.get()
  bf16 = ir.BF16Type.get()
  f16 = ir.F16Type.get()
  if acc_dtype not in {f32, f16}:
    raise NotImplementedError("Only float32 and float16 accumulators supported")
  if utils.bitwidth(input_dtype) == 16:
    if input_dtype not in {f16, bf16}:
      raise NotImplementedError(
          "The only supported 16-bit input types are float16 and bfloat16, got"
          f" {input_dtype}"
      )
    desc = 0
    desc |= (acc_dtype == f32) << 4  # D dtype, bits 4-5
    # Bit 6 is reserved
    desc |= (input_dtype == bf16) << 7  # A dtype, bits 7-9
    desc |= (input_dtype == bf16) << 10  # B dtype, bits 10-12
    return _finish_instr_descriptor(desc, m, n, transpose_a, transpose_b)
  elif utils.bitwidth(input_dtype) == 8:
    desc = 0
    desc |= (acc_dtype == f32) << 4  # D dtype, bits 4-5
    # Bit 6 is reserved
    if input_dtype == ir.Float8E4M3FNType.get():
      input_dtype_enum = 0
    elif input_dtype == ir.Float8E5M2Type.get():
      input_dtype_enum = 1
    else:
      raise NotImplementedError(f"Unsupported input dtype: {input_dtype}")
    desc |= input_dtype_enum << 7  # A dtype, bits 7-9
    desc |= input_dtype_enum << 10  # B dtype, bits 10-12
    return _finish_instr_descriptor(desc, m, n, transpose_a, transpose_b)
  else:
    raise NotImplementedError(f"Unsupported input dtype: {input_dtype}")


def _finish_instr_descriptor(
    desc: int, m: int, n: int, transpose_a: bool, transpose_b: bool,
):
  # We ignore sparsity in bits 0-3
  # A, B and D types are set by the caller
  # We ignore negate bits 13-14
  desc |= transpose_a << 15  # Transpose A
  desc |= transpose_b << 16  # Transpose B
  if n % 8 or n > 256:
    raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
  desc |= (n >> 3) << 17  # N, bits 17-22
  # Bit 23 is reserved
  if m % 16 or m > 256:
    raise ValueError(f"M must be a multiple of 16 and <= 256, got: {m}")
  desc |= (m >> 4) << 24  # M >> 4, bits 24-28
  # Bit 29 is reserved
  # We ignore max shift under .ws, bits 30-31
  return arith.constant(ir.IntegerType.get_signless(32), desc)
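
# As an illustration of the bit packing above (not executed; worked out by hand
# from the shifts in the two functions): for m=128, n=128, an f32 accumulator
# and bf16 inputs with no transposes, the descriptor value is
#   (1 << 4) | (1 << 7) | (1 << 10)            # f32 acc, bf16 A and B
#   | ((128 >> 3) << 17) | ((128 >> 4) << 24)  # N=128, M=128
#   == 0x8200490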


def mma(
    d: TMEMRef,
    a: ir.Value | TMEMRef,
    b: ir.Value,
    *,
    a_swizzle: int = 128,
    b_swizzle: int = 128,
    accumulate: ir.Value | bool = True,
    collective: bool = False,
):
  if a_swizzle == 16 or b_swizzle == 16:
    raise NotImplementedError("Unswizzled (swizzle=16) operands are not supported")
  i32 = ir.IntegerType.get_signless(32)
  i64 = ir.IntegerType.get_signless(64)
  if isinstance(accumulate, bool):
    accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
  if a_swizzle != b_swizzle:
    raise NotImplementedError(f"{a_swizzle=} != {b_swizzle=}")
  swizzle = a_swizzle
  num_cta = 2 if collective else 1

  # Step 1. Establish the shape and element type of the operation.
  if not ir.MemRefType.isinstance(b.type):
    raise ValueError(f"B must be a memref, got: {b.type}")
  (k, n), element_type = mma_utils.tiled_memref_shape(b)
  if isinstance(a, TMEMRef):
    m, k2 = a.shape
    element_type2 = a.dtype
    if collective and n * num_cta == 512:
      raise NotImplementedError("Collective MMA with N=512 is not supported")
    if a.layout != (expected_layout := _infer_tmem_layout(a.shape, packing=2)):
      raise ValueError(
          f"A layout mismatch: expected {expected_layout}, got {a.layout}"
      )
  else:
    if not ir.MemRefType.isinstance(a.type):
      raise ValueError(f"A must be a memref, got {a.type}")
    (m, k2), element_type2 = mma_utils.tiled_memref_shape(a)
  if k != k2:
    raise ValueError(
        "MMA requires A and B to have the same contraction dimension (K),"
        f" got: {k2} and {k}"
    )
  if element_type != element_type2:
    raise ValueError(
        "MMA requires A and B to have the same element type, got:"
        f" {element_type2} and {element_type}"
    )
  if d.shape != (m, n * num_cta):
    raise ValueError(
        f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}"
    )
  expected_d_layout = (
      TMEM_COLLECTIVE_N512_LAYOUT
      if collective and n * num_cta == 512
      else TMEM_DEFAULT_LAYOUT
  )
  if d.layout != expected_d_layout:
    raise ValueError(
        f"Accumulator layout mismatch: expected {expected_d_layout}, got {d.layout}"
    )
  f32 = ir.F32Type.get()
  f16 = ir.F16Type.get()
  if element_type == f32 or element_type == ir.BF16Type.get():
    if d.dtype != f32:
      raise ValueError(
          f"MMA with element type {element_type} only supports accumulators"
          f" of type f32, but got: {d.dtype}"
      )
  elif any(
      t.isinstance(element_type)
      for t in {ir.F16Type, ir.Float8E5M2Type, ir.Float8E4M3FNType}
  ):
    if d.dtype != f16 and d.dtype != f32:
      raise ValueError(
          f"MMA with element type {element_type} only supports accumulators of"
          f" type f32 or f16, but got: {d.dtype}"
      )
  else:
    raise NotImplementedError(f"Unsupported element type: {element_type}")

  # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles,
  # instructions must be issued in groups of the same width as the swizzle.
  m_group_elems = d.layout.elements_in_tile[0]
  if m_group_elems != 128:
    raise NotImplementedError("Only 128-row accumulators supported for now")
  k_group_elems = swizzle // utils.bytewidth(element_type)
  if n % 8:
    raise ValueError(f"N must be a multiple of 8, got: {n}")
  elif n > 256 and n != 512:
    raise ValueError("Only N <= 256 or N == 512 are supported")
  n_group_elems = min(n, 256 // num_cta)
  if m % m_group_elems:
    raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}")
  if k % k_group_elems:
    raise ValueError(f"K must be a multiple of {k_group_elems}, got: {k}")
  if n % n_group_elems:
    raise ValueError(f"N must be a multiple of {n_group_elems}, got: {n}")
  m_groups = m // m_group_elems
  k_groups = k // k_group_elems
  n_groups = n // n_group_elems
  # TODO(apaszke): Require users to bitcast input refs to tf32 before MMA.
  mma_element_type = (
      ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type
  )

  # Step 3. Compute the operand descriptors.
  if not isinstance(a, TMEMRef):
    (
        (a_desc_base, a_k_instr_stride),
        (a_m_group_stride, a_k_group_stride),
        a_fastest,
    ) = mma_utils.create_descriptor(
        a,
        swizzle=swizzle,
        group_size=(m_group_elems, k_group_elems),
        logical_k_major=False,
    )
  else:
    a_fastest = mma_utils.Dim.K
    a_k_instr_stride = None
    a_m_group_stride = a_k_group_stride = a_desc_base = None
  (
      (b_desc_base, b_k_instr_stride),
      (b_n_group_stride, b_k_group_stride),
      b_fastest,
  ) = mma_utils.create_descriptor(
      b,
      swizzle=swizzle,
      group_size=(k_group_elems, n_group_elems),
      logical_k_major=True,
  )

  # Step 4. Issue the instructions.
  true = arith.constant(ir.IntegerType.get_signless(1), 1)
  n_collective_group_elems = n_group_elems * num_cta
  for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups):
    if isinstance(a, TMEMRef):
      a_mk = a.slice(slice(None), utils.ds(ki * k_group_elems, k_group_elems)).address
    else:
      a_offset = mi * a_m_group_stride + ki * a_k_group_stride
      a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64))
    b_offset = ni * b_n_group_stride + ki * b_k_group_stride
    b_nk = arith.addi(b_desc_base, utils.c(mma_utils.encode_addr(b_offset), i64))
    if m_groups != 1:
      raise NotImplementedError("D needs to be sliced")
    acc = accumulate if ki == 0 else true
    _do_mma(
        arith.addi(
            d.address, arith.constant(i32, ni * n_collective_group_elems)
        ),
        a_mk,
        b_nk,
        d_type=d.dtype,
        m=m_group_elems,
        n=n_group_elems,
        collective=collective,
        a_transpose=a_fastest != mma_utils.Dim.K,
        b_transpose=b_fastest != mma_utils.Dim.K,
        a_k_stride=a_k_instr_stride,
        b_k_stride=b_k_instr_stride,
        accumulate=acc,
        swizzle=swizzle,
        element_type=mma_element_type,
    )
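
# A rough usage sketch (illustrative only; `acc_tmem`, `a_smem`, `b_smem` and
# `mma_barrier` are assumed to be set up by the caller, with the SMEM operands
# tiled to match the 128-byte swizzle, and the barrier wait is hypothetical and
# depends on the caller's barrier API):
#
#   mma(acc_tmem, a_smem, b_smem, a_swizzle=128, b_swizzle=128, accumulate=False)
#   commit_arrive(mma_barrier)   # make MMA completion observable on the barrier
#   mma_barrier.wait()           # hypothetical wait for the arrival
#   result = acc_tmem.load()     # FragmentedArray in fa.TCGEN05_LAYOUT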


def _do_mma(
    d_addr: ir.Value,
    a_desc_or_addr: ir.Value,  # TMEM address if a_k_stride is None
    b_desc: ir.Value,
    a_transpose: bool,
    b_transpose: bool,
    a_k_stride: int | None,
    b_k_stride: int,
    m: int,
    n: int,
    swizzle: int,
    element_type: ir.Type,
    d_type: ir.Type,
    accumulate: ir.Value,
    collective: bool,
):
  i1 = ir.IntegerType.get_signless(1)
  i32 = ir.IntegerType.get_signless(32)
  i64 = ir.IntegerType.get_signless(64)
  elem_bytewidth = utils.bytewidth(element_type)
  kn_tiling = swizzle // elem_bytewidth
  instr_k = 32 // elem_bytewidth
  packing = 4 // elem_bytewidth
  if (a_k_stride is not None and a_k_stride % 16) or b_k_stride % 16:
    raise ValueError(
        f"K strides must be multiples of 16, got: {a_k_stride=}, {b_k_stride=}"
    )

  if ir.F16Type.isinstance(element_type) or ir.BF16Type.isinstance(element_type):
    kind = "f16"
  elif ir.Float8E5M2Type.isinstance(element_type):
    kind = "f8f6f4"
  elif ir.Float8E4M3FNType.isinstance(element_type):
    kind = "f8f6f4"
  else:
    raise NotImplementedError(f"Unsupported input element type: {element_type}")

  num_cta = 2 if collective else 1
  i_desc = create_instr_descriptor(
      m * num_cta, n * num_cta, d_type, element_type, a_transpose, b_transpose
  )
  a_in_tmem = a_k_stride is None
  a_ptx = "[$1]" if a_in_tmem else "$1"
  a_ptx_constraint = "r" if a_in_tmem else "l"
  assert a_desc_or_addr.type == ir.IntegerType.get_signless(32 if a_in_tmem else 64)
  for _ in range(kn_tiling // instr_k):
    llvm.inline_asm(
        ir.Type.parse("!llvm.void"),
        [d_addr, a_desc_or_addr, b_desc, i_desc, accumulate],
        f"tcgen05.mma.cta_group::{num_cta}.kind::{kind} [$0], {a_ptx}, $2, $3, $4;",
        f"r,{a_ptx_constraint},l,r,b",
        has_side_effects=True,
    )
    accumulate = arith.constant(i1, 1)
    if not a_in_tmem:
      a_desc_or_addr = arith.addi(
          a_desc_or_addr, arith.constant(i64, a_k_stride >> 4)
      )
    else:
      a_desc_or_addr = arith.addi(
          a_desc_or_addr, arith.constant(i32, instr_k // packing)
      )
    b_desc = arith.addi(b_desc, arith.constant(i64, b_k_stride >> 4))


def commit_arrive(
    barrier: utils.BarrierRef | ir.Value,
    collective: bool = False,
    ctx: LaunchContext | None = None,
):
  if isinstance(barrier, utils.BarrierRef):
    barrier = barrier.get_ptr()
  elif barrier.type != ir.Type.parse("!llvm.ptr<3>"):
    raise ValueError(
        "barrier must be a Mosaic barrier or a SMEM pointer, got:"
        f" {barrier.type}"
    )
  if collective:
    if ctx is None:
      raise ValueError("ctx must be provided for collective barriers")
    # TODO(apaszke): This is just 0b11 shifted by the even CTA index.
    if ctx.cluster_size != (2, 1, 1):
      raise NotImplementedError("Collective arrivals only support (2, 1, 1)-shaped clusters")
    ptx = """
    {
        .reg .b16 msk;
        mov.b16 msk, 3;
        tcgen05.commit.cta_group::2.mbarrier::arrive::one.multicast::cluster.b64 [$0], msk;
    }
    """
  else:
    ptx = "tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$0];"
  return llvm.inline_asm(
      ir.Type.parse("!llvm.void"), [barrier], ptx, "l", has_side_effects=True
  )


def _alloc_ncols(ncols: int, exact: bool):
  if exact:
    if ncols.bit_count() != 1 or not 32 <= ncols <= 512:
      raise ValueError(f"ncols must be a power of 2 and within [32, 512], got: {ncols}")
  else:
    ncols = max(32, 1 << (ncols - 1).bit_length())
    if ncols > 512:
      raise ValueError(
          f"After rounding up, got {ncols} columns, exceeding the limit of 512"
      )
  return ncols


def tmem_alloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: bool = True):
  if ir.MemRefType.isinstance(tmem_addr.type):
    ref_ty = ir.MemRefType(tmem_addr.type)
    if ref_ty.element_type != ir.IntegerType.get_signless(32):
      raise ValueError(f"tmem_addr must be an i32 memref, got: {ref_ty}")
    if ref_ty.memory_space != ir.Attribute.parse("#gpu.address_space<workgroup>"):
      raise ValueError(f"tmem_addr must be in shared memory, got: {ref_ty}")
    if math.prod(ref_ty.shape) != 1:
      raise ValueError(f"tmem_addr must contain a single element, got: {ref_ty}")
    tmem_addr = utils.memref_ptr(tmem_addr, memory_space=3)
  elif tmem_addr.type != ir.Type.parse("!llvm.ptr<3>"):
    raise ValueError(f"tmem_addr must be an SMEM pointer or a memref, got: {tmem_addr.type}")
  ncols = _alloc_ncols(ncols, exact)
  num_cta = 2 if collective else 1
  return llvm.inline_asm(
      ir.Type.parse("!llvm.void"),
      [tmem_addr],
      f"tcgen05.alloc.cta_group::{num_cta}.sync.aligned.shared::cta.b32 [$0], {ncols};",
      "r",
      has_side_effects=True,
  )


def tmem_dealloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: bool = True):
  if tmem_addr.type != ir.IntegerType.get_signless(32):
    raise ValueError(f"tmem_addr must be an i32, got: {tmem_addr.type}")
  ncols = _alloc_ncols(ncols, exact)
  num_cta = 2 if collective else 1
  return llvm.inline_asm(
      ir.Type.parse("!llvm.void"),
      [tmem_addr],
      f"tcgen05.dealloc.cta_group::{num_cta}.sync.aligned.b32 $0, {ncols};",
      "r",
      has_side_effects=True,
  )


def tmem_relinquish_alloc_permit(collective: bool):
  num_cta = 2 if collective else 1
  return llvm.inline_asm(
      ir.Type.parse("!llvm.void"),
      [],
      f"tcgen05.relinquish_alloc_permit.cta_group::{num_cta}.sync.aligned;",
      "",
      has_side_effects=True,
  )
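
# A rough sketch of the allocation lifecycle these helpers implement
# (illustrative only; `alloc_addr_ref` is a hypothetical (1,)-shaped i32 SMEM
# ref and `acc_shape`/`f32` come from the caller):
#
#   tmem_alloc(alloc_addr_ref, ncols=128, collective=False, exact=False)
#   tmem_relinquish_alloc_permit(collective=False)  # no further allocs planned
#   acc = TMEMRef.from_alloc(alloc_addr_ref, acc_shape, f32, collective=False)
#   ...                                             # MMAs, loads, stores
#   tmem_dealloc(acc.address, ncols=128, collective=False, exact=False)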


def _tmem_access_helper(shape, num):
  if num.bit_count() != 1 or num > 128:
    raise ValueError(f"num must be a power of 2 and <= 128, got: {num}")
  match shape:
    case "32x32b":
      num_regs = 1
    case "16x128b":
      num_regs = 2
    case "16x256b":
      num_regs = 4
    case _:
      raise NotImplementedError(f"{shape=} is unsupported")
  num_regs *= num
  if num_regs > 255:
    raise ValueError(
        f"TMEM transfer too big: {shape=} and {num=} involve"
        f" {num_regs} registers per-thread, which exceeds the limit of 255"
    )
  regs_vector = ",".join(f"${i}" for i in range(num_regs))
  regs_vector = "{" + regs_vector + "}"
  return num_regs, regs_vector
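
# For example (worked out from the shapes above): shape="16x256b" with num=32
# uses 4 * 32 = 128 registers per thread, which is within the 255-register
# limit, and expands regs_vector to "{$0,$1,...,$127}".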


def tmem_load(tmem_addr, shape, num, pack: bool):
  i32 = ir.IntegerType.get_signless(32)
  num_out_regs, regs_vector = _tmem_access_helper(shape, num)
  pack_mod = ".pack::16b" if pack else ""
  regs = llvm.inline_asm(
      ir.Type.parse(
          "!llvm.struct<(" + ",".join("i32" for _ in range(num_out_regs)) + ")>"
      ),
      [tmem_addr],
      f"tcgen05.ld.sync.aligned.{shape}.x{num}{pack_mod}.b32 {regs_vector}, [${num_out_regs}];",
      "=r," * num_out_regs + "r",
      has_side_effects=True,
  )
  return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)]


def wait_tmem_load():
  llvm.inline_asm(
      ir.Type.parse("!llvm.void"),
      [],
      "tcgen05.wait::ld.sync.aligned;",
      "",
      has_side_effects=True,
  )
  utils.warpgroup_barrier()


def tmem_store(tmem_addr, shape, num, regs, unpack: bool):
  num_out_regs, regs_vector = _tmem_access_helper(shape, num)
  pack_mod = ".unpack::16b" if unpack else ""
  llvm.inline_asm(
      ir.Type.parse("!llvm.void"),
      [*regs, tmem_addr],
      f"tcgen05.st.sync.aligned.{shape}.x{num}{pack_mod}.b32 [${num_out_regs}], {regs_vector};",
      "r," * num_out_regs + "r",
      has_side_effects=True,
  )


@dataclasses.dataclass(frozen=True)
class TMEMLayout:
  """Represents the way a shape is laid out in TMEM.

  Only 2D shapes are supported. Row tiling must be between 32 and 128, and be
  a power of 2. If the row tiling is smaller than 128 (the row count in TMEM),
  the tiles are linearized in row-major order, but laid out in TMEM in a
  column-major order.

  Consider an array that is (128, 128) and we apply tiling of (64, 64):

  +------------------+------------------+
  | [0:64, 0:64]     | [0:64, 64:128]   |
  +------------------+------------------+
  | [64:128, 0:64]   | [64:128, 64:128] |
  +------------------+------------------+

  In TMEM it will be laid out as follows:

  +------------------+------------------+
  | [0:64, 0:64]     | [64:128, 0:64]   |
  +------------------+------------------+
  | [0:64, 64:128]   | [64:128, 64:128] |
  +------------------+------------------+

  The above is further complicated by column_tile_stride, which is used to
  swizzle the ordering of column tiles. That is, if column_tile_stride is 2,
  we first lay out all tiles that have column index 0, 2, 4, and so on, until
  we run out of tiles. Only then do we lay out the tiles with column index
  1, 3, etc.
  """
  elements_in_tile: tuple[int, int]
  column_tile_stride: int = 1
  packing: int = 1

  def __post_init__(self):
    row_tiling = self.elements_in_tile[0]
    if not 32 <= row_tiling <= 128:
      raise ValueError(
          f"Row tiling must be between 32 and 128, got: {row_tiling}"
      )
    if row_tiling.bit_count() != 1:
      raise ValueError(f"Row tiling must be a power of 2, got: {row_tiling}")
    if self.elements_in_tile[1] % self.packing:
      raise ValueError(
          f"Column tiling must be a multiple of packing={self.packing}, got:"
          f" {self.elements_in_tile[1]}"
      )

  def check_type(self, shape: tuple[int, ...], dtype: ir.Type):
    if len(shape) != 2:
      raise ValueError(f"TMEM can only represent 2D shapes, got {shape}")
    if any(s % t for s, t in zip(shape, self.elements_in_tile)):
      raise ValueError(
          f"{shape} is not divisible into tiles of shape {self.elements_in_tile}"
      )
    if self.packing not in {1, fully_packed := 32 // utils.bitwidth(dtype)}:
      raise ValueError(
          f"For {utils.bitwidth(dtype)}-bit types, only packing=1 and"
          f" packing={fully_packed} are supported, but got: {self.packing}"
      )

  def cols_in_shape(self, shape: tuple[int, int]):
    cols_in_tile = self.elements_in_tile[1] // self.packing
    tiles_in_row = TMEM_ROWS // self.elements_in_tile[0]
    num_tiles = math.prod(utils.tile_shape(shape, self.elements_in_tile)[:-2])
    assert num_tiles % tiles_in_row == 0
    return num_tiles // tiles_in_row * cols_in_tile
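
# As an illustration of cols_in_shape (worked out from the definition above):
# with the default accumulator layout (elements_in_tile=(128, 8), packing=1), a
# (128, 256) array splits into 1 * 32 tiles that all land in a single row of
# TMEM tiles, so it occupies 32 * 8 = 256 TMEM columns. With packing=2 (16-bit
# data packed two per 32-bit word), the same shape needs only 128 columns.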


def _infer_tmem_layout(shape: tuple[int, int], packing: int = 1) -> TMEMLayout:
  if shape[0] > TMEM_ROWS:
    raise ValueError(
        "Can only infer TMEM layout for shapes with at most 128 rows, got:"
        f" {shape[0]}"
    )
  if shape[0] < 32:
    raise ValueError(
        "Can only infer TMEM layout for shapes with at least 32 rows, got:"
        f" {shape[0]}"
    )
  if shape[0].bit_count() != 1:
    raise ValueError(
        "Can only infer TMEM layout for shapes with row count that's a power of"
        f" 2, got: {shape[0]}"
    )
  if shape[1] % 8:
    raise ValueError(
        "Can only infer TMEM layout for shapes with column count that's a"
        f" multiple of 8, got: {shape[1]}"
    )
  return TMEMLayout(elements_in_tile=(shape[0], 8), packing=packing)


TMEM_DEFAULT_LAYOUT = TMEMLayout(elements_in_tile=(TMEM_ROWS, 8), packing=1)
TMEM_COLLECTIVE_N512_LAYOUT = TMEMLayout(
    elements_in_tile=(TMEM_ROWS, 128), column_tile_stride=2, packing=1
)


@dataclasses.dataclass(frozen=True)
class TMEMRef:
  address: ir.Value
  shape: tuple[int, int]
  dtype: ir.Type
  layout: TMEMLayout

  @classmethod
  def from_alloc(
      cls,
      tmem_addr_ref: ir.Value,
      shape: tuple[int, int],
      dtype,
      collective: bool | None = None,
      layout: TMEMLayout | None = None,
  ):
    i32 = ir.IntegerType.get_signless(32)
    if not ir.MemRefType.isinstance(tmem_addr_ref.type):
      raise ValueError(f"tmem_addr_ref must be a memref or a pointer, got: {tmem_addr_ref.type}")
    addr_ref_ty = ir.MemRefType(tmem_addr_ref.type)
    smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
    if addr_ref_ty.memory_space != smem:
      raise ValueError(f"tmem_addr_ref must be in workgroup memory, got: {addr_ref_ty}")
    if addr_ref_ty.element_type != i32:
      raise ValueError(f"tmem_addr_ref must be an i32 memref, got: {addr_ref_ty}")
    if math.prod(addr_ref_ty.shape) != 1:
      raise ValueError(f"tmem_addr_ref must contain a single element, got: {addr_ref_ty}")
    i0 = arith.ConstantOp.create_index(0)
    tmem_addr = memref.load(tmem_addr_ref, [i0] * addr_ref_ty.rank)
    if shape[0] < 32:
      raise ValueError(f"TMEM refs must have at least 32 rows, got: {shape[0]}")
    if layout is None:
      if collective is None:
        raise ValueError(
            "collective argument must be provided when TMEM layout is inferred"
        )
      # Note: _infer_tmem_layout only takes the shape (and a packing degree),
      # so collective is only used as a guard here.
      layout = _infer_tmem_layout(shape)
    else:
      layout.check_type(shape, dtype)
    # TODO: Do we have to do this??
    # warp_idx = utils.warp_idx(sync=False)
    # tmem_addr = arith.ori(tmem_addr, arith.shli(warp_idx, utils.c(21, i32)))
    return cls(tmem_addr, shape, dtype, layout)
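
  # Note (worked from _infer_tmem_layout above): when `layout` is not given, a
  # (128, N) allocation gets TMEMLayout(elements_in_tile=(128, 8), packing=1),
  # i.e. the same layout as TMEM_DEFAULT_LAYOUT, which is what `mma` expects
  # for its accumulator.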

  def slice(self, *idxs):
    i32 = ir.IntegerType.get_signless(32)
    base_idx, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape)
    if any(is_squeezed):
      raise ValueError("TMEM can only be sliced, not indexed")
    match self.layout:
      case TMEMLayout(elements_in_tile=(r, 8), packing=packing) if (
          r == TMEM_ROWS
      ):
        pass
      case _:
        raise NotImplementedError(
            "Slicing only implemented for refs with standard layout, got:"
            f" {self.layout}"
        )
    if base_idx[0] != 0 or slice_shape[0] != TMEM_ROWS:
      raise NotImplementedError("TMEM cannot be sliced along rows")
    if slice_shape[1] % 8:
      raise NotImplementedError(
          "TMEM column slice length must be a multiple of 8. "
          f"Got {slice_shape[1]}."
      )
    col_idx = base_idx[1]
    if not isinstance(col_idx, ir.Value):
      col_idx = arith.constant(i32, col_idx)
    if col_idx.type == ir.IndexType.get():
      col_idx = arith.index_cast(i32, col_idx)
    if packing != 1:
      col_idx = arith.divui(col_idx, arith.constant(i32, packing))
    return TMEMRef(
        address=arith.addi(self.address, col_idx),
        shape=tuple(slice_shape),
        layout=self.layout,
        dtype=self.dtype,
    )

  def load(self, layout: fa.TiledLayout = LAYOUT):
    i32 = ir.IntegerType.get_signless(32)
    if self.shape[1] % 8:
      raise NotImplementedError
    if utils.bitwidth(self.dtype) not in {16, 32}:
      raise NotImplementedError(f"Unsupported dtype: {self.dtype}")
    if layout == LAYOUT:
      regs_shape = layout.registers_shape(self.shape)
      match self.layout:
        case TMEMLayout(elements_in_tile=(r, 8), packing=packing) if (
            r == TMEM_ROWS
        ):
          # _load_32xcols returns a 4xN array, but the FA tiling we use here
          # tiles columns before rows, and so it is Nx4 (after ignoring all 1 dims).
          registers = _load_32xcols(
              self.address, self.shape[1], self.dtype, packing
          ).T.reshape(regs_shape)
        case TMEMLayout(elements_in_tile=(r, 128), column_tile_stride=2) if r == TMEM_ROWS:
          if self.shape[1] % 128 != 0:
            raise ValueError(
                f"TMEM layout {self.layout} is not compatible with shape {self.shape}"
            )
          num_column_tiles = self.shape[1] // 128
          column_tile_stride = self.layout.column_tile_stride
          num_strided_col_groups = utils.ceil_div(num_column_tiles, column_tile_stride)
          tiles = []
          for col_tile_base in range(num_strided_col_groups):
            for col_tile in range(col_tile_base, num_column_tiles, column_tile_stride):
              tiles.append(
                  _load_32xcols(
                      arith.addi(self.address, arith.constant(i32, col_tile * 128)),
                      cols=128,
                      dtype=self.dtype,
                      tmem_packing=1,
                  )
              )
          registers = np.concatenate(tiles, axis=1).T.reshape(regs_shape)
        case _:
          raise NotImplementedError(
              f"Loads only implemented for refs with standard layout, got: {self.layout}"
          )
    elif layout == TMEM_NATIVE_LAYOUT:
      regs_shape = layout.registers_shape(self.shape)
      match self.layout:
        case TMEMLayout(elements_in_tile=(r, c), packing=packing) if (
            r == TMEM_ROWS and c % 2 == 0
        ):
          registers = _load_32xcols_native(
              self.address, self.shape[1], self.dtype, packing
          ).reshape(regs_shape)
        case _:
          raise NotImplementedError(
              "Loads only implemented for refs with standard layout, got:"
              f" {self.layout}"
          )
    else:
      raise ValueError(
          "TMEM loads can only produce results in the tcgen05 layouts"
          f" ({LAYOUT} and {TMEM_NATIVE_LAYOUT}), but got: {layout}"
      )
    return fa.FragmentedArray(_registers=registers, _layout=layout, _is_signed=None)
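
  # A rough round-trip sketch (illustrative only; `acc` is assumed to be a
  # TMEMRef with the default accumulator layout and an f32 dtype):
  #
  #   arr = acc.load()   # FragmentedArray in fa.TCGEN05_LAYOUT
  #   acc.store(arr)     # write the same registers back to TMEM
  #   commit_tmem()      # make the store visible before the next MMA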

  def store(self, value):
    if self.shape[1] % 8:
      raise NotImplementedError
    if utils.bitwidth(self.dtype) not in {16, 32}:
      raise NotImplementedError(f"Unsupported dtype: {self.dtype}")
    if not isinstance(value, fa.FragmentedArray):
      raise ValueError(f"TMEM stores expect a FragmentedArray, got: {value}")
    if value.shape != self.shape:
      raise ValueError(
          f"Stored array has shape {value.shape}, but TMEM has shape"
          f" {self.shape}"
      )
    if value.mlir_dtype != self.dtype:
      raise ValueError(
          f"Stored array has dtype {value.mlir_dtype}, but TMEM has dtype"
          f" {self.dtype}"
      )
    if value.layout == LAYOUT:
      # TODO(apaszke): Collective MMA layout
      match self.layout:
        case TMEMLayout(elements_in_tile=(r, 8), packing=packing) if (
            r == TMEM_ROWS
        ):
          # _store_32xcols needs a 4xN array, but the FA tiling we use here
          # tiles columns before rows, and so it is Nx4 (after ignoring all 1 dims).
          _store_32xcols(
              self.address, value.registers.T.reshape((4, -1)), packing
          )
        case _:
          raise NotImplementedError(
              f"Stores only implemented for refs with standard layout, got: {self.layout}"
          )
    elif value.layout == TMEM_NATIVE_LAYOUT:
      # TODO(apaszke): Collective MMA layout
      match self.layout:
        case TMEMLayout(elements_in_tile=(r, c), packing=packing) if (
            r == TMEM_ROWS and c % 2 == 0
        ):
          _store_32xcols_native(
              self.address, value.registers.reshape(-1), packing
          )
        case _:
          raise NotImplementedError(
              f"Stores only implemented for refs with standard layout, got: {self.layout}"
          )
    else:
      raise ValueError(
          f"Stored array has layout {value.layout}, but only tcgen05.LAYOUT and"
          " tcgen05.TMEM_NATIVE_LAYOUT are supported"
      )

  def _debug_print(self):
    i32 = ir.IntegerType.get_signless(32)
    num_cols = self.layout.cols_in_shape(self.shape)
    lane = arith.remui(utils.thread_idx(), arith.constant(i32, utils.WARPGROUP_SIZE))
    for c in range(num_cols):
      val = llvm.inline_asm(
          i32,
          [arith.addi(self.address, arith.constant(i32, c))],
          "tcgen05.ld.sync.aligned.32x32b.x1.b32 {$0}, [$1];",
          "=r,r",
      )
      dtype_bitwidth = utils.bitwidth(self.dtype)
      full_packing = 32 // dtype_bitwidth
      if self.layout.packing == 1:
        if dtype_bitwidth < 32:
          val = arith.trunci(ir.IntegerType.get_signless(dtype_bitwidth), val)
        val = utils.bitcast(val, self.dtype)
      elif self.layout.packing == full_packing:
        val = utils.bitcast(val, ir.VectorType.get((full_packing,), self.dtype))
      else:
        raise NotImplementedError(f"Unsupported packing: {self.layout.packing}")
      # TODO(apaszke): Make this print logical, not physical location.
      utils.debug_print(f"[{{}}, {c}]: {{}}", lane, val, uniform=False)


def _transfer_32xcols(
    base_addr: ir.Value,
    cols: int,
    atom_shape: tuple[int, int],
    tmem_packing: int,
    reg_packing: int,
):
  """Generates a sequence of parameters for a given TMEM read or write.

  Arguments:
    base_addr: The base address of the TMEM region.
    cols: The number of logical columns to transfer.
    atom_shape: The logical shape of the tile written by the warp in a single
      TMEM transfer.
    tmem_packing: Packing degree in TMEM. When packing is 1, but the data is
      16-bit, we expect that each transfer actually involves double the number
      of physical columns.
    reg_packing: The number of elements that fit in a single 32-bit register.
  """
  i32 = ir.IntegerType.get_signless(32)
  atom_rows, atom_cols = atom_shape
  assert cols % atom_cols == 0
  total_num = cols // atom_cols
  assert total_num.bit_count() == 1
  regs_per_instr = atom_shape[0] * atom_shape[1] // (utils.WARP_SIZE * reg_packing)
  # We artificially lower the instr_num compared to its limits, because higher
  # values can lead to register spills.
  instr_num = min(total_num, 32 // regs_per_instr)
  assert 32 % atom_rows == 0
  num_row_steps = 32 // atom_rows
  for lane_step in range(num_row_steps):
    addr_row = arith.addi(base_addr, utils.c((lane_step * atom_rows) << 16, i32))
    cols_per_instr = instr_num * atom_cols
    for num_step in range(total_num // instr_num):
      num_slice = slice(num_step * instr_num, (num_step + 1) * instr_num)
      addr_row_col = arith.addi(
          addr_row, utils.c(num_step * cols_per_instr // tmem_packing, i32)
      )
      yield addr_row_col, instr_num, lane_step, num_slice
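
# As an illustration of the iteration above (not executed; worked out from the
# definition): for cols=64, atom_shape=(16, 8), tmem_packing=1 and
# reg_packing=2, we get total_num=8, regs_per_instr=2 and instr_num=8, so the
# generator yields exactly two transfers: one at the base address and one at
# base_addr + (16 << 16) (a row offset of 16 encoded in the upper half of the
# address), each covering all 64 columns with a single x8 instruction.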


def _store_32xcols(base_addr, vector_regs, tmem_packing):
  i32 = ir.IntegerType.get_signless(32)
  assert vector_regs.ndim == 2 and vector_regs.shape[0] == 4
  cols = vector_regs.shape[1] * 8

  reg_packing = 64 // utils.bitwidth(vector_regs.flat[0].type)
  if reg_packing == 1:
    store_shape = "16x256b"  # 4 threads * 64 bits per vreg = 256 bits
    regs = np.empty((4, vector_regs.shape[1], 2), dtype=object)
    c0 = arith.constant(i32, 0)
    c1 = arith.constant(i32, 1)
    for idx, vreg in np.ndenumerate(vector_regs):
      regs[(*idx, 0)] = llvm.extractelement(vreg, c0)
      regs[(*idx, 1)] = llvm.extractelement(vreg, c1)
    regs = regs.reshape(2, 2, vector_regs.shape[1], 2).swapaxes(1, 2)
    # From a single lane perspective a num tile consists of a 2x2, with the
    # minor dim traversing columns and major being 8 rows apart.
    # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
    assert regs.shape[-2:] == (2, 2)
    assert tmem_packing == 1
    unpack = False
  elif reg_packing == 2:
    store_shape = "16x128b"  # 4 threads * 32 bits per vreg = 128 bits
    # From a single lane perspective a num tile has 2 registers, 8 rows apart.
    # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b
    regs = vector_regs.reshape(2, 2, vector_regs.shape[1]).swapaxes(1, 2)
    assert 1 <= tmem_packing <= 2
    unpack = tmem_packing == 1
  else:
    raise NotImplementedError(reg_packing)

  it = _transfer_32xcols(base_addr, cols, (16, 8), tmem_packing, reg_packing)
  for addr_row_col, instr_num, lane_step, num_slice in it:
    regs_slice = regs[lane_step, num_slice].flat
    tmem_store(addr_row_col, store_shape, instr_num, regs_slice, unpack)


def _store_32xcols_native(base_addr, vector_regs, tmem_packing):
  i32 = ir.IntegerType.get_signless(32)
  assert vector_regs.ndim == 1
  cols = len(vector_regs) * TMEM_NATIVE_LAYOUT.vector_length

  reg_packing = 64 // utils.bitwidth(vector_regs.flat[0].type)
  store_shape = "32x32b"
  if reg_packing == 1:
    store_atom_shape = (32, 1)
    regs = [None] * (len(vector_regs) * 2)
    c0 = arith.constant(i32, 0)
    c1 = arith.constant(i32, 1)
    for idx, vreg in enumerate(vector_regs):
      regs[2 * idx] = llvm.extractelement(vreg, c0)
      regs[2 * idx + 1] = llvm.extractelement(vreg, c1)
    assert tmem_packing == 1
    unpack = False
  elif reg_packing == 2:
    store_atom_shape = (32, 2)
    regs = vector_regs
    assert 1 <= tmem_packing <= 2
    unpack = tmem_packing == 1
  else:
    raise NotImplementedError(reg_packing)

  it = _transfer_32xcols(base_addr, cols, store_atom_shape, tmem_packing, reg_packing)
  for addr_row_col, instr_num, lane_step, num_slice in it:
    assert lane_step == 0
    regs_slice = regs[num_slice]
    tmem_store(addr_row_col, store_shape, instr_num, regs_slice, unpack)


def _load_32xcols(base_addr, cols, dtype, tmem_packing):
  i32 = ir.IntegerType.get_signless(32)
  vec_ty = ir.VectorType.get((2,), dtype)
  reg_packing = 32 // utils.bitwidth(dtype)
  if reg_packing == 1:
    load_shape = "16x256b"  # 4 threads * 64 bits per vreg = 256 bits
    assert tmem_packing == 1
    pack = False
  elif reg_packing == 2:
    load_shape = "16x128b"  # 4 threads * 32 bits per vreg = 128 bits
    assert 1 <= tmem_packing <= 2
    pack = tmem_packing == 1
  else:
    raise NotImplementedError(reg_packing)

  vector_regs = np.ndarray((4, cols // 8), dtype=object)

  it = _transfer_32xcols(base_addr, cols, (16, 8), tmem_packing, reg_packing)
  c0 = arith.constant(i32, 0)
  c1 = arith.constant(i32, 1)
  for addr_row_col, instr_num, lane_step, num_slice in it:
    regs = tmem_load(addr_row_col, load_shape, instr_num, pack)
    row_slice = slice(lane_step * 2, (lane_step + 1) * 2)
    # This aliases the original array, so updates will be reflected there.
    vector_regs_update = vector_regs[row_slice, num_slice]
    assert vector_regs_update.shape == (2, instr_num), (vector_regs_update.shape, instr_num)
    if reg_packing == 1:
      regs = [llvm.bitcast(dtype, r) for r in regs]
      # From a single lane perspective a num tile consists of a 2x2, with the
      # minor dim traversing columns and major being 8 rows apart.
      # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
      regs = np.asarray(regs, dtype=object).reshape(instr_num, 2, 2).swapaxes(0, 1)
      undef = llvm.mlir_undef(vec_ty)
      assert regs.shape == (*vector_regs_update.shape, 2)
      for idx in np.ndindex(vector_regs_update.shape):
        high_undef = llvm.insertelement(undef, regs[(*idx, 0)], c0)
        vreg = llvm.insertelement(high_undef, regs[(*idx, 1)], c1)
        vector_regs_update[idx] = vreg
    else:
      assert reg_packing == 2
      regs = [llvm.bitcast(vec_ty, r) for r in regs]
      # From a single lane perspective a num tile has 2 registers, 8 rows apart.
      # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b
      regs = np.asarray(regs, dtype=object).reshape(instr_num, 2).swapaxes(0, 1)
      vector_regs_update[...] = regs

  return vector_regs


def _load_32xcols_native(base_addr, cols, dtype, tmem_packing):
  i32 = ir.IntegerType.get_signless(32)
  vec_ty = ir.VectorType.get((2,), dtype)
  reg_packing = 32 // utils.bitwidth(dtype)
  load_shape = "32x32b"
  if reg_packing == 1:
    load_atom_shape = (32, 1)
    assert tmem_packing == 1
    pack = False
  elif reg_packing == 2:
    load_atom_shape = (32, 2)
    assert 1 <= tmem_packing <= 2
    pack = tmem_packing == 1
  else:
    raise NotImplementedError(reg_packing)

  it = _transfer_32xcols(base_addr, cols, load_atom_shape, tmem_packing, reg_packing)
  c0 = arith.constant(i32, 0)
  c1 = arith.constant(i32, 1)
  regs = [None] * (cols // reg_packing)
  for addr_row_col, instr_num, lane_step, num_slice in it:
    assert lane_step == 0, lane_step
    instr_regs = tmem_load(addr_row_col, load_shape, instr_num, pack)
    if reg_packing == 1:
      regs[num_slice] = [llvm.bitcast(dtype, r) for r in instr_regs]
    else:
      assert reg_packing == 2
      regs[num_slice] = [llvm.bitcast(vec_ty, r) for r in instr_regs]

  if reg_packing == 1:
    vector_regs = np.ndarray((cols // 2,), dtype=object)
    undef = llvm.mlir_undef(vec_ty)
    for idx in range(vector_regs.size):
      high_undef = llvm.insertelement(undef, regs[2 * idx], c0)
      vreg = llvm.insertelement(high_undef, regs[2 * idx + 1], c1)
      vector_regs[idx] = vreg
  else:
    assert reg_packing == 2
    vector_regs = np.asarray(regs, dtype=object)

  assert vector_regs.shape == (cols // TMEM_NATIVE_LAYOUT.vector_length,)
  return vector_regs


def _m128_layout(shape: tuple[int, ...]):
  if len(shape) != 2:
    raise ValueError(f"Shape {shape} is not 2D")
  if shape[0] % 128 != 0 or shape[1] % 8 != 0:
    raise ValueError(f"Shape {shape} is not a multiple of 128x8")
  return LAYOUT


def commit_tmem():
  void = ir.Type.parse("!llvm.void")
  llvm.inline_asm(
      void, [], "tcgen05.wait::st.sync.aligned;", "", has_side_effects=True,
  )
  utils.warpgroup_barrier()