sharded_checkpoint_utils

Provides utilities for distributed loading, saving, and manipulation of large language model checkpoints across multiple GPUs/processes.

Uses native HuggingFace models with deci_x_patcher for heterogeneous layer configurations.

Functions

create_local_shard_

create_sharded_model

is_in_safetensors_format

load_and_shard_model

load_sharded_state_dict

keys_to_load: if None, the entire state_dict is loaded; otherwise only a partial state_dict containing these keys is loaded.

load_state_dict_to_shards

save_sharded_model

out_path is usually output_checkpoint_path / "model.safetensors"

set_submodule

Set a submodule on a model by dotted path.

create_local_shard_(model, owned_block_indexes, descriptor, runtime)
Parameters:

owned_block_indexes (set[int])

create_sharded_model(runtime, descriptor, model_config, owned_block_indexes, device='meta', dtype=torch.float32)
Parameters:
  • model_config (PretrainedConfig)

  • owned_block_indexes (set[int])

  • device (str | device | None)

  • dtype (dtype | None)

is_in_safetensors_format(checkpoint_dir)
Parameters:

checkpoint_dir (Path)

Return type:

bool
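
The reference above gives only the signature, so the exact criterion this function applies is not stated. As a minimal sketch, assuming that "safetensors format" means the directory contains at least one `*.safetensors` file, a check of this kind could look as follows. The name `looks_like_safetensors_dir` is hypothetical and is not part of this module.

```python
from pathlib import Path


def looks_like_safetensors_dir(checkpoint_dir) -> bool:
    """Hypothetical stand-in, not the module's implementation.

    Assumption: a checkpoint counts as safetensors-formatted when the
    directory exists and holds at least one *.safetensors file.
    """
    checkpoint_dir = Path(checkpoint_dir)
    return checkpoint_dir.is_dir() and any(checkpoint_dir.glob("*.safetensors"))
```

The real function may apply a different test (for example, looking for an index file), so treat this only as an illustration of the Path-in, bool-out contract.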

load_and_shard_model(descriptor, checkpoint_path, owned_block_indexes='auto', model_config=None)
Parameters:
  • checkpoint_path (str | Path)

  • owned_block_indexes (set[int] | Literal['auto'])

  • model_config (PretrainedConfig | None)

load_sharded_state_dict(model_name_or_path, keys_to_load=None, device='cpu')

keys_to_load: if None, the entire state_dict is loaded; otherwise only a partial state_dict containing these keys is loaded.

Parameters:
  • model_name_or_path (str | Path)

  • keys_to_load (Iterable[str] | None)

  • device (device | str)

Return type:

dict[str, Tensor]
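
To illustrate the keys_to_load semantics, here is a minimal sketch of how a partial load might decide which shard files to open. It assumes the common Hugging Face sharded layout, in which a `model.safetensors.index.json` file carries a `weight_map` from tensor name to shard file name. Whether load_sharded_state_dict works this way is an assumption; the helper `group_keys_by_shard` and the tensor and file names below are hypothetical.

```python
from collections import defaultdict


def group_keys_by_shard(weight_map: dict, keys_to_load=None) -> dict:
    """Hypothetical helper: map each shard file to the tensor names to read from it.

    keys_to_load=None selects every tensor, mirroring the documented
    "entire state_dict if None" behaviour.
    """
    wanted = None if keys_to_load is None else set(keys_to_load)
    by_shard = defaultdict(list)
    for name, shard_file in weight_map.items():
        if wanted is None or name in wanted:
            by_shard[shard_file].append(name)
    return dict(by_shard)


# Made-up weight_map, shaped like the "weight_map" entry of an index file.
weight_map = {
    "embed.weight": "model-00001-of-00002.safetensors",
    "layers.0.weight": "model-00001-of-00002.safetensors",
    "layers.1.weight": "model-00002-of-00002.safetensors",
}

full_plan = group_keys_by_shard(weight_map)
partial_plan = group_keys_by_shard(weight_map, ["layers.1.weight"])
```

With a plan like this, a loader only needs to open the shard files that actually contain requested keys, which is the main benefit of passing keys_to_load on a large checkpoint.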

load_state_dict_to_shards(model_shard, loaded_state_dict=None)
Parameters:
  • model_shard (Module)

  • loaded_state_dict (dict | None)

Return type:

None

save_sharded_model(model_shard, out_path)

out_path is usually output_checkpoint_path / "model.safetensors"

Parameters:
  • model_shard (Module | dict[str, Tensor])

  • out_path (str | Path)
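
The usual out_path described above can be built with pathlib's `/` operator. This is a small sketch; the directory name is a hypothetical placeholder, and the call to save_sharded_model is left as a comment because it needs a real model shard.

```python
from pathlib import Path

# Hypothetical output directory; substitute your own checkpoint location.
output_checkpoint_path = Path("checkpoints") / "my_model"

# The conventional target file inside that directory.
out_path = output_checkpoint_path / "model.safetensors"

# save_sharded_model(model_shard, out_path)  # signature as documented above
```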

set_submodule(model, module_name, new_submodule)

Set a submodule on a model by dotted path.

Parameters:
  • model (Module)

  • module_name (str)

  • new_submodule (Module)

Return type:

None
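
As an illustration of what "by dotted path" means, the sketch below walks every path component except the last with `getattr` and then replaces the final attribute with `setattr`. It uses plain Python objects to stay dependency-free; the helper name `set_by_dotted_path` and the attribute names are hypothetical, and the module's own set_submodule operates on Module instances and may differ in detail.

```python
from types import SimpleNamespace


def set_by_dotted_path(root, dotted_name: str, new_value) -> None:
    """Hypothetical helper: replace the attribute addressed by a dotted path."""
    *parents, leaf = dotted_name.split(".")
    target = root
    for part in parents:
        # Descend one level per path component, e.g. "decoder" in "decoder.attn".
        target = getattr(target, part)
    setattr(target, leaf, new_value)


# Stand-in for a nested model: model.decoder.attn and model.decoder.mlp.
model = SimpleNamespace(decoder=SimpleNamespace(attn="old_attention", mlp="old_mlp"))
set_by_dotted_path(model, "decoder.attn", "new_attention")
```

Only the addressed attribute changes; siblings such as `model.decoder.mlp` are left untouched.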