common

Functions

infer_weights_dtype

resolve_torch_dtype

Resolve a dtype that may be a string (e.g. from Hydra/OmegaConf config) to torch.dtype.

infer_weights_dtype(state_dict)
Parameters:

state_dict (dict[str, Tensor])

Return type:

torch.dtype
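The page gives no docstring for infer_weights_dtype, so how the dtype is chosen is not documented here. The sketch below is a hypothetical heuristic, not the library's code. It assumes the function reports the most common floating-point dtype among the tensors in the state dict and skips integer buffers. The helper name infer_weights_dtype_sketch and the ValueError on a state dict with no floating-point tensors are also assumptions.

```python
from collections import Counter

import torch


def infer_weights_dtype_sketch(state_dict):
    """Hypothetical heuristic: most common floating-point dtype in a state dict."""
    # Count only floating-point tensors; integer buffers such as BatchNorm's
    # num_batches_tracked (int64) would otherwise skew the result.
    counts = Counter(
        t.dtype
        for t in state_dict.values()
        if isinstance(t, torch.Tensor) and t.is_floating_point()
    )
    if not counts:
        # Assumed behavior: there is nothing to infer from.
        raise ValueError("state_dict contains no floating-point tensors")
    # most_common(1) returns [(dtype, count)] for the highest count.
    return counts.most_common(1)[0][0]


# Usage: a bfloat16 Linear layer yields bfloat16 weights and bias.
model = torch.nn.Linear(4, 2).to(torch.bfloat16)
print(infer_weights_dtype_sketch(model.state_dict()))  # torch.bfloat16
```

Counting instead of reading the first tensor keeps a single odd-dtype entry from deciding the result. Ties go to whichever dtype appears first, because `Counter.most_common` orders equal counts by first insertion.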

resolve_torch_dtype(dtype)

Resolve a dtype that may be a string (e.g. from Hydra/OmegaConf config) to torch.dtype.

Accepts torch.dtype objects (returned as-is) and strings like "torch.bfloat16" or "bfloat16".

Parameters:

dtype (str | torch.dtype)

Return type:

torch.dtype
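The documented contract for resolve_torch_dtype can be sketched as follows. This is a minimal re-implementation for illustration, not the library's source. Two things are assumptions: the name resolve_torch_dtype_sketch, and how errors are handled (ValueError for an unknown name, TypeError for any other input type). The real function may handle these cases differently.

```python
import torch


def resolve_torch_dtype_sketch(dtype):
    """Hypothetical sketch of resolve_torch_dtype's documented behavior."""
    # torch.dtype objects are returned as-is.
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        # Accept both "torch.bfloat16" and "bfloat16" by dropping an optional prefix.
        name = dtype.removeprefix("torch.")
        resolved = getattr(torch, name, None)
        # Guard against attributes of torch that are not dtypes (e.g. "Tensor").
        if isinstance(resolved, torch.dtype):
            return resolved
        raise ValueError(f"Unknown torch dtype string: {dtype!r}")  # assumed error type
    raise TypeError(f"Expected str or torch.dtype, got {type(dtype).__name__}")  # assumed


# Usage: a string read from a Hydra/OmegaConf config resolves to a real dtype.
print(resolve_torch_dtype_sketch("bfloat16"))        # torch.bfloat16
print(resolve_torch_dtype_sketch("torch.float16"))   # torch.float16
print(resolve_torch_dtype_sketch(torch.float32))     # torch.float32
```

Looking the name up with `getattr` on the `torch` module avoids a hand-maintained lookup table. The `isinstance` check keeps strings such as `"Tensor"` from resolving to something that is not a dtype.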