# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Experimental GPU backend for Pallas targeting H100.

These APIs are highly unstable and can change weekly. Use at your own risk.
"""

# NOTE(review): the docstring above names H100 only, yet this module also
# re-exports ``tcgen05_mma``, ``commit_tmem`` and ``TMEM`` -- confirm whether
# the summary should mention newer hardware as well.

# This module only re-exports names. The redundant ``name as name`` aliases
# follow the explicit re-export convention understood by type checkers and
# linters, so they must be kept when adding new entries.

# Core abstractions: memory spaces, meshes, barriers and ref transforms.
from jax._src.pallas.mosaic_gpu.core import (
  Barrier as Barrier,
  BlockSpec as BlockSpec,
  ClusterBarrier as ClusterBarrier,
  CompilerParams as CompilerParams,
  kernel as kernel,
  MemorySpace as MemorySpace,
  Mesh as Mesh,
  PeerMemRef as PeerMemRef,
  RefUnion as RefUnion,
  remote_ref as remote_ref,
  SemaphoreType as SemaphoreType,
  SwizzleTransform as SwizzleTransform,
  TilingTransform as TilingTransform,
  transform_ref as transform_ref,
  transpose_ref as transpose_ref,
  TransposeTransform as TransposeTransform,
  unswizzle_ref as unswizzle_ref,
  untile_ref as untile_ref,
  WarpMesh as WarpMesh,
  # Short alias of WGMMAAccumulatorRef. It is not a redundant alias, so the
  # unused-import warning has to be silenced explicitly.
  WGMMAAccumulatorRef as ACC,  # noqa: F401
  WGMMAAccumulatorRef as WGMMAAccumulatorRef,
)

# Helpers.
from jax._src.pallas.mosaic_gpu.helpers import nd_loop as nd_loop

# Pipelining utilities.
from jax._src.pallas.mosaic_gpu.pipeline import (
  emit_pipeline as emit_pipeline,
  emit_pipeline_warp_specialized as emit_pipeline_warp_specialized,
)

# Primitives.
from jax._src.pallas.mosaic_gpu.primitives import (
  barrier_arrive as barrier_arrive,
  barrier_wait as barrier_wait,
  broadcasted_iota as broadcasted_iota,
  commit_smem as commit_smem,
  commit_smem_to_gmem_group as commit_smem_to_gmem_group,
  commit_tmem as commit_tmem,
  copy_gmem_to_smem as copy_gmem_to_smem,
  copy_smem_to_gmem as copy_smem_to_gmem,
  inline_mgpu as inline_mgpu,
  Layout as Layout,
  layout_cast as layout_cast,
  load as load,
  RefType as RefType,
  set_max_registers as set_max_registers,
  ShapeDtypeStruct as ShapeDtypeStruct,
  tcgen05_mma as tcgen05_mma,
  wait_smem_to_gmem as wait_smem_to_gmem,
  wgmma as wgmma,
  wgmma_wait as wgmma_wait,
)

# Re-exported from the Mosaic GPU package rather than from Pallas internals.
from jax.experimental.mosaic.gpu.core import LoweringSemantics as LoweringSemantics


# Module-level shorthands for the MemorySpace members. The ``#:`` comments are
# the per-attribute documentation and must stay directly above each name.

#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.GMEM`.
GMEM = MemorySpace.GMEM
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.SMEM`.
SMEM = MemorySpace.SMEM
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.TMEM`.
TMEM = MemorySpace.TMEM