ffn_intermediate_pruning_mixin

Classes

FFNIntermediateLayerDescriptor

FFNIntermediateLayerDescriptor(down_proj_name: str, ffn_prefix_name: str, linear_weight_names: List[str] = <factory>)

FFNIntermediatePruningMixIn

class FFNIntermediateLayerDescriptor

Bases: LayerDescriptor

__init__(down_proj_name, ffn_prefix_name, linear_weight_names=<factory>)
Parameters:
  • down_proj_name (str)

  • ffn_prefix_name (str)

  • linear_weight_names (List[str])

Return type:

None

down_proj_name: str

ffn_prefix(layer_idx)
Parameters:

layer_idx (int)

Return type:

str

ffn_prefix_name: str

linear_weight_names: List[str]

module_name_regex()
Return type:

str
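
The reference above lists only signatures, so the following is a minimal sketch of how a descriptor with these three fields might behave. It is not the library's implementation. `SketchFFNDescriptor` is a hypothetical stand-in, and two details are assumptions: that `ffn_prefix_name` is a template containing a `{layer_idx}` placeholder, and that `module_name_regex()` matches the down projection of any layer. The only detail taken from the signature is that `linear_weight_names` uses a default factory.

```python
import re
from dataclasses import dataclass, field
from typing import List


@dataclass
class SketchFFNDescriptor:
    """Hypothetical stand-in that mirrors the three documented fields."""

    down_proj_name: str
    ffn_prefix_name: str
    # `<factory>` in the documented signature indicates a default factory;
    # an empty list is assumed here.
    linear_weight_names: List[str] = field(default_factory=list)

    def ffn_prefix(self, layer_idx: int) -> str:
        # Assumption: the prefix is a template with a {layer_idx} placeholder.
        return self.ffn_prefix_name.format(layer_idx=layer_idx)

    def module_name_regex(self) -> str:
        # Assumption: the regex matches the down projection of any layer.
        head, tail = self.ffn_prefix_name.split("{layer_idx}")
        return (
            re.escape(head) + r"\d+" + re.escape(tail)
            + r"\." + re.escape(self.down_proj_name)
        )


descriptor = SketchFFNDescriptor(
    down_proj_name="down_proj",
    ffn_prefix_name="model.layers.{layer_idx}.mlp",
    linear_weight_names=["gate_proj", "up_proj", "down_proj"],
)
prefix_for_layer_3 = descriptor.ffn_prefix(3)
```

The names `model.layers.{layer_idx}.mlp`, `gate_proj`, `up_proj`, and `down_proj` are illustrative values only; the actual module names depend on the model architecture being pruned.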

class FFNIntermediatePruningMixIn

Bases: PruningMixIn

__init__(layer_descriptor)
Parameters:

layer_descriptor (FFNIntermediateLayerDescriptor)

prune_single_layer(layer_idx, parent_state_dict, new_state_dict, original_config, new_config, mlp_init_mode, mlp_init_config, keys, keys_to_remove, **kwargs)
Parameters:
  • layer_idx (int)

  • parent_state_dict (dict)

  • new_state_dict (dict)

  • original_config (PreTrainedConfig)

  • new_config (PreTrainedConfig)

  • mlp_init_mode (MlpInitMode)

  • mlp_init_config (dict[str, Any] | None)

  • keys (dict)

  • keys_to_remove (dict)

Return type:

Dict[str, Tensor]
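
The signature of `prune_single_layer` does not say how weights are reduced, so the sketch below only illustrates the general idea of FFN intermediate pruning: keeping a subset of intermediate channels across all linear weights of one layer. It is an assumption-laden illustration, not this method's actual logic. `prune_ffn_intermediate`, the `keep` index list, and the `"{prefix}.{name}.weight"` key layout are all hypothetical, and nested Python lists stand in for tensors so the example runs without extra dependencies.

```python
from typing import Dict, List

Matrix = List[List[float]]


def prune_ffn_intermediate(
    parent_state_dict: Dict[str, Matrix],
    prefix: str,
    linear_weight_names: List[str],
    down_proj_name: str,
    keep: List[int],
) -> Dict[str, Matrix]:
    """Hypothetical helper: keep only the intermediate channels in `keep`."""
    pruned: Dict[str, Matrix] = {}
    for name in linear_weight_names:
        key = f"{prefix}.{name}.weight"  # assumed key layout
        weight = parent_state_dict[key]
        if name == down_proj_name:
            # The down projection has shape [hidden, intermediate],
            # so the intermediate axis is the column axis.
            pruned[key] = [[row[i] for i in keep] for row in weight]
        else:
            # Other projections have shape [intermediate, hidden],
            # so the intermediate axis is the row axis.
            pruned[key] = [list(weight[i]) for i in keep]
    return pruned


parent = {
    "model.layers.0.mlp.up_proj.weight": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    "model.layers.0.mlp.down_proj.weight": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
}
child = prune_ffn_intermediate(
    parent,
    prefix="model.layers.0.mlp",
    linear_weight_names=["up_proj", "down_proj"],
    down_proj_name="down_proj",
    keep=[0, 2],
)
```

Here the intermediate size shrinks from 3 to 2 while the hidden size stays at 2. How the kept channels are chosen (presumably governed by `mlp_init_mode` and `mlp_init_config`) is not covered by this sketch.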

supported_hooks()
Return type:

List[Type[ForwardHook]]