2025-08-11 12:24:21 +08:00

72 lines
4.3 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental GPU backend for Pallas targeting H100.
These APIs are highly unstable and can change weekly. Use at your own risk.
"""
from jax._src.pallas.mosaic_gpu.core import Barrier as Barrier
from jax._src.pallas.mosaic_gpu.core import ClusterBarrier as ClusterBarrier
from jax._src.pallas.mosaic_gpu.core import BlockSpec as BlockSpec
from jax._src.pallas.mosaic_gpu.core import CompilerParams as CompilerParams
from jax._src.pallas.mosaic_gpu.core import Mesh as Mesh
from jax._src.pallas.mosaic_gpu.core import MemorySpace as MemorySpace
from jax._src.pallas.mosaic_gpu.core import kernel as kernel
from jax._src.pallas.mosaic_gpu.core import PeerMemRef as PeerMemRef
from jax._src.pallas.mosaic_gpu.core import RefUnion as RefUnion
from jax._src.pallas.mosaic_gpu.core import remote_ref as remote_ref
from jax._src.pallas.mosaic_gpu.core import SemaphoreType as SemaphoreType
from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform
from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform
from jax._src.pallas.mosaic_gpu.core import transform_ref as transform_ref
from jax._src.pallas.mosaic_gpu.core import transpose_ref as transpose_ref
from jax._src.pallas.mosaic_gpu.core import untile_ref as untile_ref
from jax._src.pallas.mosaic_gpu.core import unswizzle_ref as unswizzle_ref
from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform
from jax._src.pallas.mosaic_gpu.core import WarpMesh as WarpMesh
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
from jax._src.pallas.mosaic_gpu.helpers import nd_loop as nd_loop
from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline
from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline_warp_specialized as emit_pipeline_warp_specialized
from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive
from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait
from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota
from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem
from jax._src.pallas.mosaic_gpu.primitives import commit_smem_to_gmem_group as commit_smem_to_gmem_group
from jax._src.pallas.mosaic_gpu.primitives import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import inline_mgpu as inline_mgpu
from jax._src.pallas.mosaic_gpu.primitives import Layout as Layout
from jax._src.pallas.mosaic_gpu.primitives import layout_cast as layout_cast
from jax._src.pallas.mosaic_gpu.primitives import load as load
from jax._src.pallas.mosaic_gpu.primitives import RefType as RefType
from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers
from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem
from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait as wgmma_wait
from jax._src.pallas.mosaic_gpu.primitives import tcgen05_mma as tcgen05_mma
from jax._src.pallas.mosaic_gpu.primitives import commit_tmem as commit_tmem
from jax.experimental.mosaic.gpu.core import LoweringSemantics as LoweringSemantics
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.GMEM`.
GMEM = MemorySpace.GMEM
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.SMEM`.
SMEM = MemorySpace.SMEM
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.TMEM`.
TMEM = MemorySpace.TMEM