sharded_checkpoint_utils
Provides utilities for distributed loading, saving, and manipulation of large language model checkpoints across multiple GPUs/processes.
Uses native HuggingFace models with deci_x_patcher for heterogeneous layer configurations.
Functions
- create_local_shard_(model, owned_block_indexes, descriptor, runtime)
- Parameters:
owned_block_indexes (set[int])
- create_sharded_model(runtime, descriptor, model_config, owned_block_indexes, device='meta', dtype=torch.float32)
- Parameters:
model_config (PretrainedConfig)
owned_block_indexes (set[int])
device (str | device | None)
dtype (dtype | None)
- is_in_safetensors_format(checkpoint_dir)
- Parameters:
checkpoint_dir (Path)
- Return type:
bool
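The check can be sketched as a simple directory scan. This is an assumption about the logic, not the library's actual implementation: a checkpoint directory counts as safetensors-format if it contains at least one `*.safetensors` file (single-file or sharded).

```python
from pathlib import Path


def is_in_safetensors_format(checkpoint_dir: Path) -> bool:
    # Heuristic sketch (assumption): any *.safetensors file in the
    # directory, whether a single "model.safetensors" or sharded
    # "model-00001-of-000NN.safetensors" pieces, marks the format.
    return any(checkpoint_dir.glob("*.safetensors"))
```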
- load_and_shard_model(descriptor, checkpoint_path, owned_block_indexes='auto', model_config=None)
- Parameters:
checkpoint_path (str | Path)
owned_block_indexes (set[int] | Literal['auto'])
model_config (PretrainedConfig | None)
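With `owned_block_indexes='auto'`, each rank presumably receives a disjoint subset of the transformer blocks. The helper below is hypothetical (the real assignment policy may differ, e.g. contiguous ranges instead of round-robin); it only illustrates the invariant that every block index is owned by exactly one rank:

```python
def auto_block_assignment(rank: int, world_size: int, num_blocks: int) -> set[int]:
    # Hypothetical round-robin policy: block i is owned by rank i % world_size,
    # so the per-rank sets partition range(num_blocks).
    return set(range(rank, num_blocks, world_size))
```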
- load_sharded_state_dict(model_name_or_path, keys_to_load=None, device='cpu')
If keys_to_load is None, the entire state_dict is loaded; otherwise a partial state_dict containing only those keys is returned
- Parameters:
model_name_or_path (str | Path)
keys_to_load (Iterable[str] | None)
device (device | str)
- Return type:
dict[str, Tensor]
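The `keys_to_load` filter amounts to selecting a subset of tensor names before materializing them on `device`. A stdlib-only sketch of just that selection step (the real function additionally reads the safetensors shards and moves tensors to the target device):

```python
from typing import Iterable, Optional


def select_keys(state_dict: dict, keys_to_load: Optional[Iterable[str]] = None) -> dict:
    # None means "load everything"; otherwise keep only the requested keys.
    if keys_to_load is None:
        return dict(state_dict)
    wanted = set(keys_to_load)
    return {k: v for k, v in state_dict.items() if k in wanted}
```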
- load_state_dict_to_shards(model_shard, loaded_state_dict=None)
- Parameters:
model_shard (Module)
loaded_state_dict (dict | None)
- Return type:
None
- save_sharded_model(model_shard, out_path)
out_path is usually output_checkpoint_path / "model.safetensors"
- Parameters:
model_shard (Module | dict[str, Tensor])
out_path (str | Path)
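Since `model_shard` may be either a `Module` or an already-extracted `dict[str, Tensor]`, saving presumably begins by normalizing to a dict of tensors. A duck-typed sketch of that normalization (the safetensors write itself is omitted; the stand-in below is an assumption about the dispatch, not the library's code):

```python
def to_state_dict(model_shard) -> dict:
    # Accept either a plain dict[str, Tensor] or anything exposing
    # .state_dict() (e.g. a torch.nn.Module).
    if isinstance(model_shard, dict):
        return model_shard
    return model_shard.state_dict()
```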
- set_submodule(model, module_name, new_submodule)
Set a submodule on a model by dotted path.
- Parameters:
model (Module)
module_name (str)
new_submodule (Module)
- Return type:
None
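Setting a submodule by dotted path reduces to walking every attribute except the last, then assigning the leaf. A minimal sketch using plain attribute access (in the real helper, torch's `Module` attribute machinery handles submodule registration on assignment):

```python
def set_submodule(model, module_name: str, new_submodule) -> None:
    # Walk e.g. "encoder.attn" down to the parent object ("encoder"),
    # then replace the leaf attribute ("attn") with the new submodule.
    *parents, leaf = module_name.split(".")
    obj = model
    for name in parents:
        obj = getattr(obj, name)
    setattr(obj, leaf, new_submodule)
```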