functional
Supporting utility functions.
Classes
- ClipFunction — A universal tensor clip function.
- FastHadamardTransform — The fast Hadamard transform.
Functions
- normalized_hadamard_transform — Normalized fast Hadamard transform.
- class ClipFunction
Bases: Function
A universal tensor clip function.
PyTorch's clamp() only supports a scalar range and does not broadcast. This implementation uses min/max, which is more general. The gradient is defined according to IBM's PACT paper (https://arxiv.org/abs/1805.06085), which is also the behavior of TensorFlow's clip_by_value().
- static backward(ctx, grad_output)
Backward pass for the clip function.
- static forward(ctx, input, clip_value_min, clip_value_max)
Forward pass for the clip function.
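A minimal sketch of how such a clip function can be implemented as a custom autograd Function. The class name ClipSketch is hypothetical; the forward uses broadcasting min/max, and the backward passes the gradient only where the input lay inside the clip range, which is the PACT / clip_by_value behavior described above:

```python
import torch

class ClipSketch(torch.autograd.Function):
    """Hypothetical illustration of a broadcastable tensor clip."""

    @staticmethod
    def forward(ctx, input, clip_value_min, clip_value_max):
        # torch.minimum/torch.maximum broadcast, unlike clamp's scalar range.
        ctx.save_for_backward(input, clip_value_min, clip_value_max)
        return torch.minimum(torch.maximum(input, clip_value_min), clip_value_max)

    @staticmethod
    def backward(ctx, grad_output):
        input, lo, hi = ctx.saved_tensors
        # Straight-through gradient: pass it only where the input was
        # inside [lo, hi]; zero it where the output was clipped.
        inside = (input >= lo) & (input <= hi)
        return grad_output * inside, None, None

# Usage: values outside [-1, 1] are clipped and receive zero gradient.
x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
y = ClipSketch.apply(x, torch.tensor(-1.0), torch.tensor(1.0))
y.sum().backward()
```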
- class FastHadamardTransform
Bases: Function
The fast Hadamard transform.
This only works when inputs.shape[-1] is a power of 2.
- static backward(ctx, grad_outputs)
Hadamard backward.
- static forward(ctx, inputs)
Hadamard forward.
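A sketch of the underlying butterfly algorithm, assuming the standard iterative formulation (function name hypothetical, not the library's implementation). Because the Hadamard matrix is symmetric, the backward pass can apply the same transform to the incoming gradients:

```python
import torch

def fht_sketch(x: torch.Tensor) -> torch.Tensor:
    """Unnormalized fast Hadamard transform along the last dimension."""
    n = x.shape[-1]
    assert n > 0 and n & (n - 1) == 0, "last dim must be a power of 2"
    h = 1
    while h < n:
        # Split the last dim into blocks of 2*h and butterfly each pair.
        y = x.reshape(*x.shape[:-1], n // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        x = torch.stack([a + b, a - b], dim=-2).reshape(*x.shape[:-1], n)
        h *= 2
    return x
```

Applying the transform twice recovers the input scaled by n, since H @ H == n * I.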
- normalized_hadamard_transform(inputs, rotate_fp32=False, block_size=None)
Normalized fast Hadamard transform.
Supports block-granular RHT for last-dimension sizes that are not a power of 2. When block_size is given, the last dimension is split into blocks of block_size elements (block_size must be a power of 2), and the Hadamard transform is applied per block. This enables RHT for MoE expert channel dimensions (e.g. 1920, 1536, 896) that are not powers of 2.
- Parameters:
inputs – Input tensor, Hadamard is applied along the last dimension.
rotate_fp32 – If True, compute rotation in float32.
block_size – Block size for block-granular RHT. Must be a power of 2 and divide inputs.shape[-1]. If None: use the full-dimension FHT when the last dimension is a power of 2; otherwise auto-select the largest power-of-2 divisor of that dimension.
- Returns:
Rotated tensor with same shape as inputs.
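The block-granular behavior can be sketched as follows, under the assumption that the per-block transform is scaled by 1/sqrt(block_size) to make it orthonormal; the function names are hypothetical stand-ins, not the library's API:

```python
import torch

def _fht(x: torch.Tensor) -> torch.Tensor:
    # Unnormalized FHT along the last dimension (size must be a power of 2).
    n = x.shape[-1]
    h = 1
    while h < n:
        y = x.reshape(*x.shape[:-1], n // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        x = torch.stack([a + b, a - b], dim=-2).reshape(*x.shape[:-1], n)
        h *= 2
    return x

def normalized_hadamard_sketch(inputs: torch.Tensor, block_size=None):
    """Block-granular normalized Hadamard rotation (illustrative sketch)."""
    n = inputs.shape[-1]
    if block_size is None:
        # Largest power-of-2 divisor of n; equals n when n is a power of 2.
        block_size = n & (-n)
    assert block_size & (block_size - 1) == 0 and n % block_size == 0
    x = inputs.reshape(*inputs.shape[:-1], n // block_size, block_size)
    x = _fht(x) / block_size ** 0.5  # orthonormal scaling per block
    return x.reshape(inputs.shape)
```

Because each block is rotated by an orthonormal matrix, the result preserves the input's norm, and applying the rotation twice recovers the original tensor.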