Talker MLP Kernels
- void trt_edgellm::kernel::invokeTalkerMLP(
- void *cublasHandle,
- rt::Tensor const &input,
- rt::Tensor const &fc1Weight,
- rt::Tensor const &fc1Bias,
- rt::Tensor const &fc2Weight,
- rt::Tensor const &fc2Bias,
- rt::Tensor &output,
- rt::Tensor &workspace,
- cudaStream_t stream
- )
Two-layer MLP with SiLU activation (Talker projection layers)
Performs: output = FC2(SiLU(FC1(input) + bias1)) + bias2, where FC1: [inputDim → hiddenDim] and FC2: [hiddenDim → outputDim].
Architecture:

input [N, 2048]
↓ FC1 (Linear) + bias1
[N, 2048]
↓ SiLU
[N, 2048]
↓ FC2 (Linear) + bias2
output [N, 1024]
Note: Weight matrices are stored in column-major format (cuBLAS convention).

Note: Workspace must be pre-allocated with size [numTokens, 2048] * sizeof(half).
- Parameters:
cublasHandle – [in] cuBLAS handle for GEMM operations
input – [in] Input tensor with shape [numTokens, 2048] (FP16)
fc1Weight – [in] FC1 weight matrix with shape [2048, 2048] (FP16, column-major)
fc1Bias – [in] FC1 bias vector with shape [2048] (FP16)
fc2Weight – [in] FC2 weight matrix with shape [2048, 1024] (FP16, column-major)
fc2Bias – [in] FC2 bias vector with shape [1024] (FP16)
output – [out] Output tensor with shape [numTokens, 1024] (FP16)
workspace – [inout] Workspace buffer for intermediate FC1 output [numTokens, 2048] (FP16)
stream – [in] CUDA stream for execution
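The arithmetic can be sketched as a CPU reference in FP32 (the actual kernel runs FP16 GEMMs through cuBLAS on device tensors); `talkerMlpRef`, `linear`, and `silu` are illustrative names, not part of the API:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// SiLU(x) = x * sigmoid(x)
inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

// y[N, out] = x[N, in] @ W^T + b, with W stored row-major as [out, in].
std::vector<float> linear(std::vector<float> const& x, std::size_t n, std::size_t in,
                          std::vector<float> const& w, std::vector<float> const& b,
                          std::size_t out) {
    std::vector<float> y(n * out, 0.0f);
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t o = 0; o < out; ++o) {
            float acc = b[o];
            for (std::size_t k = 0; k < in; ++k)
                acc += x[i * in + k] * w[o * in + k];
            y[i * out + o] = acc;
        }
    return y;
}

// output = FC2(SiLU(FC1(input) + bias1)) + bias2
std::vector<float> talkerMlpRef(std::vector<float> const& x, std::size_t n, std::size_t in,
                                std::vector<float> const& w1, std::vector<float> const& b1,
                                std::size_t hid,
                                std::vector<float> const& w2, std::vector<float> const& b2,
                                std::size_t out) {
    // The FP16 kernel stages this intermediate in `workspace` ([numTokens, hiddenDim]).
    std::vector<float> h = linear(x, n, in, w1, b1, hid);
    for (float& v : h) v = silu(v);
    return linear(h, n, hid, w2, b2, out);
}
```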
- void trt_edgellm::kernel::invokeLinearLayer(
- void *cublasHandle,
- rt::Tensor const &input,
- rt::Tensor const &weight,
- rt::Tensor const &bias,
- rt::Tensor &output,
- cudaStream_t stream
- )
Single linear layer: output = input @ weight.T + bias.
- Parameters:
cublasHandle – [in] cuBLAS handle for GEMM operations
input – [in] Input tensor with shape [N, inputDim] (FP16)
weight – [in] Weight matrix with shape [outputDim, inputDim] (FP16, row-major)
bias – [in] Bias vector with shape [outputDim] (FP16)
output – [out] Output tensor with shape [N, outputDim] (FP16)
stream – [in] CUDA stream for execution
- void trt_edgellm::kernel::invokeGather(
- rt::Tensor const &source,
- rt::Tensor const &indices,
- rt::Tensor &output,
- cudaStream_t stream
- )
Gather operation: select rows from source tensor by indices.
Performs: output[i] = source[indices[i]] where each row has hiddenDim elements.
- Parameters:
source – [in] Source tensor with shape [srcNumTokens, hiddenDim] (FP16)
indices – [in] Indices tensor with shape [numIndices] (INT32)
output – [out] Output tensor with shape [numIndices, hiddenDim] (FP16)
stream – [in] CUDA stream for execution
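A host-side sketch of the gather semantics (hypothetical helper name; the real kernel reads FP16 device tensors and INT32 indices on a CUDA stream):

```cpp
#include <cstdint>
#include <cstddef>
#include <vector>

// output[i] = source[indices[i]], each row hiddenDim elements wide.
std::vector<float> gatherRows(std::vector<float> const& src, std::size_t hiddenDim,
                              std::vector<int32_t> const& indices) {
    std::vector<float> out(indices.size() * hiddenDim);
    for (std::size_t i = 0; i < indices.size(); ++i)
        for (std::size_t d = 0; d < hiddenDim; ++d)
            out[i * hiddenDim + d] =
                src[static_cast<std::size_t>(indices[i]) * hiddenDim + d];
    return out;
}
```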
- void trt_edgellm::kernel::invokeScatter(
- rt::Tensor const &source,
- rt::Tensor const &indices,
- rt::Tensor &output,
- cudaStream_t stream
- )
Scatter operation: place rows from source to output by indices.
Performs: output[indices[i]] = source[i] where each row has hiddenDim elements.
- Parameters:
source – [in] Source tensor with shape [numIndices, hiddenDim] (FP16)
indices – [in] Indices tensor with shape [numIndices] (INT32)
output – [out] Output tensor with shape [dstNumTokens, hiddenDim] (FP16)
stream – [in] CUDA stream for execution
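Scatter is the inverse mapping of gather: row i of the source lands at row indices[i] of the output, and destination rows not named by any index are left untouched. A host-side sketch (hypothetical helper name):

```cpp
#include <cstdint>
#include <cstddef>
#include <vector>

// output[indices[i]] = source[i], each row hiddenDim elements wide.
// Rows of `out` not targeted by any index keep their previous contents.
void scatterRows(std::vector<float> const& src, std::size_t hiddenDim,
                 std::vector<int32_t> const& indices, std::vector<float>& out) {
    for (std::size_t i = 0; i < indices.size(); ++i)
        for (std::size_t d = 0; d < hiddenDim; ++d)
            out[static_cast<std::size_t>(indices[i]) * hiddenDim + d] =
                src[i * hiddenDim + d];
}
```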
- void trt_edgellm::kernel::invokeAssistantPreamble(
- rt::Tensor const &projected,
- rt::Tensor const &ttsPadEmbed,
- rt::Tensor const &ttsBosEmbed,
- rt::Tensor const &ttsEosEmbed,
- rt::Tensor const &talkerEmbTable,
- int32_t codecNothinkId,
- int32_t codecThinkBosId,
- int32_t codecThinkEosId,
- int32_t speakerId,
- int32_t codecPadId,
- int32_t codecBosId,
- int32_t textLen,
- rt::Tensor &output,
- cudaStream_t stream
- )
Fused non-streaming assistant preamble construction for TTS input projection.
Builds the complete non-streaming prefill buffer in one pass. Total rows written = 8 + textLen + 2 (= seqLen + 2).
Row layout (written at outputOffset):
[3]: ttsPadEmbed + talkerEmbTable[codecNothinkId]
[4]: ttsPadEmbed + talkerEmbTable[codecThinkBosId]
[5]: ttsPadEmbed + talkerEmbTable[codecThinkEosId]
[6]: ttsPadEmbed + talkerEmbTable[speakerId]
[7]: ttsBosEmbed + talkerEmbTable[codecPadId]
[8..8+N-1]: projected[3+i] + talkerEmbTable[codecPadId] (text tokens, N = textLen)
[8+N]: ttsEosEmbed + talkerEmbTable[codecPadId]
[8+N+1]: ttsPadEmbed + talkerEmbTable[codecBosId]
- Parameters:
projected – MLP output [seqLen, H] (FP16)
ttsPadEmbed/ttsBosEmbed/ttsEosEmbed – TTS special embeddings [H] (FP16)
talkerEmbTable – Talker embedding table [vocabSize, H] (FP16)
codecNothinkId..codecBosId – Codec token IDs used in rows [3-8+N+1]
speakerId – Speaker codec token ID (row 6)
textLen – Number of text token rows (N = seqLen - 8)
output – Full output buffer [8+N+2, H] (FP16)
stream – CUDA stream
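The row arithmetic of the preamble layout can be sketched with two small helpers; `preambleRows` and `classifyRow` are hypothetical names for illustration, and rows 0-2 (not described in the layout above) fall through to "unspecified":

```cpp
#include <cstdint>
#include <string>

// Total prefill rows written: 8 preamble rows + textLen text rows + EOS row + BOS row.
int32_t preambleRows(int32_t textLen) { return 8 + textLen + 2; }

// Which embedding sum lands in a given row of the output buffer.
std::string classifyRow(int32_t row, int32_t textLen) {
    if (row >= 3 && row <= 6)
        return "ttsPadEmbed + codec special"; // nothink / think-bos / think-eos / speaker
    if (row == 7) return "ttsBosEmbed + codecPadId";
    if (row >= 8 && row < 8 + textLen) return "projected text + codecPadId";
    if (row == 8 + textLen) return "ttsEosEmbed + codecPadId";
    if (row == 8 + textLen + 1) return "ttsPadEmbed + codecBosId";
    return "unspecified";
}
```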
- void trt_edgellm::kernel::invokeResidualConnection(
- rt::Tensor const &codecHiddens,
- rt::Tensor const &embTable0,
- rt::Tensor const &embTable15,
- int32_t code0,
- int32_t code15,
- half const *addend,
- rt::Tensor &output,
- cudaStream_t stream
- )
Fused residual connection for TTS decode input.
Computes: output = embed0[code0] + embed15[code15] + addend + sum(codecHiddens[1..14]). Fusing this replaces 7 separate dispatches (2x H→D, 2x embLookup, 2x D→D, 1x sumReduce) with one kernel.
- Parameters:
codecHiddens – [1, 16, H] buffer — rows 1-14 pre-filled by CodePredictor (FP16)
embTable0 – Talker embedding table [vocabSize, H] (FP16) — for embed(code0)
embTable15 – CodePredictor embedding table[-1] [vocabSize, H] (FP16) — for embed(code15)
code0/code15 – Token IDs passed as scalars (no H→D upload needed)
addend – Row pointer [H] — trailing_text_hidden[generationStep] or tts_pad_embed (FP16)
output – Output tensor [1, 1, H] (FP16)
stream – CUDA stream
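The fused sum can be written out as a CPU reference in FP32 (hypothetical helper name; the kernel operates on FP16 device memory and takes the token IDs as scalars, so no host-to-device upload is needed):

```cpp
#include <cstdint>
#include <cstddef>
#include <vector>

// output[d] = embTable0[code0][d] + embTable15[code15][d] + addend[d]
//           + sum over r in [1, 14] of codecHiddens[r][d]
std::vector<float> residualConnectionRef(std::vector<float> const& codecHiddens, // [16, H]
                                         std::size_t H,
                                         std::vector<float> const& embTable0, int32_t code0,
                                         std::vector<float> const& embTable15, int32_t code15,
                                         std::vector<float> const& addend) {   // [H]
    std::vector<float> out(H);
    for (std::size_t d = 0; d < H; ++d) {
        float acc = embTable0[static_cast<std::size_t>(code0) * H + d]
                  + embTable15[static_cast<std::size_t>(code15) * H + d]
                  + addend[d];
        for (int r = 1; r <= 14; ++r) // rows 1-14 pre-filled by CodePredictor
            acc += codecHiddens[static_cast<std::size_t>(r) * H + d];
        out[d] = acc;
    }
    return out;
}
```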
- void trt_edgellm::kernel::invokeTalkerLogitAdjust(
- rt::Tensor const &seenTokens,
- rt::Tensor &logits,
- int32_t suppressStart,
- int32_t suppressEnd,
- int32_t codecEosId,
- int32_t numSeenTokens,
- float repetitionPenalty,
- cudaStream_t stream
- )
Adjust Talker logits: suppress special tokens and apply repetition penalty.
Performs two in-place modifications on the logits before sampling:
Suppression: sets logits[i] = -inf for all i in [suppressStart, suppressEnd), except for codecEosId which is always preserved.
Repetition penalty: for each token in seenTokens[], divides positive logits by repetitionPenalty and multiplies negative logits by repetitionPenalty, matching the HuggingFace repetition_penalty convention.
Operates on FP32 logits tensor with shape [1, vocabSize].
- Parameters:
seenTokens – [in] GPU tensor of previously generated token IDs [maxAudioLength] INT32
logits – [inout] Logits tensor [1, vocabSize] (FP32, in-place)
suppressStart – [in] Start of suppress range (inclusive)
suppressEnd – [in] End of suppress range (exclusive)
codecEosId – [in] Token ID exempt from suppression (EOS must remain samplable)
numSeenTokens – [in] Number of valid entries in seenTokens (0 to disable penalty)
repetitionPenalty – [in] Penalty factor >= 1.0 (1.0 = no penalty)
stream – [in] CUDA stream for execution