Sparse Attention#
Source: NVIDIA/TensorRT-LLM.
1"""
2This example demonstrates how to use sparse attention with TensorRT-LLM.
3
4Supported sparse attention algorithms:
5- RocketKV
6
7Usage:
8```bash
9python llm_sparse_attention.py --algo RocketKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048
10```
11"""
import argparse
import json

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (CudaGraphConfig, DeepSeekSparseAttentionConfig,
                                 KvCacheConfig, MoeConfig,
                                 RocketSparseAttentionConfig)

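# Each line of the input JSONL is expected to be a JSON object with
# 'input_context', 'input_query', and 'outputs' fields (see run_llm below);
# the default --input_file points at a JSONL used by the repo's unit tests.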
def read_input(input_file):
    results = []
    with open(input_file, 'r') as f:
        for line in f:
            ret = json.loads(line)
            results.append(ret)
    return results

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_path',
        type=str,
        default=
        "/home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct"
    )
    parser.add_argument(
        '--input_file',
        type=str,
        default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl"
    )
    # Build config
    parser.add_argument('--algo',
                        type=str,
                        default='ROCKETKV',
                        choices=['ROCKETKV', 'DSA'])
    parser.add_argument('--attention_backend',
                        type=str,
                        default='TRTLLM',
                        choices=['VANILLA', 'TRTLLM'])
    parser.add_argument('--window_size',
                        type=int,
                        default=32,
                        help="The window size for RocketKV.")
    parser.add_argument('--kernel_size',
                        type=int,
                        default=63,
                        help="The kernel size for RocketKV.")
    parser.add_argument('--prompt_budget',
                        type=int,
                        default=2048,
                        help="The prompt budget for RocketKV.")
    parser.add_argument('--index_max_chunk_size',
                        type=int,
                        default=32768,
                        help="The maximum chunk size for the indexer.")
    parser.add_argument("--max_seq_len",
                        type=int,
                        default=8192,
                        help="The maximum sequence length.")
    parser.add_argument("--max_batch_size",
                        type=int,
                        default=256,
                        help="The maximum batch size.")
    parser.add_argument("--max_new_tokens",
                        type=int,
                        default=128,
                        help="The maximum number of new tokens to generate.")
    parser.add_argument(
        "--max_num_tokens",
        type=int,
        default=8192,
        help=
        "The maximum total tokens (context + generation) across all sequences in a batch."
    )

    # Parallelism
    parser.add_argument('--moe_backend',
                        type=str,
                        default='CUTLASS',
                        choices=[
                            'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP',
                            'DEEPGEMM', 'CUTEDSL', 'TRITON'
                        ])
    parser.add_argument('--tp_size', type=int, default=1)
    parser.add_argument('--moe_ep_size', type=int, default=-1)
    parser.add_argument('--enable_attention_dp',
                        default=False,
                        action='store_true')

    # KV cache
    parser.add_argument('--kv_cache_dtype', type=str, default='auto')
    parser.add_argument("--kv_cache_fraction", type=float, default=None)
    parser.add_argument('--num_samples', type=int, default=10)

    # Runtime
    parser.add_argument('--print_iter_log',
                        default=False,
                        action='store_true',
                        help='Print iteration logs during execution')
    parser.add_argument('--use_cuda_graph', default=False, action='store_true')
    parser.add_argument('--cuda_graph_padding_enabled',
                        default=False,
                        action='store_true')
    parser.add_argument('--cuda_graph_batch_sizes',
                        nargs='+',
                        type=int,
                        default=None)
    args = parser.parse_args()
    return args

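# Sparse attention is exercised through the PyTorch backend here
# (backend='pytorch'), and KV cache block reuse is disabled below because
# sparse attention does not support it yet.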
def run_llm(args, sparse_attention_config):
    data = read_input(args.input_file)
    num_samples = args.num_samples if args.num_samples is not None else len(
        data)
    data = data[:num_samples]

    kv_cache_config = KvCacheConfig(
        enable_block_reuse=
        False,  # sparse attention does not support KV cache reuse yet
        free_gpu_memory_fraction=args.kv_cache_fraction,
        dtype=args.kv_cache_dtype,
    )

    cuda_graph_config = CudaGraphConfig(
        batch_sizes=args.cuda_graph_batch_sizes,
        enable_padding=args.cuda_graph_padding_enabled,
    ) if args.use_cuda_graph else None

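    # The RocketKV and DSA paths differ only in the sparse_attention_config
    # they pass in; every other LLM() argument below is shared.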
    llm = LLM(
        model=args.model_path,
        backend='pytorch',
        kv_cache_config=kv_cache_config,
        attn_backend=args.attention_backend,
        sparse_attention_config=sparse_attention_config,
        max_batch_size=args.max_batch_size,
        max_seq_len=args.max_seq_len,
        max_num_tokens=args.max_num_tokens,
        tensor_parallel_size=args.tp_size,
        moe_expert_parallel_size=args.moe_ep_size,
        enable_attention_dp=args.enable_attention_dp,
        cuda_graph_config=cuda_graph_config,
        print_iter_log=args.print_iter_log,
        enable_iter_perf_stats=args.print_iter_log,
        moe_config=MoeConfig(backend=args.moe_backend),
    )

    prompts = []
    reference = []
    for sample in data:
        prompts.append(
            {'prompt': sample['input_context'] + sample['input_query']})
        reference.append(sample['outputs'])

    sampling_params = SamplingParams(add_special_tokens=False,
                                     max_tokens=args.max_new_tokens,
                                     temperature=0.8,
                                     top_p=0.95)

    outputs = llm.generate(prompts, sampling_params)
    for idx, output in enumerate(outputs):
        print(
            f'Generated text: {output.outputs[0].text!r}, ref: {reference[idx]}'
        )

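# Rough intuition for the RocketKV knobs (see the RocketKV paper for the exact
# semantics): window_size is the observation window of recent tokens used to
# score prompt KV entries, kernel_size is the pooling kernel applied to those
# scores, and prompt_budget caps how many prompt tokens survive compression.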
def run_RocketKV(args):
    sparse_attention_config = RocketSparseAttentionConfig(
        window_size=args.window_size,
        kernel_size=args.kernel_size,
        prompt_budget=args.prompt_budget,
    )
    run_llm(args, sparse_attention_config)

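# DSA (DeepSeek Sparse Attention) picks the KV entries each query attends to
# via a lightweight indexer; indexer_max_chunk_size bounds how large a chunk
# the indexer processes at once. Unlike RocketKV, which is training-free, DSA
# assumes a checkpoint trained with it (e.g. DeepSeek-V3.2-style models).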
def run_DSA(args):
    sparse_attention_config = DeepSeekSparseAttentionConfig(
        indexer_max_chunk_size=args.index_max_chunk_size)
    run_llm(args, sparse_attention_config)

def main():
    args = parse_arguments()
    if args.algo == 'ROCKETKV':
        run_RocketKV(args)
    elif args.algo == 'DSA':
        run_DSA(args)
    else:
        raise ValueError(f"Invalid algorithm: {args.algo}")


if __name__ == "__main__":
    main()
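The usage line in the module docstring covers RocketKV. Below is a minimal invocation sketch for the DSA path; all flags come from the argparse options above, and the model path is a placeholder (DSA assumes a checkpoint trained with DeepSeek Sparse Attention, whereas the default `--model_path`, Llama-3.1-8B-Instruct, is intended for the RocketKV path):

```bash
python llm_sparse_attention.py \
    --algo DSA \
    --model_path <path-to-a-DeepSeek-DSA-checkpoint> \
    --index_max_chunk_size 32768
```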