Multi-Head, Multi-Query, and Group-Query Attention

This document details the implementation of multi-head attention (MHA), multi-query attention (MQA), and group-query attention (GQA) for auto-regressive GPT-like models in TensorRT-LLM. As a quick reminder, multi-head attention is the sequence of a batched matmul, a softmax, and another batched matmul described in the Attention Is All You Need paper. Multi-query attention (MQA) and group-query attention (GQA) are variants of MHA that use fewer so-called K/V heads than query heads. In TensorRT-LLM, MHA, MQA, and GQA are implemented by the operator tensorrt_llm.functional.gpt_attention.
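To make the relationship between the three variants concrete, here is a minimal NumPy reference sketch (not the TensorRT-LLM kernel) of causal attention for one sequence, in which num_heads query heads share num_kv_heads K/V heads. With num_kv_heads == num_heads it is MHA, with num_kv_heads == 1 it is MQA, and anything in between is GQA. The function name and tensor layout are illustrative assumptions.

```python
import numpy as np

def grouped_query_attention(q, k, v):
    """Reference causal attention for a single sequence.

    q:    [num_heads, seq_len, head_size]
    k, v: [num_kv_heads, seq_len, head_size], with num_heads % num_kv_heads == 0
    """
    num_heads, seq_len, head_size = q.shape
    num_kv_heads = k.shape[0]
    assert num_heads % num_kv_heads == 0
    group_size = num_heads // num_kv_heads
    out = np.empty_like(q)
    # Causal mask: token i may only attend to tokens j <= i.
    mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    for h in range(num_heads):
        kv = h // group_size  # query heads of the same group share one K/V head
        scores = q[h] @ k[kv].T / np.sqrt(head_size)  # first batched matmul
        scores = np.where(mask, -np.inf, scores)
        scores -= scores.max(axis=-1, keepdims=True)
        probs = np.exp(scores)
        probs /= probs.sum(axis=-1, keepdims=True)    # softmax
        out[h] = probs @ v[kv]                         # second batched matmul
    return out
```

GQA with 2 K/V heads gives the same result as MHA in which each K/V head is simply repeated for the 4 query heads of its group; the K/V cache, however, only has to hold the 2 distinct heads, which is where the memory saving comes from.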

Important Note

As discussed below, the current implementation supports two input modes: padded and packed (non-padded). Because the packed mode is always more memory-efficient and faster than the padded mode, support for the padded mode may be removed in the future.

Padded and Packed Tensors

In TensorRT-LLM, the GPT attention operator supports two different types of QKV inputs: padded and packed (i.e., non-padded) inputs. The mode is determined by the global configuration parameter remove_input_padding defined in tensorrt_llm.plugin.

When padding is enabled (that is, remove_input_padding is False), sequences shorter than max_sequence_length are padded to that maximum length. This may result in excessive memory consumption, as well as unneeded computation on padding tokens (in the various matrix multiplications that surround the MHA block).

To overcome that problem, TensorRT-LLM supports a mode without padding, in which the tokens of the different sequences are packed together and the user provides the operator with a 1D tensor containing the lengths of the different sequences. Users are recommended to always use the packed mode (support for the padded mode may be removed in the future).
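A minimal sketch of what packing means, assuming each padded sequence is a Python list of token ids and the true lengths are known (the function names are hypothetical; the real operator works on GPU tensors): the valid tokens are concatenated into one flat buffer, and per-sequence start offsets can be recovered from the 1D tensor of lengths with a prefix sum.

```python
from itertools import accumulate

def pack(padded_batch, seq_lengths):
    """Drop padding: concatenate the valid tokens of every sequence."""
    packed = []
    for row, length in zip(padded_batch, seq_lengths):
        packed.extend(row[:length])
    return packed

def sequence_offsets(seq_lengths):
    """Start offset of each sequence in the packed buffer, plus the total length."""
    return [0] + list(accumulate(seq_lengths))
```

For a batch padded to 4 tokens with lengths [3, 1, 2], padded mode processes 12 token slots while packed mode processes only the 6 real tokens.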

Context and Generation Phases

The GPT attention operator encapsulates different implementations for both context and generation phases in auto-regressive models like GPT.

Context Phase

If context_fmha_type is set to disabled (refer to tensorrt_llm.plugin), the implementation maps to a sequence of GPU kernels that store the intermediate Q*K^T tensor in memory before calling the softmax operator. This is the slowest method, and its memory footprint is significant (it grows quadratically with the sequence length).
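To give a feel for the quadratic term, here is back-of-the-envelope arithmetic for the size of the intermediate Q*K^T tensor, assuming it is materialized at once for the whole batch in FP16 (2 bytes per element). The actual kernels may lay memory out differently, so treat the result as an order of magnitude.

```python
def qk_scores_bytes(batch_size, num_heads, seq_len, bytes_per_elem=2):
    # One [seq_len, seq_len] score matrix per (sequence, head) pair.
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem
```

With batch_size=8, num_heads=32, and seq_len=4096, this is 8 GiB; doubling the sequence length to 8192 quadruples it to 32 GiB. The fused FMHA path avoids storing this tensor.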

Otherwise, if context_fmha_type is set to enabled or enabled_with_fp32_acc (accumulation in the first batched matmul is forced to FP32), the operator triggers a single kernel that performs the whole MHA/MQA block. For short sequences, that kernel uses a vanilla implementation of MHA/MQA. For longer sequences, it uses the Flash Attention algorithm as described in FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness and FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.

Currently, the implementation triggers extra kernels that apply pre-processing to the elements (like RoPE) and populate the KV cache (see below). A future release is planned to reduce the number of such kernels to improve overall performance.

FP8 Context FMHA

When FP8 quantization is activated, the attention can be further accelerated by enabling FP8 Context FMHA (use_fp8_context_fmha = enable).

FP8 Paged Context FMHA is also supported with the FP8 quantization workflow. To use it, set use_fp8_context_fmha = enable and use_paged_context_fmha = enable at the same time.

Please be aware that this is an experimental feature only supported on Hopper. If you notice a significant decrease in accuracy, it is recommended to disable it.

Generation Phase

The generation phase is implemented using a single kernel called the masked multi-head attention in TensorRT-LLM. That kernel is able to apply pre-processing on the Q, K, and V elements on-the-fly: it adds the QKV bias, applies RoPE, and performs dequantization and quantization. TensorRT-LLM will continue to add (or enable) additional features in future releases, for example, support for IA3.

The masked MHA kernel has a special version that distributes the work across multiple CUDA thread-blocks on the GPU for cases where the GPU occupancy is low. That mode, called multi-block, is turned on by default starting from TRT-LLM 0.13 and can be disabled at runtime using --multi_block_mode=False. Users are recommended to test that mode in scenarios where both the batch size and the number of heads in the model are relatively small. The exact definition of small in that context depends on the GPU model and is hard to predict, but as a rule of thumb, it is worth testing that mode when batch_size * num_heads is less than the number of multiprocessors on the GPU (that suggestion may evolve as more research is conducted and the software improves).
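The rule of thumb above fits in a one-line check. This encodes only the heuristic stated in this section, not TensorRT-LLM's internal dispatch logic; the function name is hypothetical, and num_sms is the number of streaming multiprocessors on your GPU (for example, 108 on an A100).

```python
def worth_testing_multi_block(batch_size, num_heads, num_sms):
    """Rule of thumb: (sequence, head) pairs are too few to occupy every SM."""
    return batch_size * num_heads < num_sms
```

For example, batch size 2 with 32 heads gives 64 pairs, which cannot occupy all 108 SMs of an A100 with one thread-block per head. Batch size 8 gives 256 pairs, which already can.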

Note that even if the multi-block mode is enabled, the attention operator will not immediately trigger the multi-block version of the GPU kernel. A minimum number of tokens (input + generated) is required for the multi-block version to become more efficient than the “vanilla” implementation that uses a single CUDA thread-block per head. That threshold is controlled by an internal heuristic.

Also note that, because the masked MHA kernels use an amount of shared memory proportional to the sequence length, the GPU’s shared memory may be insufficient in some cases when multi-block mode is not enabled. In these cases, multi-block mode is forced on so that the masked MHA kernel can run, and a warning is logged.

XQA Optimization

Another optimization for MQA/GQA in the generation phase is called XQA. It is still experimental and supports a limited set of configurations; LLAMA2 70B is one model that it supports.

Support matrix of the XQA optimization:

  • FP16 / BF16 compute data type.

  • FP16 / BF16 / FP8 / INT8 KV cache data type.

  • Paged KV cache (64 / 128 tokens per block).

XQA is enabled by default. To disable it, use the flag --disable_xqa when building the engines. Note that a heuristic is also used to decide whether to use the XQA kernel or the masked MHA kernel for better performance. This means that even if --disable_xqa is not set, XQA kernels may not be used. To always use XQA kernels when possible, set TRTLLM_FORCE_XQA=1; this forces their use whenever the model configuration is supported. The detailed supported configurations can be found in the function shouldUse of the class DecoderXQARunner in cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h.

In-flight Batching

TensorRT-LLM supports in-flight batching of requests (also known as continuous batching or iteration-level batching) for higher serving throughput. With this feature, sequences in the context phase can be processed together with sequences in the generation phase. The purpose of this technique is to better interleave requests to reduce latency and make better use of the GPUs. For efficiency reasons (1), in-flight batching requires the input tensors to be packed (no padding).

In the current implementation, the sequences that are going through the context phase must be before the sequences in the generation phase in the input tensor. For example, for sequences S0, S1 and S2, if S0 and S2 are in context phase (and S1 in generation), tokens from S0 and S2 must appear before the tokens of S1 in the input tensor. The constraint may or may not be relaxed in a future version.

(1) Padding sequences in the generation phase, which contain a single token, to the length of the maximum input sequence is an inefficient use of resources.
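The ordering constraint can be illustrated with a small sketch that builds a packed batch with every context-phase sequence ahead of every generation-phase sequence, while preserving the relative order within each group. The Request dataclass and its fields are hypothetical stand-ins, not TensorRT-LLM types.

```python
from dataclasses import dataclass, field

@dataclass
class Request:
    name: str
    in_context_phase: bool
    tokens: list = field(default_factory=list)  # tokens contributed to this step

def order_for_inflight_batch(requests):
    """Stable partition: context-phase requests first, then generation-phase ones."""
    context = [r for r in requests if r.in_context_phase]
    generation = [r for r in requests if not r.in_context_phase]
    ordered = context + generation
    packed_tokens = [t for r in ordered for t in r.tokens]
    return [r.name for r in ordered], packed_tokens
```

With S0 and S2 in the context phase and S1 in the generation phase, as in the example above, the packed input holds the tokens of S0, then S2, then the single generation token of S1.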

Chunked Context

Originally, all context tokens were processed at once. The chunked context feature splits the context into several chunks. This way, the context chunks can be batched with more tokens during the generation phase, which is expected to increase the total throughput. Chunking contexts also removes constraints on input length. To enable this feature, the FMHA paged KV cache also needs to be enabled. Except for the last chunk, the size of each context chunk needs to be an integer multiple of the KV-cache block size. Refer to the performance best practices for usage.
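A sketch of one chunking rule that satisfies the block-size constraint, assuming a requested chunk size (a hypothetical parameter) that is rounded down to a whole number of KV-cache blocks. Only the final chunk may be shorter.

```python
def split_context(context_len, chunk_size, tokens_per_block):
    """Chunk lengths; every chunk except the last is a multiple of tokens_per_block."""
    # Round the requested chunk size down to a whole number of blocks.
    chunk = (chunk_size // tokens_per_block) * tokens_per_block
    if chunk == 0:
        raise ValueError("chunk_size must cover at least one KV-cache block")
    chunks = []
    remaining = context_len
    while remaining > 0:
        step = min(chunk, remaining)
        chunks.append(step)
        remaining -= step
    return chunks
```

For example, a 1000-token context with a requested chunk size of 300 and 64 tokens per block is split into chunks of 256, 256, 256, and 232 tokens.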

KV Cache

In the generation phase, a common optimization is to provide the MHA kernel with a cache containing the values of the past K and V elements that have already been computed. That cache is known as the KV cache. TensorRT-LLM uses that technique to accelerate its generation phase. In TensorRT-LLM, there is one KV cache per Transformer layer, which means that there are as many KV caches as layers in a model. The current version of TensorRT-LLM supports two different types of KV caches: contiguous and paged KV caches.

Contiguous KV Cache

The contiguous KV cache is a monolithic tensor. Its shape is:

[max_batch_size * max_beam_width, 2, num_heads, max_seqlen, hidden_dim_per_head].

That implementation uses a lot more memory than needed when the sequences are shorter than the maximum sequence length (even if they end up close to the limit after the generation of many output tokens, it may take a lot of steps to reach that point).
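The shape above translates directly into a memory budget that is reserved regardless of the actual sequence lengths. This small calculator assumes FP16 storage (2 bytes per element) and reads num_heads in the shape as the number of K/V heads. Both are assumptions, and the example values are illustrative rather than taken from a specific model configuration.

```python
def contiguous_kv_cache_bytes(max_batch_size, max_beam_width, num_kv_heads,
                              max_seqlen, head_size, num_layers, bytes_per_elem=2):
    # Per layer: [max_batch_size * max_beam_width, 2, num_heads, max_seqlen, head_size]
    per_layer = (max_batch_size * max_beam_width * 2 * num_kv_heads
                 * max_seqlen * head_size * bytes_per_elem)
    # There is one such cache per Transformer layer.
    return per_layer * num_layers
```

With max_batch_size=8, max_beam_width=1, 32 K/V heads, max_seqlen=4096, head size 128, and 32 layers, this is 512 MiB per layer and 16 GiB in total, even if every sequence stops after a few hundred tokens.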

Paged KV Cache

The paged KV cache decomposes the KV cache into blocks that are distributed to the different requests by a cache manager during processing. That cache manager keeps track of the sequences, allocates new blocks from a pool, and recycles those blocks when required. See the simplified implementation in tensorrt_llm.runtime.KVCacheManager. A more efficient C++ implementation is included in the Batch Manager.
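A toy block allocator illustrates the idea. It is not the tensorrt_llm.runtime.KVCacheManager API, and all names are hypothetical. Each sequence owns a list of block ids; a new block is taken from the free pool only when the sequence's last block is full, and all of its blocks return to the pool when the sequence finishes.

```python
class ToyPagedKVCache:
    def __init__(self, num_blocks, tokens_per_block):
        self.tokens_per_block = tokens_per_block
        self.free_blocks = list(range(num_blocks))  # pool of unused block ids
        self.block_tables = {}                      # seq_id -> list of block ids
        self.lengths = {}                           # seq_id -> number of cached tokens

    def append_token(self, seq_id):
        table = self.block_tables.setdefault(seq_id, [])
        length = self.lengths.get(seq_id, 0)
        if length % self.tokens_per_block == 0:     # no block yet, or last block full
            if not self.free_blocks:
                raise MemoryError("KV-cache block pool exhausted")
            table.append(self.free_blocks.pop())
        self.lengths[seq_id] = length + 1

    def free(self, seq_id):
        # Recycle every block of a finished sequence.
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)
```

Memory is therefore committed in block-sized increments as sequences actually grow, instead of reserving max_seqlen slots up front as the contiguous cache does.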

INT8/FP8 KV Caches

In its current implementation, even if the rest of the network runs in INT8 or FP8, the GPT attention operator works with FP32, FP16, and BFloat16 inputs and outputs. However, TensorRT-LLM supports INT8 and FP8 (kv_cache_quant_mode=QuantMode.INT8_KV_CACHE and kv_cache_quant_mode=QuantMode.FP8_KV_CACHE) KV caches.

The GPT attention operator populates the KV cache. When INT8 or FP8 KV caches are enabled, the input values have to be quantized to 8 bits using a scaling factor. The scaling factor is stored in the kv_cache_scaling_factor tensor. Its shape is [1], and only per-tensor quantization is supported in the current version. Quantization uses the inverse scale: the plugin computes fp_value * (1.0 / kv_cache_scaling_factor).

During generation, the values read from the cache are dequantized on-the-fly in the MHA/MQA kernel. Dequantization can be described as quantized_value * kv_cache_scaling_factor.
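A pure-Python round-trip sketch of the two formulas for the INT8 case. Choosing the per-tensor scale as amax / 127 is a common calibration convention and is an assumption here, as are round-to-nearest and clamping to [-128, 127]. The plugin's exact conversion may differ, and FP8 is not covered.

```python
def quantize_int8(values, kv_cache_scaling_factor):
    # The plugin multiplies by the inverse of the scaling factor.
    inv_scale = 1.0 / kv_cache_scaling_factor
    return [max(-128, min(127, round(x * inv_scale))) for x in values]

def dequantize_int8(quantized, kv_cache_scaling_factor):
    # Dequantization multiplies by the scaling factor itself.
    return [q * kv_cache_scaling_factor for q in quantized]
```

With values [0.5, -1.27, 0.01, 1.27] and scale 1.27 / 127 = 0.01, the largest magnitude maps to ±127, and every dequantized value is within half a quantization step (0.005) of the original.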

Sliding Window Attention, Cyclic (Rolling Buffer) KV Cache

TensorRT-LLM has a feature called Cyclic KV Cache, which treats the KV cache as a circular buffer. This means that it only stores the KV cache for the last N tokens, where N is determined by the max_attention_window_size parameter in GenerationSession.setup. You can see examples of this in the run.py or summarize.py files. When the cache is full, the KV cache of new tokens overwrites the “least recently used” entries.

In the context phase, if the input length surpasses max_attention_window_size, Sliding Window Attention is activated, and max_attention_window_size serves as the sliding window size.

This feature helps reduce the memory footprint of the KV cache when dealing with very long sequences.
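A minimal sketch of the circular-buffer indexing, assuming a single sequence and a Python list standing in for the cache (the real cache stores K/V tensors per layer and head, and the class name is hypothetical). Token t is written to slot t % max_attention_window_size, so once the buffer is full, each new token overwrites the oldest entry.

```python
class CyclicKVCache:
    def __init__(self, max_attention_window_size):
        self.window = max_attention_window_size
        self.slots = [None] * max_attention_window_size
        self.num_tokens = 0

    def append(self, kv):
        # Once the buffer is full, this overwrites the oldest entry.
        self.slots[self.num_tokens % self.window] = kv
        self.num_tokens += 1

    def visible(self):
        """Cached entries in token order, oldest first."""
        if self.num_tokens <= self.window:
            return self.slots[:self.num_tokens]
        start = self.num_tokens % self.window
        return self.slots[start:] + self.slots[:start]
```

After appending tokens 0 through 4 to a window of 3, the buffer physically holds [3, 4, 2], and the tokens visible to attention are 2, 3, and 4.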

Different max_attention_window_size values for each layer are also supported. To use this feature, provide an int32 torch.Tensor or list to GenerationSession.setup when using the Python runtime session, or provide a vector to the KvCacheConfig when using the C++ runtime. If the number of provided elements is less than the number of layers, the provided tensor/list/vector is repeated up to the number of layers and then saved as a new tensor. This tensor serves as the buffer for max_attention_window_size, setting unique values for each layer. However, note that the memory allocation for the KV cache still depends on the buffer’s maximum value.
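The repetition rule for per-layer windows can be sketched as follows. The helper name is hypothetical, and the sketch assumes that "repeated" means tiling the provided values in order and truncating to the number of layers.

```python
def expand_window_sizes(window_sizes, num_layers):
    """Tile per-layer max_attention_window_size values up to num_layers."""
    if not window_sizes:
        raise ValueError("need at least one window size")
    repeats = -(-num_layers // len(window_sizes))  # ceiling division
    per_layer = (list(window_sizes) * repeats)[:num_layers]
    # KV-cache memory is still allocated for the largest window.
    return per_layer, max(per_layer)
```

For example, [4096, 1024] over 5 layers alternates the two sizes and still allocates for 4096 tokens, so mixing in short-window layers saves computation but not allocation under this scheme.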

Note that the cyclic KV cache feature currently doesn’t work with beam search, because the context KV cache is shared across beams.

StreamingLLM

The StreamingLLM feature uses window attention to perform efficient and stable LLM inference on long texts, which means that only N tokens need to be stored in the KV cache. As with the cyclic KV cache feature in TensorRT-LLM, the max_attention_window_size parameter determines N. Unlike the cyclic KV cache feature, the first S tokens, called sink tokens, are always kept in the attention window, where S is determined by the sink_token_length parameter in GenerationSession.setup. In the context phase, however, self-attention is dense in the official implementation of StreamingLLM: it uses all of the tokens for computation and only saves N tokens to the KV cache.

In addition, StreamingLLM changes the relative position embedding. When determining the relative distance and adding positional information to tokens, StreamingLLM uses the positions within the cache rather than those in the original text.

The streamingllm flag is used to enable this feature.
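A sketch of which tokens StreamingLLM keeps and which positions it uses for them. It assumes that the N = max_attention_window_size slots include the S = sink_token_length sink tokens; this is illustrative bookkeeping with hypothetical function names, not the kernel.

```python
def streaming_llm_kept_tokens(num_tokens, window, sink_token_length):
    """Original-text positions kept in a cache of `window` slots."""
    if window <= sink_token_length:
        raise ValueError("window must be larger than sink_token_length")
    if num_tokens <= window:
        return list(range(num_tokens))
    sinks = list(range(sink_token_length))  # always kept
    recent = list(range(num_tokens - (window - sink_token_length), num_tokens))
    return sinks + recent

def cache_positions(kept_tokens):
    # StreamingLLM uses positions within the cache, not in the original text.
    return list(range(len(kept_tokens)))
```

With 10 tokens, a window of 6, and 2 sink tokens, the cache holds original tokens 0, 1, 6, 7, 8, and 9, and they are treated as positions 0 through 5.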

Input QKV tensor

The input QKV tensor packs the Q, K, and V tensors (concatenated along the last dimension) after the projection of the hidden states. In padded mode it is a 3D tensor, and in packed mode it is a 2D tensor (see the shapes below). RoPE and quantization to INT8 or FP8 (when needed) are performed by the GPT attention operator.

In padded mode, its shape is [batch_beam_size, max_seqlen, 3 * hidden_dim] where batch_beam_size is the batch size (number of sequences) for the context phase and the batch size multiplied by the beam width for the generation phase. Having different beam widths per sequence in padded mode is not supported.

In packed mode, its shape is [num_tokens, 3 * hidden_dim] where num_tokens is the total number of tokens in the batch. For the sequences in context phase, the number of tokens of a sequence corresponds to its input length (even if the beam width is greater than 1 for beam search). For the sequences in generation phase, there are beam_width tokens per sequence. The beam width can be different for each sequence.

In other words, the pseudo-code to compute the number of tokens is:

num_tokens = 0

# Add the length of each sequence in context phase.
for seq in context_phase:
    num_tokens += seq.length

# Add the width of the beam for each sequence in generation phase.
for seq in generation_phase:
    num_tokens += seq.beam_width
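The pseudo-code above becomes runnable once the sequences have a concrete type. Here, Sequence is a hypothetical stand-in with only the two fields the pseudo-code reads. The result is the first dimension of the packed QKV tensor [num_tokens, 3 * hidden_dim].

```python
from dataclasses import dataclass

@dataclass
class Sequence:
    length: int      # input length, used while in the context phase
    beam_width: int  # tokens per step, used while in the generation phase

def packed_num_tokens(context_phase, generation_phase):
    # Context-phase sequences contribute their whole input, even with beam search.
    num_tokens = sum(seq.length for seq in context_phase)
    # Generation-phase sequences contribute one token per beam.
    num_tokens += sum(seq.beam_width for seq in generation_phase)
    return num_tokens

def packed_qkv_shape(context_phase, generation_phase, hidden_dim):
    return (packed_num_tokens(context_phase, generation_phase), 3 * hidden_dim)
```

For two context-phase sequences of lengths 100 (beam width 4) and 20, plus two generation-phase sequences with beam widths 4 and 1, the packed tensor has 100 + 20 + 4 + 1 = 125 rows.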

Rotary Positional Embedding (RoPE)

The GPT attention operator can compute the Rotary Positional Embedding (RoPE). When that operation is enabled (rotary_embedding_dim is set to a value greater than 0), it is fused with other operations. The GPT operator supports the GPT-NeoX and GPT-J forms of RoPE, selected by setting position_embedding_type to PositionEmbeddingType.rope_gpt_neox or PositionEmbeddingType.rope_gptj.
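A NumPy sketch of the two RoPE forms applied to a single head vector at position pos. The base of 10000 and the inverse-frequency formula follow the usual RoPE definition and are assumptions here. The point is the pairing: the GPT-NeoX form rotates element i together with element i + d/2 (split halves), while the GPT-J form rotates adjacent pairs (2i, 2i + 1).

```python
import numpy as np

def _angles(pos, dim, base=10000.0):
    inv_freq = base ** (-np.arange(0, dim, 2) / dim)  # dim // 2 frequencies
    return pos * inv_freq

def rope_gpt_neox(x, pos):
    d = x.shape[-1]
    c, s = np.cos(_angles(pos, d)), np.sin(_angles(pos, d))
    x1, x2 = x[: d // 2], x[d // 2 :]  # pairs (i, i + d/2)
    return np.concatenate([x1 * c - x2 * s, x1 * s + x2 * c])

def rope_gptj(x, pos):
    d = x.shape[-1]
    c, s = np.cos(_angles(pos, d)), np.sin(_angles(pos, d))
    x_even, x_odd = x[0::2], x[1::2]   # pairs (2i, 2i + 1)
    out = np.empty_like(x)
    out[0::2] = x_even * c - x_odd * s
    out[1::2] = x_even * s + x_odd * c
    return out
```

Both forms are the same rotation applied to different element pairings. In both, the dot product of a rotated query and key depends only on the distance between their positions, which is what makes RoPE a relative encoding.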

ALiBi

The GPT attention operator can apply ALiBi to the result of the Q*K^T product. The bias is computed on-the-fly from the ALiBi slopes in the optimized kernel.
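A sketch of computing the bias on the fly from per-head slopes. The slope schedule is the one from the ALiBi paper for a power-of-two number of heads, slope_h = 2^(-8(h+1)/num_heads). How TensorRT-LLM derives slopes for other head counts is not covered here, and the function names are hypothetical.

```python
def alibi_slopes(num_heads):
    # Geometric sequence from the ALiBi paper; this sketch requires a power of two.
    assert num_heads > 0 and num_heads & (num_heads - 1) == 0
    return [2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads)]

def alibi_bias(slope, query_pos, key_pos):
    # Added to the Q*K^T score; keys further in the past get a larger penalty.
    return slope * (key_pos - query_pos)
```

For 8 heads, the slopes are 1/2, 1/4, ..., 1/256. Because the bias is a simple function of the slope and the two positions, it never needs to be stored as a [seq_len, seq_len] tensor.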

Scaling factor(s)

In MHA, the output of the Q*K^T product is scaled by a constant value that is computed as:

norm_factor = 1.f / (q_scaling * sqrt(head_size)).
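For example, with q_scaling = 1 the formula reduces to the familiar 1/sqrt(head_size) from the original Transformer. A quick check in Python:

```python
import math

def norm_factor(head_size, q_scaling=1.0):
    # Scale applied to the Q*K^T product before the softmax.
    return 1.0 / (q_scaling * math.sqrt(head_size))
```

With head_size = 64, the factor is 0.125; a q_scaling greater than 1 shrinks the scores further.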

Cross Attention

In addition to the MHA used as self-attention by GPT-style decoder-only models, gpt_attention also supports cross attention.

This enables gpt_attention to be used more broadly as a generic decoder component. For example, the Encoder-Decoder model uses gpt_attention to implement both the self-attention and cross-attention modules in its decoder.
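A minimal single-head NumPy sketch of the difference between the two modules in a decoder. In cross attention, the queries come from the decoder tokens while the keys and values come from the encoder output, and no causal mask is applied. The random projection matrices are placeholders, not a real model's weights.

```python
import numpy as np

def attention(q, k, v, causal):
    """Single-head scaled dot-product attention: q [Tq, d], k and v [Tk, d]."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if causal:
        # Query i may only see keys j <= i (decoder self attention).
        mask = np.triu(np.ones(scores.shape, dtype=bool), k=1)
        scores = np.where(mask, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v

rng = np.random.default_rng(0)
d = 8
decoder_h = rng.standard_normal((3, d))    # hidden states of 3 decoder tokens
encoder_out = rng.standard_normal((5, d))  # encoder output for 5 source tokens
# Placeholder projections; real models learn separate weights per module.
wq_s, wk_s, wv_s, wq_c, wk_c, wv_c = (rng.standard_normal((d, d)) for _ in range(6))

# Self attention: Q, K, and V all come from the decoder tokens.
self_attn = attention(decoder_h @ wq_s, decoder_h @ wk_s, decoder_h @ wv_s, causal=True)
# Cross attention: Q from the decoder, K and V from the encoder, no causal mask.
cross_attn = attention(decoder_h @ wq_c, encoder_out @ wk_c, encoder_out @ wv_c, causal=False)
```

Both calls produce one output row per decoder token. In cross attention, the K/V length is the encoder length (5 here) rather than the number of decoder tokens.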

Relative Attention Bias (RAB)

Relative attention bias (RAB) is a form of relative position modeling that adds an attention bias (Q*K^T + bias) according to relative positions. RAB is a lightweight way to include relative-position information, and is used in the popular Encoder-Decoder model T5 as well as other models in the T5 family.

RAB is supported in two modes: i) regular mode, in which the user passes in a relative attention bias computed ahead of MHA; ii) implicit mode, which computes the relative attention bias on the fly in MHA. The implicit mode suits cases where the relative attention bias is too large to fit in memory, and can be turned on by passing in max_distance.
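To illustrate what the implicit mode computes, here is a pure-Python sketch of the T5 relative-position bucketing scheme (as in the Hugging Face T5 implementation), in which max_distance bounds the logarithmic buckets. Whether TensorRT-LLM's kernel matches this bit for bit is not claimed, and implicit_rab with its per-head bias_table is a hypothetical helper.

```python
import math

def t5_relative_position_bucket(relative_position, bidirectional=True,
                                num_buckets=32, max_distance=128):
    """Map key_pos - query_pos to a bucket index (T5 scheme)."""
    bucket = 0
    if bidirectional:
        num_buckets //= 2  # half of the buckets for each direction
        if relative_position > 0:
            bucket += num_buckets
        n = abs(relative_position)
    else:
        n = max(-relative_position, 0)
    max_exact = num_buckets // 2
    if n < max_exact:
        return bucket + n  # exact buckets for small distances
    # Logarithmically larger bins up to max_distance; farther ones share the last bin.
    large = max_exact + int(
        math.log(n / max_exact) / math.log(max_distance / max_exact)
        * (num_buckets - max_exact))
    return bucket + min(large, num_buckets - 1)

def implicit_rab(query_pos, key_pos, bias_table, **kwargs):
    """Look up one head's bias on the fly instead of storing a full bias matrix."""
    return bias_table[t5_relative_position_bucket(key_pos - query_pos, **kwargs)]
```

Only the small per-head bias_table (num_buckets entries) has to be kept in memory, which is why this mode helps when the full [seq_len, seq_len] bias would not fit.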