C++ GPT Runtime
TensorRT-LLM includes a C++ component to execute TensorRT engines built with the Python API as described in the TensorRT-LLM Architecture section. That component is called the C++ runtime.
The API of the C++ runtime is composed of the classes declared in cpp/include/tensorrt_llm/runtime and implemented in cpp/tensorrt_llm/runtime.
Even though the different components described in this document mention GPT in their name, they are not restricted to this specific model. Those classes can be used to implement auto-regressive models like BLOOM, GPT-J, GPT-NeoX or LLaMA, for example.
Complete support of encoder-decoder models, like T5, will be added to TensorRT-LLM in a future release. An experimental version, only in Python for now, can be found in the examples/enc_dec folder.
Overview
Runtime models are described by an instance of the `ModelConfig` class and a pointer to the TensorRT engine that must be executed to perform the inference. The environment is configured through the `WorldConfig` (that name comes from MPI and its "famous" `MPI_COMM_WORLD` default communicator). The `SamplingConfig` class encapsulates parameters that control the generation of new tokens.
Model Configuration
The model configuration is an instance of the `ModelConfig` class. That class encapsulates the following parameters (they are declared as private member variables and exposed through getters and setters):

* `vocabSize`, the size of the vocabulary,
* `numLayers`, the number of layers in the model,
* `numHeads`, the number of heads in the attention block,
* `numKvHeads`, the number of heads for K and V in the attention component. When the number of K/V heads is the same as the number of (Q) heads, the model uses multi-head attention. When the number of K/V heads is 1, it uses multi-query attention. Otherwise, it uses group-query attention. Refer to Multi-Head, Multi-Query, and Group-Query Attention for more information,
* `hiddenSize`, the size of the hidden dimension,
* `dataType`, the datatype that was used to build the TensorRT engine and that must be used to run the model during inference,
* `useGptAttentionPlugin`, indicates if the Multi-Head, Multi-Query, and Group-Query Attention operator was compiled using the GPT Attention plugin,
* `inputPacked`, indicates that the input must be packed (or padded when set to `false`). For performance reasons, it is recommended to always use packed inputs, even though the default is `false` (this will be changed in a future release). Refer to Multi-Head, Multi-Query, and Group-Query Attention for more information,
* `pagedKvCache`, indicates if the K/V cache uses paging. Refer to Multi-Head, Multi-Query, and Group-Query Attention for more information,
* `tokensPerBlock`, the number of tokens in each block of the K/V cache. It is relevant when the paged K/V cache is enabled. By default, the value is 64. Refer to Multi-Head, Multi-Query, and Group-Query Attention for more information,
* `quantMode`, controls the quantization method. Refer to Numerical Precision for more information,
* `maxBatchSize`, the maximum batch size that the TensorRT engine was built for,
* `maxInputLen`, the maximum size of the input sequences,
* `maxSequenceLen`, the maximum total size (input + output) of the sequences.
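To make the parameters above concrete, the following sketch configures a LLaMA-like model. The constructor arguments and the setter names are assumptions made for illustration; the headers in cpp/include/tensorrt_llm/runtime are the authoritative reference for the exact API of your release.

```cpp
// Sketch only: constructor arguments and setter names are assumed from the
// parameter list above; check cpp/include/tensorrt_llm/runtime/modelConfig.h.
#include "tensorrt_llm/runtime/modelConfig.h"

using namespace tensorrt_llm::runtime;

ModelConfig makeLlamaLikeConfig()
{
    // vocabSize, numLayers, numHeads, hiddenSize, dataType
    ModelConfig config{32000, 32, 32, 4096, nvinfer1::DataType::kHALF};
    config.setNbKvHeads(32);         // == numHeads, so multi-head attention
    config.useGptAttentionPlugin(true);
    config.usePackedInput(true);     // packed inputs are recommended for performance
    config.usePagedKvCache(true);
    config.setTokensPerBlock(64);    // only relevant with the paged K/V cache
    config.setMaxBatchSize(8);
    config.setMaxInputLen(1024);
    config.setMaxSequenceLen(2048);  // maximum input + output length
    return config;
}
```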
World Configuration
Familiarity with MPI is not required to use the TensorRT-LLM C++ runtime. There are two main things you need to know:

* The C++ Runtime in TensorRT-LLM uses processes to execute TensorRT engines on the different GPUs. Those GPUs can be located on a single node as well as on different nodes in a cluster. Each process is called a rank in MPI.
* The ranks are grouped in communication groups. The TensorRT-LLM C++ Runtime calls that group the world.
The world configuration is an instance of the `WorldConfig` class, which encapsulates the following parameters:

* `tensorParallelism`, the number of ranks that collaborate together to implement Tensor Parallelism (TP). With TP, each GPU performs computations for all the layers of the model. Some of those computations are distributed across the GPUs. TP is more balanced than Pipeline Parallelism (PP) in most cases, but requires higher bandwidth between the GPUs. It is the recommended setting in the presence of NVLINK between GPUs,
* `pipelineParallelism`, the number of ranks that collaborate together to implement Pipeline Parallelism (PP). With PP, each GPU works on a subset of consecutive layers. Communications between the GPUs happen only at the boundaries of the subsets of layers. It is harder to guarantee the full utilization of the GPUs with PP, but it requires less bandwidth between the GPUs. It is the recommended setting in the absence of NVLINK between GPUs,
* `rank`, the unique identifier of the rank,
* `gpusPerNode`, the number of GPUs on each node. Having that information allows the C++ runtime to optimize communications between GPUs in a node (like taking advantage of the NVLINK interconnect between the GPUs of an A100 DGX node).
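The sketch below shows how a rank is typically mapped to its tensor-parallel group, pipeline stage and local GPU. Both the getter names and the rank layout are assumptions for illustration; cpp/include/tensorrt_llm/runtime/worldConfig.h is the authoritative reference.

```cpp
// Illustration only: the layout (consecutive ranks form a tensor-parallel
// group, groups of `tp` ranks form the pipeline stages) and the getter names
// are assumptions; check worldConfig.h for the actual API.
#include "tensorrt_llm/runtime/worldConfig.h"

#include <cstdio>

using namespace tensorrt_llm::runtime;

void describeRank(WorldConfig const& world)
{
    int const rank = world.getRank();
    int const tp = world.getTensorParallelism();
    int const pp = world.getPipelineParallelism();

    int const tpRank = rank % tp; // position inside the tensor-parallel group
    int const ppRank = rank / tp; // pipeline stage this rank belongs to

    // Each rank drives one GPU; ranks map to GPUs round-robin within a node.
    int const deviceId = rank % world.getGpusPerNode();

    std::printf("rank %d -> TP rank %d/%d, PP stage %d/%d, GPU %d\n",
                rank, tpRank, tp, ppRank, pp, deviceId);
}
```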
Sampling Parameters
The `SamplingConfig` class encapsulates parameters that control the generation of new tokens. A comparison of the decoding methods supported by HF and TRT-LLM is listed in the table below (X means the method is not supported yet).

Except for the `beamWidth` parameter, all the fields are optional and the runtime will use a default value if no values are provided by the user. For vector fields, the TensorRT-LLM runtime supports one value per sequence (that is, the vector contains `batchSize` values). If all the sequences use the same value for a given parameter, the vector can be limited to a single element (that is, `size() == 1`).
| Method name in HF | Condition in HF | Method name in TRT-LLM | Condition in TRT-LLM |
|---|---|---|---|
| assisted decoding |  | X |  |
| beam-search decoding |  | beam search |  |
| beam-search multinomial sampling |  | X |  |
| constrained beam-search decoding |  | X |  |
| contrastive search |  | X |  |
| diverse beam-search decoding |  | X |  |
| greedy decoding |  | sampling |  |
| multinomial sampling |  | sampling |  |
General
| Name in TRT-LLM | Description | Data type | Range of value | Default value | Name in HF |
|---|---|---|---|---|---|
|  | modulation of logits in sampling workflow | List[Float] | [0.0f, $+\infty$) |  |  |
|  | lower-bound on the number of tokens generated | List[Int] | [0, $+\infty$) |  |  |
|  | penalize repetitive tokens | List[Float] | [0.0f, $+\infty$) |  |  |
|  | penalize tokens that already appeared | List[Float] | ($-\infty$, $+\infty$) |  | no |
|  | penalize tokens that already appeared | List[Float] | ($-\infty$, $+\infty$) |  | no |
|  |  | List[Int] | [0, $+\infty$) |  |  |
* The tokens of the input prompt are included when applying `repetitionPenalty`, `presencePenalty`, and `frequencyPenalty` to the logits.
* The parameters `repetitionPenalty`, `presencePenalty`, and `frequencyPenalty` are not mutually exclusive.
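The sketch below shows a conventional way the three penalties can be applied to a logits vector. It is only meant to illustrate how they combine (and why they are not mutually exclusive); the actual TensorRT-LLM kernels may differ in detail.

```cpp
// Conventional formulation of the three penalties; not the actual TRT-LLM kernels.
#include <cstdint>
#include <unordered_map>
#include <vector>

void applyPenalties(std::vector<float>& logits,
                    std::vector<std::int32_t> const& previousTokens, // prompt + generated tokens
                    float repetitionPenalty, // multiplicative, 1.0f disables it
                    float presencePenalty,   // additive, 0.0f disables it
                    float frequencyPenalty)  // additive, 0.0f disables it
{
    // Count how often each token has appeared so far (the prompt is included).
    std::unordered_map<std::int32_t, std::int32_t> counts;
    for (std::int32_t token : previousTokens)
    {
        ++counts[token];
    }

    for (auto const& [token, count] : counts)
    {
        float& logit = logits[token];
        // Repetition penalty: push the logit away from being selected again.
        logit = logit > 0.0f ? logit / repetitionPenalty : logit * repetitionPenalty;
        // Presence penalty: flat cost once the token has appeared at all.
        logit -= presencePenalty;
        // Frequency penalty: cost grows with the number of appearances.
        logit -= frequencyPenalty * static_cast<float>(count);
    }
}
```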
Sampling
| Name in TRT-LLM | Description | Data type | Range of value | Default value | Name in HF |
|---|---|---|---|---|---|
|  | random seed for random number generator | Int64 | [0, 2^64-1] |  | no |
|  | the number of logits to sample from | List[Int] | [0, 1024] |  |  |
|  | the top-P probability to sample from | List[Float] | [0.0f, 1.0f] |  |  |
|  | the decay in the `topP` algorithm | List[Float] | (0.0f, 1.0f] |  | no |
|  | the decay in the `topP` algorithm | List[Float] | (0.0f, 1.0f] |  | no |
|  | the decay in the `topP` algorithm | List[Int] | [-1, $+\infty$) |  | no |
* If `topK = 0` and `topP = 0.0f`, greedy search is performed.
* If `topK > 0` and `topP = 0.0f`, the `topK` tokens with the highest probabilities become the candidates of sampling (named `TopK sampling` in TRT-LLM).
* If `topK = 0` and `topP > 0.0f`, tokens are sorted by probability in descending order, and the tokens with the highest probabilities whose accumulated probability is larger than `topP` become the candidates of sampling (named `TopP sampling` in TRT-LLM).
* If `topK > 0` and `topP > 0.0f`, the `topK` tokens with the highest probabilities are selected, then those selected tokens are sorted by probability in descending order and their probabilities are normalized; the tokens with the highest normalized probabilities whose accumulated probability is larger than `topP` become the candidates of sampling (named `TopKTopP sampling` in TRT-LLM).
* If different `topK` values are provided for the different sequences in the batch, the performance of the implementation depends on the largest value. For efficiency reasons, we recommend batching requests with similar `topK` values together.
* `topPDecay`, `topPMin` and `topPResetIds` are explained in Factuality Enhanced Language Models for Open-Ended Text Generation: `topPDecay` is the decay, `topPMin` is the lower-bound and `topPResetIds` indicates where to reset the decay.
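The host-side sketch below mirrors the selection rules above to make them concrete; the runtime itself implements them with fused CUDA kernels.

```cpp
// Host-side illustration of the TopK / TopP / TopKTopP candidate selection;
// not the kernels used by the C++ runtime.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<std::int32_t> selectCandidates(std::vector<float> const& probs,
                                           std::int32_t topK, float topP)
{
    // Sort token ids by probability, highest first.
    std::vector<std::int32_t> ids(probs.size());
    std::iota(ids.begin(), ids.end(), 0);
    std::sort(ids.begin(), ids.end(),
              [&probs](std::int32_t a, std::int32_t b) { return probs[a] > probs[b]; });

    // topK > 0: keep only the topK most likely tokens.
    if (topK > 0 && topK < static_cast<std::int32_t>(ids.size()))
    {
        ids.resize(topK);
    }

    // topP > 0: renormalize over the kept tokens and keep the smallest prefix
    // whose accumulated probability is larger than topP.
    if (topP > 0.0f)
    {
        float total = 0.0f;
        for (std::int32_t id : ids)
        {
            total += probs[id];
        }
        float cumulative = 0.0f;
        std::size_t kept = 0;
        while (kept < ids.size())
        {
            cumulative += probs[ids[kept]] / total;
            ++kept;
            if (cumulative >= topP)
            {
                break;
            }
        }
        ids.resize(kept);
    }

    // topK == 0 and topP == 0.0f: greedy search, only the best token remains.
    if (topK == 0 && topP == 0.0f)
    {
        ids.resize(1);
    }

    return ids; // the next token is sampled among these candidates
}
```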
Beam-search
| Name in TRT-LLM | Description | Data type | Range of value | Default value | Name in HF |
|---|---|---|---|---|---|
|  | width for beam-search algorithm | Int | [0, 64] |  |  |
|  | diversity of generated tokens | List[Float] | [0, $+\infty$) |  |  |
|  | penalize longer sequences | List[Float] | [0, $+\infty$) |  |  |
|  | see description below | List[Int] | ($-\infty$, $+\infty$) |  |  |
* Beam-search algorithm: beam search.
* The parameter `diversity_penalty` in HF is only used for `diverse beam-search decoding` (also named `Group-Beam-Search`), which is not supported by TRT-LLM yet.
* If `earlyStopping = 1`, decoding stops once `beamWidth` finished sentences are generated.
* If `earlyStopping = 0`, decoding keeps going until no better sentences (with a better score) can be generated.
* If `earlyStopping` is set to any other value, decoding stops depending on `lengthPenalty` only.
* The `beamWidth` parameter is a scalar value. It means that, in this release of TensorRT-LLM, it is not possible to specify a different width for each input sequence. This limitation is likely to be removed in a future release.
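To make the role of `lengthPenalty` and the notion of "score" concrete, the snippet below shows a commonly used formulation of the beam score; the exact formula used by TRT-LLM may differ in detail.

```cpp
// A common beam score: accumulated log-probability normalized by a power of
// the length. lengthPenalty > 1.0f favors longer outputs, < 1.0f shorter ones.
// Shown for illustration only.
#include <cmath>

float beamScore(float sumLogProbs, int generatedLength, float lengthPenalty)
{
    return sumLogProbs / std::pow(static_cast<float>(generatedLength), lengthPenalty);
}
```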
The Session
The runtime session is deprecated in favor of the Executor API. It will be removed in a future release of TensorRT-LLM.
An example of how to use the `GptSession` to run a GPT-like auto-regressive model can be found in cpp/tests/runtime/gptSessionTest.cpp.
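For orientation, the overall flow looks roughly like the sketch below. The constructor and `generate()` signatures are assumptions about the (now deprecated) `GptSession` API; gptSessionTest.cpp is the authoritative reference.

```cpp
// Outline only: signatures are assumed; see cpp/tests/runtime/gptSessionTest.cpp.
#include "tensorrt_llm/runtime/gptSession.h"

#include <string>

using namespace tensorrt_llm::runtime;

void runGeneration(ModelConfig const& modelConfig, WorldConfig const& worldConfig,
                   std::string const& enginePath, SamplingConfig const& samplingConfig,
                   GenerationInput const& inputs, GenerationOutput& outputs)
{
    GptSession::Config sessionConfig{/*maxBatchSize=*/8, /*maxBeamWidth=*/1,
                                     /*maxSequenceLength=*/2048};
    GptSession session{sessionConfig, modelConfig, worldConfig, enginePath};

    // Runs the TensorRT engine and the decoder until every sequence is finished.
    session.generate(outputs, inputs, samplingConfig);
}
```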
Internal Components
The `GptSession` class encapsulates two main components. The `TllmRuntime` is in charge of the execution of the TensorRT engine. The `GptDecoder` generates the tokens from the logits. The `TllmRuntime` class is an internal component and you are not expected to use that class directly. The `GptDecoder` can be used directly to implement a custom generation loop and for use cases that cannot be satisfied by the implementation in `GptSession`.
In-flight Batching Support
In-flight batching is supported using separate decoders per request. The biggest difference compared to using a single decoder is in how the token generation from logits is managed. A batch is split into `batchSize` individual requests and kernels are issued using separate CUDA streams. This behavior may be revisited in a future release to maintain the structure of the batch and improve efficiency.
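Conceptually, the per-request organization looks like the sketch below; this is an illustration of the idea only, not the actual TensorRT-LLM batch manager code.

```cpp
// Illustration of "one decoder per request, each on its own CUDA stream".
#include <cuda_runtime_api.h>

#include <vector>

struct PerRequestState
{
    cudaStream_t stream{}; // decoding kernels for this request are issued here
    // per-request decoder state (logits buffer, sampled tokens, ...) would live here
};

std::vector<PerRequestState> makeRequestStates(int batchSize)
{
    std::vector<PerRequestState> states(batchSize);
    for (auto& state : states)
    {
        cudaStreamCreate(&state.stream); // an independent stream per request
    }
    return states;
}

void releaseRequestStates(std::vector<PerRequestState>& states)
{
    for (auto& state : states)
    {
        cudaStreamDestroy(state.stream);
    }
}
```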
Known Issues and Future Changes
In the current release of TensorRT-LLM, the C++ and Python runtimes are two separate software components and the C++ runtime is being more actively developed (with features like in-flight batching). An objective, for a future release, could be to rebuild the Python runtime on top of the C++ one.