
Dtypes

get_autocast_dtype(precision)

Returns the torch dtype corresponding to the given precision.

Parameters:

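| Name | Type | Description | Default |
|------|------|-------------|---------|
| precision | PrecisionTypes | The precision type to convert; must be a key of precision_to_dtype. | required |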

Returns:

| Type | Description |
|------|-------------|
| torch.dtype | The torch dtype corresponding to the given precision. |

Raises:

| Type | Description |
|------|-------------|
| ValueError | If the precision is not supported, i.e. it is not a key of precision_to_dtype. |
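The contract above is a plain dictionary lookup that raises ValueError on a miss. The sketch below is a minimal, self-contained illustration of that behavior. HYPOTHETICAL_PRECISION_TO_DTYPE and autocast_dtype_sketch are hypothetical stand-ins: the real keys and values of precision_to_dtype are not listed on this page, so the entries shown are assumptions.

```python
import torch

# Hypothetical mapping for illustration only; the actual precision_to_dtype
# table in bionemo/core/utils/dtypes.py is not shown on this page.
HYPOTHETICAL_PRECISION_TO_DTYPE = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
    "bf16-mixed": torch.bfloat16,
}


def autocast_dtype_sketch(precision) -> torch.dtype:
    """Return the dtype for `precision`, or raise ValueError if it is not in the table."""
    if precision in HYPOTHETICAL_PRECISION_TO_DTYPE:
        return HYPOTHETICAL_PRECISION_TO_DTYPE[precision]
    raise ValueError(f"Unsupported precision: {precision}")


print(autocast_dtype_sketch("bf16-mixed"))  # torch.bfloat16
try:
    autocast_dtype_sketch("int8")
except ValueError as err:
    print(err)  # Unsupported precision: int8
```

Several precision strings can map to the same dtype (here "bf16" and "bf16-mixed"), which is why a lookup table is a natural fit.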

Source code in bionemo/core/utils/dtypes.py
def get_autocast_dtype(precision: PrecisionTypes) -> torch.dtype:
    """Returns the torch dtype corresponding to the given precision.

    Args:
        precision: The precision type.

    Returns:
        torch.dtype: The torch dtype corresponding to the given precision.

    Raises:
        ValueError: If the precision is not supported.
    """
    # TODO move this to a utilities folder, or find/import the function that does this in NeMo
    if precision in precision_to_dtype:
        return precision_to_dtype[precision]
    else:
        raise ValueError(f"Unsupported precision: {precision}")