# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """Primitives for calling out to cusparse.
 | |
| 
 | |
| In general, these primitives are not meant to be used directly, but rather
 | |
| are used internally in GPU translation rules of higher-level primitives.
 | |
| """
 | |
| 
 | |
from functools import partial
from typing import Any

from jax._src import core
from jax._src import dispatch
from jax._src import ffi
from jax._src.interpreters import mlir
from jax._src.lib import gpu_sparse
from jax._src.lib import has_cpu_sparse
import numpy as np

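# Register the FFI targets exposed by the GPU and CPU sparse plugins, when
# available, so that the lowering rules below can refer to them by name.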
if hasattr(gpu_sparse, "registrations"):
  for platform, targets in gpu_sparse.registrations().items():
    for name, value, api_version in targets:
      ffi.register_ffi_target(
          name, value, platform=platform, api_version=api_version
      )

if has_cpu_sparse:
  from jax._src.lib import cpu_sparse

  if hasattr(cpu_sparse, "registrations"):
    for platform, targets in cpu_sparse.registrations().items():
      for name, value, api_version in targets:
        ffi.register_ffi_target(
            name, value, platform=platform, api_version=api_version
        )

def _get_module(target_name_prefix: str) -> Any:
  if target_name_prefix == "cu":
    return gpu_sparse._cusparse
  elif target_name_prefix == "hip":
    return gpu_sparse._hipsparse
  else:
    raise ValueError(f"Unsupported target_name_prefix: {target_name_prefix}")

SUPPORTED_DATA_DTYPES = [np.float32, np.float64, np.complex64, np.complex128]
SUPPORTED_INDEX_DTYPES = [np.int32]

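# The raw lowerings below accept only these dtype combinations; since the
# primitives in this file are internal-only, the abstract evaluation rules
# assert these invariants rather than raising user-facing errors.
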
# coo_spmv_p
# This is an internal-only primitive that calls into cusparse COO SpMV.
# This is a raw lowering that does no validation of inputs; the indices are
# assumed to be lexicographically sorted, deduplicated, and in-bounds.
coo_spmv_p = core.Primitive("coo_spmv")

def _coo_spmv_abstract_eval(data, row, col, x, *, transpose, shape):
  # TODO(jakevdp) support for batched matvec.
  assert data.shape == row.shape == col.shape
  assert row.ndim == 1
  assert x.ndim == 1

  assert row.dtype == col.dtype
  assert row.dtype in SUPPORTED_INDEX_DTYPES

  assert data.dtype == x.dtype
  assert x.dtype in SUPPORTED_DATA_DTYPES

  assert len(shape) == 2
  assert x.shape[0] == (shape[0] if transpose else shape[1])

  return core.ShapedArray(
    shape=shape[1:] if transpose else shape[:1],
    dtype=x.dtype)

def _coo_spmv_gpu_lowering(ctx, data, row, col, x, *, transpose, shape,
                           target_name_prefix):
  rows, cols = shape
  data_aval, row_aval, _, x_aval = ctx.avals_in
  nnz, = data_aval.shape
  buffer_size, opaque = _get_module(target_name_prefix).build_coo_matvec_descriptor(
      data_aval.dtype, x_aval.dtype, data_aval.dtype, row_aval.dtype,
      rows, cols, nnz, transpose)
  buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8)
  sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval])
  rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_matvec_ffi")
  return rule(sub_ctx, data, row, col, x, opaque=opaque)[:1]

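# Note on the pattern used by every GPU lowering in this file: the descriptor
# builder returns the size of the scratch workspace the cusparse call needs,
# the FFI call is emitted with an extra int8 buffer as a trailing output, and
# that buffer is then dropped from the results (the [:1] above).
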
coo_spmv_p.def_abstract_eval(_coo_spmv_abstract_eval)
dispatch.simple_impl(coo_spmv_p)
if gpu_sparse.cuda_is_supported:
  mlir.register_lowering(
    coo_spmv_p,
    partial(_coo_spmv_gpu_lowering, target_name_prefix='cu'),
    platform='cuda')
if gpu_sparse.rocm_is_supported:
  mlir.register_lowering(
    coo_spmv_p,
    partial(_coo_spmv_gpu_lowering, target_name_prefix='hip'),
    platform='rocm')
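
# For illustration (hypothetical shapes): a higher-level lowering rule would
# bind this primitive roughly as
#
#   y = coo_spmv_p.bind(data, row, col, x, transpose=False, shape=(m, n))
#
# where data, row, and col have length nnz and x has length n (or m when
# transpose=True); y then has length m (or n).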


# coo_spmm_p
# This is an internal-only primitive that calls into cusparse COO SpMM.
# This is a raw lowering that does no validation of inputs; the indices are
# assumed to be lexicographically sorted, deduplicated, and in-bounds.
coo_spmm_p = core.Primitive("coo_spmm")

def _coo_spmm_abstract_eval(data, row, col, x, *, transpose, shape):
  # TODO(jakevdp) support for batched matmat.
  assert data.shape == row.shape == col.shape
  assert row.ndim == 1
  assert x.ndim == 2

  assert row.dtype == col.dtype
  assert row.dtype in SUPPORTED_INDEX_DTYPES

  assert data.dtype == x.dtype
  assert x.dtype in SUPPORTED_DATA_DTYPES

  assert len(shape) == 2
  assert x.shape[0] == (shape[0] if transpose else shape[1])

  return core.ShapedArray(
    shape=(shape[1] if transpose else shape[0], x.shape[1]),
    dtype=x.dtype)

def _coo_spmm_gpu_lowering(ctx, data, row, col, x, *, transpose, shape,
                           target_name_prefix):
  data_aval, row_aval, _, x_aval = ctx.avals_in
  nnz, = data_aval.shape
  _, Ccols = x_aval.shape

  batch_count = 1
  if len(shape) == 2:
    rows, cols = shape
  elif len(shape) == 3:
    batch_count, rows, cols = shape
    nnz = nnz // batch_count
  else:
    raise NotImplementedError(f"Unsupported shape: {shape}")

  # TODO(tianjianlu): use the batch stride to select among batched
  # computation modes. Currently batch_stride = 0 is not allowed because of
  # a cusparse issue:
  # https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643
  # For now, set the batch stride to the size of a single matrix.
  lhs_batch_stride = nnz
  B_rows = rows if transpose else cols
  rhs_batch_stride = B_rows * Ccols
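  # For example (hypothetical sizes): with shape=(4, 64, 128), 4000 total
  # stored elements, transpose=False, and x of shape (128, 32), each batch
  # holds nnz = 4000 // 4 = 1000 entries, so lhs_batch_stride = 1000 and
  # rhs_batch_stride = 128 * 32 = 4096.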

  buffer_size, opaque = _get_module(target_name_prefix).build_coo_matmat_descriptor(
      data_aval.dtype, x_aval.dtype, data_aval.dtype, row_aval.dtype,
      rows, cols, Ccols, nnz, transpose, batch_count, lhs_batch_stride,
      rhs_batch_stride)

  buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8)
  sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval])
  rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_matmat_ffi")
  return rule(sub_ctx, data, row, col, x, opaque=opaque)[:1]


coo_spmm_p.def_abstract_eval(_coo_spmm_abstract_eval)
dispatch.simple_impl(coo_spmm_p)
if gpu_sparse.cuda_is_supported:
  mlir.register_lowering(
    coo_spmm_p,
    partial(_coo_spmm_gpu_lowering, target_name_prefix='cu'),
    platform='cuda')
if gpu_sparse.rocm_is_supported:
  mlir.register_lowering(
    coo_spmm_p,
    partial(_coo_spmm_gpu_lowering, target_name_prefix='hip'),
    platform='rocm')
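
# For illustration (hypothetical shapes): binding coo_spmm_p with data, row,
# and col of length nnz and a dense x of shape (n, k), as in
#
#   y = coo_spmm_p.bind(data, row, col, x, transpose=False, shape=(m, n))
#
# yields y of shape (m, k); with transpose=True, x instead has shape (m, k)
# and y has shape (n, k).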

# csr_spmv_p
# This is an internal-only primitive that calls into cusparse CSR SpMV.
# This is a raw lowering that does no validation of inputs; the indices are
# assumed to be lexicographically sorted, deduplicated, and in-bounds.
csr_spmv_p = core.Primitive("csr_spmv")

def _csr_spmv_abstract_eval(data, indices, indptr, x, *, transpose, shape):
  # TODO(tianjianlu) support for batched matvec.
  assert data.ndim == indices.ndim == indptr.ndim == 1
  assert data.shape == indices.shape
  assert indptr.shape[0] == shape[0] + 1
  assert x.ndim == 1

  assert indices.dtype == indptr.dtype
  assert indices.dtype in SUPPORTED_INDEX_DTYPES
  assert data.dtype == x.dtype
  assert x.dtype in SUPPORTED_DATA_DTYPES

  assert len(shape) == 2
  assert x.shape[0] == (shape[0] if transpose else shape[1])

  return core.ShapedArray(
    shape=shape[1:] if transpose else shape[:1],
    dtype=x.dtype)

def _csr_spmv_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape,
                           target_name_prefix):
  rows, cols = shape
  data_aval, indices_aval, _, x_aval = ctx.avals_in
  nnz, = data_aval.shape
  buffer_size, opaque = _get_module(target_name_prefix).build_csr_matvec_descriptor(
      data_aval.dtype, x_aval.dtype, data_aval.dtype, indices_aval.dtype,
      rows, cols, nnz, transpose)
  buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8)
  sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval])
  rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_matvec_ffi")
  return rule(sub_ctx, data, indices, indptr, x, opaque=opaque)[:1]

csr_spmv_p.def_abstract_eval(_csr_spmv_abstract_eval)
dispatch.simple_impl(csr_spmv_p)
if gpu_sparse.cuda_is_supported:
  mlir.register_lowering(
    csr_spmv_p,
    partial(_csr_spmv_gpu_lowering, target_name_prefix='cu'),
    platform='cuda')
if gpu_sparse.rocm_is_supported:
  mlir.register_lowering(
    csr_spmv_p,
    partial(_csr_spmv_gpu_lowering, target_name_prefix='hip'),
    platform='rocm')
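
# For illustration (hypothetical shapes): for an (m, n) CSR matrix, data and
# indices have length nnz and indptr has length m + 1, so
#
#   y = csr_spmv_p.bind(data, indices, indptr, x, transpose=False,
#                       shape=(m, n))
#
# takes x of length n (or m when transpose=True) and yields y of length m
# (or n).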

# csr_spmm_p
# This is an internal-only primitive that calls into cusparse CSR SpMM.
# This is a raw lowering that does no validation of inputs; the indices are
# assumed to be lexicographically sorted, deduplicated, and in-bounds.
csr_spmm_p = core.Primitive("csr_spmm")

def _csr_spmm_abstract_eval(data, indices, indptr, x, *, transpose, shape):
  # TODO(tianjianlu) support for batched matmat.
  assert data.ndim == indices.ndim == indptr.ndim == 1
  assert data.shape == indices.shape
  assert indptr.shape[0] == shape[0] + 1
  assert x.ndim == 2

  assert indices.dtype == indptr.dtype
  assert indices.dtype in SUPPORTED_INDEX_DTYPES
  assert data.dtype == x.dtype
  assert x.dtype in SUPPORTED_DATA_DTYPES

  assert len(shape) == 2
  assert x.shape[0] == (shape[0] if transpose else shape[1])

  return core.ShapedArray(
    shape=(shape[1] if transpose else shape[0], x.shape[1]),
    dtype=x.dtype)

def _csr_spmm_gpu_lowering(ctx, data, indices, indptr, x, *, transpose, shape,
                           target_name_prefix):
  rows, cols = shape
  data_aval, indices_aval, _, x_aval = ctx.avals_in
  nnz, = data_aval.shape
  _, Ccols = x_aval.shape
  buffer_size, opaque = _get_module(target_name_prefix).build_csr_matmat_descriptor(
      data_aval.dtype, x_aval.dtype, data_aval.dtype, indices_aval.dtype,
      rows, cols, Ccols, nnz, transpose)
  buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8)
  sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval])
  rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_matmat_ffi")
  return rule(sub_ctx, data, indices, indptr, x, opaque=opaque)[:1]

csr_spmm_p.def_abstract_eval(_csr_spmm_abstract_eval)
dispatch.simple_impl(csr_spmm_p)
if gpu_sparse.cuda_is_supported:
  mlir.register_lowering(
    csr_spmm_p,
    partial(_csr_spmm_gpu_lowering, target_name_prefix='cu'),
    platform='cuda')
if gpu_sparse.rocm_is_supported:
  mlir.register_lowering(
    csr_spmm_p,
    partial(_csr_spmm_gpu_lowering, target_name_prefix='hip'),
    platform='rocm')
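
# For illustration (hypothetical shapes): binding csr_spmm_p with an (m, n)
# CSR matrix and a dense x of shape (n, k), as in
#
#   y = csr_spmm_p.bind(data, indices, indptr, x, transpose=False,
#                       shape=(m, n))
#
# yields y of shape (m, k); with transpose=True, x instead has shape (m, k)
# and y has shape (n, k).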


if has_cpu_sparse:
  # Lowering rules are invoked with the primitive's parameters as keyword
  # arguments, so this rule accepts `transpose` and `shape` even though the
  # kernel recovers the problem dimensions from its operand shapes.
  def _csr_spmm_cpu_lowering(ctx, data, outer_indices, inner_indices, rhs, *,
                             transpose, shape):
    del shape  # Unused: implied by the operand shapes.
    rule = ffi.ffi_lowering("cpu_csr_sparse_dense_ffi")
    # `transpose` is forwarded as an attribute of the FFI call (assuming the
    # cpu_csr_sparse_dense_ffi handler declares it).
    return rule(ctx, data, outer_indices, inner_indices, rhs,
                transpose=transpose)

  # _csr_spmm_cpu_lowering can handle both matrix-matrix and matrix-vector
  # multiplication.
  mlir.register_lowering(
      csr_spmv_p,
      _csr_spmm_cpu_lowering,
      platform="cpu",
  )
  mlir.register_lowering(
      csr_spmm_p,
      _csr_spmm_cpu_lowering,
      platform="cpu",
  )

def coo_todense_gpu_lowering(ctx, data, row, col, *, shape, target_name_prefix):
  data_aval, row_aval, _ = ctx.avals_in
  nnz, = data_aval.shape
  rows, cols = shape
  buffer_size, opaque = _get_module(target_name_prefix).build_coo_todense_descriptor(
      data_aval.dtype, row_aval.dtype, rows, cols, nnz)
  buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8)
  sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval])
  rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_todense_ffi")
  return rule(sub_ctx, data, row, col, opaque=opaque)[0]

def coo_fromdense_gpu_lowering(ctx, mat, *, nnz, index_dtype, target_name_prefix):
  mat_aval, = ctx.avals_in
  rows, cols = mat_aval.shape
  buffer_size, opaque = _get_module(target_name_prefix).build_coo_fromdense_descriptor(
      mat_aval.dtype, np.dtype(index_dtype), rows, cols, nnz)
  buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8)
  sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval])
  rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_coo_fromdense_ffi")
  return rule(sub_ctx, mat, opaque=opaque)[:3]

def csr_todense_gpu_lowering(ctx, data, indices, indptr, *, shape, target_name_prefix):
  data_aval, indices_aval, _ = ctx.avals_in
  nnz, = data_aval.shape
  rows, cols = shape
  buffer_size, opaque = _get_module(target_name_prefix).build_csr_todense_descriptor(
      data_aval.dtype, indices_aval.dtype, rows, cols, nnz)
  buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8)
  sub_ctx = ctx.replace(avals_out=[ctx.avals_out[0], buffer_aval])
  rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_todense_ffi")
  return rule(sub_ctx, data, indices, indptr, opaque=opaque)[0]

def csr_fromdense_gpu_lowering(ctx, mat, *, nnz, index_dtype, target_name_prefix):
  mat_aval, = ctx.avals_in
  rows, cols = mat_aval.shape
  buffer_size, opaque = _get_module(target_name_prefix).build_csr_fromdense_descriptor(
      mat_aval.dtype, np.dtype(index_dtype), rows, cols, nnz)
  buffer_aval = core.ShapedArray(shape=(buffer_size,), dtype=np.int8)
  sub_ctx = ctx.replace(avals_out=[*ctx.avals_out, buffer_aval])
  rule = ffi.ffi_lowering(f"{target_name_prefix}sparse_csr_fromdense_ffi")
  return rule(sub_ctx, mat, opaque=opaque)[:3]
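
# The todense/fromdense helpers above follow the same workspace-buffer
# pattern as the SpMV/SpMM lowerings: the fromdense variants return the
# three sparse buffers (data plus two index arrays) and drop the scratch
# buffer, while the todense variants return only the dense matrix.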