# Sparse Attention

Source: [NVIDIA/TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM).

  1"""
  2This example demonstrates how to use sparse attention with TensorRT-LLM.
  3
  4Supported sparse attention algorithms:
  5- RocketKV
  6- DSA
  7
  8Usage:
  9```bash
 10python llm_sparse_attention.py --algo ROCKETKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048
 11```
 12"""
import argparse
import json

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (CudaGraphConfig, DeepSeekSparseAttentionConfig,
                                 KvCacheConfig, MoeConfig,
                                 RocketSparseAttentionConfig)


def read_input(input_file):
    """Load one JSON record per line from a JSONL file."""
    results = []
    with open(input_file, 'r') as f:
        for line in f:
            results.append(json.loads(line))
    return results


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_path',
        type=str,
        default="/home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct"
    )
    parser.add_argument(
        '--input_file',
        type=str,
        default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl"
    )

    # Build config
    parser.add_argument('--algo',
                        type=str,
                        default='ROCKETKV',
                        choices=['ROCKETKV', 'DSA'])
    parser.add_argument('--attention_backend',
                        type=str,
                        default='TRTLLM',
                        choices=['VANILLA', 'TRTLLM'])

    # RocketKV config
    parser.add_argument('--window_size',
                        type=int,
                        default=32,
                        help="The window size for RocketKV.")
    parser.add_argument('--kernel_size',
                        type=int,
                        default=63,
                        help="The kernel size for RocketKV.")
    parser.add_argument('--prompt_budget',
                        type=int,
                        default=2048,
                        help="The prompt budget for RocketKV.")
    parser.add_argument('--topk',
                        type=int,
                        default=64,
                        help='Top-k for RocketKV.')
    parser.add_argument('--kt_cache_dtype',
                        type=str,
                        default='float8_e5m2',
                        choices=['bfloat16', 'float8_e5m2'])
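    # A rough reading of the RocketKV knobs above, following the RocketKV
    # paper's two-stage scheme (not a spec of this implementation): prefill
    # keeps at most `prompt_budget` prompt tokens, scoring them with attention
    # pooled over `kernel_size` while always retaining the trailing
    # `window_size` observation window; decode then re-selects `topk` entries
    # per step, with `kt_cache_dtype` setting the precision of the auxiliary
    # key cache used for that selection.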

    # DSA config
    parser.add_argument('--index_max_chunk_size',
                        type=int,
                        default=32768,
                        help="The maximum chunk size for the indexer.")
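    # `index_max_chunk_size` above feeds indexer_max_chunk_size in run_DSA()
    # below. Roughly, DSA restricts attention to tokens picked by a lightweight
    # indexer, and this caps how many tokens the indexer scores per chunk.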

    # Sequence and batch limits
    parser.add_argument("--max_seq_len",
                        type=int,
                        default=10240,
                        help="The maximum sequence length.")
    parser.add_argument("--max_batch_size",
                        type=int,
                        default=256,
                        help="The maximum batch size.")
    parser.add_argument("--max_new_tokens",
                        type=int,
                        default=128,
                        help="The maximum number of new tokens to generate.")
    parser.add_argument(
        "--max_num_tokens",
        type=int,
        default=81920,
        help="The maximum total tokens (context + generation) across all sequences in a batch."
    )

    # Parallelism and MoE
    parser.add_argument('--moe_backend',
                        type=str,
                        default='CUTLASS',
                        choices=[
                            'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP',
                            'DEEPGEMM', 'CUTEDSL', 'TRITON'
                        ])
    parser.add_argument('--tp_size', type=int, default=1)
    parser.add_argument('--moe_ep_size', type=int, default=-1)
    parser.add_argument('--enable_attention_dp',
                        default=False,
                        action='store_true')

    # KV cache
    parser.add_argument('--kv_cache_dtype', type=str, default='auto')
    parser.add_argument("--kv_cache_fraction", type=float, default=0.7)
    parser.add_argument('--tokens_per_block', type=int, default=32)
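    # `kv_cache_fraction` maps to KvCacheConfig.free_gpu_memory_fraction in
    # run_llm() below: the share of still-free GPU memory reserved for the
    # paged KV cache.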
    parser.add_argument('--num_samples',
                        type=int,
                        default=10,
                        help="Number of input samples to run.")

    # Runtime
    parser.add_argument('--print_iter_log',
                        default=False,
                        action='store_true',
                        help='Print iteration logs during execution')
    parser.add_argument('--use_cuda_graph', default=False, action='store_true')
    parser.add_argument('--cuda_graph_padding_enabled',
                        default=False,
                        action='store_true')
    parser.add_argument('--cuda_graph_batch_sizes',
                        nargs='+',
                        type=int,
                        default=None)
    parser.add_argument('--enable_chunked_prefill',
                        default=False,
                        action='store_true',
                        help='Enable chunked prefill')
    args = parser.parse_args()
    return args


def run_llm(args, sparse_attention_config):
    data = read_input(args.input_file)
    num_samples = args.num_samples if args.num_samples is not None else len(data)
    data = data[:num_samples]

    kv_cache_config = KvCacheConfig(
        # Sparse attention does not support KV cache block reuse yet.
        enable_block_reuse=False,
        free_gpu_memory_fraction=args.kv_cache_fraction,
        tokens_per_block=args.tokens_per_block,
        dtype=args.kv_cache_dtype,
    )
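
    # CUDA graphs capture decode-step kernels once and replay them, cutting
    # CPU launch overhead; enable_padding pads requests up to the nearest
    # captured batch size so one graph can serve several batch sizes.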
    cuda_graph_config = CudaGraphConfig(
        batch_sizes=args.cuda_graph_batch_sizes,
        enable_padding=args.cuda_graph_padding_enabled,
    ) if args.use_cuda_graph else None

    llm = LLM(
        model=args.model_path,
        backend='pytorch',
        kv_cache_config=kv_cache_config,
        attn_backend=args.attention_backend,
        sparse_attention_config=sparse_attention_config,
        max_batch_size=args.max_batch_size,
        max_seq_len=args.max_seq_len,
        max_num_tokens=args.max_num_tokens,
        tensor_parallel_size=args.tp_size,
        moe_expert_parallel_size=args.moe_ep_size,
        enable_attention_dp=args.enable_attention_dp,
        cuda_graph_config=cuda_graph_config,
        print_iter_log=args.print_iter_log,
        enable_iter_perf_stats=args.print_iter_log,
        moe_config=MoeConfig(backend=args.moe_backend),
        enable_chunked_prefill=args.enable_chunked_prefill,
    )
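    # Note that the sparse attention algorithm enters only through
    # sparse_attention_config; the rest is ordinary LLM-API setup.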

    prompts = []
    reference = []
    for sample in data:
        # Concatenate the long context with the query to form each prompt.
        prompts.append(
            {'prompt': sample['input_context'] + sample['input_query']})
        reference.append(sample['outputs'])

    sampling_params = SamplingParams(add_special_tokens=False,
                                     max_tokens=args.max_new_tokens,
                                     temperature=0.8,
                                     top_p=0.95)
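    # Sampling here is stochastic (temperature=0.8, top_p=0.95), so generated
    # text will not match the references exactly; both are printed below for
    # manual comparison.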

    outputs = llm.generate(prompts, sampling_params)
    for idx, output in enumerate(outputs):
        print(
            f'Generated text: {output.outputs[0].text!r}, ref: {reference[idx]}'
        )


def run_RocketKV(args):
    sparse_attention_config = RocketSparseAttentionConfig(
        window_size=args.window_size,
        kernel_size=args.kernel_size,
        prompt_budget=args.prompt_budget,
        topk=args.topk,
        kt_cache_dtype=args.kt_cache_dtype,
    )
    run_llm(args, sparse_attention_config)


def run_DSA(args):
    sparse_attention_config = DeepSeekSparseAttentionConfig(
        indexer_max_chunk_size=args.index_max_chunk_size)
    run_llm(args, sparse_attention_config)


def main():
    args = parse_arguments()
    if args.algo == 'ROCKETKV':
        run_RocketKV(args)
    elif args.algo == 'DSA':
        run_DSA(args)
    else:
        raise ValueError(f"Invalid algorithm: {args.algo}")


if __name__ == "__main__":
    main()
````
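
To try the DSA path, switch `--algo` and point `--model_path` at a checkpoint whose architecture actually supports DeepSeek sparse attention; the script only wires the config through, so a DeepSeek-style model is assumed here and the path below is a placeholder:

```bash
python llm_sparse_attention.py --algo DSA --attention_backend TRTLLM \
    --model_path <path-to-a-DSA-capable-checkpoint> \
    --index_max_chunk_size 32768
```

All of these flags are defined by the script's own argument parser; `--index_max_chunk_size` is shown with its default value.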