kv_heads_pruning_mixin

Classes

KVHeadsLayerDescriptor

KVHeadsLayerDescriptor(o_proj_name: str, attn_prefix_name: str, qkvo_weight_names: List[str] = <factory>)

KVHeadsPruningMixIn

class KVHeadsLayerDescriptor

Bases: LayerDescriptor

__init__(o_proj_name, attn_prefix_name, qkvo_weight_names=<factory>)
Parameters:
  • o_proj_name (str)
  • attn_prefix_name (str)
  • qkvo_weight_names (List[str])

Return type:

None

attn_prefix(layer_idx)
Parameters:

layer_idx (int)

Return type:

str

attn_prefix_name: str

module_name_regex()
Return type:

str

o_proj_name: str
qkvo_weight_names: List[str]
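
The fields and the attn_prefix(layer_idx) signature above suggest a small dataclass that maps a layer index to a module-name prefix. The following is a minimal sketch of that shape, not the library's implementation. ExampleKVHeadsLayerDescriptor is a hypothetical stand-in, the "{layer_idx}" template convention and the projection names are assumptions, and the real default for qkvo_weight_names is not shown in this reference.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ExampleKVHeadsLayerDescriptor:
    """Hypothetical stand-in that mirrors the documented fields only."""

    o_proj_name: str
    attn_prefix_name: str
    # Placeholder default: the reference shows only <factory>, not its value.
    qkvo_weight_names: List[str] = field(default_factory=list)

    def attn_prefix(self, layer_idx: int) -> str:
        # Assumption: attn_prefix_name is a template with a {layer_idx} slot.
        return self.attn_prefix_name.format(layer_idx=layer_idx)


# Assumed naming scheme, typical of Hugging Face decoder checkpoints.
desc = ExampleKVHeadsLayerDescriptor(
    o_proj_name="o_proj",
    attn_prefix_name="model.layers.{layer_idx}.self_attn",
    qkvo_weight_names=["q_proj", "k_proj", "v_proj", "o_proj"],
)

prefix = desc.attn_prefix(3)
# State-dict keys for the attention projections of layer 3.
weight_keys = [f"{prefix}.{name}.weight" for name in desc.qkvo_weight_names]
```

With these assumed names, weight_keys holds one state-dict key per projection, which is the kind of lookup a per-layer pruning routine needs.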

class KVHeadsPruningMixIn

Bases: PruningMixIn

__init__(layer_descriptor)
Parameters:

layer_descriptor (KVHeadsLayerDescriptor)

prune_single_layer(layer_idx, parent_state_dict, new_state_dict, original_config, new_config, gqa_init_mode, mlp_init_config, is_original_mha, keys, keys_to_remove, **kwargs)
Parameters:
  • layer_idx (int)
  • parent_state_dict (dict)
  • new_state_dict (dict)
  • original_config (PreTrainedConfig)
  • new_config (PreTrainedConfig)
  • gqa_init_mode (GQAInitMode)
  • mlp_init_config (dict[str, Any] | None)
  • is_original_mha (bool)
  • keys (dict)
  • keys_to_remove (dict)

supported_hooks()
Return type:

List[Type[ForwardHook]]
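
As a rough illustration of what pruning KV heads in one layer can involve, the sketch below keeps a subset of heads from a K or V projection weight. It is not the logic of prune_single_layer: select_kv_heads is a hypothetical helper, the weight is a plain list of rows rather than a tensor, and the layout (rows grouped contiguously by head, head_dim rows per head) is an assumption. How the heads are chosen or combined in practice depends on gqa_init_mode, whose values are not listed in this reference.

```python
from typing import List, Sequence


def select_kv_heads(
    weight: List[List[float]],
    num_kv_heads: int,
    keep_heads: Sequence[int],
) -> List[List[float]]:
    """Keep the row blocks of a K or V projection that belong to keep_heads."""
    rows = len(weight)
    if num_kv_heads <= 0 or rows % num_kv_heads != 0:
        raise ValueError("row count must be a positive multiple of num_kv_heads")
    head_dim = rows // num_kv_heads  # rows owned by each head

    pruned: List[List[float]] = []
    for head in keep_heads:
        if not 0 <= head < num_kv_heads:
            raise IndexError(f"head index {head} out of range")
        start = head * head_dim
        # Copy the contiguous block of rows for this head.
        pruned.extend(list(row) for row in weight[start:start + head_dim])
    return pruned


# Toy projection: 4 KV heads, head_dim 2, hidden size 3; row r is filled with r.
toy_weight = [[float(r)] * 3 for r in range(8)]
pruned_weight = select_kv_heads(toy_weight, num_kv_heads=4, keep_heads=[0, 2])
kept_rows = [row[0] for row in pruned_weight]
```

Here kept_rows identifies which original rows survive, so the result can be checked against the expected head blocks before a pruned tensor is written into a new state dict.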