# Sparse Attention
Source: NVIDIA/TensorRT-LLM.
1"""
2This example demonstrates how to use sparse attention with TensorRT-LLM.
3
4Supported sparse attention algorithms:
5- RocketKV
6- DSA
7
8Usage:
9```bash
10python llm_sparse_attention.py --algo ROCKETKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048
11```
12"""
import argparse
import json

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (CudaGraphConfig, DeepSeekSparseAttentionConfig,
                                 KvCacheConfig, MoeConfig,
                                 RocketSparseAttentionConfig)

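# Each line of the input JSONL file is one sample; run_llm() below expects the
# keys 'input_context', 'input_query', and 'outputs' in every record.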
def read_input(input_file):
    results = []
    with open(input_file, 'r') as f:
        for line in f:
            results.append(json.loads(line))
    return results

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_path',
        type=str,
        default="/home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct"
    )
    parser.add_argument(
        '--input_file',
        type=str,
        default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl"
    )

    # Build config
    parser.add_argument('--algo',
                        type=str,
                        default='ROCKETKV',
                        choices=['ROCKETKV', 'DSA'])
    parser.add_argument('--attention_backend',
                        type=str,
                        default='TRTLLM',
                        choices=['VANILLA', 'TRTLLM'])

    # RocketKV config
    parser.add_argument('--window_size',
                        type=int,
                        default=32,
                        help="The window size for RocketKV.")
    parser.add_argument('--kernel_size',
                        type=int,
                        default=63,
                        help="The kernel size for RocketKV.")
    parser.add_argument('--prompt_budget',
                        type=int,
                        default=2048,
                        help="The prompt budget for RocketKV.")
    parser.add_argument('--topk',
                        type=int,
                        default=64,
                        help="Top-k for RocketKV.")
    parser.add_argument('--kt_cache_dtype',
                        type=str,
                        default='float8_e5m2',
                        choices=['bfloat16', 'float8_e5m2'])

    # DSA config
    parser.add_argument('--index_max_chunk_size',
                        type=int,
                        default=32768,
                        help="The maximum chunk size for the DSA indexer.")

    # Sequence and batch limits
    parser.add_argument("--max_seq_len",
                        type=int,
                        default=10240,
                        help="The maximum sequence length.")
    parser.add_argument("--max_batch_size",
                        type=int,
                        default=256,
                        help="The maximum batch size.")
    parser.add_argument("--max_new_tokens",
                        type=int,
                        default=128,
                        help="The maximum number of new tokens to generate.")
    parser.add_argument(
        "--max_num_tokens",
        type=int,
        default=81920,
        help="The maximum total tokens (context + generation) across all sequences in a batch."
    )

    # Parallelism
    parser.add_argument('--moe_backend',
                        type=str,
                        default='CUTLASS',
                        choices=[
                            'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP',
                            'DEEPGEMM', 'CUTEDSL', 'TRITON'
                        ])
    parser.add_argument('--tp_size', type=int, default=1)
    parser.add_argument('--moe_ep_size', type=int, default=-1)
    parser.add_argument('--enable_attention_dp',
                        default=False,
                        action='store_true')

    # KV cache
    parser.add_argument('--kv_cache_dtype', type=str, default='auto')
    parser.add_argument("--kv_cache_fraction", type=float, default=0.7)
    parser.add_argument('--tokens_per_block', type=int, default=32)

    # Input data
    parser.add_argument('--num_samples',
                        type=int,
                        default=10,
                        help="Number of samples from the input file to run.")

    # Runtime
    parser.add_argument('--print_iter_log',
                        default=False,
                        action='store_true',
                        help='Print iteration logs during execution')
    parser.add_argument('--use_cuda_graph', default=False, action='store_true')
    parser.add_argument('--cuda_graph_padding_enabled',
                        default=False,
                        action='store_true')
    parser.add_argument('--cuda_graph_batch_sizes',
                        nargs='+',
                        type=int,
                        default=None)
    parser.add_argument('--enable_chunked_prefill',
                        default=False,
                        action='store_true',
                        help='Enable chunked prefill')
    args = parser.parse_args()
    return args

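# run_llm builds an LLM with the PyTorch backend and the given sparse attention
# config, generates from the JSONL samples, and prints each output next to its
# reference answer.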
def run_llm(args, sparse_attention_config):
    data = read_input(args.input_file)
    num_samples = args.num_samples if args.num_samples is not None else len(
        data)
    data = data[:num_samples]

    kv_cache_config = KvCacheConfig(
        # Block reuse is disabled because sparse attention does not support
        # KV cache reuse yet.
        enable_block_reuse=False,
        free_gpu_memory_fraction=args.kv_cache_fraction,
        tokens_per_block=args.tokens_per_block,
        dtype=args.kv_cache_dtype,
    )

    cuda_graph_config = CudaGraphConfig(
        batch_sizes=args.cuda_graph_batch_sizes,
        enable_padding=args.cuda_graph_padding_enabled,
    ) if args.use_cuda_graph else None

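    # sparse_attention_config is what switches the attention backend into
    # RocketKV or DSA mode; the remaining arguments are regular LLM-API knobs.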
    llm = LLM(
        model=args.model_path,
        backend='pytorch',
        kv_cache_config=kv_cache_config,
        attn_backend=args.attention_backend,
        sparse_attention_config=sparse_attention_config,
        max_batch_size=args.max_batch_size,
        max_seq_len=args.max_seq_len,
        max_num_tokens=args.max_num_tokens,
        tensor_parallel_size=args.tp_size,
        moe_expert_parallel_size=args.moe_ep_size,
        enable_attention_dp=args.enable_attention_dp,
        cuda_graph_config=cuda_graph_config,
        print_iter_log=args.print_iter_log,
        enable_iter_perf_stats=args.print_iter_log,
        moe_config=MoeConfig(backend=args.moe_backend),
        enable_chunked_prefill=args.enable_chunked_prefill,
    )

    prompts = []
    reference = []
    for sample in data:
        prompts.append(
            {'prompt': sample['input_context'] + sample['input_query']})
        reference.append(sample['outputs'])

    sampling_params = SamplingParams(add_special_tokens=False,
                                     max_tokens=args.max_new_tokens,
                                     temperature=0.8,
                                     top_p=0.95)

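    # temperature=0.8 / top_p=0.95 sampling is stochastic, so the generated
    # text is not expected to match the reference verbatim.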
    outputs = llm.generate(prompts, sampling_params)
    for idx, output in enumerate(outputs):
        print(
            f'Generated text: {output.outputs[0].text!r}, ref: {reference[idx]}'
        )

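# RocketKV: all parameters are forwarded verbatim from the CLI flags; see the
# argparse help strings above for what each one controls.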
def run_RocketKV(args):
    sparse_attention_config = RocketSparseAttentionConfig(
        window_size=args.window_size,
        kernel_size=args.kernel_size,
        prompt_budget=args.prompt_budget,
        topk=args.topk,
        kt_cache_dtype=args.kt_cache_dtype,
    )
    run_llm(args, sparse_attention_config)

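# DSA: the only tunable exposed in this example is the indexer's maximum chunk
# size, passed through as indexer_max_chunk_size.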
def run_DSA(args):
    sparse_attention_config = DeepSeekSparseAttentionConfig(
        indexer_max_chunk_size=args.index_max_chunk_size)
    run_llm(args, sparse_attention_config)

def main():
    args = parse_arguments()
    if args.algo == 'ROCKETKV':
        run_RocketKV(args)
    elif args.algo == 'DSA':
        run_DSA(args)
    else:
        raise ValueError(f"Invalid algorithm: {args.algo}")


if __name__ == "__main__":
    main()
````