utils
PyTorch LLaMA model.
Classes
- LlamaAttention: Multi-headed attention from the 'Attention Is All You Need' paper.
- LlamaDecoderLayer: LlamaDecoderLayer class.
- LlamaMLP: LlamaMLP class.
- LlamaRMSNorm: LlamaRMSNorm class.
- LlamaRotaryEmbedding: Llama Rotary Embedding.
Functions
- apply_rotary_pos_emb: Apply rotary position embedding.
- expand_mask: Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len].
- make_causal_mask: Make the causal mask used for causal (uni-directional) self-attention.
- repeat_kv: Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
- rotate_half: Rotates half the hidden dims of the input.
- class LlamaAttention
Bases:
Module
Multi-headed attention from the 'Attention Is All You Need' paper.
- __init__(hidden_size, num_attention_heads, num_key_value_heads, max_position_embeddings, rope_theta)
Init function for LlamaAttention.
- forward(hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False)
Forward function for LlamaAttention.
- Parameters:
hidden_states (Tensor) –
attention_mask (Tensor | None) –
position_ids (LongTensor | None) –
past_key_value (Tuple[Tensor] | None) –
output_attentions (bool) –
use_cache (bool) –
- Return type:
Tuple[Tensor, Tensor | None, Tuple[Tensor] | None]
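A minimal usage sketch for LlamaAttention, assuming this module is importable as `utils` and that make_causal_mask (documented below) produces the 4D additive mask the forward pass expects; the hyperparameters follow the documented signature and the shapes follow the standard LLaMA formulation.

```python
import torch
from utils import LlamaAttention, make_causal_mask  # assumed import path

bsz, seq_len, hidden_size = 2, 16, 4096
attn = LlamaAttention(
    hidden_size=hidden_size,
    num_attention_heads=32,
    num_key_value_heads=8,
    max_position_embeddings=131072,
    rope_theta=500000.0,
)

hidden_states = torch.randn(bsz, seq_len, hidden_size)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)
# Additive causal mask of shape (bsz, 1, seq_len, seq_len); assumed format.
attention_mask = make_causal_mask((bsz, seq_len), hidden_states.dtype, hidden_states.device)

attn_output, attn_weights, past_key_value = attn(
    hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    use_cache=True,
)
print(attn_output.shape)  # expected (bsz, seq_len, hidden_size)
```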
- class LlamaDecoderLayer
Bases:
Module
LlamaDecoderLayer class.
- __init__(index, hidden_size, intermediate_size=14336, rms_norm_eps=1e-05, num_attention_heads=32, num_key_value_heads=8, max_position_embeddings=131072, rope_theta=500000.0)
Init function for LlamaDecoderLayer.
- forward(hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False)
Forward function for LlamaDecoderLayer.
- Parameters:
hidden_states (Tensor) –
attention_mask (Tensor | None) –
position_ids (LongTensor | None) –
past_key_value (Tuple[Tensor] | None) –
output_attentions (bool | None) –
use_cache (bool | None) –
- Return type:
Tuple[FloatTensor, Tuple[FloatTensor, FloatTensor] | None]
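A similar usage sketch for LlamaDecoderLayer with its documented defaults, again assuming the `utils` import path. The first element of the returned tuple is the hidden-state tensor; the optional second element holds the cached key/value pair.

```python
import torch
from utils import LlamaDecoderLayer, make_causal_mask  # assumed import path

bsz, seq_len, hidden_size = 2, 16, 4096
layer = LlamaDecoderLayer(index=0, hidden_size=hidden_size)  # other arguments keep their defaults

hidden_states = torch.randn(bsz, seq_len, hidden_size)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)
attention_mask = make_causal_mask((bsz, seq_len), hidden_states.dtype, hidden_states.device)

outputs = layer(
    hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    use_cache=True,
)
hidden_states = outputs[0]
print(hidden_states.shape)  # expected (bsz, seq_len, hidden_size)
```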
- class LlamaMLP
Bases:
Module
LlamaMLP class.
- __init__(hidden_size, intermediate_size)
Init function for LlamaMLP.
- forward(x)
Forward function for LlamaMLP.
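In the LLaMA architecture the MLP is a gated SiLU (SwiGLU-style) feed-forward block. A standalone reference sketch of that computation follows; the class in this module may differ in naming or details.

```python
import torch
from torch import nn


class GatedMLP(nn.Module):
    """Reference LLaMA-style MLP: down_proj(silu(gate_proj(x)) * up_proj(x))."""

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
```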
- class LlamaRMSNorm
Bases:
Module
LlamaRMSNorm class.
- __init__(hidden_size, eps=1e-06)
LlamaRMSNorm is equivalent to T5LayerNorm.
- forward(hidden_states)
Forward function for LlamaRMSNorm.
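RMSNorm (the T5LayerNorm formulation mentioned above) rescales each vector by the reciprocal root mean square of its last dimension and applies a learned weight, with no mean subtraction and no bias. A standalone reference sketch:

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)  # compute statistics in fp32
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```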
- class LlamaRotaryEmbedding
Bases:
Module
Llama Rotary Embedding.
- __init__(dim, max_position_embeddings=2048, base=10000, device=None)
Init function for LlamaRotaryEmbedding.
- forward(x, seq_len=None)
Forward function for LlamaRotaryEmbedding.
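The rotary embedding builds cos/sin tables from inverse frequencies 1 / base**(2i/dim) taken over the sequence positions. A standalone reference sketch matching the documented signature (the class in this module may cache the tables rather than recompute them):

```python
import torch
from torch import nn


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_position_embeddings = max_position_embeddings

    def forward(self, x, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[-2]
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)    # (seq_len, dim / 2)
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
```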
- apply_rotary_pos_emb(q, k, cos, sin, position_ids)
Apply rotary position embedding.
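A reference sketch of how the rotation is typically applied, assuming cos/sin tables of shape (seq_len, head_dim) as sketched above and q/k of shape (batch, num_heads, seq_len, head_dim); `_rotate_half` is the helper documented at the end of this page.

```python
import torch


def _rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_ref(q, k, cos, sin, position_ids):
    cos = cos[position_ids].unsqueeze(1)  # (batch, 1, seq_len, head_dim)
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (_rotate_half(q) * sin)
    k_embed = (k * cos) + (_rotate_half(k) * sin)
    return q_embed, k_embed
```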
- expand_mask(mask, dtype, tgt_len=None)
Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len].
- Parameters:
mask (Tensor) –
dtype (dtype) –
tgt_len (int | None) –
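A reference sketch of the expansion, assuming the usual convention that `mask` holds 1 for tokens to attend to and 0 for padding, and that the result is an additive mask with large negative values at masked positions:

```python
import torch


def expand_mask_ref(mask, dtype, tgt_len=None):
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    # Broadcast (bsz, src_len) -> (bsz, 1, tgt_len, src_len).
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    # Masked positions get the most negative representable value.
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)
```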
- make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0)
Make the causal mask used for causal (uni-directional) self-attention.
- Parameters:
input_ids_shape (Size) –
dtype (dtype) –
device (device) –
past_key_values_length (int) –
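A reference sketch of the causal mask construction: a lower-triangular additive mask of shape (tgt_len, tgt_len), prefixed with zero columns for any cached past positions and broadcast over the batch.

```python
import torch


def make_causal_mask_ref(input_ids_shape, dtype, device, past_key_values_length=0):
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(cond < (cond + 1).view(mask.size(-1), 1), 0)  # zero on and below the diagonal
    mask = mask.to(dtype)
    if past_key_values_length > 0:
        past = torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device)
        mask = torch.cat([past, mask], dim=-1)  # always attend to cached positions
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
```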
- repeat_kv(hidden_states, n_rep)
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim), where num_attention_heads = num_key_value_heads * n_rep.

- Parameters:
hidden_states (Tensor) –
n_rep (int) –
- Return type:
Tensor
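A reference sketch of the expansion used for grouped-query attention: each key/value head is duplicated n_rep times so the KV tensors match the number of query heads.

```python
import torch


def repeat_kv_ref(hidden_states, n_rep):
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # (batch, kv_heads, seq, dim) -> (batch, kv_heads, n_rep, seq, dim) -> (batch, kv_heads * n_rep, seq, dim)
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
```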
- rotate_half(x)
Rotates half the hidden dims of the input.
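A reference sketch with a small worked example, assuming the standard convention of splitting the last dimension in half and mapping (x1, x2) to (-x2, x1):

```python
import torch


def rotate_half_ref(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


print(rotate_half_ref(torch.tensor([1.0, 2.0, 3.0, 4.0])))  # tensor([-3., -4.,  1.,  2.])
```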