medusa_model
Medusa model to support Medusa decoding.
Classes
MedusaModel | Base Medusa Model.
ResBlock | A Residual Block module.
- class MedusaModel
Bases:
DynamicModule
Base Medusa Model.
- modify(medusa_num_heads=0, medusa_num_layers=0)
Base Medusa Model modify function. Child classes should implement the details.
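As a rough illustration of what a child class's modify might build, the sketch below assumes the standard Medusa layout: medusa_num_heads extra heads reading the base model's last hidden state, each a stack of medusa_num_layers ResBlock layers followed by a linear projection to vocabulary logits. The class MedusaHeadsSketch and its hidden_size and vocab_size arguments are hypothetical. It is a standalone nn.Module, not the DynamicModule-based implementation.

```python
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    """Residual block: x + SiLU(Linear(x))."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.act(self.linear(x))


class MedusaHeadsSketch(nn.Module):
    """Hypothetical container for Medusa heads (not the library's class)."""

    def __init__(self, hidden_size: int, vocab_size: int,
                 medusa_num_heads: int, medusa_num_layers: int):
        super().__init__()
        # One head per extra future token: medusa_num_layers ResBlocks,
        # then a projection from hidden states to vocabulary logits.
        self.heads = nn.ModuleList([
            nn.Sequential(
                *[ResBlock(hidden_size) for _ in range(medusa_num_layers)],
                nn.Linear(hidden_size, vocab_size, bias=False),
            )
            for _ in range(medusa_num_heads)
        ])

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_size)
        # returns:       (medusa_num_heads, batch, seq_len, vocab_size)
        return torch.stack([head(hidden_states) for head in self.heads])


heads = MedusaHeadsSketch(hidden_size=16, vocab_size=50,
                          medusa_num_heads=4, medusa_num_layers=1)
logits = heads(torch.randn(2, 3, 16))
print(logits.shape)  # torch.Size([4, 2, 3, 50])
```

Stacking the per-head logits along a new leading dimension is one convenient choice; a real implementation may return them differently.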
- class ResBlock
Bases:
Module
A Residual Block module.
This module performs a linear transformation followed by a SiLU activation, and then adds the result to the original input, creating a residual connection.
- Parameters:
hidden_size (int) – The size of the hidden layers in the block.
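A minimal PyTorch sketch of the computation described above, assuming the block holds a single nn.Linear(hidden_size, hidden_size) and an nn.SiLU (the attribute names linear and act are assumptions). Because SiLU(0) = 0, zeroing the linear layer's weight and bias turns the block into an exact identity map, which shows how the residual path carries the input through unchanged.

```python
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    """Sketch: y = x + SiLU(W x + b), with W of shape (hidden_size, hidden_size)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)  # assumed attribute name
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection: add the activated projection back to the input.
        return x + self.act(self.linear(x))


block = ResBlock(hidden_size=8)
with torch.no_grad():
    # With W = 0 and b = 0 the activation outputs SiLU(0) = 0,
    # so the block returns its input unchanged.
    nn.init.zeros_(block.linear.weight)
    nn.init.zeros_(block.linear.bias)

x = torch.randn(2, 5, 8)
print(torch.equal(block(x), x))  # True
```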
- __init__(hidden_size)
Initialize the ResBlock.
- Parameters:
hidden_size (int) – The size of the hidden layers in the block.
- forward(x)
Forward pass of the ResBlock.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
Output after the activation and residual connection, with the same shape as x.
- Return type:
torch.Tensor
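The forward pass can also be written functionally, which makes the shape contract explicit: the linear layer acts on the last dimension, so x may have any number of leading dimensions as long as its last dimension equals hidden_size, and the output keeps the shape of x. The helper res_block_forward is hypothetical; the check at the end compares it against an nn.Linear plus nn.SiLU pair holding the same parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def res_block_forward(x: torch.Tensor, weight: torch.Tensor,
                      bias: torch.Tensor) -> torch.Tensor:
    """Hypothetical functional form of ResBlock.forward: x + SiLU(x W^T + b)."""
    # A square weight keeps the last dimension unchanged, so the sum is well defined.
    assert x.shape[-1] == weight.shape[0] == weight.shape[1]
    return x + F.silu(F.linear(x, weight, bias))


hidden_size = 4
weight = torch.randn(hidden_size, hidden_size)
bias = torch.randn(hidden_size)

# Any leading dimensions work: (hidden,), (batch, hidden), (batch, seq, hidden).
for shape in [(hidden_size,), (3, hidden_size), (2, 7, hidden_size)]:
    y = res_block_forward(torch.randn(*shape), weight, bias)
    print(tuple(y.shape))

# Same result as a module-based block that holds identical parameters.
lin = nn.Linear(hidden_size, hidden_size)
with torch.no_grad():
    lin.weight.copy_(weight)
    lin.bias.copy_(bias)
x = torch.randn(2, 7, hidden_size)
print(torch.allclose(res_block_forward(x, weight, bias), x + nn.SiLU()(lin(x))))
```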