2025-08-11 12:24:21 +08:00

218 lines
6.7 KiB
Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layout utilities."""
import re
from jax._src.lib.mlir import ir
from . import fragmented_array as fa
# Matches the textual form of a splat layout attribute, e.g.
# `#mosaic_gpu.WGSplatFragLayout<[1, 2]>`, and captures the contents of the
# bracketed list (without the brackets) as `shape`.
# The dot after `mosaic_gpu` is escaped so that it only matches a literal ".";
# unescaped it would accept any character in that position.
_splat_fragmented_layout_attr_pattern = re.compile(
    r"^#mosaic_gpu\.WGSplatFragLayout<\[(?P<shape>.*)\]>$"
)
def to_splat_fragmented_layout_attr(layout: fa.WGSplatFragLayout) -> ir.Attribute:
  """Builds the #mosaic_gpu.WGSplatFragLayout attribute for a WGSplatFragLayout.

  Args:
    layout: The layout whose `shape` is embedded in the attribute.

  Returns:
    The result of parsing the textual attribute
    `#mosaic_gpu.WGSplatFragLayout<[...]>`, where the bracketed list is the
    string form of `list(layout.shape)`.
  """
  dims = list(layout.shape)
  attr_text = f"#mosaic_gpu.WGSplatFragLayout<{dims}>"
  return ir.Attribute.parse(attr_text)
def from_splat_fragmented_layout_attr(attr: ir.Attribute) -> fa.WGSplatFragLayout:
  """Constructs a WGSplatFragLayout from a #mosaic_gpu.WGSplatFragLayout attribute.

  The attribute is inspected through its string form. The comma-separated
  entries of the bracketed list become the layout's `shape`; an empty list
  (`<[]>`) produces the empty shape `()`.

  Args:
    attr: The attribute to convert.

  Returns:
    A WGSplatFragLayout whose `shape` is the tuple of integers found in the
    attribute.

  Raises:
    ValueError: If the attribute is not a #mosaic_gpu.WGSplatFragLayout
      attribute, or if one of its shape entries is not an integer.
  """
  match = _splat_fragmented_layout_attr_pattern.fullmatch(str(attr))
  if not match:
    raise ValueError(
        f"Expected a #mosaic_gpu.WGSplatFragLayout attribute, got {attr}"
    )
  shape_str = match.group("shape").strip()
  # "".split(",") yields [""], which int() rejects, so an empty list has to be
  # mapped to the empty shape explicitly.
  if not shape_str:
    return fa.WGSplatFragLayout(shape=())
  return fa.WGSplatFragLayout(
      shape=tuple(int(s) for s in shape_str.split(","))
  )
def is_splat_fragmented_layout(attr: ir.Attribute) -> bool:
  """Returns whether `str(attr)` matches the WGSplatFragLayout attribute pattern."""
  found = _splat_fragmented_layout_attr_pattern.search(str(attr))
  return found is not None
# Matches the textual form of a strided layout attribute, e.g.
# `#mosaic_gpu.WGStridedFragLayout<[64, 128], 4>`. `shape` captures the
# contents of the bracketed list and `vector_size` the trailing integer.
# The dot after `mosaic_gpu` is escaped so that it only matches a literal ".".
_strided_fragmented_layout_attr_pattern = re.compile(
    r"^#mosaic_gpu\.WGStridedFragLayout<\[(?P<shape>.*)\],"
    r" (?P<vector_size>\d+)>$"
)
def to_strided_fragmented_layout_attr(
    layout: fa.WGStridedFragLayout,
) -> ir.Attribute:
  """Builds the #mosaic_gpu.WGStridedFragLayout attribute for a WGStridedFragLayout.

  Args:
    layout: The layout whose `shape` and `vec_size` are embedded in the
      attribute.

  Returns:
    The result of parsing the textual attribute
    `#mosaic_gpu.WGStridedFragLayout<[...], vec_size>`.
  """
  dims = list(layout.shape)
  attr_text = (
      f"#mosaic_gpu.WGStridedFragLayout<{dims},"
      f" {layout.vec_size}>"
  )
  return ir.Attribute.parse(attr_text)
def from_strided_fragmented_layout_attr(
    attr: ir.Attribute,
) -> fa.WGStridedFragLayout:
  """Builds a WGStridedFragLayout from a #mosaic_gpu.WGStridedFragLayout attribute.

  The attribute is inspected through its string form: the comma-separated
  entries of the bracketed list become `shape`, and the trailing integer
  becomes `vec_size`.

  Args:
    attr: The attribute to convert.

  Returns:
    The corresponding WGStridedFragLayout.

  Raises:
    ValueError: If the attribute is not a #mosaic_gpu.WGStridedFragLayout
      attribute.
  """
  parsed = _strided_fragmented_layout_attr_pattern.fullmatch(str(attr))
  if parsed is None:
    raise ValueError(
        f"Expected a #mosaic_gpu.WGStridedFragLayout attribute, got {attr}"
    )
  dims = tuple(map(int, parsed.group("shape").split(",")))
  vector_size = int(parsed.group("vector_size"))
  return fa.WGStridedFragLayout(shape=dims, vec_size=vector_size)
def is_strided_fragmented_layout(attr: ir.Attribute) -> bool:
  """Returns whether `str(attr)` matches the WGStridedFragLayout attribute pattern."""
  found = _strided_fragmented_layout_attr_pattern.search(str(attr))
  return found is not None
# Matches the textual form of a tiled layout attribute, e.g.
# `#mosaic_gpu.TiledLayout<[[64, 8], [1, 2]], warp_dim = -8,
#  lane_dims = [-4, -3], vector_dim = -1>` (on a single line).
# Captured groups:
#   tiling:     contents of the outer tiling list (the inner lists keep their
#               brackets).
#   warp_dim:   raw text of the warp dimension.
#   lane_dims:  contents of the lane_dims list.
#   vector_dim: the vector dimension.
# The dot after `mosaic_gpu` is escaped so that it only matches a literal ".".
_tiled_layout_attr_pattern = re.compile(
    r"^#mosaic_gpu\.TiledLayout<\[(?P<tiling>.*)\],"
    r" warp_dim\s*=\s*(?P<warp_dim>.+),"
    r" lane_dims\s*=\s*\[(?P<lane_dims>.*)\],"
    r" vector_dim\s*=\s*(?P<vector_dim>[-\d]+)>$"
)
def to_tiled_layout_attr(
    layout: fa.TiledLayout,
) -> ir.Attribute:
  """Builds the #mosaic_gpu.TiledLayout attribute for a TiledLayout.

  Args:
    layout: The layout whose `tiling.tiles`, `warp_dim`, `lane_dims` and
      `vector_dim` are embedded in the attribute.

  Returns:
    The result of parsing the textual attribute
    `#mosaic_gpu.TiledLayout<[tiles], warp_dim=..., lane_dims=[...],
    vector_dim=...>`.
  """

  def render_dim(dim: int | fa.Replicated) -> str:
    # Replicated dims are spelled as a nested attribute; anything else is
    # rendered with str().
    if not isinstance(dim, fa.Replicated):
      return str(dim)
    return f"#mosaic_gpu.Replicated<times={dim.times}>"

  def render_tile(tile) -> str:
    # A single tile is a bracketed, ", "-separated list of its entries.
    entries = ", ".join(str(d) for d in tile)
    return f"[{entries}]"

  tiles_text = ", ".join(render_tile(tile) for tile in layout.tiling.tiles)
  tiling = f"[{tiles_text}]"
  # Lane dims are joined with a bare "," (no space).
  lanes_text = ",".join(render_dim(d) for d in layout.lane_dims)
  lane_dims = f"[{lanes_text}]"
  warp_dim = render_dim(layout.warp_dim)
  return ir.Attribute.parse(
      f"#mosaic_gpu.TiledLayout<{tiling},"
      f" warp_dim={warp_dim},"
      f" lane_dims={lane_dims}, vector_dim={layout.vector_dim}>"
  )
# Separator between two adjacent inner lists of a list of lists: `], [`.
_list_of_lists_delimiter = re.compile(r"\]\s*,\s*\[")
# An optionally negative integer with an optional `: <type>` suffix, e.g. `-8`
# or `-8 : i64`. Only a single leading "-" is accepted, so strings such as
# "1-2" or "--" are rejected here instead of failing later inside int().
_int_pattern = re.compile(r"^(?P<num>-?\d+)(\s*:\s*\w+)?$")
# A replicated dimension, e.g. `#mosaic_gpu.Replicated<times=2>`. The dot is
# escaped so that it only matches a literal ".".
_replicated_pattern = re.compile(
    r"^#mosaic_gpu\.Replicated<\s*times\s*=\s*(?P<times>\d+)\s*>\s*$"
)
def from_tiled_layout_attr(
    attr: ir.Attribute,
) -> fa.TiledLayout:
  """Builds a TiledLayout from a #mosaic_gpu.TiledLayout attribute.

  The attribute is inspected through its string form. The tiling is parsed
  as a list of integer lists. The warp dim and each lane dim are parsed
  either as a `#mosaic_gpu.Replicated<times=N>` entry or as an integer.

  Args:
    attr: The attribute to convert.

  Returns:
    The corresponding TiledLayout.

  Raises:
    ValueError: If the attribute is not a #mosaic_gpu.TiledLayout
      attribute, or if a warp/lane dim is neither an integer nor a
      replicated entry.
  """
  parsed = _tiled_layout_attr_pattern.fullmatch(str(attr))
  if parsed is None:
    raise ValueError(
        f"Expected a #mosaic_gpu.TiledLayout attribute, got {attr}"
    )

  def parse_dim(text: str) -> int | fa.Replicated:
    # Try the replicated spelling first, then fall back to a plain integer.
    if (replicated := _replicated_pattern.fullmatch(text)) is not None:
      return fa.Replicated(int(replicated.group("times")))
    if (number := _int_pattern.fullmatch(text)) is not None:
      return int(number.group("num"))
    raise ValueError(f"Unexpected format for replicated dim {text}")

  tiling_text = parsed.group("tiling")
  if len(tiling_text) > 2:
    # Drop the first "[" and the last "]" so that splitting on "], [" leaves
    # only the comma-separated entries of each tile.
    tile_texts = _list_of_lists_delimiter.split(tiling_text[1:-1])
  else:
    # Two characters or fewer: treated as having no tiles.
    tile_texts = []
  tiles = tuple(
      tuple(int(entry) for entry in tile_text.split(","))
      for tile_text in tile_texts
  )

  tiling = fa.Tiling(tiles)
  # The warp dim text is passed through as captured; lane dim entries are
  # stripped of surrounding whitespace first.
  warp_dim = parse_dim(parsed.group("warp_dim"))
  lane_dims = tuple(
      parse_dim(entry.strip())
      for entry in parsed.group("lane_dims").split(",")
  )
  vector_dim = int(parsed.group("vector_dim"))
  return fa.TiledLayout(
      tiling=tiling,
      warp_dim=warp_dim,
      lane_dims=lane_dims,
      vector_dim=vector_dim,
  )
def is_tiled_layout(attr: ir.Attribute) -> bool:
  """Returns whether `str(attr)` matches the TiledLayout attribute pattern."""
  found = _tiled_layout_attr_pattern.search(str(attr))
  return found is not None
def to_layout_attr(
    layout: (
        fa.WGSplatFragLayout
        | fa.WGStridedFragLayout
        | fa.TiledLayout
    ),
) -> ir.Attribute:
  """Converts a layout into the MLIR attribute that represents it.

  Dispatches on the layout's type to the matching `to_*_attr` converter.

  Args:
    layout: A WGSplatFragLayout, WGStridedFragLayout or TiledLayout.

  Returns:
    The MLIR attribute produced by the type-specific converter.

  Raises:
    NotImplementedError: If `layout` is not one of the supported layout
      types.
  """
  if isinstance(layout, fa.WGSplatFragLayout):
    return to_splat_fragmented_layout_attr(layout)
  if isinstance(layout, fa.WGStridedFragLayout):
    return to_strided_fragmented_layout_attr(layout)
  if isinstance(layout, fa.TiledLayout):
    return to_tiled_layout_attr(layout)
  raise NotImplementedError(
      f"Unsupported layout for conversion to MLIR attribute: {layout}"
  )
def from_layout_attr(
    attr: ir.Attribute,
) -> (
    fa.WGSplatFragLayout
    | fa.WGStridedFragLayout
    | fa.TiledLayout
):
  """Converts an MLIR layout attribute into the layout it represents.

  The `is_*` predicates are tried in order (splat, strided, tiled) and the
  converter paired with the first one that accepts `attr` is used.

  Args:
    attr: The attribute to convert.

  Returns:
    The layout produced by the matching `from_*_attr` converter.

  Raises:
    NotImplementedError: If no predicate recognizes `attr`.
  """
  converters = (
      (is_splat_fragmented_layout, from_splat_fragmented_layout_attr),
      (is_strided_fragmented_layout, from_strided_fragmented_layout_attr),
      (is_tiled_layout, from_tiled_layout_attr),
  )
  for recognizes, convert in converters:
    if recognizes(attr):
      return convert(attr)
  raise NotImplementedError(
      f"Unsupported layout for conversion from MLIR attribute: {attr}"
  )