
Torch dataloader utils

collate_sparse_matrix_batch(batch)

Collate function to create a batch out of sparse tensors.

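This is necessary because the samples store different numbers of nonzero entries, so they cannot be stacked into a single dense tensor the way the default collate function stacks equally sized tensors.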
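As the source code below implies, each element of `batch` is a 2 x nnz tensor: row 0 holds the stored values and row 1 holds their column indices, with nnz varying from sample to sample. The following is a minimal sketch of the same assembly using plain Python lists instead of tensors, so the resulting CSR arrays can be inspected directly. The helper name `build_csr_arrays` is hypothetical and not part of bionemo.

```python
from itertools import accumulate


def build_csr_arrays(batch):
    """Plain-list mirror of collate_sparse_matrix_batch (hypothetical helper).

    Each sample is [values, column_indices]; both lists have length nnz.
    Returns (crow_indices, col_indices, values, shape).
    """
    # Row pointers: running total of stored entries per sample, starting at 0.
    crow_indices = list(accumulate([len(vals) for vals, _ in batch], initial=0))
    # Concatenate column indices and values in row order.
    col_indices = [int(c) for _, cols in batch for c in cols]
    values = [v for vals, _ in batch for v in vals]
    # Width is the largest column index plus one, or 0 if nothing is stored.
    n_cols = max(col_indices) + 1 if col_indices else 0
    return crow_indices, col_indices, values, (len(batch), n_cols)


batch = [
    [[1.0, 2.0], [0, 3]],          # row 0: 1.0 at column 0, 2.0 at column 3
    [[5.0], [1]],                  # row 1: 5.0 at column 1
    [[7.0, 8.0, 9.0], [0, 2, 4]],  # row 2: three stored entries
]
print(build_csr_arrays(batch))
# ([0, 2, 3, 6], [0, 3, 1, 0, 2, 4], [1.0, 2.0, 5.0, 7.0, 8.0, 9.0], (3, 5))
```

Because the width comes from the largest column index present, two batches drawn from the same dataset can end up with different numbers of columns.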

Parameters:

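batch (list[Tensor], required): A list of sparse samples to collate into a batch. Each sample is a 2 x nnz tensor whose first row holds the stored values and whose second row holds their column indices; nnz may differ between samples.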

Returns:

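Tensor: A sparse tensor in CSR (Compressed Sparse Row) format with one row per sample, of shape (len(batch), max_col + 1), where max_col is the largest column index in the batch. The width is 0 if the batch contains no stored entries.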
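In CSR format, the entries of row i occupy positions crow_indices[i] through crow_indices[i + 1] - 1 of both the column-index and value arrays. A PyTorch CSR tensor exposes these arrays through `crow_indices()`, `col_indices()` and `values()`, and `to_dense()` expands it. The minimal pure-Python sketch below performs the same expansion by hand, which can help when checking a collated batch; the helper `csr_to_dense` is hypothetical.

```python
def csr_to_dense(crow_indices, col_indices, values, shape):
    """Expand CSR arrays into a list of dense rows (hypothetical helper)."""
    n_rows, n_cols = shape
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for row in range(n_rows):
        # Stored entries of this row lie in the half-open slice [start, end).
        start, end = crow_indices[row], crow_indices[row + 1]
        for col, value in zip(col_indices[start:end], values[start:end]):
            dense[row][col] = value
    return dense


rows = csr_to_dense([0, 2, 3, 6], [0, 3, 1, 0, 2, 4], [1.0, 2.0, 5.0, 7.0, 8.0, 9.0], (3, 5))
for r in rows:
    print(r)
# [1.0, 0.0, 0.0, 2.0, 0.0]
# [0.0, 5.0, 0.0, 0.0, 0.0]
# [7.0, 0.0, 8.0, 0.0, 9.0]
```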

Source code in bionemo/scdl/util/torch_dataloader_utils.py
import torch


def collate_sparse_matrix_batch(batch: list[torch.Tensor]) -> torch.Tensor:
    """Collate function to create a batch out of sparse tensors.

    This is necessary to collate sparse matrices of various lengths.

    Args:
        batch: A list of Tensors to collate into a batch.

    Returns:
        The tensors collated into a CSR (Compressed Sparse Row) Format.
    """
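    # Row pointers (crow_indices): running total of the stored entries per
    # sample (shape[1]), starting at 0.
    batch_rows = torch.cumsum(
        torch.tensor([0] + [sparse_representation.shape[1] for sparse_representation in batch]), dim=0
    )
    # Row 1 of each sample holds column indices. Values and indices live in one
    # tensor, so the indices carry its dtype; cast them to int32.
    batch_cols = torch.cat([sparse_representation[1] for sparse_representation in batch]).to(torch.int32)
    # Row 0 of each sample holds the stored values.
    batch_values = torch.cat([sparse_representation[0] for sparse_representation in batch])
    # Infer the number of columns from the largest column index in the batch.
    if len(batch_cols) == 0:
        max_pointer = 0
    else:
        max_pointer = int(batch_cols.max().item() + 1)
    batch_sparse_tensor = torch.sparse_csr_tensor(batch_rows, batch_cols, batch_values, size=(len(batch), max_pointer))
    return batch_sparse_tensor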