kv_heads_pruning_mixin
Classes
- KVHeadsLayerDescriptor(o_proj_name: str, attn_prefix_name: str, qkvo_weight_names: List[str] = <factory>)
- KVHeadsPruningMixIn(layer_descriptor)
- class KVHeadsLayerDescriptor
Bases: LayerDescriptor
KVHeadsLayerDescriptor(o_proj_name: str, attn_prefix_name: str, qkvo_weight_names: List[str] = <factory>)
- __init__(o_proj_name, attn_prefix_name, qkvo_weight_names=<factory>)
- Parameters:
o_proj_name (str)
attn_prefix_name (str)
qkvo_weight_names (List[str])
- Return type:
None
- attn_prefix(layer_idx)
- Parameters:
layer_idx (int)
- Return type:
str
- attn_prefix_name: str
- module_name_regex()
- Return type:
str
- o_proj_name: str
- qkvo_weight_names: List[str]
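The `<factory>` default in the signature is how Python's `dataclasses` module displays a field declared with `default_factory`, so each descriptor gets its own `qkvo_weight_names` list instead of a shared mutable default. The stand-in below is a minimal sketch of that mechanism, together with one plausible reading of `attn_prefix` and `module_name_regex`. The class name `DescriptorSketch`, the default weight names, the `{layer_idx}` template convention, the Llama-style module paths, and the `weight_keys` helper are all hypothetical, not the library's documented behavior.

```python
import inspect
import re
from dataclasses import dataclass, field
from typing import List


@dataclass
class DescriptorSketch:
    """Hypothetical stand-in mirroring KVHeadsLayerDescriptor's fields."""

    o_proj_name: str
    # Assumption: a template such as "model.layers.{layer_idx}.self_attn".
    attn_prefix_name: str
    # default_factory builds a fresh list per instance; the generated
    # __init__ signature displays this default as <factory>.
    qkvo_weight_names: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )

    def attn_prefix(self, layer_idx: int) -> str:
        # Assumption: substitute the layer index into the prefix template.
        return self.attn_prefix_name.format(layer_idx=layer_idx)

    def module_name_regex(self) -> str:
        # Assumption: match the attention module of any layer by turning the
        # "{layer_idx}" placeholder into a digit group.
        head, tail = self.attn_prefix_name.split("{layer_idx}")
        return "^" + re.escape(head) + r"\d+" + re.escape(tail) + "$"

    def weight_keys(self, layer_idx: int) -> List[str]:
        # Hypothetical helper: state-dict keys of this layer's q/k/v/o weights.
        prefix = self.attn_prefix(layer_idx)
        return [f"{prefix}.{name}.weight" for name in self.qkvo_weight_names]


# The qkvo_weight_names default is rendered as <factory> in this signature.
print(inspect.signature(DescriptorSketch))

desc = DescriptorSketch("o_proj", "model.layers.{layer_idx}.self_attn")
print(desc.attn_prefix(3))     # model.layers.3.self_attn
print(desc.weight_keys(0)[1])  # model.layers.0.self_attn.k_proj.weight
pattern = re.compile(desc.module_name_regex())
print(bool(pattern.match("model.layers.12.self_attn")))  # True
```

Keeping per-layer names behind a descriptor like this lets the pruning logic stay independent of any one model family's module naming.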
- class KVHeadsPruningMixIn
Bases: PruningMixIn
- __init__(layer_descriptor)
- Parameters:
layer_descriptor (KVHeadsLayerDescriptor)
- prune_single_layer(layer_idx, parent_state_dict, new_state_dict, original_config, new_config, gqa_init_mode, mlp_init_config, is_original_mha, keys, keys_to_remove, **kwargs)
- Parameters:
layer_idx (int)
parent_state_dict (dict)
new_state_dict (dict)
original_config (PreTrainedConfig)
new_config (PreTrainedConfig)
gqa_init_mode (GQAInitMode)
mlp_init_config (dict[str, Any] | None)
is_original_mha (bool)
keys (dict)
keys_to_remove (dict)
- supported_hooks()
- Return type:
List[Type[ForwardHook]]
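`prune_single_layer` takes a `gqa_init_mode` that governs how the reduced set of KV heads is initialized from the parent layer. This page does not list its options. As an illustration of one common choice, the sketch below mean-pools contiguous groups of key/value heads, the conversion proposed in the GQA paper for turning multi-head checkpoints into grouped-query ones. The function `mean_pool_kv_heads` is hypothetical and is not the library's implementation. The sketch assumes a `k_proj`/`v_proj` weight laid out head by head, with `head_dim` rows per head, as in a linear layer of shape `[num_kv_heads * head_dim, hidden_size]`, and assumes the new head count divides the old one.

```python
from typing import List

Matrix = List[List[float]]


def mean_pool_kv_heads(weight: Matrix, num_kv_heads: int,
                       new_num_kv_heads: int) -> Matrix:
    """Shrink a k_proj/v_proj weight from num_kv_heads to new_num_kv_heads.

    Each new head is the element-wise mean of a contiguous group of
    num_kv_heads // new_num_kv_heads old heads.
    """
    if num_kv_heads % new_num_kv_heads != 0:
        raise ValueError("new_num_kv_heads must divide num_kv_heads")
    if len(weight) % num_kv_heads != 0:
        raise ValueError("row count must be a multiple of num_kv_heads")
    head_dim = len(weight) // num_kv_heads
    group = num_kv_heads // new_num_kv_heads
    pooled: Matrix = []
    for new_head in range(new_num_kv_heads):
        for r in range(head_dim):
            # Row r of every old head that folds into this new head.
            rows = [weight[(new_head * group + g) * head_dim + r]
                    for g in range(group)]
            pooled.append([sum(col) / group for col in zip(*rows)])
    return pooled


# 4 KV heads with head_dim=1 and hidden_size=2, pooled down to 2 heads.
w = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(mean_pool_kv_heads(w, 4, 2))  # [[2.0, 3.0], [6.0, 7.0]]
```

Grouping contiguous heads keeps the usual GQA mapping consistent. Query head `i` uses KV head `i // (num_q_heads // num_kv_heads)`, so after pooling each query is served by the averaged head that contains the KV head it used before.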