utils

PyTorch LLaMA model.

Classes

LlamaAttention

Multi-headed attention from the 'Attention Is All You Need' paper.

LlamaDecoderLayer

A single LLaMA transformer decoder layer.

LlamaMLP

The LLaMA feed-forward (MLP) block.

LlamaRMSNorm

Root-mean-square layer normalization (RMSNorm) as used by LLaMA.

LlamaRotaryEmbedding

Llama Rotary Embedding.

Functions

apply_rotary_pos_emb

Apply rotary position embedding.

expand_mask

Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len].

make_causal_mask

Make the causal mask used for uni-directional (autoregressive) self-attention.

repeat_kv

This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).

rotate_half

Rotates half the hidden dims of the input.

class LlamaAttention

Bases: Module

Multi-headed attention from the ‘Attention Is All You Need’ paper.

__init__(hidden_size, num_attention_heads, num_key_value_heads, max_position_embeddings, rope_theta)

Init function for LlamaAttention.

forward(hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False)

Forward function for LlamaAttention.

Parameters:
  • hidden_states (Tensor) – Input hidden states of the layer.

  • attention_mask (Tensor | None) – Optional mask that blocks attention to selected positions.

  • position_ids (LongTensor | None) – Optional position indices of the input tokens, consumed by the rotary embedding.

  • past_key_value (Tuple[Tensor] | None) – Optional cached key/value states from previous decoding steps.

  • output_attentions (bool) – Whether to also return the attention weights.

  • use_cache (bool) – Whether to return the updated key/value cache.

Return type:

Tuple[Tensor, Tensor | None, Tuple[Tensor] | None]
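
A hedged usage sketch based only on the signatures above; the hidden size, tensor shapes, and hyperparameter values are assumptions (Llama-3-8B-style numbers taken from the defaults elsewhere on this page), not requirements of the class:

    import torch
    # Assumes LlamaAttention has been imported from this module.

    attn = LlamaAttention(
        hidden_size=4096,                 # assumed model width
        num_attention_heads=32,
        num_key_value_heads=8,
        max_position_embeddings=131072,
        rope_theta=500000.0,
    )
    hidden_states = torch.randn(2, 16, 4096)                    # (batch, seq_len, hidden_size), assumed layout
    position_ids = torch.arange(16).unsqueeze(0).expand(2, -1)  # token positions per batch element
    attn_output, attn_weights, past_key_value = attn(
        hidden_states,
        position_ids=position_ids,
        output_attentions=False,
        use_cache=True,
    )
    # attn_output matches the input shape; attn_weights is None because
    # output_attentions=False; past_key_value holds the cached key/value states.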

class LlamaDecoderLayer

Bases: Module

A single LLaMA transformer decoder layer.

__init__(index, hidden_size, intermediate_size=14336, rms_norm_eps=1e-05, num_attention_heads=32, num_key_value_heads=8, max_position_embeddings=131072, rope_theta=500000.0)

Init function for LlamaDecoderLayer.

forward(hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False)

Forward function for LlamaDecoderLayer.

Parameters:
  • hidden_states (Tensor) – Input hidden states of the layer.

  • attention_mask (Tensor | None) – Optional mask that blocks attention to selected positions.

  • position_ids (LongTensor | None) – Optional position indices of the input tokens.

  • past_key_value (Tuple[Tensor] | None) – Optional cached key/value states from previous decoding steps.

  • output_attentions (bool | None) – Whether to also return the attention weights.

  • use_cache (bool | None) – Whether to return the updated key/value cache.

Return type:

Tuple[FloatTensor, Tuple[FloatTensor, FloatTensor] | None]
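
A similar hedged sketch for a single decoder layer, using the defaults from __init__ above; hidden_size=4096 and the tensor shapes are assumptions:

    import torch
    # Assumes LlamaDecoderLayer has been imported from this module.

    layer = LlamaDecoderLayer(index=0, hidden_size=4096)   # remaining arguments keep their defaults
    hidden_states = torch.randn(1, 8, 4096)                # (batch, seq_len, hidden_size), assumed layout
    position_ids = torch.arange(8).unsqueeze(0)
    outputs = layer(hidden_states, position_ids=position_ids, use_cache=True)
    hidden_states = outputs[0]        # FloatTensor with the same shape as the input
    present_key_value = outputs[1]    # (key, value) cache, present because use_cache=True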

class LlamaMLP

Bases: Module

The LLaMA feed-forward (MLP) block.

__init__(hidden_size, intermediate_size)

Init function for LlamaMLP.

forward(x)

Forward function for LlamaMLP.
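
A minimal usage sketch; the sizes are the defaults seen elsewhere on this page, and the expectation that the output returns to hidden_size follows the usual LLaMA MLP design rather than anything stated here:

    import torch
    # Assumes LlamaMLP has been imported from this module.

    mlp = LlamaMLP(hidden_size=4096, intermediate_size=14336)
    x = torch.randn(2, 16, 4096)
    y = mlp(x)   # expected shape (2, 16, 4096): project up to intermediate_size, then back down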

class LlamaRMSNorm

Bases: Module

Root-mean-square layer normalization (RMSNorm) as used by LLaMA.

__init__(hidden_size, eps=1e-06)

LlamaRMSNorm is equivalent to T5LayerNorm.

forward(hidden_states)

Forward function for LlamaRMSNorm.
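
Since LlamaRMSNorm is documented as equivalent to T5LayerNorm, the computation can be sketched as a reference formulation (this is a sketch of the math, not the class itself):

    import torch

    def rms_norm_reference(hidden_states, weight, eps=1e-06):
        # T5LayerNorm / RMSNorm: no mean subtraction, just scaling by the
        # root mean square of the last dimension, followed by a learned gain.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        return weight * (hidden_states * torch.rsqrt(variance + eps))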

class LlamaRotaryEmbedding

Bases: Module

Llama Rotary Embedding.

__init__(dim, max_position_embeddings=2048, base=10000, device=None)

Init function for LlamaRotaryEmbedding.

forward(x, seq_len=None)

Forward function for LlamaRotaryEmbedding.

apply_rotary_pos_emb(q, k, cos, sin, position_ids)

Apply rotary position embedding.
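
A hedged sketch tying LlamaRotaryEmbedding and apply_rotary_pos_emb together; it assumes (as in common LLaMA implementations) that the embedding's forward pass returns the cos/sin tables and that q and k are laid out as (batch, num_heads, seq_len, head_dim):

    import torch
    # Assumes LlamaRotaryEmbedding and apply_rotary_pos_emb are imported from this module.

    head_dim, seq_len = 128, 16
    rope = LlamaRotaryEmbedding(dim=head_dim)
    q = torch.randn(1, 32, seq_len, head_dim)   # query heads (assumed layout)
    k = torch.randn(1, 8, seq_len, head_dim)    # key heads (grouped-query: fewer KV heads)
    position_ids = torch.arange(seq_len).unsqueeze(0)
    cos, sin = rope(q, seq_len=seq_len)         # assumed to return the cos/sin tables
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
    # q_rot and k_rot keep their input shapes; only the positional phase changes.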

expand_mask(mask, dtype, tgt_len=None)

Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len].

Parameters:
  • mask (Tensor) – Attention mask of shape [bsz, seq_len].

  • dtype (dtype) – Data type of the expanded mask.

  • tgt_len (int | None) – Optional target sequence length of the expanded mask.
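
A usage sketch that relies only on the shape contract stated in the docstring; the additive-mask behavior mentioned in the comment is an assumption based on typical implementations:

    import torch
    # Assumes expand_mask has been imported from this module.

    mask = torch.ones(2, 16, dtype=torch.long)      # [bsz, seq_len], 1 = attend
    mask[:, :4] = 0                                 # pretend the first 4 positions are padding
    expanded = expand_mask(mask, dtype=torch.float32, tgt_len=16)
    # expanded has shape [2, 1, 16, 16]; typically the masked positions carry a large
    # negative value so they vanish after the softmax inside the attention layer.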

make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0)

Make the causal mask used for uni-directional (autoregressive) self-attention.

Parameters:
  • input_ids_shape (Size) – Shape of the input ids tensor, i.e. (bsz, seq_len).

  • dtype (dtype) – Data type of the mask.

  • device (device) – Device on which the mask is created.

  • past_key_values_length (int) – Number of previously cached positions to account for.
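
A hedged sketch of building the causal mask; the output shape and the treatment of the cached prefix are assumptions based on standard implementations, not guarantees of this function:

    import torch
    # Assumes make_causal_mask has been imported from this module.

    causal = make_causal_mask(
        torch.Size([2, 16]),              # (bsz, seq_len) of the current input ids
        dtype=torch.float32,
        device=torch.device("cpu"),
        past_key_values_length=4,         # positions already held in the KV cache
    )
    # Expected (assumed) shape: [2, 1, 16, 16 + 4], with future positions masked so each
    # token attends only to itself, earlier tokens, and the cached prefix.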

repeat_kv(hidden_states, n_rep)

This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).

The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)

Parameters:
  • hidden_states (Tensor) – Key/value states of shape (batch, num_key_value_heads, seqlen, head_dim).

  • n_rep (int) – Number of times each key/value head is repeated.

Return type:

Tensor
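
The docstring states the equivalence with torch.repeat_interleave, so it can be checked directly; the concrete sizes are illustrative only:

    import torch
    # Assumes repeat_kv has been imported from this module.

    batch, num_kv_heads, seqlen, head_dim = 2, 8, 16, 128
    n_rep = 4                                       # num_attention_heads // num_key_value_heads
    kv = torch.randn(batch, num_kv_heads, seqlen, head_dim)
    out = repeat_kv(kv, n_rep)                      # (2, 32, 16, 128)
    assert torch.equal(out, torch.repeat_interleave(kv, repeats=n_rep, dim=1))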

rotate_half(x)

Rotates half the hidden dims of the input.
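
A reference formulation consistent with the docstring, assuming the standard half-split used by rotary embeddings (a sketch, not necessarily the exact implementation):

    import torch

    def rotate_half_reference(x):
        # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)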