sharded_checkpoint_utils

Provides utilities for distributed loading, saving, and manipulation of large language model checkpoints across multiple GPUs/processes.

Uses native HuggingFace models with deci_x_patcher for heterogeneous layer configurations.

Functions

  • set_submodule: Set a submodule on a model by dotted path.

  • load_and_shard_model

  • create_sharded_model

  • load_sharded_state_dict: Load the entire state_dict when keys_to_load is None; otherwise load a partial state_dict containing only these keys.

  • is_in_safetensors_format

create_sharded_model(runtime, descriptor, model_config, owned_block_indexes, device='meta', dtype=torch.float32)
Parameters:
  • runtime

  • descriptor

  • model_config (PreTrainedConfig)

  • owned_block_indexes (set[int])

  • device (str | device | None)

  • dtype (dtype | None)
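With device='meta', PyTorch creates parameters that record only shape and dtype, so no memory is allocated until real weights are loaded. The sketch below shows one way a rank could decide which state_dict keys it must materialize, given owned_block_indexes. The helper keys_for_owned_blocks and the "model.layers.<i>." naming pattern are assumptions for illustration (the pattern matches Llama-style HuggingFace checkpoints); they are not part of this module's documented API.

```python
from __future__ import annotations

import re
from typing import Iterable

# ASSUMPTION: decoder blocks live under "model.layers.<index>." (Llama-style
# HuggingFace naming). sharded_checkpoint_utils may use a different scheme.
_BLOCK_KEY = re.compile(r"^model\.layers\.(\d+)\.")


def keys_for_owned_blocks(all_keys: Iterable[str], owned_block_indexes: set[int]) -> list[str]:
    """Hypothetical helper: keep the state_dict keys this rank must materialize.

    Keys that do not belong to any decoder block (embeddings, final norm,
    lm_head) are kept unconditionally in this sketch.
    """
    selected = []
    for key in all_keys:
        match = _BLOCK_KEY.match(key)
        # Non-block keys are always kept; block keys only if the block is owned.
        if match is None or int(match.group(1)) in owned_block_indexes:
            selected.append(key)
    return selected
```

Keeping non-block weights on every rank is a simplifying choice; a pipeline-style split might instead place the embeddings only on the first rank and lm_head only on the last.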

is_in_safetensors_format(checkpoint_dir)
Parameters:
  • checkpoint_dir (Path)

Return type:

bool
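The reference above does not state which files is_in_safetensors_format checks for. A minimal sketch, assuming it looks for the standard HuggingFace layout (a single model.safetensors file, or a sharded checkpoint described by model.safetensors.index.json); the name looks_like_safetensors_checkpoint is hypothetical.

```python
from __future__ import annotations

from pathlib import Path


def looks_like_safetensors_checkpoint(checkpoint_dir: Path) -> bool:
    """Hypothetical stand-in for is_in_safetensors_format.

    ASSUMPTION: a directory is "in safetensors format" if it has a sharded
    index file or at least one *.safetensors file.
    """
    checkpoint_dir = Path(checkpoint_dir)
    # Sharded checkpoints ship an index that maps tensor names to shard files.
    if (checkpoint_dir / "model.safetensors.index.json").is_file():
        return True
    # Covers both "model.safetensors" and shard files such as
    # "model-00001-of-00002.safetensors".
    return any(checkpoint_dir.glob("*.safetensors"))
```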

load_and_shard_model(descriptor, checkpoint_path, owned_block_indexes='auto', model_config=None)
Parameters:
  • descriptor

  • checkpoint_path (str | Path)

  • owned_block_indexes (set[int] | Literal['auto'])

  • model_config (PreTrainedConfig | None)
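The reference does not document how owned_block_indexes='auto' assigns blocks. One plausible policy is a contiguous, near-equal split of the decoder blocks across ranks, sketched below. The function auto_owned_block_indexes is hypothetical; in a real run, num_blocks might come from model_config.num_hidden_layers, and rank and world_size from torch.distributed.get_rank() and torch.distributed.get_world_size().

```python
from __future__ import annotations


def auto_owned_block_indexes(num_blocks: int, rank: int, world_size: int) -> set[int]:
    """Hypothetical 'auto' policy: contiguous, near-equal block ranges per rank.

    The first (num_blocks % world_size) ranks receive one extra block.
    """
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} out of range for world_size {world_size}")
    base, extra = divmod(num_blocks, world_size)
    # Ranks before this one each took `base` blocks, plus one if rank < extra.
    start = rank * base + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0)
    return set(range(start, stop))
```

For example, 10 blocks on 4 ranks yields {0, 1, 2}, {3, 4, 5}, {6, 7} and {8, 9}. Contiguous ranges keep activations flowing rank-to-rank in order, which suits a pipeline-style layout.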

load_sharded_state_dict(model_name_or_path, keys_to_load=None, device='cpu')

keys_to_load: if None, the entire state_dict is loaded; otherwise a partial state_dict containing only these keys is returned.

Parameters:
  • model_name_or_path (str | Path)

  • keys_to_load (Iterable[str] | None)

  • device (device | str)

Return type:

dict[str, Tensor]
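Loading only keys_to_load from a sharded checkpoint means opening just the shard files that contain those tensors. A minimal sketch of that planning step, assuming the HuggingFace model.safetensors.index.json layout, whose "weight_map" maps each tensor name to its shard filename. The function shards_for_keys is hypothetical, not this module's implementation.

```python
from __future__ import annotations

import json
from collections import defaultdict
from pathlib import Path
from typing import Iterable


def shards_for_keys(index_path: Path, keys_to_load: Iterable[str] | None = None) -> dict[str, list[str]]:
    """Hypothetical helper: map each shard file to the tensor names to read from it.

    If keys_to_load is None, every tensor listed in the index is planned.
    """
    weight_map: dict[str, str] = json.loads(Path(index_path).read_text())["weight_map"]
    wanted = weight_map.keys() if keys_to_load is None else keys_to_load
    plan: dict[str, list[str]] = defaultdict(list)
    for key in wanted:
        if key not in weight_map:
            raise KeyError(f"{key!r} not found in checkpoint index")
        plan[weight_map[key]].append(key)
    return dict(plan)
```

Each planned shard could then be opened with safetensors.safe_open(path, framework="pt", device=device) and its tensors read with get_tensor(key), so unrequested shards are never touched.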

set_submodule(model, module_name, new_submodule)

Set a submodule on a model by dotted path.

Parameters:
  • model (Module)

  • module_name (str)

  • new_submodule (Module)

Return type:

None
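A dotted-path setter can be sketched with getattr and setattr alone: walk to the parent of the last path component, then assign. This works on torch.nn.Module trees because nn.Module.__setattr__ registers Module values as children, and nn.ModuleList stores its children under string names such as "3". The name set_submodule_by_path is hypothetical; this is an illustration, not this module's code.

```python
from __future__ import annotations


def set_submodule_by_path(model: object, module_name: str, new_submodule: object) -> None:
    """Replace the attribute at a dotted path such as "model.layers.3.mlp"."""
    if not module_name:
        raise ValueError("module_name must be a non-empty dotted path")
    parent_path, _, attr = module_name.rpartition(".")
    parent = model
    if parent_path:
        # Resolve each intermediate component; numeric names like "3" resolve
        # on nn.ModuleList because its children are keyed by string index.
        for part in parent_path.split("."):
            parent = getattr(parent, part)
    # On nn.Module this re-registers the child under the same name.
    setattr(parent, attr, new_submodule)
```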