medusa_model

Medusa model to support Medusa decoding.

Classes

MedusaModel

Base Medusa Model.

ResBlock

A Residual Block module.

class MedusaModel

Bases: DynamicModule

Base Medusa Model.

modify(medusa_num_heads=0, medusa_num_layers=0)

Base modify function of the Medusa model. Child classes should implement the details.
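As a rough illustration of what a child class might do in modify, the sketch below builds a list of decoding heads. It rests on assumptions this page does not state: that medusa_num_heads is the number of extra heads, that medusa_num_layers is the number of residual blocks per head, and that each head ends in a linear projection to the vocabulary. ToyMedusaModel, _ToyResBlock, hidden_size, and vocab_size are hypothetical names, and the class deliberately does not subclass the real MedusaModel or DynamicModule.

```python
import torch
import torch.nn as nn


class _ToyResBlock(nn.Module):
    """Hypothetical stand-in for ResBlock: x + SiLU(Linear(x))."""

    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x):
        return x + self.act(self.linear(x))


class ToyMedusaModel(nn.Module):
    """Hypothetical child-style class; NOT the library's implementation."""

    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.medusa_heads = nn.ModuleList()

    def modify(self, medusa_num_heads=0, medusa_num_layers=0):
        # Assumption: one head per extra token position, each made of
        # `medusa_num_layers` residual blocks plus a vocabulary projection.
        self.medusa_num_heads = medusa_num_heads
        self.medusa_num_layers = medusa_num_layers
        heads = []
        for _ in range(medusa_num_heads):
            layers = [_ToyResBlock(self.hidden_size) for _ in range(medusa_num_layers)]
            layers.append(nn.Linear(self.hidden_size, self.vocab_size, bias=False))
            heads.append(nn.Sequential(*layers))
        self.medusa_heads = nn.ModuleList(heads)

    def forward(self, hidden_states):
        # Returns one logits tensor per head.
        return [head(hidden_states) for head in self.medusa_heads]


model = ToyMedusaModel(hidden_size=8, vocab_size=32)
model.modify(medusa_num_heads=3, medusa_num_layers=2)
logits = model(torch.randn(2, 4, 8))
```

With the default arguments (both 0), this sketch creates no heads, so forward returns an empty list.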

class ResBlock

Bases: Module

A Residual Block module.

This module performs a linear transformation followed by a SiLU activation, and then adds the result to the original input, creating a residual connection.

Parameters:

hidden_size (int) – The size of the hidden layers in the block.

__init__(hidden_size)

Initialize the ResBlock.

Parameters:

hidden_size (int) – The size of the hidden layers in the block.

forward(x)

Forward pass of the ResBlock.

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

Output after the linear transformation, SiLU activation, and residual connection.

Return type:

torch.Tensor
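The computation described for ResBlock (a linear transformation, then SiLU, then adding the original input) can be sketched in PyTorch as follows. This is a minimal illustration, not the library's source: ResBlockSketch is a hypothetical name, and the choice of a square nn.Linear with default initialization is an assumption.

```python
import torch
import torch.nn as nn


class ResBlockSketch(nn.Module):
    """Illustrative stand-in for ResBlock: x + SiLU(Linear(x))."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # Square linear layer so its output can be added back to the input.
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection: the input is added to the activated projection.
        return x + self.act(self.linear(x))


block = ResBlockSketch(hidden_size=8)
x = torch.randn(2, 4, 8)
y = block(x)  # same shape as x
```

Because the residual path passes the input through unchanged, the output has the same shape as the input, and the last dimension of x must equal hidden_size.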