Sparse Attention

Source: NVIDIA/TensorRT-LLM.

  1"""
  2This example demonstrates how to use sparse attention with TensorRT-LLM.
  3
  4Supported sparse attention algorithms:
  5- RocketKV
  6
  7Usage:
  8```bash
  9python llm_sparse_attention.py --algo RocketKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048
 10```
 11"""
import argparse
import json

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig


def read_input(input_file):
    """Read one JSON record per line from a JSONL file."""
    results = []
    with open(input_file, 'r') as f:
        for line in f:
            results.append(json.loads(line))
    return results

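# NOTE: each line of the input file is expected to be a JSON object carrying
# the fields consumed in run_RocketKV below; the values here are illustrative:
# {"input_context": "...", "input_query": "...", "outputs": ["..."]}
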
def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_path',
        type=str,
        default="/home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct"
    )
    parser.add_argument(
        '--input_file',
        type=str,
        default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl"
    )
    # Build config
    parser.add_argument('--algo',
                        type=str,
                        default='ROCKETKV',
                        choices=['ROCKETKV'])
    parser.add_argument('--attention_backend',
                        type=str,
                        default='TRTLLM',
                        choices=['VANILLA', 'TRTLLM'])
    parser.add_argument('--window_size',
                        type=int,
                        default=32,
                        help="The window size for RocketKV.")
    parser.add_argument('--kernel_size',
                        type=int,
                        default=63,
                        help="The kernel size for RocketKV.")
    parser.add_argument('--prompt_budget',
                        type=int,
                        default=2048,
                        help="The prompt budget for RocketKV.")
    parser.add_argument("--max_seq_len",
                        type=int,
                        default=8192,
                        help="The maximum sequence length.")
    parser.add_argument("--max_batch_size",
                        type=int,
                        default=256,
                        help="The maximum batch size.")
    parser.add_argument("--max_new_tokens",
                        type=int,
                        default=128,
                        help="The maximum number of new tokens to generate.")
    parser.add_argument(
        "--max_num_tokens",
        type=int,
        default=8192,
        help="The maximum total tokens (context + generation) across all sequences in a batch."
    )
    parser.add_argument('--tensor_parallel_size', type=int, default=1)

    # KV cache
    parser.add_argument('--kv_cache_dtype', type=str, default='auto')
    parser.add_argument("--kv_cache_fraction", type=float, default=None)
    parser.add_argument('--num_samples', type=int, default=10)

    return parser.parse_args()


def run_RocketKV(args):
    data = read_input(args.input_file)
    num_samples = args.num_samples if args.num_samples is not None else len(data)
    data = data[:num_samples]

    kv_cache_config = KvCacheConfig(
        # Sparse attention does not support KV cache block reuse yet.
        enable_block_reuse=False,
        free_gpu_memory_fraction=args.kv_cache_fraction,
        dtype=args.kv_cache_dtype,
    )
    sparse_attention_config = RocketSparseAttentionConfig(
        window_size=args.window_size,
        kernel_size=args.kernel_size,
        prompt_budget=args.prompt_budget,
    )

    llm = LLM(
        model=args.model_path,
        backend='pytorch',
        kv_cache_config=kv_cache_config,
        attn_backend=args.attention_backend,
        sparse_attention_config=sparse_attention_config,
        max_batch_size=args.max_batch_size,
        max_seq_len=args.max_seq_len,
        max_num_tokens=args.max_num_tokens,
        tensor_parallel_size=args.tensor_parallel_size,
        # Sparse attention does not support CUDA graphs yet.
        cuda_graph_config=None,
    )

    prompts = []
    reference = []
    for sample in data:
        prompts.append({'prompt': sample['input_context'] + sample['input_query']})
        reference.append(sample['outputs'])

    sampling_params = SamplingParams(add_special_tokens=False,
                                     max_tokens=args.max_new_tokens,
                                     temperature=0.8,
                                     top_p=0.95)

    outputs = llm.generate(prompts, sampling_params)
    for idx, output in enumerate(outputs):
        print(f'Generated text: {output.outputs[0].text!r}, ref: {reference[idx]}')


def main():
    args = parse_arguments()
    if args.algo == 'ROCKETKV':
        run_RocketKV(args)
    else:
        raise ValueError(f"Invalid algorithm: {args.algo}")


if __name__ == "__main__":
    main()
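
For reference, a fuller invocation than the docstring's one-liner might look like the following. All flags come from the script's `parse_arguments`; the model and input paths are placeholders, so substitute your own checkpoint and JSONL file:

```bash
python llm_sparse_attention.py \
    --algo ROCKETKV \
    --attention_backend TRTLLM \
    --model_path /path/to/Llama-3.1-8B-Instruct \
    --input_file /path/to/dataset.jsonl \
    --window_size 32 \
    --kernel_size 63 \
    --prompt_budget 2048 \
    --max_seq_len 8192 \
    --max_new_tokens 128
```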