ffn_intermediate_pruning_mixin
Classes
- FFNIntermediateLayerDescriptor(down_proj_name: str, ffn_prefix_name: str, linear_weight_names: List[str] = <factory>)
- FFNIntermediatePruningMixIn(layer_descriptor)
- class FFNIntermediateLayerDescriptor
Bases: LayerDescriptor
Dataclass with the fields down_proj_name, ffn_prefix_name, and linear_weight_names. The <factory> marker means that linear_weight_names takes its default from a default_factory.
- __init__(down_proj_name, ffn_prefix_name, linear_weight_names=<factory>)
- Parameters:
down_proj_name (str)
ffn_prefix_name (str)
linear_weight_names (List[str])
- Return type:
None
- down_proj_name: str
- ffn_prefix(layer_idx)
- Parameters:
layer_idx (int)
- Return type:
str
- ffn_prefix_name: str
- linear_weight_names: List[str]
- module_name_regex()
- Return type:
str
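This reference gives only names and types. To illustrate how a descriptor of this shape might be used, here is a minimal sketch built on a hypothetical stand-in class, ToyFFNDescriptor, which has the same three fields. It assumes that ffn_prefix_name is a template containing a {layer_idx} placeholder and that module_name_regex() should match the FFN module of any layer. This reference confirms neither assumption, and the real FFNIntermediateLayerDescriptor may build prefixes and regexes differently.

```python
import re
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyFFNDescriptor:
    """Hypothetical stand-in with the same fields as FFNIntermediateLayerDescriptor.

    Assumption: ffn_prefix_name is a template with a "{layer_idx}" placeholder,
    e.g. "model.layers.{layer_idx}.mlp". The real class may differ.
    """

    down_proj_name: str
    ffn_prefix_name: str
    # Assumed default: an empty list produced by default_factory.
    linear_weight_names: List[str] = field(default_factory=list)

    def ffn_prefix(self, layer_idx: int) -> str:
        # Substitute the layer index into the assumed template.
        return self.ffn_prefix_name.format(layer_idx=layer_idx)

    def module_name_regex(self) -> str:
        # Replace the placeholder with \d+ so any layer index matches;
        # escape the literal parts so "." only matches a literal dot.
        head, tail = self.ffn_prefix_name.split("{layer_idx}")
        return re.escape(head) + r"\d+" + re.escape(tail)


desc = ToyFFNDescriptor(
    down_proj_name="down_proj",
    ffn_prefix_name="model.layers.{layer_idx}.mlp",
    linear_weight_names=["gate_proj", "up_proj"],  # hypothetical Llama-style names
)
prefix = desc.ffn_prefix(3)
pattern = desc.module_name_regex()
```

With these assumed names, prefix is "model.layers.3.mlp", and re.fullmatch(pattern, ...) accepts "model.layers.12.mlp" but rejects non-numeric indices.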
- class FFNIntermediatePruningMixIn
Bases: PruningMixIn
- __init__(layer_descriptor)
- Parameters:
layer_descriptor (FFNIntermediateLayerDescriptor)
- prune_single_layer(layer_idx, parent_state_dict, new_state_dict, original_config, new_config, mlp_init_mode, mlp_init_config, keys, keys_to_remove, **kwargs)
- Parameters:
layer_idx (int)
parent_state_dict (dict)
new_state_dict (dict)
original_config (PreTrainedConfig)
new_config (PreTrainedConfig)
mlp_init_mode (MlpInitMode)
mlp_init_config (dict[str, Any] | None)
keys (dict)
keys_to_remove (dict)
- Return type:
Dict[str, Tensor]
- supported_hooks()
- Return type:
List[Type[ForwardHook]]
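This reference lists only signatures. As an illustration, the sketch below shows the general technique the class name suggests: structured pruning of an FFN's intermediate dimension, with channels ranked by activation statistics that a hook could collect. ChannelActivationScorer and prune_ffn_intermediate are hypothetical names that are not part of this module. The scorer only mimics the PyTorch forward-hook call signature (module, args, output) and is not one of the ForwardHook types returned by supported_hooks(). The sketch assumes a gated FFN stored in nn.Linear layout ([out_features][in_features]), with linear_weight_names holding the hidden-to-intermediate projections (here gate_proj and up_proj) and down_proj_name naming the intermediate-to-hidden projection. It does not model mlp_init_mode, keys, keys_to_remove, or biases. Plain Python lists stand in for tensors so that the example runs without PyTorch.

```python
from typing import Any, Dict, List, Tuple

Matrix = List[List[float]]  # row-major, nn.Linear layout: weight[out_features][in_features]


class ChannelActivationScorer:
    """Hypothetical hook-style accumulator (not the library's ForwardHook).

    Assumes args[0] is the input to down_proj: one row of intermediate
    activations per token. Channels are ranked by mean absolute activation.
    """

    def __init__(self, num_channels: int) -> None:
        self.sums = [0.0] * num_channels
        self.count = 0

    def __call__(self, module: Any, args: Tuple[Any, ...], output: Any) -> None:
        # Accumulate |activation| per intermediate channel over all tokens seen.
        for row in args[0]:
            for i, value in enumerate(row):
                self.sums[i] += abs(value)
            self.count += 1

    def top_channels(self, k: int) -> List[int]:
        # Indices of the k highest-scoring channels, returned in ascending order.
        means = [s / max(self.count, 1) for s in self.sums]
        ranked = sorted(range(len(means)), key=lambda i: means[i], reverse=True)
        return sorted(ranked[:k])


def prune_ffn_intermediate(
    state_dict: Dict[str, Matrix],
    prefix: str,
    linear_weight_names: List[str],
    down_proj_name: str,
    keep: List[int],
) -> Dict[str, Matrix]:
    """Hypothetical sketch: keep only the intermediate channels listed in keep.

    Input projections have rows indexed by intermediate channel; down_proj has
    columns indexed by intermediate channel. Biases are ignored.
    """
    new_state_dict: Dict[str, Matrix] = {}
    for name in linear_weight_names:
        key = f"{prefix}.{name}.weight"
        # Keep the surviving rows (output channels).
        new_state_dict[key] = [list(state_dict[key][i]) for i in keep]
    key = f"{prefix}.{down_proj_name}.weight"
    # Keep the matching columns (input channels) of down_proj in every row.
    new_state_dict[key] = [[row[i] for i in keep] for row in state_dict[key]]
    return new_state_dict


# Usage: two tokens through a toy FFN with hidden size 2 and intermediate size 3.
scorer = ChannelActivationScorer(num_channels=3)
scorer(None, ([[0.1, -2.0, 0.5], [0.2, 1.0, -0.7]],), None)
keep = scorer.top_channels(2)
state_dict = {
    "mlp.gate_proj.weight": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    "mlp.up_proj.weight": [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
    "mlp.down_proj.weight": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
}
pruned = prune_ffn_intermediate(state_dict, "mlp", ["gate_proj", "up_proj"], "down_proj", keep)
```

Slicing the rows of the input projections together with the matching columns of down_proj shrinks only the intermediate size and leaves the layer's input and output widths unchanged. This is presumably one reason prune_single_layer receives both original_config and new_config.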