sharded_checkpoint_utils
Provides utilities for distributed loading, saving, and manipulation of large language model checkpoints across multiple GPUs/processes.
Uses native HuggingFace models with deci_x_patcher for heterogeneous layer configurations.
Functions
- create_sharded_model(runtime, descriptor, model_config, owned_block_indexes, device='meta', dtype=torch.float32)
- Parameters:
model_config (PreTrainedConfig)
owned_block_indexes (set[int])
device (str | device | None)
dtype (dtype | None)
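A model built by create_sharded_model presumably materializes only the blocks in owned_block_indexes. With the default device='meta', PyTorch parameters start out without real storage, so only the owned blocks' weights need loading later. The sketch below shows that selection at the state_dict level. It assumes Llama-style key names (`model.layers.<i>.`), and the helper `owned_block_keys` is hypothetical, not part of this module. A list like its result could, for example, be passed as keys_to_load to load_sharded_state_dict.

```python
import re
from collections.abc import Iterable

# Assumption: transformer blocks live under "model.layers.<i>." (Llama-style
# Hugging Face naming). Other architectures may use different prefixes.
_BLOCK_KEY = re.compile(r"^model\.layers\.(\d+)\.")


def owned_block_keys(keys: Iterable[str], owned_block_indexes: set[int]) -> list[str]:
    """Return the keys that belong to a block in owned_block_indexes.

    Keys outside any block (embeddings, final norm, lm_head) are excluded;
    how those are assigned to ranks is outside the scope of this sketch.
    """
    selected = []
    for key in keys:
        match = _BLOCK_KEY.match(key)
        if match and int(match.group(1)) in owned_block_indexes:
            selected.append(key)
    return selected


all_keys = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.1.mlp.up_proj.weight",
    "model.layers.10.mlp.up_proj.weight",
    "lm_head.weight",
]
print(owned_block_keys(all_keys, {1, 10}))
# ['model.layers.1.mlp.up_proj.weight', 'model.layers.10.mlp.up_proj.weight']
```

The helper parses the block index as an integer instead of testing a string prefix such as `model.layers.1`. A prefix test would wrongly match block 10's keys when selecting block 1.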
- is_in_safetensors_format(checkpoint_dir)
- Parameters:
checkpoint_dir (Path)
- Return type:
bool
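The source does not say what is_in_safetensors_format inspects. One plausible heuristic follows Hugging Face checkpoint conventions: a sharded checkpoint has a `model.safetensors.index.json` file, and a single-file checkpoint has a `*.safetensors` file. The minimal sketch below implements that assumption. `looks_like_safetensors_checkpoint` is a hypothetical name, not this module's implementation.

```python
import json
import tempfile
from pathlib import Path


def looks_like_safetensors_checkpoint(checkpoint_dir: Path) -> bool:
    """Heuristic check for a safetensors checkpoint directory.

    Assumption: Hugging Face layout, i.e. either a shard index file
    "model.safetensors.index.json" or at least one "*.safetensors" file.
    """
    checkpoint_dir = Path(checkpoint_dir)
    if (checkpoint_dir / "model.safetensors.index.json").is_file():
        return True
    return any(p.is_file() for p in checkpoint_dir.glob("*.safetensors"))


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp)
    print(looks_like_safetensors_checkpoint(ckpt))  # False (empty directory)
    (ckpt / "model.safetensors.index.json").write_text(json.dumps({"weight_map": {}}))
    print(looks_like_safetensors_checkpoint(ckpt))  # True
```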
- load_and_shard_model(descriptor, checkpoint_path, owned_block_indexes='auto', model_config=None)
- Parameters:
checkpoint_path (str | Path)
owned_block_indexes (set[int] | Literal['auto'])
model_config (PreTrainedConfig | None)
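The source does not specify how owned_block_indexes='auto' assigns blocks to processes. A common scheme, shown here only as an assumption, is a contiguous, near-even split of the model's blocks across ranks. Under that scheme, num_blocks might come from a config attribute such as num_hidden_layers, and rank and world_size from the distributed runtime. `auto_owned_block_indexes` is a hypothetical helper.

```python
def auto_owned_block_indexes(num_blocks: int, rank: int, world_size: int) -> set[int]:
    """Split blocks 0..num_blocks-1 into world_size contiguous ranges.

    Returns the range owned by `rank`. The first (num_blocks % world_size)
    ranks each receive one extra block.
    """
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} out of range for world_size {world_size}")
    base, extra = divmod(num_blocks, world_size)
    start = rank * base + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0)
    return set(range(start, stop))


for r in range(4):
    print(r, sorted(auto_owned_block_indexes(10, r, 4)))
# 0 [0, 1, 2]
# 1 [3, 4, 5]
# 2 [6, 7]
# 3 [8, 9]
```

Contiguous ranges keep each rank's blocks adjacent, which suits pipeline-style execution. A strided assignment (rank, rank + world_size, ...) would balance equally well but would force activations to cross processes between every block.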
- load_sharded_state_dict(model_name_or_path, keys_to_load=None, device='cpu')
Loads the entire state_dict when keys_to_load is None; otherwise returns a partial state_dict containing only the listed keys.
- Parameters:
model_name_or_path (str | Path)
keys_to_load (Iterable[str] | None)
device (device | str)
- Return type:
dict[str, Tensor]
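For a multi-shard safetensors checkpoint, a partial load only needs to open the shards that hold at least one requested key. The minimal sketch below assumes the Hugging Face index file `model.safetensors.index.json`, whose `weight_map` maps each tensor name to its shard file, and the `safetensors` package's `safe_open` API. It handles only local directories, not Hub model names. `plan_shard_reads` and `load_partial_state_dict` are hypothetical names, not this module's implementation.

```python
from __future__ import annotations

import json
from collections import defaultdict
from collections.abc import Iterable
from pathlib import Path


def plan_shard_reads(index: dict, keys_to_load: Iterable[str] | None = None) -> dict[str, list[str]]:
    """Group the requested tensor names by the shard file that stores them."""
    weight_map: dict[str, str] = index["weight_map"]
    wanted = weight_map.keys() if keys_to_load is None else keys_to_load
    plan: dict[str, list[str]] = defaultdict(list)
    for key in wanted:
        if key not in weight_map:
            raise KeyError(f"{key!r} not found in checkpoint index")
        plan[weight_map[key]].append(key)
    return dict(plan)


def load_partial_state_dict(checkpoint_dir, keys_to_load=None, device="cpu"):
    """Read only the needed shards; requires safetensors (and torch for "pt")."""
    from safetensors import safe_open  # imported lazily so planning works without it

    checkpoint_dir = Path(checkpoint_dir)
    index = json.loads((checkpoint_dir / "model.safetensors.index.json").read_text())
    state_dict = {}
    for shard_file, keys in plan_shard_reads(index, keys_to_load).items():
        with safe_open(str(checkpoint_dir / shard_file), framework="pt", device=str(device)) as f:
            for key in keys:
                state_dict[key] = f.get_tensor(key)
    return state_dict


index = {"weight_map": {
    "a.weight": "model-00001-of-00002.safetensors",
    "b.weight": "model-00002-of-00002.safetensors",
    "c.weight": "model-00001-of-00002.safetensors",
}}
print(plan_shard_reads(index, ["c.weight"]))
# {'model-00001-of-00002.safetensors': ['c.weight']}
```

Raising KeyError for unknown names makes a typo in keys_to_load fail loudly instead of silently producing a smaller state_dict.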
- set_submodule(model, module_name, new_submodule)
Set a submodule on a model by dotted path.
- Parameters:
model (Module)
module_name (str)
new_submodule (Module)
- Return type:
None
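Dotted-path assignment can be sketched as walking the path with getattr up to the parent and then calling setattr on the last component. This works for real models because torch.nn.Module resolves registered children through attribute access, including ModuleList/Sequential indices such as `layers.0`. Assigning a Module with setattr also registers it as a child. `set_submodule_by_path` is a hypothetical name, and the demo uses SimpleNamespace stand-ins so it runs without torch.

```python
from types import SimpleNamespace


def set_submodule_by_path(root, dotted_path: str, new_child) -> None:
    """Replace the attribute at dotted_path (e.g. "decoder.block.mlp") on root."""
    parent_path, _, child_name = dotted_path.rpartition(".")
    parent = root
    if parent_path:
        # Walk to the parent; raises AttributeError if a component is missing.
        for name in parent_path.split("."):
            parent = getattr(parent, name)
    setattr(parent, child_name, new_child)


# Strings stand in for modules here.
model = SimpleNamespace(decoder=SimpleNamespace(block=SimpleNamespace(mlp="old_mlp")))
set_submodule_by_path(model, "decoder.block.mlp", "new_mlp")
print(model.decoder.block.mlp)  # new_mlp
```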