
Permute

permute(index, length, seed)

Index into a permuted array with constant space and time complexity.

This function returns the element at position `index` of a pseudorandom permutation of `[0, length)` selected by `seed`. It computes that element with a hash function instead of a stored array. See https://afnan.io/posts/2019-04-05-explaining-the-hashed-permutation/ for more details and "Correlated Multi-Jittered Sampling" by Andrew Kensler for the original algorithm.
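The structure of the algorithm can be illustrated with a deliberately simple stand-in. The sketch below is an assumption for illustration only. The hypothetical `toy_permute_steps` replaces Kensler's hash with an affine map `x -> (a * x + b) mod 2**k`, which is a bijection because `a` is odd. It keeps the same idea of an all-ones bit mask plus a cycle-walking loop. It also counts loop iterations. Because the mask size `2**k` is less than `2 * length`, the iterations summed over all indices never exceed `2**k`, so the average number of passes is below two.

```python
def toy_permute_steps(index: int, length: int, seed: int) -> tuple[int, int]:
    """Hypothetical stand-in for `permute`; returns (permuted_index, loop_iterations).

    NOT Kensler's hash: an affine bijection on [0, 2**k) is used instead, so only
    the mask-and-cycle-walk structure matches the real function.
    """
    if length <= 1:
        raise ValueError("length must be greater than 1")
    if index not in range(length):
        raise ValueError("index must be in [0, length)")

    # All-ones mask 2**k - 1 covering length - 1, so length <= 2**k < 2 * length.
    mask = (1 << (length - 1).bit_length()) - 1
    multiplier = (2 * seed + 1) & mask  # odd, hence invertible modulo 2**k
    offset = (seed * 0x9E3779B9) & mask  # arbitrary seed-dependent shift

    steps = 0
    while True:
        # One application of the bijection x -> multiplier * x + offset (mod 2**k).
        index = (multiplier * index + offset) & mask
        steps += 1
        # Cycle-walk: stop as soon as the value lands inside [0, length).
        if index < length:
            return index, steps


length, seed = 1000, 7
results = [toy_permute_steps(i, length, seed) for i in range(length)]
outputs = [value for value, _ in results]
total_steps = sum(steps for _, steps in results)

print(sorted(outputs) == list(range(length)))  # True: every index maps to a distinct slot
print(total_steps <= 1024)  # True: the mask size 2**k = 1024 bounds the total work
```

The same counting argument applies to any bijection on `[0, 2**k)`, including the hash used by `permute`. Each value outside `[0, length)` is visited by at most one walk, so the total work over all indices is at most `2**k`.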

Parameters:

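| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `index` | `int` | The index to permute. Must lie in `[0, length)`. | required |
| `length` | `int` | The size of the permuted range; the result lies in `[0, length)`. Must be greater than 1. | required |
| `seed` | `int` | The permutation seed. Must be non-negative. | required |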

Returns:

| Type | Description |
| --- | --- |
| `int` | The permuted index, in the range `[0, length)`. |

Source code in bionemo/core/data/permute.py
import warnings


def permute(index: int, length: int, seed: int) -> int:
    """Index into a permuted array with constant space and time complexity.

    This function returns the element at position `index` of a pseudorandom permutation of
    `[0, length)` selected by `seed`, computed with a hash function instead of a stored array. See
    https://afnan.io/posts/2019-04-05-explaining-the-hashed-permutation/ for more details and
    "Correlated Multi-Jittered Sampling" by Andrew Kensler for the original algorithm.

    Args:
        index: The index to permute.
        length: The range of the permuted index.
        seed: The permutation seed.

    Returns:
        The permuted index in range(0, length).
    """
    if length <= 1:
        raise ValueError("The length of the permuted range must be greater than 1.")

    if index not in range(length):
        raise ValueError("The index to permute must be in the range [0, length).")

    if seed < 0:
        raise ValueError("The permutation seed must be greater than or equal to 0.")

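    # Suppress any warnings raised by the integer arithmetic below.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        # Smear the highest set bit of length - 1 into every lower bit, giving an
        # all-ones mask w = 2**k - 1 with 2**k >= length (valid while length - 1 < 2**32).
        w = length - 1
        w |= w >> 1
        w |= w >> 2
        w |= w >> 4
        w |= w >> 8
        w |= w >> 16

        # Cycle-walk: hash the index until it lands inside [0, length). Each pass
        # is a bijection on [0, w], so the walk always returns to the range.
        while True:
            # Kensler's hash rounds. Python ints do not wrap at 32 bits, but higher bits
            # never feed back into the bits kept by `& w`, so the masked result is unaffected.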
            index ^= seed
            index *= 0xE170893D
            index ^= seed >> 16
            index ^= (index & w) >> 4
            index ^= seed >> 8
            index *= 0x0929EB3F
            index ^= seed >> 23
            index ^= (index & w) >> 1
            index *= 1 | seed >> 27
            index *= 0x6935FA69
            index ^= (index & w) >> 11
            index *= 0x74DCB303
            index ^= (index & w) >> 2
            index *= 0x9E501CC3
            index ^= (index & w) >> 2
            index *= 0xC860A3DF
            index &= w  # keep only the bits covered by the mask
            if index < length:
                break

    # Rotate by the seed, as in Kensler's original algorithm.
    return (index + seed) % length
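A brute-force check can make the termination argument concrete. The sketch below copies the loop body from the source above into a hypothetical helper, `hash_pass`. For small masks, it confirms that one pass maps `[0, w]` onto itself bijectively. This holds because every step is an XOR with a constant, a multiplication by an odd constant, or a right xorshift of the masked bits. The sketch also reproduces the bit smearing in a hypothetical `smear_mask` helper. Because the shifts stop at 16, that mask is all ones only while `length - 1` fits in 32 bits.

```python
def smear_mask(length: int) -> int:
    """Hypothetical helper reproducing the bit smearing in `permute`."""
    w = length - 1
    for shift in (1, 2, 4, 8, 16):
        w |= w >> shift
    return w


def hash_pass(index: int, w: int, seed: int) -> int:
    """One pass of the loop body in `permute`, copied from the source above."""
    index ^= seed
    index *= 0xE170893D
    index ^= seed >> 16
    index ^= (index & w) >> 4
    index ^= seed >> 8
    index *= 0x0929EB3F
    index ^= seed >> 23
    index ^= (index & w) >> 1
    index *= 1 | seed >> 27
    index *= 0x6935FA69
    index ^= (index & w) >> 11
    index *= 0x74DCB303
    index ^= (index & w) >> 2
    index *= 0x9E501CC3
    index ^= (index & w) >> 2
    index *= 0xC860A3DF
    return index & w


# The smeared mask equals 2**k - 1 with 2**k >= length while length - 1 < 2**32.
for length in range(2, 5000):
    assert smear_mask(length) == (1 << (length - 1).bit_length()) - 1
assert smear_mask(2**31 + 1) == 2**32 - 1
assert smear_mask(2**32 + 1) != 2**33 - 1  # shifts stop at 16, so bit 0 stays unset

# Each pass permutes [0, w], which is why the cycle-walking loop always terminates.
for k in range(1, 11):
    w = (1 << k) - 1
    for seed in (0, 1, 12345, 2**31 + 7):
        assert sorted(hash_pass(x, w, seed) for x in range(w + 1)) == list(range(w + 1))
print("all checks passed")
```

The final `(index + seed) % length` is a rotation of `[0, length)`, which is also a bijection. Together with the cycle-walking loop, this is why distinct indices map to distinct outputs.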