# Sparse Attention
# Source: NVIDIA/TensorRT-LLM.
1"""
2This example demonstrates how to use sparse attention with TensorRT-LLM.
3
4Supported sparse attention algorithms:
5- RocketKV
6
7Usage:
8```bash
9python llm_sparse_attention.py --algo RocketKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048
10```
11"""
12import argparse
13import json
14
15from tensorrt_llm import LLM, SamplingParams
16from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig
17
18
def read_input(input_file):
    """Load a JSONL file and return one parsed dict per line.

    Args:
        input_file: Path to a JSON-Lines file (one JSON object per line).

    Returns:
        List of the decoded objects, in file order.
    """
    with open(input_file, 'r') as f:
        return [json.loads(line) for line in f]
26
27
def parse_arguments(argv=None):
    """Parse command-line options for the sparse-attention example.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads sys.argv[1:], so existing callers are unaffected;
            passing an explicit list makes the parser usable from tests and
            other Python code.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Sparse attention (RocketKV) example for TensorRT-LLM.")
    parser.add_argument(
        '--model_path',
        type=str,
        default=
        "/home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct"
    )
    parser.add_argument(
        '--input_file',
        type=str,
        default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl"
    )
    # Build config
    parser.add_argument('--algo',
                        type=str,
                        default='ROCKETKV',
                        choices=['ROCKETKV'])
    parser.add_argument('--attention_backend',
                        type=str,
                        default='TRTLLM',
                        choices=['VANILLA', 'TRTLLM'])
    parser.add_argument('--window_size',
                        type=int,
                        default=32,
                        help="The window size for RocketKV.")
    parser.add_argument('--kernel_size',
                        type=int,
                        default=63,
                        help="The kernel size for RocketKV.")
    parser.add_argument('--prompt_budget',
                        type=int,
                        default=2048,
                        help="The prompt budget for RocketKV.")
    parser.add_argument("--max_seq_len",
                        type=int,
                        default=8192,
                        help="The maximum sequence length.")
    parser.add_argument("--max_batch_size",
                        type=int,
                        default=256,
                        help="The maximum batch size.")
    parser.add_argument("--max_new_tokens",
                        type=int,
                        default=128,
                        help="The maximum new tokens.")
    parser.add_argument(
        "--max_num_tokens",
        type=int,
        default=8192,
        help=
        "The maximum total tokens (context + generation) across all sequences in a batch."
    )
    parser.add_argument('--tensor_parallel_size', type=int, default=1)

    # KV cache
    parser.add_argument('--kv_cache_dtype', type=str, default='auto')
    parser.add_argument("--kv_cache_fraction", type=float, default=None)
    parser.add_argument('--num_samples', type=int, default=10)

    args = parser.parse_args(argv)
    return args
90
91
def run_RocketKV(args):
    """Run the RocketKV sparse-attention example end to end.

    Loads JSONL samples, builds an LLM configured for RocketKV sparse
    attention, generates completions, and prints each generation next to
    its reference answer.
    """
    samples = read_input(args.input_file)
    # --num_samples caps how many records are used; None means all of them.
    limit = len(samples) if args.num_samples is None else args.num_samples
    samples = samples[:limit]

    kv_cache_config = KvCacheConfig(
        # Sparse attention does not support KV cache reuse now.
        enable_block_reuse=False,
        free_gpu_memory_fraction=args.kv_cache_fraction,
        dtype=args.kv_cache_dtype,
    )
    sparse_attention_config = RocketSparseAttentionConfig(
        window_size=args.window_size,
        kernel_size=args.kernel_size,
        prompt_budget=args.prompt_budget,
    )

    llm = LLM(
        model=args.model_path,
        backend='pytorch',
        kv_cache_config=kv_cache_config,
        attn_backend=args.attention_backend,
        sparse_attention_config=sparse_attention_config,
        max_batch_size=args.max_batch_size,
        max_seq_len=args.max_seq_len,
        max_num_tokens=args.max_num_tokens,
        tensor_parallel_size=args.tensor_parallel_size,
        # Sparse attention does not support CUDA graph now.
        cuda_graph_config=None,
    )

    # Each prompt is the sample's context immediately followed by its query;
    # the expected answers are kept alongside for eyeballing the output.
    prompts = [{
        'prompt': sample['input_context'] + sample['input_query']
    } for sample in samples]
    reference = [sample['outputs'] for sample in samples]

    sampling_params = SamplingParams(add_special_tokens=False,
                                     max_tokens=args.max_new_tokens,
                                     temperature=0.8,
                                     top_p=0.95)

    for idx, output in enumerate(llm.generate(prompts, sampling_params)):
        print(
            f'Generated text: {output.outputs[0].text!r}, ref: {reference[idx]}'
        )
141
142
def main():
    """Entry point: parse CLI options and dispatch to the chosen algorithm."""
    args = parse_arguments()
    # Guard clause: only ROCKETKV is implemented; anything else is an error.
    if args.algo != 'ROCKETKV':
        raise ValueError(f"Invalid algorithm: {args.algo}")
    run_RocketKV(args)
149
150
# Script entry point: run the example only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()