173 lines
5.8 KiB
Python
173 lines
5.8 KiB
Python
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from collections.abc import Sequence
|
|
from typing import Any, Protocol
|
|
import jax
|
|
from jax._src import random
|
|
from jax._src.typing import Array, ArrayLike
|
|
from jax import numpy as jnp
|
|
|
|
NdKeyList = Any
|
|
Shape = random.Shape
|
|
|
|
class SampleFn(Protocol):
  """Call signature expected of a sampler passed to `sample_block`.

  A conforming callable takes a PRNG key as its first positional argument,
  accepts the output shape through the keyword-only `shape` argument, may
  take further positional/keyword arguments, and returns an array.
  """

  def __call__(
      self,
      key: ArrayLike,
      *args,
      shape: Shape,
      **kwargs,
  ) -> Array:
    ...
|
|
|
|
|
|
def _compute_tile_index(block_index: Sequence[int],
|
|
block_size_in_tiles: Shape,
|
|
total_size_in_tiles: Shape,
|
|
tile_index_in_block: Sequence[int]) -> int:
|
|
ndims = len(block_index)
|
|
dim_size = 1
|
|
total_idx = 0
|
|
for i in range(ndims-1, -1, -1):
|
|
dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
|
|
total_idx += dim_idx * dim_size
|
|
dim_size *= total_size_in_tiles[i]
|
|
return total_idx
|
|
|
|
|
|
def blocked_fold_in(
    global_key: ArrayLike,
    total_size: Shape,
    block_size: Shape,
    tile_size: Shape,
    block_index: Sequence[ArrayLike],
) -> NdKeyList:
  """Computes a grid of keys for block-invariant sampling.

  Example: to build a 16x512 array of random numbers that comes out the same
  whether it is produced in 16x128 blocks or in 16x256 blocks, pick a tile
  size of 8x128 (a divisor of both block sizes). The full array then splits
  into tiles as:
    ---------------------------------
    | 8x128 | 8x128 | 8x128 | 8x128 |
    ---------------------------------
    | 8x128 | 8x128 | 8x128 | 8x128 |
    ---------------------------------

  Every tile gets its own key:
    tile_key = fold_in(global_key, tile_idx)

  where tile_idx is the row-major raveled index of the tile:
    -----------------
    | 0 | 1 | 2 | 3 |
    -----------------
    | 4 | 5 | 6 | 7 |
    -----------------

  This function returns the keys of exactly those tiles that make up the
  block selected by `block_index`.

  With a 16x256 block size, a block covers 4 (2x2) tiles:
    ---------------
    | 0, 1 | 2, 3 |
    | 4, 5 | 6, 7 |
    ---------------
  so a 2x2 grid of keys is returned per block (2 blocks total).

  With a 16x128 block size, a block covers 2 (2x1) tiles:
    -----------------
    | 0 | 1 | 2 | 3 |
    | 4 | 5 | 6 | 7 |
    -----------------
  so a 2x1 grid of keys is returned per block (4 blocks total).

  Args:
    global_key: The global key shared between all blocks.
    total_size: The shape of the array being generated.
    block_size: The shape of an individual block.
    tile_size: The shape of a `tile`, the smallest unit at which samples are
      generated. Choose it to divide every block size the result should be
      invariant to.
    block_index: The index denoting which block to generate keys for.

  Returns:
    An N-dimensional nested list of keys, one per tile of the block given by
    `block_index`.
  """
  # Number of tiles one block spans along each axis.
  tiles_per_block = tuple(
      extent // tile for extent, tile in zip(block_size, tile_size))

  # Ceiling division: a partial trailing tile still needs its own number.
  tiles_in_total = tuple(
      (extent + tile - 1) // tile
      for extent, tile in zip(total_size, tile_size))

  rank = len(tiles_per_block)

  def _build_keys(axis, tile_coords):
    if axis < rank:
      # Still descending: one nested list level per axis.
      return [
          _build_keys(axis + 1, tile_coords + (t,))
          for t in range(tiles_per_block[axis])
      ]
    # All axes fixed: derive this tile's key from its global flat index.
    flat_index = _compute_tile_index(
        block_index, tiles_per_block, tiles_in_total, tile_coords)
    return jax.random.fold_in(global_key, flat_index)

  return _build_keys(0, ())
|
|
|
|
|
|
def sample_block(
    sampler_fn: SampleFn,
    keys: NdKeyList,
    block_size: Shape,
    tile_size: Shape,
    *args,
    **kwargs
) -> jax.Array:
  """Draws random samples for a single block.

  This function is intended to be used in conjunction with `blocked_fold_in`:
  ```
  key_list = blocked_fold_in(global_key, total_size, block_size, tile_size,
                             block_index)
  samples = sample_block(jax.random.uniform, key_list, block_size, tile_size)
  ```

  One tile of shape `tile_size` is sampled per key, and the tiles are
  concatenated along each axis in turn to form the block.

  Args:
    sampler_fn: A random sampling function, e.g. jax.random.uniform.
    keys: A grid of keys generated by `blocked_fold_in`. For a 0-d block
      (`block_size == ()`), this is the single key itself rather than a list.
    block_size: The shape of an individual block.
    tile_size: The shape of a `tile`, which is the smallest unit at
      which samples are generated. This should be selected to be a divisor
      of all block sizes one needs to be invariant to.
    args: varargs for sampler_fn.
    kwargs: kwargs for sampler_fn.

  Returns:
    An array of random samples drawn using sampler_fn.
  """
  size_in_tiles = tuple(
      _shape // _element for _shape, _element in zip(block_size, tile_size))

  def _nested_index(nested: NdKeyList, idx: Sequence[int]) -> NdKeyList:
    # Descend one list level per index entry. An empty index (0-d block)
    # leaves `nested` untouched, so the bare key is returned instead of
    # failing on `idx[0]`.
    for i in idx:
      nested = nested[i]
    return nested

  def _sample_loop(axis: int, prefix: tuple[int, ...]) -> jax.Array:
    if axis == len(size_in_tiles):
      # Every axis is fixed: sample the single tile addressed by `prefix`.
      return sampler_fn(_nested_index(keys, prefix), *args,
                        shape=tile_size, **kwargs)
    samples = [
        _sample_loop(axis + 1, prefix + (i,))
        for i in range(size_in_tiles[axis])
    ]
    # Stitch the sub-results for this axis back together.
    return jnp.concatenate(samples, axis=axis)

  return _sample_loop(0, ())
|