Sparse Attention
Source: NVIDIA/TensorRT-LLM. This example shows how to enable sparse attention (RocketKV or DSA) through the TensorRT-LLM LLM API by passing a sparse_attention_config to the LLM constructor; the complete script is listed below, followed by usage notes.
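Before the full listing, here is a minimal sketch distilled from the script (RocketKV path): the only sparse-attention-specific step is building a config object and handing it to the LLM constructor. The model name, prompt, and inline parameter notes are illustrative placeholders and assumptions, not part of the original example.

```python
# Minimal sketch distilled from the full script below (RocketKV path).
# Model name and prompt are placeholders; parameter values mirror the script's defaults.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig

sparse_attention_config = RocketSparseAttentionConfig(
    window_size=32,      # --window_size in the script
    kernel_size=63,      # --kernel_size
    prompt_budget=2048,  # --prompt_budget
)

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder checkpoint
    attn_backend="TRTLLM",
    sparse_attention_config=sparse_attention_config,
    # Block reuse is disabled because sparse attention does not support KV cache reuse yet.
    kv_cache_config=KvCacheConfig(enable_block_reuse=False),
)

outputs = llm.generate(
    ["Summarize the following document: ..."],
    SamplingParams(max_tokens=128, temperature=0.8, top_p=0.95),
)
print(outputs[0].outputs[0].text)
```

The complete llm_sparse_attention.py follows.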
1"""
2This example demonstrates how to use sparse attention with TensorRT-LLM.
3
4Supported sparse attention algorithms:
5- RocketKV
6- DSA
7
8Usage:
9```bash
10python llm_sparse_attention.py --algo ROCKETKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048
11```
12"""
13import argparse
14import json
15
16from tensorrt_llm import LLM, SamplingParams
17from tensorrt_llm.llmapi import (CudaGraphConfig, DeepSeekSparseAttentionConfig,
18 KvCacheConfig, MoeConfig,
19 RocketSparseAttentionConfig)
20
21
22def read_input(input_file):
23 results = []
24 with open(input_file, 'r') as f:
25 for line in f:
26 ret = json.loads(line)
27 results.append(ret)
28 return results
29
30
31def parse_arguments():
32 parser = argparse.ArgumentParser()
33 parser.add_argument(
34 '--model_path',
35 type=str,
36 default=
37 "/home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct"
38 )
39 parser.add_argument(
40 '--input_file',
41 type=str,
42 default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl"
43 )
44 # Build config
45 parser.add_argument('--algo',
46 type=str,
47 default='ROCKETKV',
48 choices=['ROCKETKV', 'DSA'])
49 parser.add_argument('--attention_backend',
50 type=str,
51 default='TRTLLM',
52 choices=['VANILLA', 'TRTLLM'])
53 parser.add_argument('--window_size',
54 type=int,
55 default=32,
56 help="The window size for RocketKV.")
57 parser.add_argument('--kernel_size',
58 type=int,
59 default=63,
60 help="The kernel size for RocketKV.")
61 parser.add_argument('--prompt_budget',
62 type=int,
63 default=2048,
64 help="The prompt budget for RocketKV.")
65 parser.add_argument('--index_max_chunk_size',
66 type=int,
67 default=32768,
68 help="The maximum chunk size for the indexer.")
69 parser.add_argument("--max_seq_len",
70 type=int,
71 default=10240,
72 help="The maximum sequence length.")
73 parser.add_argument("--max_batch_size",
74 type=int,
75 default=256,
76 help="The maximum batch size.")
77 parser.add_argument("--max_new_tokens",
78 type=int,
79 default=128,
80 help="The maximum new tokens.")
81 parser.add_argument(
82 "--max_num_tokens",
83 type=int,
84 default=81920,
85 help=
86 "The maximum total tokens (context + generation) across all sequences in a batch."
87 )
88
89 # Parallelism
90 parser.add_argument('--moe_backend',
91 type=str,
92 default='CUTLASS',
93 choices=[
94 'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP',
95 'DEEPGEMM', 'CUTEDSL', 'TRITON'
96 ])
97 parser.add_argument('--tp_size', type=int, default=1)
98 parser.add_argument('--moe_ep_size', type=int, default=-1)
99 parser.add_argument('--enable_attention_dp',
100 default=False,
101 action='store_true')
102
103 # KV cache
104 parser.add_argument('--kv_cache_dtype', type=str, default='auto')
105 parser.add_argument("--kv_cache_fraction", type=float, default=0.7)
106 parser.add_argument('--num_samples', type=int, default=10)
107
108 # Runtime
109 parser.add_argument('--print_iter_log',
110 default=False,
111 action='store_true',
112 help='Print iteration logs during execution')
113 parser.add_argument('--use_cuda_graph', default=False, action='store_true')
114 parser.add_argument('--cuda_graph_padding_enabled',
115 default=False,
116 action='store_true')
117 parser.add_argument('--cuda_graph_batch_sizes',
118 nargs='+',
119 type=int,
120 default=None)
121 args = parser.parse_args()
122 return args
123
124
125def run_llm(args, sparse_attention_config):
126 data = read_input(args.input_file)
127 num_samples = args.num_samples if args.num_samples is not None else len(
128 data)
129 data = data[:num_samples]
130
131 kv_cache_config = KvCacheConfig(
132 enable_block_reuse=
133 False, # sparse attention does not support kv cache reuse now
134 free_gpu_memory_fraction=args.kv_cache_fraction,
135 dtype=args.kv_cache_dtype,
136 )
137
138 cuda_graph_config = CudaGraphConfig(
139 batch_sizes=args.cuda_graph_batch_sizes,
140 enable_padding=args.cuda_graph_padding_enabled,
141 ) if args.use_cuda_graph else None
142
143 llm = LLM(
144 model=args.model_path,
145 backend='pytorch',
146 kv_cache_config=kv_cache_config,
147 attn_backend=args.attention_backend,
148 sparse_attention_config=sparse_attention_config,
149 max_batch_size=args.max_batch_size,
150 max_seq_len=args.max_seq_len,
151 max_num_tokens=args.max_num_tokens,
152 tensor_parallel_size=args.tp_size,
153 moe_expert_parallel_size=args.moe_ep_size,
154 enable_attention_dp=args.enable_attention_dp,
155 cuda_graph_config=cuda_graph_config,
156 print_iter_log=args.print_iter_log,
157 enable_iter_perf_stats=args.print_iter_log,
158 moe_config=MoeConfig(backend=args.moe_backend),
159 )
160
161 prompts = []
162 reference = []
163 for sample in data:
164 prompts.append(
165 {'prompt': sample['input_context'] + sample['input_query']})
166 reference.append(sample['outputs'])
167
168 sampling_params = SamplingParams(add_special_tokens=False,
169 max_tokens=args.max_new_tokens,
170 temperature=0.8,
171 top_p=0.95)
172
173 outputs = llm.generate(prompts, sampling_params)
174 for idx, output in enumerate(outputs):
175 print(
176 f'Generated text: {output.outputs[0].text!r}, ref: {reference[idx]}'
177 )
178
179
180def run_RocketKV(args):
181 sparse_attention_config = RocketSparseAttentionConfig(
182 window_size=args.window_size,
183 kernel_size=args.kernel_size,
184 prompt_budget=args.prompt_budget,
185 )
186 run_llm(args, sparse_attention_config)
187
188
189def run_DSA(args):
190 sparse_attention_config = DeepSeekSparseAttentionConfig(
191 indexer_max_chunk_size=args.index_max_chunk_size, )
192 run_llm(args, sparse_attention_config)
193
194
195def main():
196 args = parse_arguments()
197 if args.algo == 'ROCKETKV':
198 run_RocketKV(args)
199 elif args.algo == 'DSA':
200 run_DSA(args)
201 else:
202 raise ValueError(f"Invalid algorithm: {args.algo}")
203
204
205if __name__ == "__main__":
206 main()
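The module docstring shows the RocketKV invocation; the DSA path runs through the same script with --algo DSA. The second command below is an assumed analog built from the argparse options above: only --index_max_chunk_size feeds DeepSeekSparseAttentionConfig, and the default --model_path (a Llama checkpoint) is presumably not what you want for DSA, which targets models that ship a DeepSeek-style sparse-attention indexer. The model path placeholder is hypothetical.

```bash
# RocketKV (as in the module docstring)
python llm_sparse_attention.py --algo ROCKETKV --attention_backend TRTLLM \
    --window_size 32 --kernel_size 63 --prompt_budget 2048

# DSA (assumed analog; point --model_path at a checkpoint that actually uses
# DeepSeek sparse attention, since the Llama default above will not exercise DSA)
python llm_sparse_attention.py --algo DSA --attention_backend TRTLLM \
    --index_max_chunk_size 32768 --model_path <path-to-deepseek-model>
```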