# Copyright 2024 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Layout utilities.""" import re from jax._src.lib.mlir import ir from . import fragmented_array as fa _splat_fragmented_layout_attr_pattern = re.compile( r"^#mosaic_gpu.WGSplatFragLayout<\[(?P.*)\]>$" ) def to_splat_fragmented_layout_attr(layout: fa.WGSplatFragLayout) -> ir.Attribute: """Constructs a #mosaic_gpu.WGSplatFragLayout attribute from a WGSplatFragLayout.""" return ir.Attribute.parse( f"#mosaic_gpu.WGSplatFragLayout<{list(layout.shape)}>" ) def from_splat_fragmented_layout_attr(attr: ir.Attribute) -> fa.WGSplatFragLayout: """Constructs a WGSplatFragLayout from a #mosaic_gpu.WGSplatFragLayout attribute. Raises: ValueError: If the attribute is not a #mosaic_gpu.WGSplatFragLayout attribute. """ match = _splat_fragmented_layout_attr_pattern.fullmatch(str(attr)) if not match: raise ValueError( f"Expected a #mosaic_gpu.WGSplatFragLayout attribute, got {attr}" ) return fa.WGSplatFragLayout( shape=tuple(int(s) for s in match.group("shape").split(",")) ) def is_splat_fragmented_layout(attr: ir.Attribute) -> bool: return bool(_splat_fragmented_layout_attr_pattern.search(str(attr))) _strided_fragmented_layout_attr_pattern = re.compile( r"^#mosaic_gpu.WGStridedFragLayout<\[(?P.*)\]," r" (?P\d+)>$" ) def to_strided_fragmented_layout_attr( layout: fa.WGStridedFragLayout, ) -> ir.Attribute: """Constructs a #mosaic_gpu.WGStridedFragLayout attribute from a WGStridedFragLayout.""" return ir.Attribute.parse( f"#mosaic_gpu.WGStridedFragLayout<{list(layout.shape)}," f" {layout.vec_size}>" ) def from_strided_fragmented_layout_attr( attr: ir.Attribute, ) -> fa.WGStridedFragLayout: """Constructs a WGStridedFragLayout from a #mosaic_gpu.WGStridedFragLayout attribute. Raises: ValueError: If the attribute is not a #mosaic_gpu.WGStridedFragLayout attribute. """ match = _strided_fragmented_layout_attr_pattern.fullmatch(str(attr)) if not match: raise ValueError( f"Expected a #mosaic_gpu.WGStridedFragLayout attribute, got {attr}" ) return fa.WGStridedFragLayout( shape=tuple(int(s) for s in match.group("shape").split(",")), vec_size=int(match.group("vector_size")), ) def is_strided_fragmented_layout(attr: ir.Attribute) -> bool: return bool(_strided_fragmented_layout_attr_pattern.search(str(attr))) _tiled_layout_attr_pattern = re.compile( r"^#mosaic_gpu.TiledLayout<\[(?P.*)\]," r" warp_dim\s*=\s*(?P.+)," r" lane_dims\s*=\s*\[(?P.*)\]," r" vector_dim\s*=\s*(?P[-\d]+)>$" ) def to_tiled_layout_attr( layout: fa.TiledLayout, ) -> ir.Attribute: """Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout.""" def _int_or_replicated(d: int | fa.Replicated) -> str: if isinstance(d, fa.Replicated): return f"#mosaic_gpu.Replicated" return str(d) tile_str = lambda tile: "[" + ", ".join(str(d) for d in tile) + "]" tiling = "[" + ", ".join(tile_str(tile) for tile in layout.tiling.tiles) + "]" lane_dims = ( "[" + ",".join(_int_or_replicated(d) for d in layout.lane_dims) + "]" ) return ir.Attribute.parse( f"#mosaic_gpu.TiledLayout<{tiling}," f" warp_dim={_int_or_replicated(layout.warp_dim)}," f" lane_dims={lane_dims}, vector_dim={layout.vector_dim}>" ) _list_of_lists_delimiter = re.compile(r"\]\s*,\s*\[") _int_pattern = re.compile(r"^(?P[-\d]+)(\s*:\s*\w+)?$") _replicated_pattern = re.compile( r"^#mosaic_gpu.Replicated<\s*times\s*=\s*(?P\d+)\s*>\s*$" ) def from_tiled_layout_attr( attr: ir.Attribute, ) -> fa.TiledLayout: """Constructs a TiledLayout from a #mosaic_gpu.TiledLayout attribute. Raises: ValueError: If the attribute is not a #mosaic_gpu.TiledLayout attribute. """ match = _tiled_layout_attr_pattern.fullmatch(str(attr)) if not match: raise ValueError( f"Expected a #mosaic_gpu.TiledLayout attribute, got {attr}" ) def _int_or_replicated(replicated_dim: str) -> int | fa.Replicated: match = _replicated_pattern.fullmatch(replicated_dim) if match: return fa.Replicated(int(match.group("times"))) match = _int_pattern.fullmatch(replicated_dim) if match: return int(match.group("num")) raise ValueError(f"Unexpected format for replicated dim {replicated_dim}") tiling_str = match.group("tiling") tile_strings = [] if len(tiling_str) > 2: tile_strings = _list_of_lists_delimiter.split(tiling_str[1:-1]) tiles = tuple(tuple(map(int, ts.split(","))) for ts in tile_strings) return fa.TiledLayout( tiling=fa.Tiling(tiles), warp_dim=_int_or_replicated(match.group("warp_dim")), lane_dims=tuple( _int_or_replicated(s.strip()) for s in match.group("lane_dims").split(",") ), vector_dim=int(match.group("vector_dim")), ) def is_tiled_layout(attr: ir.Attribute) -> bool: return bool(_tiled_layout_attr_pattern.search(str(attr))) def to_layout_attr( layout: ( fa.WGSplatFragLayout | fa.WGStridedFragLayout | fa.TiledLayout ), ) -> ir.Attribute: """Constructs an MLIR attribute that corresponds to the given layout.""" match layout: case fa.WGSplatFragLayout(): return to_splat_fragmented_layout_attr(layout) case fa.WGStridedFragLayout(): return to_strided_fragmented_layout_attr(layout) case fa.TiledLayout(): return to_tiled_layout_attr(layout) case _: raise NotImplementedError( f"Unsupported layout for conversion to MLIR attribute: {layout}" ) def from_layout_attr( attr: ir.Attribute, ) -> ( fa.WGSplatFragLayout | fa.WGStridedFragLayout | fa.TiledLayout ): """Constructs a layout from an MLIR attribute.""" if is_splat_fragmented_layout(attr): return from_splat_fragmented_layout_attr(attr) elif is_strided_fragmented_layout(attr): return from_strided_fragmented_layout_attr(attr) elif is_tiled_layout(attr): return from_tiled_layout_attr(attr) else: raise NotImplementedError( f"Unsupported layout for conversion from MLIR attribute: {attr}" )