
Batch augmentation

BatchAugmentation

Factory that creates batch augmentation objects for a specified optimal transport type.

Parameters:

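| Name | Type | Description | Default |
|---|---|---|---|
| device | str | The device to use for computations (e.g., 'cpu', 'cuda'). Forwarded to the OT samplers built by create(). | required |
| num_threads | int | The number of threads to use. Forwarded to the OT samplers built by create(). | required |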
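For OptimalTransportType.EXACT, create() builds an OTSampler with method="exact". The NumPy sketch below shows what exact minibatch OT pairing generally means; it is an illustration under stated assumptions, not the bionemo implementation. It assumes equal batch sizes, uniform weights, and a squared Euclidean cost. Under those assumptions an optimal transport plan can be taken to be a permutation, so pairing noise samples x0 with data samples x1 becomes an assignment problem. `exact_ot_pairing` is a hypothetical helper. It brute-forces every permutation, so it only scales to tiny batches; in practice a proper solver such as `scipy.optimize.linear_sum_assignment` or POT's `ot.emd` would replace the loop.

```python
import itertools

import numpy as np


def exact_ot_pairing(x0: np.ndarray, x1: np.ndarray) -> np.ndarray:
    """Hypothetical helper (not part of bionemo): exact minibatch OT pairing.

    x0, x1: arrays of shape (batch, dim), e.g. noise and data samples.
    Returns perm such that x0[perm][j] is the sample matched to x1[j].
    Assumes equal batch sizes, uniform weights and squared Euclidean cost,
    so an optimal plan can be taken to be a permutation.
    """
    n = x0.shape[0]
    if x1.shape[0] != n:
        raise ValueError("this sketch assumes equal batch sizes")
    # cost[i, j] = squared distance between x0[i] and x1[j]
    cost = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(axis=-1)
    # Brute force over all assignments: acceptable for tiny batches only
    best = min(
        itertools.permutations(range(n)),
        key=lambda perm: sum(cost[perm[j], j] for j in range(n)),
    )
    return np.array(best, dtype=int)


x0 = np.array([[10.0, 0.0], [0.0, 0.0]])
x1 = np.array([[0.1, 0.0], [9.9, 0.0]])
perm = exact_ot_pairing(x0, x1)
print(perm.tolist())  # [1, 0]
x0_paired = x0[perm]  # row j of x0_paired is the OT partner of x1[j]
```

Reordering x0 rather than x1 keeps the data batch in its original order, which is convenient when the data carries labels or masks.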
Source code in bionemo/moco/interpolants/batch_augmentation.py
class BatchAugmentation:
    """Facilitates the creation of batch augmentation objects based on specified optimal transport types.

    Args:
        device (str): The device to use for computations (e.g., 'cpu', 'cuda').
        num_threads (int): The number of threads to utilize.
    """

    def __init__(self, device, num_threads):
        """Initializes a BatchAugmentation instance.

        Args:
            device (str): Device for computation.
            num_threads (int): Number of threads to use.
        """
        self.device = device
        self.num_threads = num_threads

    def create(self, method_type: OptimalTransportType):
        """Creates a batch augmentation object of the specified type.

        Args:
            method_type (OptimalTransportType): The type of optimal transport method.

        Returns:
            The augmentation object if the type is supported, otherwise **None**.
        """
        if method_type == OptimalTransportType.EXACT:
            augmentation = OTSampler(method="exact", device=self.device, num_threads=self.num_threads)
        elif method_type == OptimalTransportType.KABSCH:
            augmentation = KabschAugmentation()
        elif method_type == OptimalTransportType.EQUIVARIANT:
            augmentation = EquivariantOTSampler(method="exact", device=self.device, num_threads=self.num_threads)
        else:
            return None
        return augmentation
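The other two branches build a KabschAugmentation and an EquivariantOTSampler. This page does not describe their internals. Judging only from the names, the sketch below assumes two things: KabschAugmentation rotationally aligns point clouds with the Kabsch algorithm, and EquivariantOTSampler pairs samples using an OT cost that ignores rigid rotations. `kabsch_rotation`, `kabsch_align`, and `equivariant_ot_pairing` are hypothetical NumPy helpers written to illustrate those two ideas. They are not the bionemo classes or their API.

```python
import itertools

import numpy as np


def kabsch_rotation(p: np.ndarray, q: np.ndarray) -> np.ndarray:
    """Hypothetical helper: proper rotation R minimizing
    sum_i ||R (p_i - p_mean) - (q_i - q_mean)||^2.

    p, q: (num_points, 3) clouds with matching point order.
    """
    p_c = p - p.mean(axis=0)
    q_c = q - q.mean(axis=0)
    h = p_c.T @ q_c  # 3x3 cross-covariance matrix
    u, _, vt = np.linalg.svd(h)
    # Flip the last singular direction if needed so det(R) = +1 (no reflection)
    d = np.sign(np.linalg.det(vt.T @ u.T))
    return vt.T @ np.diag([1.0, 1.0, d]) @ u.T


def kabsch_align(p: np.ndarray, q: np.ndarray) -> np.ndarray:
    """Rotate p about its centroid onto q, then move it to q's centroid."""
    r = kabsch_rotation(p, q)
    return (p - p.mean(axis=0)) @ r.T + q.mean(axis=0)


def equivariant_ot_pairing(x0: np.ndarray, x1: np.ndarray):
    """Hypothetical helper: OT pairing with a rotation-invariant cost.

    x0, x1: (batch, num_points, 3). The cost of pairing x0[i] with x1[j] is
    the squared distance left after Kabsch-aligning x0[i] onto x1[j].
    Brute-force assignment, so tiny batches only. Returns (perm, aligned_x0).
    """
    n = x0.shape[0]
    cost = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            cost[i, j] = ((kabsch_align(x0[i], x1[j]) - x1[j]) ** 2).sum()
    best = min(
        itertools.permutations(range(n)),
        key=lambda perm: sum(cost[perm[j], j] for j in range(n)),
    )
    perm = np.array(best, dtype=int)
    # Align each selected x0 sample onto its partner in x1
    aligned = np.stack([kabsch_align(x0[perm[j]], x1[j]) for j in range(n)])
    return perm, aligned
```

Aligning inside the cost makes the pairing insensitive to how each noise sample happens to be oriented. The price is that every entry of the cost matrix needs an SVD, which makes it noticeably more expensive than the plain exact pairing.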

__init__(device, num_threads)

Initializes a BatchAugmentation instance.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| device | str | Device for computation. | required |
| num_threads | int | Number of threads to use. | required |

create(method_type)

Creates a batch augmentation object of the specified type.

Parameters:

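| Name | Type | Description | Default |
|---|---|---|---|
| method_type | OptimalTransportType | The type of optimal transport method. | required |

Returns:

| Type | Description |
|---|---|
| OTSampler, KabschAugmentation, EquivariantOTSampler, or None | An OTSampler (method="exact") for OptimalTransportType.EXACT, a KabschAugmentation for OptimalTransportType.KABSCH, an EquivariantOTSampler (method="exact") for OptimalTransportType.EQUIVARIANT, and None for any other type. |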
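Because create() returns None instead of raising for an unsupported type, a caller that uses the result directly fails later with a less helpful error. The sketch below shows a small guard that fails early. `require_augmentation` is a hypothetical wrapper, not part of bionemo. `DemoOTType` and `demo_create` are stand-ins that mimic only the documented object-or-None contract, so the example runs without bionemo installed. The `UNKNOWN` member is invented purely to exercise the None path.

```python
from enum import Enum
from typing import Any, Callable, Optional


class DemoOTType(Enum):
    """Hypothetical stand-in for OptimalTransportType, used only for this demo."""
    EXACT = "exact"
    KABSCH = "kabsch"
    EQUIVARIANT = "equivariant"
    UNKNOWN = "unknown"  # invented member with no branch in create()


def demo_create(method_type: DemoOTType) -> Optional[str]:
    """Stand-in factory with the same contract as create(): an object or None."""
    supported = {
        DemoOTType.EXACT: "OTSampler(method='exact')",
        DemoOTType.KABSCH: "KabschAugmentation()",
        DemoOTType.EQUIVARIANT: "EquivariantOTSampler(method='exact')",
    }
    return supported.get(method_type)  # None for anything unsupported


def require_augmentation(create: Callable[[Any], Optional[Any]], method_type: Any) -> Any:
    """Call a create()-style factory and raise instead of passing None along."""
    augmentation = create(method_type)
    if augmentation is None:
        raise ValueError(f"Unsupported optimal transport type: {method_type!r}")
    return augmentation


print(require_augmentation(demo_create, DemoOTType.KABSCH))  # KabschAugmentation()
try:
    require_augmentation(demo_create, DemoOTType.UNKNOWN)
except ValueError as err:
    print(err)  # Unsupported optimal transport type: <DemoOTType.UNKNOWN: 'unknown'>
```

With the real class, the same guard would be called as `require_augmentation(BatchAugmentation(device="cpu", num_threads=1).create, OptimalTransportType.EXACT)`.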
