# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpers for block-invariant random sampling.

`blocked_fold_in` derives one key per tile of an array, and `sample_block`
uses those keys to draw the samples of a single block tile by tile.
"""

from collections.abc import Sequence
from typing import Any, Protocol

import jax
from jax._src import random
from jax._src.typing import Array, ArrayLike
from jax import numpy as jnp

# An N-dimensional nested list of keys (a bare key in the 0-dimensional case).
NdKeyList = Any
Shape = random.Shape


class SampleFn(Protocol):
  """Signature of a sampler accepted by `sample_block`.

  The callable receives a key, optional positional arguments, a keyword-only
  `shape` and optional keyword arguments, and returns an array.
  """

  def __call__(self, key: ArrayLike, *args, shape: Shape,
               **kwargs) -> Array:
    ...


def _compute_tile_index(block_index: Sequence[int],
                        block_size_in_tiles: Shape,
                        total_size_in_tiles: Shape,
                        tile_index_in_block: Sequence[int]) -> int:
  """Returns the row-major raveled index of a tile in the full tile grid.

  Args:
    block_index: Per-axis index of the block that contains the tile.
    block_size_in_tiles: Per-axis number of tiles in one block.
    total_size_in_tiles: Per-axis number of tiles in the full array.
    tile_index_in_block: Per-axis index of the tile inside its block.

  Returns:
    The flattened (row-major) index of the tile. This is 0 when there are no
    dimensions.
  """
  ndims = len(block_index)
  dim_size = 1  # Row-major stride (in tiles) of the axis being visited.
  total_idx = 0
  # Walk from the last (fastest-varying) axis to the first.
  for i in range(ndims - 1, -1, -1):
    # Position of the tile along axis `i` in the full tile grid.
    dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
    total_idx += dim_idx * dim_size
    dim_size *= total_size_in_tiles[i]
  return total_idx


def blocked_fold_in(
    global_key: ArrayLike,
    total_size: Shape,
    block_size: Shape,
    tile_size: Shape,
    block_index: Sequence[ArrayLike],
) -> NdKeyList:
  """Computes a grid of keys for block-invariant sampling.

  Suppose we wished to construct a 16x512 array of random numbers, using
  block sizes of 16x128 and 16x256. We could select a tile size of 8x128
  (which divides both 16x128 and 16x256) and divide the total array in tiles as:
    ---------------------------------
    | 8x128 | 8x128 | 8x128 | 8x128 |
    ---------------------------------
    | 8x128 | 8x128 | 8x128 | 8x128 |
    ---------------------------------

  We generate a key for each tile as:
    tile_key = fold_in(global_key, tile_idx)

  Where the tile_idx is the row-major raveled index of each element:
    -----------------
    | 0 | 1 | 2 | 3 |
    -----------------
    | 4 | 5 | 6 | 7 |
    -----------------

  We then compute and return the keys required to sample the tiles that make
  up the current block (specified via `block_index`).
  With a 16x256 block size, each block requires 4 (2x2) tile keys:
    ---------------
    | 0, 1 | 2, 3 |
    | 4, 5 | 6, 7 |
    ---------------
  Therefore, we return a grid of 2x2 keys for each block (2 blocks total).

  With a 16x128 block size, each block requires 2 (2x1) tile keys:
    -----------------
    | 0 | 1 | 2 | 3 |
    | 4 | 5 | 6 | 7 |
    -----------------
  Therefore, we return a grid of 2x1 keys for each block (4 blocks total).

  Args:
    global_key: The global key shared between all blocks.
    total_size: The shape of the array being generated.
    block_size: The shape of an individual block.
    tile_size: The shape of a `tile`, which is the smallest unit at
      which samples are generated. This should be selected to be a divisor
      of all block sizes one needs to be invariant to.
    block_index: The index denoting which block to generate keys for.

  Returns:
    An N-dimensional nested list of keys required to sample the tiles
    corresponding to the block specified by `block_index`. For 0-dimensional
    shapes this is a single key rather than a list.
  """
  block_size_in_tiles = tuple(
      _shape // _element for _shape, _element in zip(block_size, tile_size)
  )

  # Round up to make sure every tile is numbered.
  total_size_in_tiles = tuple(
      (_shape + _element - 1) // _element
      for _shape, _element in zip(total_size, tile_size)
  )

  def _keygen_loop(axis, prefix):
    # `prefix` holds the tile index chosen so far along axes [0, axis).
    if axis == len(block_size_in_tiles):
      tile_idx = _compute_tile_index(
          block_index, block_size_in_tiles, total_size_in_tiles, prefix)
      return jax.random.fold_in(global_key, tile_idx)
    return [
        _keygen_loop(axis + 1, prefix + (i,))
        for i in range(block_size_in_tiles[axis])
    ]

  return _keygen_loop(0, ())


def sample_block(
    sampler_fn: SampleFn,
    keys: NdKeyList,
    block_size: Shape,
    tile_size: Shape,
    *args,
    **kwargs
) -> jax.Array:
  """Draws random samples for a single block.

  This function is intended to be used in conjunction with `blocked_fold_in`:
  ```
  key_list = blocked_fold_in(global_key, total_size, block_size, tile_size,
                             block_index)
  samples = sample_block(jax.random.uniform, key_list, block_size, tile_size)
  ```

  Args:
    sampler_fn: A random sampling function, e.g. jax.random.uniform.
    keys: A grid of keys generated by `blocked_fold_in`. For 0-dimensional
      shapes this is the single key that `blocked_fold_in` returns.
    block_size: The shape of an individual block.
    tile_size: The shape of a `tile`, which is the smallest unit at
      which samples are generated. This should be selected to be a divisor
      of all block sizes one needs to be invariant to.
    args: varargs for sampler_fn.
    kwargs: kwargs for sampler_fn.

  Returns:
    An array of random samples drawn using sampler_fn. Each tile is sampled
    with `shape=tile_size` and the tiles are concatenated along every axis.
  """
  size_in_tiles = tuple(
      _shape // _element for _shape, _element in zip(block_size, tile_size))

  def _nested_index(nested: NdKeyList, idx: Sequence[int]) -> Any:
    # Descend one nesting level per index entry. An empty `idx` (the
    # 0-dimensional case) returns `nested` itself instead of raising.
    for i in idx:
      nested = nested[i]
    return nested

  def _sample_loop(axis: int, prefix: tuple[int, ...]) -> jax.Array:
    # `prefix` holds the tile index chosen so far along axes [0, axis).
    if axis == len(size_in_tiles):
      return sampler_fn(_nested_index(keys, prefix), *args,
                        shape=tile_size, **kwargs)
    samples = [
        _sample_loop(axis + 1, prefix + (i,))
        for i in range(size_in_tiles[axis])
    ]
    # Stitch the sub-results back together along the axis being iterated.
    return jnp.concatenate(samples, axis=axis)

  return _sample_loop(0, ())