common
Functions
- infer_weights_dtype(state_dict)
- Parameters:
state_dict (dict[str, Tensor])
- Return type:
dtype
- resolve_torch_dtype(dtype)
Resolve a dtype that may be a string (e.g. from a Hydra/OmegaConf config) to a torch.dtype.
Accepts torch.dtype objects (returned as-is) and strings such as "torch.bfloat16" or "bfloat16".
- Parameters:
dtype (str | dtype)
- Return type:
dtype