Sparse Attention

Source: NVIDIA/TensorRT-LLM.

  1"""
  2This example demonstrates how to use sparse attention with TensorRT-LLM.
  3
  4Supported sparse attention algorithms:
  5- RocketKV
  6
  7Usage:
  8```bash
  9python llm_sparse_attention.py --algo RocketKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048
 10```
 11"""
import argparse
import json

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (CudaGraphConfig, DeepSeekSparseAttentionConfig,
                                 KvCacheConfig, MoeConfig,
                                 RocketSparseAttentionConfig)


def read_input(input_file):
    """Load samples from a JSONL file (one JSON object per line)."""
    results = []
    with open(input_file, 'r') as f:
        for line in f:
            results.append(json.loads(line))
    return results


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_path',
        type=str,
        default="/home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct")
    parser.add_argument(
        '--input_file',
        type=str,
        default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl")
    parser.add_argument('--num_samples',
                        type=int,
                        default=10,
                        help="Number of samples to read from the input file.")

    # Build config
    parser.add_argument('--algo',
                        type=str,
                        default='ROCKETKV',
                        choices=['ROCKETKV', 'DSA'])
    parser.add_argument('--attention_backend',
                        type=str,
                        default='TRTLLM',
                        choices=['VANILLA', 'TRTLLM'])
    parser.add_argument('--window_size',
                        type=int,
                        default=32,
                        help="The window size for RocketKV.")
    parser.add_argument('--kernel_size',
                        type=int,
                        default=63,
                        help="The kernel size for RocketKV.")
    parser.add_argument('--prompt_budget',
                        type=int,
                        default=2048,
                        help="The prompt budget for RocketKV.")
    parser.add_argument('--index_max_chunk_size',
                        type=int,
                        default=32768,
                        help="The maximum chunk size for the DSA indexer.")
    parser.add_argument("--max_seq_len",
                        type=int,
                        default=8192,
                        help="The maximum sequence length.")
    parser.add_argument("--max_batch_size",
                        type=int,
                        default=256,
                        help="The maximum batch size.")
    parser.add_argument("--max_new_tokens",
                        type=int,
                        default=128,
                        help="The maximum number of new tokens to generate.")
    parser.add_argument(
        "--max_num_tokens",
        type=int,
        default=8192,
        help="The maximum total tokens (context + generation) across all sequences in a batch.")

    # Parallelism and MoE
    parser.add_argument('--moe_backend',
                        type=str,
                        default='CUTLASS',
                        choices=[
                            'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP',
                            'DEEPGEMM', 'CUTEDSL', 'TRITON'
                        ])
    parser.add_argument('--tp_size', type=int, default=1)
    parser.add_argument('--moe_ep_size', type=int, default=-1)
    parser.add_argument('--enable_attention_dp',
                        default=False,
                        action='store_true')

    # KV cache
    parser.add_argument('--kv_cache_dtype', type=str, default='auto')
    parser.add_argument('--kv_cache_fraction', type=float, default=None)

    # Runtime
    parser.add_argument('--print_iter_log',
                        default=False,
                        action='store_true',
                        help='Print iteration logs during execution')
    parser.add_argument('--use_cuda_graph', default=False, action='store_true')
    parser.add_argument('--cuda_graph_padding_enabled',
                        default=False,
                        action='store_true')
    parser.add_argument('--cuda_graph_batch_sizes',
                        nargs='+',
                        type=int,
                        default=None)
    return parser.parse_args()


def run_llm(args, sparse_attention_config):
    """Build an LLM with the given sparse attention config and run generation."""
    data = read_input(args.input_file)
    num_samples = args.num_samples if args.num_samples is not None else len(data)
    data = data[:num_samples]

    kv_cache_config = KvCacheConfig(
        # Sparse attention does not support KV cache block reuse yet.
        enable_block_reuse=False,
        free_gpu_memory_fraction=args.kv_cache_fraction,
        dtype=args.kv_cache_dtype,
    )

    cuda_graph_config = CudaGraphConfig(
        batch_sizes=args.cuda_graph_batch_sizes,
        enable_padding=args.cuda_graph_padding_enabled,
    ) if args.use_cuda_graph else None

    llm = LLM(
        model=args.model_path,
        backend='pytorch',
        kv_cache_config=kv_cache_config,
        attn_backend=args.attention_backend,
        sparse_attention_config=sparse_attention_config,
        max_batch_size=args.max_batch_size,
        max_seq_len=args.max_seq_len,
        max_num_tokens=args.max_num_tokens,
        tensor_parallel_size=args.tp_size,
        moe_expert_parallel_size=args.moe_ep_size,
        enable_attention_dp=args.enable_attention_dp,
        cuda_graph_config=cuda_graph_config,
        print_iter_log=args.print_iter_log,
        enable_iter_perf_stats=args.print_iter_log,
        moe_config=MoeConfig(backend=args.moe_backend),
    )

    # Each sample provides a long context plus a query; the reference
    # answers are kept only for side-by-side printing below.
    prompts = []
    reference = []
    for sample in data:
        prompts.append({'prompt': sample['input_context'] + sample['input_query']})
        reference.append(sample['outputs'])

    sampling_params = SamplingParams(add_special_tokens=False,
                                     max_tokens=args.max_new_tokens,
                                     temperature=0.8,
                                     top_p=0.95)

    outputs = llm.generate(prompts, sampling_params)
    for idx, output in enumerate(outputs):
        print(f'Generated text: {output.outputs[0].text!r}, ref: {reference[idx]}')


def run_RocketKV(args):
    sparse_attention_config = RocketSparseAttentionConfig(
        window_size=args.window_size,
        kernel_size=args.kernel_size,
        prompt_budget=args.prompt_budget,
    )
    run_llm(args, sparse_attention_config)


def run_DSA(args):
    sparse_attention_config = DeepSeekSparseAttentionConfig(
        indexer_max_chunk_size=args.index_max_chunk_size)
    run_llm(args, sparse_attention_config)


def main():
    args = parse_arguments()
    if args.algo == 'ROCKETKV':
        run_RocketKV(args)
    elif args.algo == 'DSA':
        run_DSA(args)
    else:
        raise ValueError(f"Invalid algorithm: {args.algo}")


if __name__ == "__main__":
    main()
````
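
The script drives everything through argparse, but the same configuration objects can be used directly from Python. Below is a minimal sketch that reuses only the `tensorrt_llm.llmapi` classes and call patterns shown in the script above, with the script's default RocketKV parameters; the model path and prompt are placeholders:

```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig

# RocketKV configuration mirroring the script's defaults
# (--window_size 32 --kernel_size 63 --prompt_budget 2048).
sparse_attention_config = RocketSparseAttentionConfig(
    window_size=32,
    kernel_size=63,
    prompt_budget=2048,
)

llm = LLM(
    model="/path/to/Llama-3.1-8B-Instruct",  # placeholder model path
    backend='pytorch',
    attn_backend='TRTLLM',
    sparse_attention_config=sparse_attention_config,
    # Block reuse must stay disabled: sparse attention does not support
    # KV cache reuse yet.
    kv_cache_config=KvCacheConfig(enable_block_reuse=False),
)

outputs = llm.generate(
    ["<long context> What does the document say about sparse attention?"],
    SamplingParams(max_tokens=128, temperature=0.8, top_p=0.95),
)
print(outputs[0].outputs[0].text)
```

To try DSA instead, swap the config for `DeepSeekSparseAttentionConfig(indexer_max_chunk_size=32768)`, as `run_DSA` does above.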