# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| """ANN (Approximate Nearest Neighbor) computes top-k with a configurable recall rate.
 | |
| 
 | |
| This package only optimizes the TPU backend. For other device types it fallbacks
 | |
| to sort and slice.
 | |
| 
 | |
| Usage::
 | |
| 
 | |
|   import functools
 | |
|   import jax
 | |
| 
 | |
|   # MIPS := maximal inner product search
 | |
|   # Inputs:
 | |
|   #   qy: f32[qy_size, feature_dim]
 | |
|   #   db: f32[db_size, feature_dim]
 | |
|   #
 | |
|   # Returns:
 | |
|   #   (f32[qy_size, k], i32[qy_size, k])
 | |
|   @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
 | |
|   def mips(qy, db, k=10, recall_target=0.95):
 | |
|     dists = jax.lax.dot(qy, db.transpose())
 | |
|     # Computes max_k along the last dimension
 | |
|     # returns (f32[qy_size, k], i32[qy_size, k])
 | |
|     return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)
 | |
| 
 | |
|   # Multi-core example
 | |
|   # Inputs:
 | |
|   #   qy: f32[num_devices, qy_size, feature_dim]
 | |
|   #   db: f32[num_devices, per_device_db_size, feature_dim]
 | |
|   #   db_offset: i32[num_devices]
 | |
|   #   db_size = num_devices * per_device_db_size
 | |
|   #
 | |
|   # Returns:
 | |
|   #   (f32[qy_size, num_devices, k], i32[qy_size, num_devices, k])
 | |
|   @functools.partial(
 | |
|       jax.pmap,
 | |
|       # static args: db_size, k, recall_target
 | |
|       static_broadcasted_argnums=[3, 4, 5],
 | |
|       out_axes=(1, 1))
 | |
|   def pmap_mips(qy, db, db_offset, db_size, k, recall_target):
 | |
|     dists = jax.lax.dot(qy, db.transpose())
 | |
|     dists, neighbors = jax.lax.approx_max_k(
 | |
|         dists, k=k, recall_target=recall_target,
 | |
|         reduction_input_size_override=db_size)
 | |
|     return (dists, neighbors + db_offset)
 | |
| 
 | |
|   # i32[qy_size, num_devices, k]
 | |
|   pmap_neighbors = pmap_mips(qy, db, db_offset, db_size, 10, 0.95)[1]
 | |
|   # i32[qy_size, num_devices * k]
 | |
|   neighbors = jax.lax.collapse(pmap_neighbors, start_dimension=1, stop_dimension=3)
 | |
| 
 | |
| Todos::
 | |
| 
 | |
|   * On host top-k aggregation
 | |
|   * Inaccurate but fast differentiation
 | |
| 
 | |
| """
 | |

from functools import partial

import numpy as np


from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import lax
from jax._src.lib import _jax
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import hlo
from jax._src.typing import Array


def approx_max_k(operand: Array,
                 k: int,
                 reduction_dimension: int = -1,
                 recall_target: float = 0.95,
                 reduction_input_size_override: int = -1,
                 aggregate_to_topk: bool = True) -> tuple[Array, Array]:
  """Returns max ``k`` values and their indices of the ``operand`` in an approximate manner.

  See https://arxiv.org/abs/2206.14286 for the algorithm details.

  Args:
    operand : Array to search for max-k. Must be of floating-point type.
    k : Specifies the number of max-k values to return.
    reduction_dimension : Integer dimension along which to search. Default: -1.
    recall_target : Recall target for the approximation.
    reduction_input_size_override : When set to a positive value, it overrides
      the size determined by ``operand[reduction_dim]`` for evaluating the
      recall. This option is useful when the given ``operand`` is only a subset
      of the overall computation in SPMD or distributed pipelines, where the
      true input size cannot be inferred from the operand shape.
    aggregate_to_topk : When true, aggregates approximate results to the top-k
      in sorted order. When false, returns the approximate results unsorted. In
      this case, the number of the approximate results is implementation defined
      and is greater than or equal to the specified ``k``.

  Returns:
    Tuple of two arrays. The arrays are the max ``k`` values and the
    corresponding indices along the ``reduction_dimension`` of the input
    ``operand``. The arrays' dimensions are the same as the input ``operand``
    except for the ``reduction_dimension``: when ``aggregate_to_topk`` is true,
    the reduction dimension is ``k``; otherwise, it is greater than or equal to
    ``k``, and the exact size is implementation-defined.

  We encourage users to wrap ``approx_max_k`` with jit. See the following
  example for maximal inner product search (MIPS):

  >>> import functools
  >>> import jax
  >>> import numpy as np
  >>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
  ... def mips(qy, db, k=10, recall_target=0.95):
  ...   dists = jax.lax.dot(qy, db.transpose())
  ...   # returns (f32[qy_size, k], i32[qy_size, k])
  ...   return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)
  >>>
  >>> qy = jax.numpy.array(np.random.rand(50, 64))
  >>> db = jax.numpy.array(np.random.rand(1024, 64))
  >>> dot_products, neighbors = mips(qy, db, k=10)
  """
  return approx_top_k_p.bind(
      operand,
      k=k,
      reduction_dimension=reduction_dimension,
      recall_target=recall_target,
      is_max_k=True,
      reduction_input_size_override=reduction_input_size_override,
      aggregate_to_topk=aggregate_to_topk)


def approx_min_k(operand: Array,
                 k: int,
                 reduction_dimension: int = -1,
                 recall_target: float = 0.95,
                 reduction_input_size_override: int = -1,
                 aggregate_to_topk: bool = True) -> tuple[Array, Array]:
  """Returns min ``k`` values and their indices of the ``operand`` in an approximate manner.

  See https://arxiv.org/abs/2206.14286 for the algorithm details.

  Args:
    operand : Array to search for min-k. Must be of floating-point type.
    k : Specifies the number of min-k values to return.
    reduction_dimension : Integer dimension along which to search. Default: -1.
    recall_target : Recall target for the approximation.
    reduction_input_size_override : When set to a positive value, it overrides
      the size determined by ``operand[reduction_dim]`` for evaluating the
      recall. This option is useful when the given operand is only a subset of
      the overall computation in SPMD or distributed pipelines, where the true
      input size cannot be inferred from the ``operand`` shape.
    aggregate_to_topk : When true, aggregates approximate results to the top-k
      in sorted order. When false, returns the approximate results unsorted. In
      this case, the number of the approximate results is implementation defined
      and is greater than or equal to the specified ``k``.

  Returns:
    Tuple of two arrays. The arrays are the least ``k`` values and the
    corresponding indices along the ``reduction_dimension`` of the input
    ``operand``.  The arrays' dimensions are the same as the input ``operand``
    except for the ``reduction_dimension``: when ``aggregate_to_topk`` is true,
    the reduction dimension is ``k``; otherwise, it is greater than or equal to
    ``k``, and the exact size is implementation-defined.

  We encourage users to wrap ``approx_min_k`` with jit. See the following example
  for nearest neighbor search over the squared l2 distance:

  >>> import functools
  >>> import jax
  >>> import numpy as np
  >>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
  ... def l2_ann(qy, db, half_db_norms, k=10, recall_target=0.95):
  ...   dists = half_db_norms - jax.lax.dot(qy, db.transpose())
  ...   return jax.lax.approx_min_k(dists, k=k, recall_target=recall_target)
  >>>
  >>> qy = jax.numpy.array(np.random.rand(50, 64))
  >>> db = jax.numpy.array(np.random.rand(1024, 64))
  >>> half_db_norm_sq = jax.numpy.linalg.norm(db, axis=1)**2 / 2
  >>> dists, neighbors = l2_ann(qy, db, half_db_norm_sq, k=10)

  In the example above, we compute ``db^2/2 - dot(qy, db^T)`` instead of
  ``qy^2 - 2 dot(qy, db^T) + db^2`` for performance reasons. The former uses
  less arithmetic and produces the same set of neighbors, because the dropped
  ``qy^2`` term is constant for a fixed query and does not affect the ranking.
  """
  return approx_top_k_p.bind(
      operand,
      k=k,
      reduction_dimension=reduction_dimension,
      recall_target=recall_target,
      is_max_k=False,
      reduction_input_size_override=reduction_input_size_override,
      aggregate_to_topk=aggregate_to_topk)


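# Shape rule: the outputs match the operand shape except along
# reduction_dimension. For example, f32[10, 1024] with k=16 and
# aggregate_to_topk=True yields (f32[10, 16], i32[10, 16]).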
def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension,
                                recall_target, is_max_k,
                                reduction_input_size_override,
                                aggregate_to_topk):
  if k <= 0:
    raise ValueError(f'k must be positive, got {k}')
  if len(operand.shape) == 0:
    raise TypeError('approx_top_k operand must have >= 1 dimension, got {}'.format(
        operand.shape))
  dims = list(operand.shape)
  if dims[reduction_dimension] < k:
    raise ValueError(
        'k must not exceed the size of reduction_dim {}, got {}'.format(
            dims[reduction_dimension], k))
  if not dtypes.issubdtype(operand.dtype, np.floating):
    raise ValueError('operand must be of floating-point type')
  reduction_input_size = dims[reduction_dimension]
  if aggregate_to_topk:
    dims[reduction_dimension] = k
  elif core.is_constant_shape((reduction_input_size, k)):
    dims[reduction_dimension] = _jax.approx_top_k_reduction_output_size(
        reduction_input_size, len(dims), k, recall_target, aggregate_to_topk,
        reduction_input_size_override)[0]
  else:
    raise NotImplementedError(
        "approx_top_k with aggregate_to_topk=False is not yet implemented when "
        f"either `k` ({k}) or the "
        f"reduction dimension size ({reduction_input_size}) is symbolic")
  return (operand.update(shape=dims, dtype=operand.dtype,
                         weak_type=operand.weak_type, vma=operand.vma),
          operand.update(shape=dims, dtype=np.dtype(np.int32), vma=operand.vma))

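# The init value is the identity element of the reduction: -inf for max-k and
# +inf for min-k, so unfilled slots always lose comparisons against real
# candidates.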
def _get_init_val_literal(op_type, is_max_k):
  return np.array(-np.inf if is_max_k else np.inf, dtype=op_type)

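# Builds the scalar comparator called by the reduction. It takes two
# (value, index) pairs as four scalar tensors and returns a single i1 tensor
# that is true when the first value wins (GT for max-k, LT for min-k); the
# index operands are ignored.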
def _comparator_builder_mlir(ctx, op_type, is_max_k):
  scalar = ir.RankedTensorType.get([], op_type)
  index = ir.RankedTensorType.get([], ir.IntegerType.get_signless(32))
  ir_types = [scalar, scalar, index, index]
  result_types = [ir.RankedTensorType.get([], ir.IntegerType.get_signless(1))]

  comparator_type = ir.FunctionType.get(ir_types, result_types)
  with ir.InsertionPoint.at_block_begin(ctx.module_context.module.body):
    comparator = func.FuncOp(
        "top_k_{}_{}_comparator".format('gt' if is_max_k else 'lt', op_type),
        comparator_type)
  ctx.module_context.symbol_table.insert(comparator)

  entry_block = comparator.add_entry_block()
  with ir.InsertionPoint(entry_block):
    p0, p1, _, _ = entry_block.arguments
    direction = hlo.ComparisonDirectionAttr.get('GT' if is_max_k else 'LT')
    cmp_result = hlo.compare(p0, p1, comparison_direction=direction)
    hlo.return_([cmp_result])

  return comparator

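# Lowers to an XLA custom call. The operand is paired with an iota of the same
# shape so the reduction carries (value, index) pairs, seeded by init_val and
# init_arg (-1). When k is dynamic, the call routes through
# stablehlo.dynamic_approx_top_k with k passed as an extra operand.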
def _approx_top_k_lowering(ctx, operand, *, k,
                           reduction_dimension, recall_target, is_max_k,
                           reduction_input_size_override,
                           aggregate_to_topk, fallback=False):
  assert ctx.avals_in
  assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in)

  op_shape = ctx.avals_in[0].shape
  if len(op_shape) == 0:
    raise ValueError(f'operand must be an array, but was {op_shape}')

  op_dims = op_shape
  op_type = mlir.dtype_to_ir_type(ctx.avals_in[0].dtype)
  recall_type = ir.F32Type.get()
  if reduction_dimension < 0:
    reduction_dimension = len(op_dims) + reduction_dimension

  comparator = _comparator_builder_mlir(ctx, op_type, is_max_k)
  iota = mlir.iota(ctx, core.ShapedArray(ctx.avals_in[0].shape, np.int32),
                   dimension=reduction_dimension)

  init_arg = hlo.constant(ir.DenseElementsAttr.get(np.int32(-1)))
  init_val_array = _get_init_val_literal(ctx.avals_in[0].dtype, is_max_k)
  init_val = mlir.ir_constant(init_val_array.reshape(()))

  backend_config = {
    "reduction_dim" : mlir.i64_attr(reduction_dimension),
    "recall_target" : mlir.ir.FloatAttr.get(recall_type, recall_target),
    "aggregate_to_topk" : mlir.ir.BoolAttr.get(aggregate_to_topk),
    "reduction_input_size_override" :
      mlir.i64_attr(reduction_input_size_override)}
  if fallback:
    backend_config["is_fallback"] = mlir.ir.BoolAttr.get(fallback)

  if all(core.is_constant_shape(aval_out.shape) for aval_out in ctx.avals_out):
    result_shapes = None
  else:
    result_shapes = [
        mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape))
        for aval_out in ctx.avals_out]

  if core.is_constant_dim(k):
    backend_config["top_k"] = mlir.i64_attr(k)
    out = mlir.custom_call(
        "ApproxTopK",
        result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
        operands=[operand, iota, init_val, init_arg],
        called_computations=[comparator.name.value],
        backend_config=backend_config,
        result_shapes=result_shapes)
  else:
    k_value, = mlir.eval_dynamic_shape_as_vals(ctx, (k,))
    out = mlir.custom_call(
        "stablehlo.dynamic_approx_top_k",
        result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
        operands=[operand, iota, init_val, init_arg, k_value],
        called_computations=[comparator.name.value],
        backend_config=backend_config,
        result_shapes=result_shapes)

  return out.results

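# Batching maps reduction_dimension around the vmapped axis. For example, with
# an operand of rank 3, batch_axis=0 and reduction_dimension=-1, dim_map is
# [1, 2], so the reduction stays on the trailing (non-batch) dimension.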
def _approx_top_k_batch_rule(batch_operands, batch_axes, *, k,
                             reduction_dimension, recall_target, is_max_k,
                             reduction_input_size_override, aggregate_to_topk):
  assert len(batch_operands) == 1
  assert len(batch_axes) == 1
  operand, = batch_operands
  batch_axis, = batch_axes
  dim_map = [d for d in range(operand.ndim) if d != batch_axis]
  reduction_dimension = dim_map[reduction_dimension]
  return approx_top_k_p.bind(
      operand,
      k=k,
      reduction_dimension=reduction_dimension,
      recall_target=recall_target,
      is_max_k=is_max_k,
      reduction_input_size_override=reduction_input_size_override,
      aggregate_to_topk=aggregate_to_topk), (batch_axis, batch_axis)


# Slow jvp implementation using gather.
#
# TODO(fchern): Some optimization ideas
# 1. ApproxTopK is internally a variadic reduce, so we can simply call
#    ApproxTopK(operand, tangent, iota) for jvp.
# 2. vjp cannot benefit from the algorithm above. We must run scatter to
#    distribute the output cotangent to input cotangent. A reasonable way to do
#    this is to run it on CPU.
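#
# The values' tangent is the input tangent gathered at the winning indices:
# idx selects arg_out along the reduction dimension and broadcasted iotas
# everywhere else, so tangent[idx] has the same shape as val_out. The indices
# output is non-differentiable, so its tangent is symbolically zero.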
def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension,
                      recall_target, is_max_k, reduction_input_size_override,
                      aggregate_to_topk):
  operand, = primals
  tangent, = tangents
  if is_max_k:
    val_out, arg_out = approx_max_k(operand, k, reduction_dimension,
                                    recall_target,
                                    reduction_input_size_override,
                                    aggregate_to_topk)
  else:
    val_out, arg_out = approx_min_k(operand, k, reduction_dimension,
                                    recall_target,
                                    reduction_input_size_override,
                                    aggregate_to_topk)
  if type(tangent) is ad_util.Zero:
    tangent_out = ad_util.Zero.from_primal_value(val_out)
  else:
    arg_shape = arg_out.shape
    rank = len(arg_shape)
    if reduction_dimension < 0:
      reduction_dimension += rank
    iotas = [
        lax.broadcasted_iota(arg_out.dtype, arg_shape, i) for i in range(rank)
    ]
    idx = tuple(
        arg_out if i == reduction_dimension else iotas[i] for i in range(rank))
    tangent_out = tangent[idx]
  return (val_out, arg_out), (tangent_out, ad_util.Zero.from_primal_value(arg_out))


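# Primitive registration. The platform-independent lowering is registered with
# is_fallback=True (the sort-and-slice path described in the module
# docstring); the TPU-specific registration takes precedence on TPU and omits
# the fallback flag.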
approx_top_k_p = core.Primitive('approx_top_k')
approx_top_k_p.multiple_results = True
approx_top_k_p.def_impl(partial(dispatch.apply_primitive, approx_top_k_p))
approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval)
mlir.register_lowering(approx_top_k_p,
                       partial(_approx_top_k_lowering, fallback=True))
mlir.register_lowering(approx_top_k_p, _approx_top_k_lowering,
                       platform='tpu')
batching.primitive_batchers[approx_top_k_p] = _approx_top_k_batch_rule
ad.primitive_jvps[approx_top_k_p] = _approx_top_k_jvp