nemotron_h_v2_model_descriptor
Classes
NemotronHV2FFNIntermediateLayerDescriptor(down_proj_name: str = 'mixer.down_proj', ffn_prefix_name: str = 'backbone.layers.{layer_idx}.mixer', linear_weight_names: List[str] = <factory>)
- class NemotronHV2FFNIntermediateLayerDescriptor
Bases:
FFNIntermediateLayerDescriptor
NemotronHV2FFNIntermediateLayerDescriptor(down_proj_name: str = 'mixer.down_proj', ffn_prefix_name: str = 'backbone.layers.{layer_idx}.mixer', linear_weight_names: List[str] = <factory>)
- __init__(down_proj_name='mixer.down_proj', ffn_prefix_name='backbone.layers.{layer_idx}.mixer', linear_weight_names=<factory>)
- Parameters:
down_proj_name (str)
ffn_prefix_name (str)
linear_weight_names (List[str])
- Return type:
None
- down_proj_name: str = 'mixer.down_proj'
- ffn_prefix_name: str = 'backbone.layers.{layer_idx}.mixer'
- linear_weight_names: List[str]
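The `<factory>` default in the signature is how Python's `dataclasses` module renders a field declared with `default_factory`, and `ffn_prefix_name` is a template with a `{layer_idx}` placeholder. The sketch below mirrors the three documented fields in a standalone dataclass to show how the prefix resolves for a given layer. It is an illustration under stated assumptions, not the real class: `FFNDescriptorSketch` and its `ffn_prefix` helper are hypothetical, it does not inherit from `FFNIntermediateLayerDescriptor`, and the real factory for `linear_weight_names` is not documented here, so an empty list is assumed.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class FFNDescriptorSketch:
    """Hypothetical stand-in mirroring the documented fields and defaults."""

    down_proj_name: str = "mixer.down_proj"
    ffn_prefix_name: str = "backbone.layers.{layer_idx}.mixer"
    # Assumption: the real default_factory contents are undocumented; use an empty list.
    linear_weight_names: List[str] = field(default_factory=list)

    def ffn_prefix(self, layer_idx: int) -> str:
        """Hypothetical helper: fill the {layer_idx} placeholder in the prefix template."""
        return self.ffn_prefix_name.format(layer_idx=layer_idx)


desc = FFNDescriptorSketch()
print(desc.ffn_prefix(3))  # backbone.layers.3.mixer

# default_factory gives every instance its own list rather than a shared one.
print(FFNDescriptorSketch().linear_weight_names is FFNDescriptorSketch().linear_weight_names)  # False
```

A `default_factory` is required here because dataclasses reject mutable defaults such as a bare `[]`.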
- class NemotronHV2ModelDescriptor
Bases:
ModelDescriptor
- static attn_no_op_post_init(decoder_layer)
- static block_config_to_layer_overrides(block_config)
- Parameters:
block_config (BlockConfig)
- classmethod create_dummy_block(original_layer, block_index)
- Parameters:
original_layer (Module)
block_index (int)
- Return type:
Module
- static decoder_layer_cls()
- static final_norm_name()
- classmethod get_weight_groups(layer_names, num_hidden_layers)
In NemotronH, norm.weight can belong to both block_{i}_ffn and block_{i}_attention, so duplicate groups that contain norm.weight must be removed.
- Parameters:
layer_names (Iterable[str])
num_hidden_layers (int)
- Return type:
Dict[str, List[str]]
- static init_rotary_embedding(model, runtime)
NemotronH has no positional embeddings.
- static input_embedding_name()
- static layer_block_name(index)
- Parameters:
index (int)
- static layer_name_predicates(num_layers)
- Parameters:
num_layers (int)
- Return type:
Dict[str, Pattern]
- static mlp_no_op_post_init(decoder_layer)
- static output_embedding_name()
- static pruning_mixins()
- Return type:
Dict[str, PruningMixIn]
- static requires_trust_remote_code()
- Return type:
bool
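The get_weight_groups docstring notes that a layer's norm.weight can match both block_{i}_ffn and block_{i}_attention, and layer_name_predicates returns a Dict[str, Pattern] keyed by group. The sketch below shows, under loose assumptions, how regex predicates of that shape can put the same norm weight into two groups and how norm-only duplicates could then be dropped. Everything here is hypothetical: the parameter names, the regex patterns, the `norm.weight` suffix test, and the helpers `layer_name_predicates_sketch`, `group_names`, and `drop_norm_only_duplicates` are illustrations, not the descriptor's actual implementation.

```python
import re
from typing import Dict, List


def layer_name_predicates_sketch(num_layers: int) -> Dict[str, re.Pattern]:
    """Hypothetical per-block regexes; the real patterns are not documented here."""
    predicates: Dict[str, re.Pattern] = {}
    for i in range(num_layers):
        prefix = rf"^backbone\.layers\.{i}\."
        # norm.weight is matched by both predicates on purpose, reproducing
        # the overlap described in the get_weight_groups docstring.
        predicates[f"block_{i}_ffn"] = re.compile(
            prefix + r"(norm\.weight|mixer\.(up_proj|down_proj)\..+)$"
        )
        predicates[f"block_{i}_attention"] = re.compile(
            prefix + r"(norm\.weight|mixer\.(?!up_proj|down_proj).+)$"
        )
    return predicates


def group_names(names: List[str], predicates: Dict[str, re.Pattern]) -> Dict[str, List[str]]:
    """Assign each parameter name to every group whose pattern matches it."""
    groups: Dict[str, List[str]] = {key: [] for key in predicates}
    for name in names:
        for key, pattern in predicates.items():
            if pattern.match(name):
                groups[key].append(name)
    return groups


def is_norm_weight(name: str) -> bool:
    # Assumption: norm parameters are identified by the "norm.weight" suffix.
    return name.endswith("norm.weight")


def drop_norm_only_duplicates(groups: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Drop groups holding nothing but norm weights that another group already owns."""
    # Every name that appears in some group with at least one non-norm weight.
    owned = {
        name
        for members in groups.values()
        if not all(is_norm_weight(n) for n in members)
        for name in members
    }
    kept: Dict[str, List[str]] = {}
    seen_norm_only = set()
    for key, members in groups.items():
        if members and all(is_norm_weight(n) for n in members):
            signature = tuple(sorted(members))
            # Skip if a substantive group owns these weights or an identical
            # norm-only group was already kept.
            if set(members) <= owned or signature in seen_norm_only:
                continue
            seen_norm_only.add(signature)
        kept[key] = members
    return kept


names = [
    "backbone.layers.0.norm.weight",
    "backbone.layers.0.mixer.q_proj.weight",
    "backbone.layers.1.norm.weight",
    "backbone.layers.1.mixer.down_proj.weight",
]
raw = group_names(names, layer_name_predicates_sketch(2))
print(raw["block_0_ffn"])  # ['backbone.layers.0.norm.weight']
deduped = drop_norm_only_duplicates(raw)
print(sorted(deduped))  # ['block_0_attention', 'block_1_ffn']
```

In this sketch, a group is only dropped when it contains nothing but norm weights, so a block's substantive group (attention or FFN) keeps its norm.weight and the redundant norm-only group disappears.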