base

Classes

ModelDescriptor

class ModelDescriptor

Bases: ABC

static attn_no_op_post_init(decoder_layer)

Post-init callback that alters a decoder layer so that its attention subblock performs as a no-op.

It is recommended to use the utility modules from no_op.py to replace layers with dummy counterparts.

Example for replacing a layernorm layer with identity:

>>> decoder_layer.post_attention_layernorm = Same()

Example for replacing an attention layer with zeros (the attention output is added to the residual hidden_states, so returning zeros leaves the residual unchanged):

>>> decoder_layer.self_attn = MatchingZeros()

In case the attention layer returns multiple outputs, i.e. hidden_states, _ = self.self_attn(), use the utility method return_tuple_of_size to return trailing None values:

>>> decoder_layer.self_attn = return_tuple_of_size(MatchingZeros, size=2)()
Parameters:

decoder_layer (Module)

classmethod attn_no_op_supported()

Check whether attn_no_op_post_init is overridden, indicating attention no-op support.

Return type:

bool

abstract static block_config_to_layer_overrides(block_config)

Map a BlockConfig to layer config overrides.

These overrides are consumed by a specific decoder layer and by the whole model. Usage can be seen in deci_x_patcher under the method _patched_decoder_layer_init.

Example implementation to override the FFN intermediate size of a block:

>>> def block_config_to_layer_overrides(block_config: BlockConfig) -> Dict[str, Any]:
...     return {"intermediate_size": block_config.ffn.intermediate_size}
Parameters:

block_config (BlockConfig)

Return type:

Dict[str, Any]

classmethod create_dummy_block(original_layer, block_index)

Create a dummy block to replace a layer for sharded model initialization.

Parameters:
  • original_layer (Module)

  • block_index (int)

Return type:

Module

abstract static decoder_layer_cls()

Decoder layer class types to patch for heterogeneous config support.

In most cases this class holds both the FFN and attention layers as attributes.

Returns:

An nn.Module class type, or a list if several class types should be patched.

Return type:

Type[Module] | List[Type[Module]]

abstract static final_norm_name()

Return the name of the final normalization layer.

static get_language_model_config(config)

Get the language model config from a PretrainedConfig.

For regular LM models, returns the config itself. For VL/multimodal models with nested configs, override to return the language model portion (e.g., config.text_config for Qwen-VL).
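As a sketch of such an override, a descriptor for a VL model could return the nested language-model config and fall back to the config itself (the class name QwenVLDescriptor and the use of SimpleNamespace as a stand-in config are illustrative, not part of the library):

```python
from types import SimpleNamespace


class QwenVLDescriptor:  # hypothetical descriptor subclass
    @staticmethod
    def get_language_model_config(config):
        # Multimodal configs nest the LM config (e.g. config.text_config);
        # regular LM configs are returned unchanged.
        return getattr(config, "text_config", config)


# Usage: a mock VL config with a nested text_config
vl_config = SimpleNamespace(text_config=SimpleNamespace(hidden_size=4096))
lm_config = QwenVLDescriptor.get_language_model_config(vl_config)
print(lm_config.hidden_size)  # → 4096
```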

classmethod get_weight_groups(layer_names, num_hidden_layers)

Group model weights to support the puzzle subblock checkpointing format.

This method uses the abstract method layer_name_predicates by default.

Parameters:
  • layer_names (Iterable[str]) – state_dict layer names of the model.

  • num_hidden_layers (int) – number of decoder layers in the model.

Returns:

Dictionary of group names to list of layer names per group, e.g.

>>> {
...     "embedding": ["model.embed_tokens.weight"],
...     "lm_head": ["lm_head.weight", "model.norm.weight"],
...     "block_0_ffn": ["model.layers.0.mlp.down_proj", ...],
...     "block_0_attention": ["model.layers.0.self_attn.q_proj", ...],
... }
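A minimal sketch of how such a grouping could be computed from regex predicates (the group names and patterns here are illustrative, not the library's actual defaults):

```python
import re
from typing import Dict, Iterable, List, Pattern


def group_weights(
    layer_names: Iterable[str], predicates: Dict[str, Pattern]
) -> Dict[str, List[str]]:
    # Assign each state_dict name to the first group whose regex matches it.
    groups: Dict[str, List[str]] = {name: [] for name in predicates}
    for layer_name in layer_names:
        for group, pattern in predicates.items():
            if pattern.search(layer_name):
                groups[group].append(layer_name)
                break
    return groups


# Illustrative predicates and names for a one-layer model
predicates = {
    "embedding": re.compile(r"embed_tokens"),
    "block_0_ffn": re.compile(r"layers\.0\.mlp\."),
    "block_0_attention": re.compile(r"layers\.0\.self_attn\."),
}
names = [
    "model.embed_tokens.weight",
    "model.layers.0.mlp.down_proj.weight",
    "model.layers.0.self_attn.q_proj.weight",
]
print(group_weights(names, predicates))
```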

abstract static init_rotary_embedding(model, runtime)

Re-initialize the rotary embeddings of an existing model.

In puzzletron we initialize a sharded model by first creating a meta model and then moving it to the actual device by loading the state_dict with the real weights.

Rotary embedding frequencies are tensor buffers that are created dynamically during init and are not part of the model state_dict, so they cannot be restored after a meta-device initialization.

abstract static input_embedding_name()

Return the name of the input embedding layer.

abstract static layer_block_name(index)

Return the name of the decoder layer at the given index.

Parameters:

index (int)
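For a typical Hugging Face decoder this is just an f-string over the index (the "model.layers." prefix is an assumption that holds for Llama-style models, not a guarantee for every architecture):

```python
def layer_block_name(index: int) -> str:
    # Llama-style naming; other architectures may use a different prefix.
    return f"model.layers.{index}"


print(layer_block_name(3))  # → model.layers.3
```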

abstract static layer_name_predicates(num_layers)

Return predicates for grouping model weights to support subblock checkpointing.

For every group name, return a regex pattern that determines whether a layer name belongs to the group.

Returns:

Dictionary of group name to regex pattern predicate.

Parameters:

num_layers (int)

Return type:

Dict[str, Pattern]
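A hedged sketch of an implementation for Llama-style state_dict naming (the group names and patterns are illustrative; match them to your model's actual weight names):

```python
import re
from typing import Dict, Pattern


def layer_name_predicates(num_layers: int) -> Dict[str, Pattern]:
    # One regex predicate per weight group, keyed by group name.
    predicates: Dict[str, Pattern] = {
        "embedding": re.compile(r"^model\.embed_tokens\."),
        "lm_head": re.compile(r"^lm_head\.|^model\.norm\."),
    }
    for i in range(num_layers):
        predicates[f"block_{i}_ffn"] = re.compile(rf"^model\.layers\.{i}\.mlp\.")
        predicates[f"block_{i}_attention"] = re.compile(
            rf"^model\.layers\.{i}\.self_attn\."
        )
    return predicates


preds = layer_name_predicates(2)
print(sorted(preds))
```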

static mlp_no_op_post_init(decoder_layer)

Post-init callback that alters a decoder layer so that its FFN/MLP subblock performs as a no-op.

It is recommended to use the utility modules from no_op.py to replace layers with dummy counterparts.

Example for replacing a layernorm layer with identity:

>>> decoder_layer.post_attention_layernorm = Same()

Example for replacing an MLP layer with zeros (zeros, because the output hidden_states are added to the residual hidden_states, so returning zeros leaves the residual unchanged):

>>> decoder_layer.mlp = MatchingZeros()

In case the MLP layer to replace returns multiple outputs, i.e. hidden_states, _ = self.mlp(), use the utility method return_tuple_of_size to return trailing None values:

>>> decoder_layer.mlp = return_tuple_of_size(MatchingZeros, size=2)()
Parameters:

decoder_layer (Module)

classmethod mlp_no_op_supported()

Check whether mlp_no_op_post_init is overridden, indicating MLP no-op support.

Return type:

bool

abstract static output_embedding_name()

Return the name of the output embedding layer.

static requires_trust_remote_code()

Whether this model descriptor requires trust_remote_code=True for loading.

Models that use custom code (e.g., via auto_map in config) should override this to return True.

Returns:

True if trust_remote_code=True is required, False otherwise.

Return type:

bool

static uses_autocast()

Whether this model supports torch.autocast.

Some models (e.g., Qwen3-VL MoE) have dtype bugs under autocast. Override and return False for models that do not support autocast.

Return type:

bool