245 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			245 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright 2024 The JAX Authors.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     https://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| """N:M-sparsity associated primitives."""
 | |
| 
 | |
| from jax._src import core
 | |
| from jax._src import dispatch
 | |
| from jax._src.lax.lax import DotDimensionNumbers
 | |
| from jax._src.lib import gpu_sparse
 | |
| from jax._src.lib.mlir.dialects import mhlo
 | |
| from jax._src.typing import Array, DTypeLike
 | |
| from jax.interpreters import mlir
 | |
| import jax.numpy as jnp
 | |
| import numpy as np
 | |
| 
 | |
| # --------------------------------------------------------------------
 | |
| # nm_spmm
 | |
| 
 | |
| nm_spmm_p = core.Primitive("sparse_dense_matmul")
 | |
| 
 | |
| _supported_input_types = (jnp.int8, jnp.int16, jnp.float16, jnp.bfloat16)
 | |
| _supported_output_types = (jnp.bfloat16, jnp.float32)
 | |
| 
 | |
| 
 | |
| def nm_spmm(
 | |
|     lhs: Array,
 | |
|     rhs: Array,
 | |
|     metadata: Array,
 | |
|     dimension_numbers: DotDimensionNumbers = (((1,), (0,)), (tuple(), tuple())),
 | |
|     sparse_operand_idx: int = 0,
 | |
|     output_dtype: DTypeLike = jnp.bfloat16,
 | |
| ) -> Array:
 | |
|   """Dot operation where one of the operands has N:M sparsity.
 | |
| 
 | |
|   Args:
 | |
|     lhs: An ndarray (first dot operand).
 | |
|     rhs: An ndarray (second dot operand).
 | |
|     metadata: An ndarray with structured sparsity metadata for the contracting
 | |
|       dimension. For 2:4 sparsity it should contain (N=2) two-bit index values
 | |
|       for each (M=4) element group.
 | |
|     dimension_numbers: a tuple of tuples of the form `((lhs_contracting_dims,
 | |
|       rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
 | |
|     sparse_operand_idx: index of the sparse operand (0 or 1).
 | |
|     output_dtype: result type.
 | |
| 
 | |
|   Returns:
 | |
|     An ndarray dense array containing the result.
 | |
|   """
 | |
|   return nm_spmm_p.bind(
 | |
|       lhs,
 | |
|       rhs,
 | |
|       metadata,
 | |
|       dimension_numbers=dimension_numbers,
 | |
|       sparse_operand_idx=sparse_operand_idx,
 | |
|       output_dtype=output_dtype,
 | |
|   )
 | |
| 
 | |
| 
 | |
| def _calc_groups_per_element(n, m):
 | |
|   group_bits = n * (m.bit_length() - 1)  # 4 bits per group for 2:4
 | |
|   return 16 // group_bits
 | |
| 
 | |
| 
 | |
| def _validate_dnums(rank, contract, batch, name):
 | |
|   non_contract = tuple(sorted(set(range(rank)) - set(contract + batch)))
 | |
|   if sorted(non_contract + contract + batch) != list(range(rank)):
 | |
|     raise TypeError(f"Incorrect dimension numbers for {name}")
 | |
|   return non_contract
 | |
| 
 | |
| 
 | |
| def _validate_metadata(lhs, rhs, metadata, dimension_numbers, index, n=2, m=4):
 | |
|   assert index in (0, 1)
 | |
|   size_factor = n * _calc_groups_per_element(n, m)
 | |
| 
 | |
|   sparse = [lhs, rhs][index]
 | |
|   sparse_contract = dimension_numbers[0][index]
 | |
|   if metadata.dtype != np.uint16:
 | |
|     raise TypeError(f"Metadata must be uint16, got {metadata.dtype}")
 | |
|   if sparse_contract[0] != sparse.ndim - 1:
 | |
|     raise TypeError("Contracting dimension must be the minor one")
 | |
|   if metadata.shape[:-1] != sparse.shape[:-1]:
 | |
|     raise TypeError(
 | |
|         "Metadata shape must match the operand shape (except for the"
 | |
|         " contracting dimension)"
 | |
|     )
 | |
|   if metadata.shape[-1] * size_factor != sparse.shape[-1]:
 | |
|     raise TypeError(
 | |
|         f"Metadata must be exactly {size_factor} times less than the"
 | |
|         f" contracting dimension for {n}:{m} structured sparsity (expected"
 | |
|         f" {sparse.shape[-1] // size_factor}, got {metadata.shape[-1]})"
 | |
|     )
 | |
|   if sparse.shape[-1] % size_factor != 0:
 | |
|     raise NotImplementedError("Metadata with padding is not supported")
 | |
| 
 | |
|   dense = [lhs, rhs][1 - index]
 | |
|   dense_contract = dimension_numbers[0][1 - index]
 | |
|   a, b = sparse.shape[sparse_contract[0]], dense.shape[dense_contract[0]]
 | |
|   if n * b != m * a:
 | |
|     raise TypeError(
 | |
|         f"Contracting dimension sizes should have {n}:{m} ratio, got {a}:{b}"
 | |
|     )
 | |
| 
 | |
| 
 | |
| def _infer_result_shape(lhs, rhs, dimension_numbers):
 | |
|   ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) = dimension_numbers
 | |
|   if len(lhs_contract) != 1 or len(rhs_contract) != 1:
 | |
|     raise TypeError("Only single contracting dimension is supported")
 | |
|   lhs_dims = _validate_dnums(lhs.ndim, lhs_contract, lhs_batch, "lhs")
 | |
|   rhs_dims = _validate_dnums(rhs.ndim, rhs_contract, rhs_batch, "rhs")
 | |
|   if len(lhs_dims) != 1 or len(rhs_dims) != 1:
 | |
|     raise TypeError("Only single non-contracting dimension is supported")
 | |
|   batch = [lhs.shape[i] for i in lhs_batch]
 | |
|   if batch != [rhs.shape[i] for i in rhs_batch]:
 | |
|     raise TypeError("Batch dimension sizes do not match")
 | |
|   return tuple(batch + [lhs.shape[lhs_dims[0]], rhs.shape[rhs_dims[0]]])
 | |
| 
 | |
| 
 | |
| def _nm_spmm_default_lowering(*_args, **_kwargs):
 | |
|   raise NotImplementedError("Sparse N:M matmul is only implemented on GPU")
 | |
| 
 | |
| 
 | |
| def _nm_spmm_gpu_lowering(
 | |
|     ctx,
 | |
|     lhs,
 | |
|     rhs,
 | |
|     metadata,
 | |
|     *,
 | |
|     dimension_numbers,
 | |
|     sparse_operand_idx,
 | |
|     output_dtype,
 | |
| ):
 | |
|   assert sparse_operand_idx in (0, 1)
 | |
|   sparsity_descriptor = mhlo.SparsityDescriptor.get(
 | |
|       dimension=dimension_numbers[0][sparse_operand_idx][0], n=2, m=4
 | |
|   )
 | |
|   dot_dnums = mhlo.DotDimensionNumbers.get(
 | |
|       lhs_batching_dimensions=dimension_numbers[1][sparse_operand_idx],
 | |
|       rhs_batching_dimensions=dimension_numbers[1][1 - sparse_operand_idx],
 | |
|       lhs_contracting_dimensions=dimension_numbers[0][sparse_operand_idx],
 | |
|       rhs_contracting_dimensions=dimension_numbers[0][1 - sparse_operand_idx],
 | |
|   )
 | |
|   dot_type = ctx.avals_out[0]
 | |
|   key = ["lhs_sparsity", "rhs_sparsity"][sparse_operand_idx]
 | |
|   kwargs = {key: sparsity_descriptor}
 | |
|   op = mhlo.SparseDotOp(
 | |
|       mlir.aval_to_ir_type(dot_type), lhs, rhs, [metadata], dot_dnums, **kwargs
 | |
|   )
 | |
|   return op.results
 | |
| 
 | |
| 
 | |
| @nm_spmm_p.def_abstract_eval
 | |
| def _nm_spmm_abstract_eval(
 | |
|     lhs, rhs, metadata, *, dimension_numbers, sparse_operand_idx, output_dtype
 | |
| ):
 | |
|   if lhs.dtype not in _supported_input_types:
 | |
|     raise TypeError(f"Unsupported lhs input type: {lhs.dtype}")
 | |
|   if rhs.dtype not in _supported_input_types:
 | |
|     raise TypeError(f"Unsupported rhs input type: {rhs.dtype}")
 | |
|   if output_dtype not in _supported_output_types:
 | |
|     raise TypeError(f"Unsupported output type: {output_dtype}")
 | |
| 
 | |
|   res_shape = _infer_result_shape(lhs, rhs, dimension_numbers)
 | |
|   _validate_metadata(lhs, rhs, metadata, dimension_numbers, sparse_operand_idx)
 | |
|   return core.ShapedArray(res_shape, output_dtype)
 | |
| 
 | |
| 
 | |
| mlir.register_lowering(nm_spmm_p, _nm_spmm_default_lowering)
 | |
| dispatch.simple_impl(nm_spmm_p)
 | |
| 
 | |
| if gpu_sparse.cuda_is_supported:
 | |
|   mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="cuda")
 | |
| 
 | |
| if gpu_sparse.rocm_is_supported:
 | |
|   mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="rocm")
 | |
| 
 | |
| # --------------------------------------------------------------------
 | |
| # nm_pack
 | |
| 
 | |
| nm_pack_p = core.Primitive("sparse_pack_nm")
 | |
| 
 | |
| 
 | |
| def nm_pack(mask: Array, n=2, m=4) -> Array:
 | |
|   """Generate metadata tensor for an N:M mask.
 | |
| 
 | |
|   Args:
 | |
|     mask: Predicates for the input tensor, where the elements are grouped in the
 | |
|       minor dimension. In each group of size M there should be exactly N true
 | |
|       values, which mark the data elements to keep.
 | |
|     n: Number of non-zero elements in a group.
 | |
|     m: Group size.
 | |
| 
 | |
|   Returns:
 | |
|     An ndarray containing only the masked input elements.
 | |
|   """
 | |
|   return nm_pack_p.bind(mask, n=n, m=m)
 | |
| 
 | |
| 
 | |
| def _compress(data, n, m, k):
 | |
|   result = []
 | |
|   expected = n * (k // m)
 | |
|   for i in range(0, len(data), k):
 | |
|     index = tuple(jnp.nonzero(data[i : i + k], size=expected)[0] % m)
 | |
|     value = sum(j * pow(m, i) for i, j in enumerate(index))
 | |
|     result.append(value)
 | |
|   return jnp.array(result, dtype=np.uint16)
 | |
| 
 | |
| 
 | |
| @nm_pack_p.def_impl
 | |
| def _nm_pack_impl(mask, *, n, m):
 | |
|   batch_size = m * _calc_groups_per_element(n, m)
 | |
|   return jnp.apply_along_axis(
 | |
|       lambda x: _compress(x, n, m, batch_size), -1, mask
 | |
|   )
 | |
| 
 | |
| 
 | |
| @nm_pack_p.def_abstract_eval
 | |
| def _nm_pack_abstract_eval(mask, *, n, m):
 | |
|   size_factor = m * _calc_groups_per_element(n, m)
 | |
|   if mask.dtype != bool:
 | |
|     raise TypeError(f"Mask should be bool, got {mask.dtype}")
 | |
|   if mask.shape[-1] % size_factor != 0:
 | |
|     raise TypeError(
 | |
|         f"Inner dimension size should be divisible by {size_factor}, got"
 | |
|         f" {mask.shape}"
 | |
|     )
 | |
|   res_shape = list(mask.shape)
 | |
|   res_shape[-1] //= size_factor
 | |
|   return core.ShapedArray(res_shape, np.uint16)
 | |
| 
 | |
| 
 | |
| _nm_pack_lowering = mlir.lower_fun(_nm_pack_impl, multiple_results=False)
 | |
| mlir.register_lowering(nm_pack_p, _nm_pack_lowering)
 | |
| dispatch.simple_impl(nm_pack_p)
 |