TensorRT-LLM Benchmarking

Important

This benchmarking suite is a work in progress. Expect breaking API changes.

TensorRT-LLM provides the trtllm-bench CLI, a packaged benchmarking utility.

Supported Networks for Benchmarking

The trtllm-bench CLI tool can automatically download the model from Hugging Face Model Hub. Export your token in the HF_TOKEN environment variable.

Support Quantization Modes

TensorRT-LLM supports a number of quantization modes:

  • None (no quantization applied)

  • W8A16

  • W4A16

  • W4A16_AWQ

  • W4A8_AWQ

  • W4A16_GPTQ

  • FP8

  • INT8

For more information about quantization, refer to Numerical Precision and the support matrix of the supported quantization methods for each network.

Inflight Benchmarking with a Dataset

This section covers how to benchmark TensorRT-LLM using inflight batching.

Quickstart

This quick start focuses on running a short max throughput benchmark on meta-llama/Llama-2-7b-hf on a synthetic dataset with a uniform distribution of prompts with ISL:OSL of 128:128. To run the benchmark from start to finish, run the following commands:

python benchmarks/cpp/prepare_dataset.py --stdout --tokenizer meta-llama/Llama-2-7b-hf token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 3000 > /tmp/synthetic_128_128.txt
trtllm-bench --model meta-llama/Llama-2-7b-hf build --dataset /tmp/synthetic_128_128.txt --quantization FP8
trtllm-bench --model meta-llama/Llama-2-7b-hf throughput --dataset /tmp/synthetic_128_128.txt --engine_dir /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1

And that’s it! After the benchmark completes, trtllm-bench prints a summary with summary metrics.

===========================================================
= ENGINE DETAILS
===========================================================
Model:                  meta-llama/Llama-2-7b-hf
Engine Directory:       /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1
TensorRT-LLM Version:   0.12.0
Dtype:                  float16
KV Cache Dtype:         FP8
Quantization:           FP8
Max Input Length:       2048
Max Sequence Length:    4098

===========================================================
= WORLD + RUNTIME INFORMATION
===========================================================
TP Size:                1
PP Size:                1
Max Runtime Batch Size: 4096
Max Runtime Tokens:     8192
Scheduling Policy:      Guaranteed No Evict
KV Memory Percentage:   99.0%
Issue Rate (req/sec):   3.680275266452667e+18
===========================================================
= STATISTICS
===========================================================
Number of requests:             3000
Average Input Length (tokens):  128.0
Average Output Length (tokens): 128.0
Token Throughput (tokens/sec):  23405.927228471104
Request Throughput (req/sec):   182.8588064724305
Total Latency (seconds):        16.406100739
===========================================================

Workflow

The workflow for trtllm-bench is composed of the following steps:

  1. Prepare a dataset to drive the inflight batching benchmark.

  2. Build a benchmark engine using trtllm-bench build subcommand.

  3. Run the max throughput benchmark using the trtllm-bench throughput subcommand or low latency benchmark using the trtllm-bench latency subcommand.

Preparing a Dataset

The inflight benchmark utilizes a fixed JSON schema so that it is simple and straightforward to specify requests. The schema is defined as follows:

Key

Required

Type

Description

task_id

Y

String

Unique identifier for the request.

prompt

N*

String

Input text for a generation request.

logits

N*

List[Integer]

List of logits that make up the request prompt.

output_tokens

Y

Integer

Number of generated tokens for this request.

Prompt and logits are mutually exclusive, but one of prompt or logits is required. If you specify logits, the prompt entry is ignored for request generation.

Refer to the following examples of valid entries for the inflight benchmark:

  • Entries with a human-readable prompt and no logits.

    {"task_id": 1, "prompt": "Generate an infinite response to the following: This is the song that never ends, it goes on and on my friend.", "output_tokens": 1000}
    {"task_id": 2, "prompt": "Generate an infinite response to the following: Na, na, na, na", "output_tokens": 1000}
    
  • Entries which contain logits.

    {"task_id":0,"logits":[863,22056,25603,11943,8932,13195,3132,25032,21747,22213],"output_tokens":128}
    {"task_id":1,"logits":[14480,13598,15585,6591,1252,8259,30990,26778,7063,30065,21764,11023,1418],"output_tokens":128}
    

Tip

Specify each entry on one line. To simplify passing the data, a complete JSON entry is on each line so that the benchmarker can simply read a line and assume a complete entry. When creating a dataset, be sure that a complete JSON entry is on every line.

Using prepare_dataset.py to Create Synthetic Datasets

In order to prepare a synthetic dataset, you can use the provided script in the benchmarks/cpp directory. For example, to generate a synthetic dataset of 1000 requests with a uniform ISL/OSL of 128/128 for Llama-2-7b, simply run:

benchmarks/cpp/prepare_dataset.py --stdout --tokenizer meta-llama/Llama-2-7b-hf token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000 > /tmp/synthetic_128_128.txt

You can pipe the above command to a file to reuse the same dataset, or simply pipe its output to the benchmark script (example below).

Building a Benchmark Engine

The trtllm-bench CLI tool provides the build subcommand to build the TRT-LLM engines for max throughput benchmark.

How to Build the Engine

To build an engine for benchmarking, you can specify the dataset generated with prepare_dataset.py through --dataset option. The trtllm-bench’s tuning heuristic uses the high-level statistics of the dataset (average ISL/OSL, max sequence length) to optimize engine build settings. The following command builds an FP8 quantized engine optimized using the dataset’s ISL/OSL.

trtllm-bench --model meta-llama/Llama-2-7b-hf build --quantization FP8 --dataset /tmp/synthetic_128_128.txt

The build subcommand also provides other ways to build the engine where users have larger control over the tuning values.

  • Build engine with self-defined tuning values: You specify the tuning values to build the engine with by setting --max_batch_size and --max_num_tokens directly. max_batch_size and max_num_tokens control the maximum number of requests and tokens that can be scheduled in each iteration. If no value is specified, the default max_batch_size and max_num_tokens values of 2048 and 8192 are used. The following command builds an FP8 quantized engine by specifying the engine tuning values.

trtllm-bench --model meta-llama/Llama-2-7b-hf build --quantization FP8 --max_seq_len 4096 --max_batch_size 1024 --max_num_tokens 2048
  • [Experimental] Build engine with target ISL/OSL for optimization: In this experimental mode, you can provide hints to trtllm-bench’s tuning heuristic to optimize the engine on specific ISL and OSL targets. Generally, the target ISL and OSL aligns with the average ISL and OSL of the dataset, but you can experiment with different values to optimize the engine using this mode. The following command builds an FP8 quantized engine and optmizes for ISL:OSL targets of 128:128.

trtllm-bench --model meta-llama/Llama-2-7b-hf build --quantization FP8 --max_seq_len 4096 --target_isl 128 --target_osl 128

Parallelism Mapping Support

The trtllm-bench build subcommand supports combinations of tensor-parallel (TP) and pipeline-parallel (PP) mappings as long as the world size (tp_size x pp_size) <= 8. The parallelism mapping in build subcommad is controlled by --tp_size and --pp_size options. The following command builds an engine with TP2-PP2 mapping.

trtllm-bench --model meta-llama/Llama-2-7b-hf build --quantization FP8 --dataset /tmp/synthetic_128_128.txt --tp_size 2 --pp_size 2

Example of Build Subcommand Output:

The output of the build subcommand looks similar to the snippet below (for meta-llama/Llama-2-7b-hf):

trtllm-bench --model meta-llama/Llama-2-7b-hf build --dataset /tmp/synthetic_128_128.txt --quantization FP8
[TensorRT-LLM] TensorRT-LLM version: 0.12.0
[08/12/2024-19:13:06] [TRT-LLM] [I] Found dataset.
[08/12/2024-19:13:07] [TRT-LLM] [I]
===========================================================
= DATASET DETAILS
===========================================================
Max Input Sequence Length:      128
Max Output Sequence Length:     128
Max Sequence Length:    256
Number of Sequences:    3000
===========================================================


[08/12/2024-19:13:07] [TRT-LLM] [I] Set multiple_profiles to True.
[08/12/2024-19:13:07] [TRT-LLM] [I] Set use_paged_context_fmha to True.
[08/12/2024-19:13:07] [TRT-LLM] [I] Set use_fp8_context_fmha to True.
[08/12/2024-19:13:07] [TRT-LLM] [I]
===========================================================
= ENGINE BUILD INFO
===========================================================
Model Name:             meta-llama/Llama-2-7b-hf
Workspace Directory:    /tmp
Engine Directory:       /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1

===========================================================
= ENGINE CONFIGURATION DETAILS
===========================================================
Max Sequence Length:            256
Max Batch Size:                 4096
Max Num Tokens:                 8192
Quantization:                   FP8
===========================================================

Loading Model: [1/3]    Downloading HF model
Downloaded model to /data/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9
Time: 0.115s
Loading Model: [2/3]    Loading HF model to memory
current rank: 0, tp rank: 0, pp rank: 0
Time: 60.786s
Loading Model: [3/3]    Building TRT-LLM engine
Time: 163.331s
Loading model done.
Total latency: 224.232s
[TensorRT-LLM][INFO] Engine version 0.12.0 found in the config file, assuming engine(s) built by new builder API.

<snip verbose logging>

[08/12/2024-19:17:09] [TRT-LLM] [I]

===========================================================
ENGINE SAVED: /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1
===========================================================

The engine in this case will be written to /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1 (the end of the log).

Max Throughput Benchmark

The trtllm-bench command line tool provides a max throughput benchmark that is accessible via the throughput subcommand. This benchmark tests a TensorRT-LLM engine under maximum load to provide an upper bound throughput number.

How the Benchmarker Works

The benchmarker reads a data file where a single line contains a complete JSON request entry as specified in Preparing a Dataset. The process that the benchmarker is as follows:

  1. Iterate over all input requests. If logits is specified, construct the request using the specified list of logits. Otherwise, tokenize the prompt with as specified by --model $HF_MODEL_NAME.

  2. Submit the dataset to the TensorRT-LLM Executor API as fast as possible (offline mode).

  3. Wait for all requests to return, compute statistics, and then report results.

To run the benchmarker, run the following commands with the engine and dataset generated from previous steps:

trtllm-bench --model meta-llama/Llama-2-7b-hf throughput --dataset /tmp/synthetic_128_128.txt --engine_dir /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1
[TensorRT-LLM] TensorRT-LLM version: 0.12.0
[08/12/2024-19:36:48] [TRT-LLM] [I] Preparing to run throughput benchmark...
[08/12/2024-19:36:49] [TRT-LLM] [I] Setting up benchmarker and infrastructure.
[08/12/2024-19:36:49] [TRT-LLM] [I] Ready to start benchmark.
[08/12/2024-19:36:49] [TRT-LLM] [I] Initializing Executor.
[TensorRT-LLM][INFO] Engine version 0.12.0 found in the config file, assuming engine(s) built by new builder API.

<snip verbose logging>

[TensorRT-LLM][INFO] Executor instance created by worker
[08/12/2024-19:36:58] [TRT-LLM] [I] Starting response daemon...
[08/12/2024-19:36:58] [TRT-LLM] [I] Executor started.
[08/12/2024-19:36:58] [TRT-LLM] [I] Request serving started.
[08/12/2024-19:36:58] [TRT-LLM] [I] Starting statistics collection.
[08/12/2024-19:36:58] [TRT-LLM] [I] Benchmark started.
[08/12/2024-19:36:58] [TRT-LLM] [I] Collecting live stats...
[08/12/2024-19:36:59] [TRT-LLM] [I] Request serving stopped.
[08/12/2024-19:37:19] [TRT-LLM] [I] Collecting last stats...
[08/12/2024-19:37:19] [TRT-LLM] [I] Ending statistics collection.
[08/12/2024-19:37:19] [TRT-LLM] [I] Stop received.
[08/12/2024-19:37:19] [TRT-LLM] [I] Stopping response parsing.
[08/12/2024-19:37:19] [TRT-LLM] [I] Collecting last responses before shutdown.
[08/12/2024-19:37:19] [TRT-LLM] [I] Completed request parsing.
[08/12/2024-19:37:19] [TRT-LLM] [I] Parsing stopped.
[08/12/2024-19:37:19] [TRT-LLM] [I] Request generator successfully joined.
[08/12/2024-19:37:19] [TRT-LLM] [I] Statistics process successfully joined.
[08/12/2024-19:37:19] [TRT-LLM] [I]
===========================================================
= ENGINE DETAILS
===========================================================
Model:                  meta-llama/Llama-2-7b-hf
Engine Directory:       /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1
TensorRT-LLM Version:   0.12.0
Dtype:                  float16
KV Cache Dtype:         FP8
Quantization:           FP8
Max Input Length:       256
Max Sequence Length:    256

===========================================================
= WORLD + RUNTIME INFORMATION
===========================================================
TP Size:                1
PP Size:                1
Max Runtime Batch Size: 4096
Max Runtime Tokens:     8192
Scheduling Policy:      Guaranteed No Evict
KV Memory Percentage:   90.0%
Issue Rate (req/sec):   2.0827970096792666e+19
===========================================================
= STATISTICS
===========================================================
Number of requests:             3000
Average Input Length (tokens):  128.0
Average Output Length (tokens): 128.0
Token Throughput (tokens/sec):  18886.813971319196
Request Throughput (req/sec):   147.55323415093122
Total Latency (seconds):        20.331645167
===========================================================

[TensorRT-LLM][INFO] Orchestrator sendReq thread exiting
[TensorRT-LLM][INFO] Orchestrator recv thread exiting
[TensorRT-LLM][INFO] Leader sendThread exiting
[TensorRT-LLM][INFO] Leader recvReq thread exiting
[TensorRT-LLM][INFO] Refreshed the MPI local session

Low Latency Benchmark

The low latency benchmark follows a similar workflow to the throughput benchmark but requires building the engine separately from trtllm-bench. Low latency benchmarks has the following modes:

  • A single-request low-latency engine

  • A Medusa-enabled speculative-decoding engine

Low Latency TensorRT-LLM Engine for Llama-3 70B

To build a low-latency engine for the latency benchmark, run the following quantize and build commands. The $checkpoint_dir is the path to the meta-llama/Meta-Llama-3-70B Hugging Face checkpoint in your cache or downloaded to a specific location with the huggingface-cli. To prepare a dataset, follow the same process as specified in Preparing a Dataset.

Benchmarking a non-Medusa Low Latency Engine

To quantize the checkpoint:

cd tensorrt_llm/examples/llama
python ../quantization/quantize.py \
    --model_dir $checkpoint_dir \
    --dtype bfloat16 \
    --qformat fp8 \
    --kv_cache_dtype fp8 \
    --output_dir /tmp/meta-llama/Meta-Llama-3-70B/checkpoint \
    --calib_size 512 \
    --tp_size $tp_size

then build,

trtllm-build \
    --checkpoint_dir /tmp/meta-llama/Meta-Llama-3-70B/checkpoint \
    --use_fused_mlp enable \
    --gpt_attention_plugin bfloat16 \
    --output_dir /tmp/meta-llama/Meta-Llama-3-70B/engine \
    --max_batch_size 1 \
    --max_seq_len $(($isl+$osl)) \
    --reduce_fusion enable \
    --gemm_plugin fp8 \
    --workers $tp_size \
    --use_fp8_context_fmha enable \
    --max_num_tokens $isl \
    --use_paged_context_fmha disable \
    --multiple_profiles enable

After the engine is built, run the low-latency benchmark:

env TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG=1 \
  TRTLLM_MMHA_KERNEL_BLOCK_SIZE=256 \
  TRTLLM_MMHA_BLOCKS_PER_SEQUENCE=32 \
  FORCE_MULTI_BLOCK_MODE=ON \
  TRTLLM_ENABLE_PDL=1 \
  trtllm-bench --model meta-llama/Meta-Llama-3-70B \
  latency \
  --dataset $DATASET_PATH \
  --engine_dir /tmp/meta-llama/Meta-Llama-3-70B/engine

Building a Medusa Low-Latency Engine

To build a Medusa-enabled engine requires checkpoints that contain Medusa heads. NVIDIA provides TensorRT-LLM checkpoints on the NVIDIA page on Hugging Face. The checkpoints are pre-quantized and can be directly built after downloading them with the huggingface-cli. After you download the checkpoints, run the following command. Make sure to specify the $tp_size supported by your Medusa checkpoint and the path to its stored location $checkpoint_dir. Additionally, $max_seq_len should be set to the model’s maximum position embedding.

Using Llama-3.1 70B as an example, for a tensor parallel 8 and bfloat16 dtype:

tp_size=8
max_seq_len=131072
trtllm-build --checkpoint_dir $checkpoint_dir \
    --speculative_decoding_mode medusa \
    --max_batch_size 1 \
    --gpt_attention_plugin bfloat16 \
    --max_seq_len $max_seq_len \
    --output_dir /tmp/meta-llama/Meta-Llama-3.1-70B/medusa/engine \
    --use_fused_mlp enable \
    --paged_kv_cache enable \
    --use_paged_context_fmha disable \
    --multiple_profiles enable \
    --reduce_fusion enable \
    --use_fp8_context_fmha enable \
    --workers $tp_size \
    --low_latency_gemm_plugin fp8

After the engine is built, you need to define the Medusa choices. The choices are specified with a YAML file like the following example (medusa.yaml):

- [0]
- [0, 0]
- [1]
- [0, 1]
- [2]
- [0, 0, 0]
- [1, 0]
- [0, 2]
- [3]
- [0, 3]
- [4]
- [0, 4]
- [2, 0]
- [0, 5]
- [0, 0, 1]

To run the Medusa-enabled engine, run the following command:

env TRTLLM_ENABLE_PDL=1 \
  UB_ONESHOT=1 \
  UB_TP_SIZE=$tp_size \
  TRTLLM_ENABLE_PDL=1 \
  TRTLLM_PDL_OVERLAP_RATIO=0.15 \
  TRTLLM_PREFETCH_RATIO=-1 \
  trtllm-bench --model meta-llama/Meta-Llama-3-70B \
  latency \
  --dataset $DATASET_PATH \
  --engine_dir /tmp/meta-llama/Meta-Llama-3-70B/medusa/engine \
  --medusa_choices medusa.yml

Summary

The following table summarizes the commands needed for running benchmarks:

Scenario

Phase

Command

Dataset

Preparation

python benchmarks/cpp/prepare_dataset.py --stdout --tokenizer $HF_MODEL token-norm-dist --input-mean $ISL --output-mean $OSL --input-stdev 0 --output-stdev 0 --num-requests $NUM_REQUESTS > $DATASET_PATH

Throughput

Build

trtllm-bench --model $HF_MODEL build --dataset $DATASET_PATH

Throughput

Benchmark

trtllm-bench --model $HF_MODEL throughput --dataset $DATASET_PATH --engine_dir $ENGINE_DIR

Latency

Build

See section about building low latency engines

Non-Medusa Latency

Benchmark

trtllm-bench --model $HF_MODEL latency --dataset $DATASET_PATH --engine_dir $ENGINE_DIR

Medusa Latency

Benchmark

trtllm-bench --model $HF_MODEL latency --dataset $DATASET_PATH --engine_dir $ENGINE_DIR --medusa_choices $MEDUSA_CHOICES

where,

$HF_MODEL

The Hugging Face name of a model.

$NUM_REQUESTS

The number of requests to generate.

$DATASET_PATH

The path where the dataset was written when preparing the dataset.

$ENGINE_DIR

The engine directory as printed by trtllm-bench build.

$MEDUSA_CHOICES

A YAML config representing the Medusa tree for the benchmark.