Utility Functions and Classes

sparse_quantize

MinkowskiEngine.utils.sparse_quantize(coordinates, features=None, labels=None, ignore_label=-100, return_index=False, return_inverse=False, return_maps_only=False, quantization_size=None, device='cpu')

Given coordinates and features (and, optionally, labels), the function generates quantized (voxelized) coordinates.

Args:

coordinates (numpy.ndarray or torch.Tensor): a matrix of size \(N \times D\) where \(N\) is the number of points in the \(D\) dimensional space.

features (numpy.ndarray or torch.Tensor, optional): a matrix of size \(N \times D_F\) where \(N\) is the number of points and \(D_F\) is the dimension of the features. Must use the same container as coordinates (i.e., if coordinates is a torch.Tensor, features must also be a torch.Tensor).

labels (numpy.ndarray or torch.IntTensor, optional): integer labels associated with each coordinate. Must use the same container as coordinates (i.e., if coordinates is a torch.Tensor, labels must also be a torch.Tensor). For classification, where a set of points maps to a single label, do not provide labels.

ignore_label (int, optional): the integer value of the ignore label, for use with torch.nn.CrossEntropyLoss(ignore_index=ignore_label).

return_index (bool, optional): set True if you want the indices of the quantized coordinates. False by default.

return_inverse (bool, optional): set True if you want the indices that can recover the discretized original coordinates. False by default. return_index must be True when return_inverse is True.

return_maps_only (bool, optional): if set, returns the unique_map, or (unique_map, inverse_map) if return_inverse is also set, but not the coordinates. Useful if you do not need the final coordinates, or if you use device='cuda' and do not need the coordinates on the GPU.

quantization_size (float, list, or numpy.ndarray, optional): the length of each side of the hyperrectangular grid cell. If set, defines the smallest distance between quantized coordinates.

device (str, optional): either 'cpu' or 'cuda'.

Example:

>>> unique_map, inverse_map = sparse_quantize(discrete_coords, return_index=True, return_inverse=True)
>>> unique_coords = discrete_coords[unique_map]
>>> print((unique_coords[inverse_map] == discrete_coords).all())  # True

Example:

>>> # Segmentation
>>> criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)
>>> coords, feats, labels = MinkowskiEngine.utils.sparse_quantize(
>>>     coords, feats, labels, ignore_label=-100, quantization_size=0.1)
>>> output = net(MinkowskiEngine.SparseTensor(feats, coords))
>>> loss = criterion(output.F, labels.long())
>>>
>>> # Classification
>>> criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)
>>> coords, feats = MinkowskiEngine.utils.sparse_quantize(coords, feats)
>>> output = net(MinkowskiEngine.SparseTensor(feats, coords))
>>> loss = criterion(output.F, labels.long())

batched_coordinates

MinkowskiEngine.utils.batched_coordinates(coords, dtype=torch.int32, device=None)

Create ME.SparseTensor coordinates from a sequence of coordinates.

Given a list of either numpy or pytorch tensor coordinates, return the batched coordinates suitable for ME.SparseTensor.

Args:

coords (a sequence of torch.Tensor or numpy.ndarray): a list of coordinates.

dtype: torch data type of the returned tensor. torch.int32 by default.

Returns:

batched_coordinates (torch.Tensor): the batched coordinates.

Warning

From v0.4, the batch index will be prepended before all coordinates.
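
Example:

A minimal sketch; the random integer coordinates below are illustrative only:

>>> import numpy as np
>>> import MinkowskiEngine as ME
>>> coords0 = np.random.randint(0, 10, size=(5, 3))  # 5 points in 3D
>>> coords1 = np.random.randint(0, 10, size=(8, 3))  # 8 points in 3D
>>> bcoords = ME.utils.batched_coordinates([coords0, coords1])
>>> print(bcoords.shape)  # torch.Size([13, 4]); column 0 is the batch index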

sparse_collate

MinkowskiEngine.utils.sparse_collate(coords, feats, labels=None, dtype=torch.int32, device=None)

Create input arguments for a sparse tensor (see the SparseTensor documentation).

Convert a set of coordinates and features into the batch coordinates and batch features.

Args:

coords (sequence of torch.Tensor or numpy.ndarray): a sequence of coordinates.

feats (sequence of torch.Tensor or numpy.ndarray): a sequence of features.

labels (sequence of torch.Tensor or numpy.ndarray, optional): a sequence of labels associated with the inputs.
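
Example:

A minimal sketch; the random coordinates and features are illustrative only:

>>> import numpy as np
>>> import MinkowskiEngine as ME
>>> coords0, feats0 = np.random.randint(0, 10, (5, 3)), np.random.rand(5, 2).astype(np.float32)
>>> coords1, feats1 = np.random.randint(0, 10, (8, 3)), np.random.rand(8, 2).astype(np.float32)
>>> bcoords, bfeats = ME.utils.sparse_collate([coords0, coords1], [feats0, feats1])
>>> sin = ME.SparseTensor(bfeats, bcoords)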

batch_sparse_collate

MinkowskiEngine.utils.batch_sparse_collate(data, dtype=torch.int32, device=None)

A wrapper function that can be used in conjunction with torch.utils.data.DataLoader to generate inputs for a sparse tensor.

Please refer to the training example for the usage.

Args:

data: list of (coordinates, features, labels) tuples.
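
Example:

A sketch assuming each dataset item is a (coordinates, features, labels) tuple:

>>> data_loader = torch.utils.data.DataLoader(
>>>     dataset,
>>>     batch_size=4,
>>>     collate_fn=MinkowskiEngine.utils.batch_sparse_collate)
>>> for coords, feats, labels in data_loader:
>>>     sin = MinkowskiEngine.SparseTensor(feats, coords)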

cat

MinkowskiEngine.cat(*sparse_tensors)

Concatenate sparse tensors

Concatenate sparse tensor features. All sparse tensors must have the same coordinate_map_key (the same coordinates). To concatenate sparse tensors with different sparsity patterns, use SparseTensor binary operations or MinkowskiEngine.MinkowskiUnion.

Example:

>>> import MinkowskiEngine as ME
>>> sin = ME.SparseTensor(feats, coords)
>>> sin2 = ME.SparseTensor(feats2, coordinate_map_key=sin.coordinate_map_key, coordinate_manager=sin.coordinate_manager)
>>> sout = UNet(sin)  # Returns an output sparse tensor on the same coordinates
>>> sout2 = ME.cat(sin, sin2, sout)  # Can concatenate multiple sparse tensors

to_sparse

MinkowskiEngine.to_sparse(x: torch.Tensor, format: Optional[str] = None, coordinates=None, device=None)

Convert a batched tensor (dimension 0 is the batch dimension) to a SparseTensor.

x (torch.Tensor): a batched tensor. The first dimension is the batch dimension.

format (str): format of the tensor. It must include 'B' and 'C', indicating the batch and channel dimensions respectively; the remaining dimensions must be 'X'. For example, use format="BCXX" for image data in BCHW format. For 3D data with the channel at the last dimension, use format="BXXXC", indicating Batch x Height x Width x Depth x Channel. If not provided, the format defaults to "BCX...X".

device: Device the sparse tensor will be generated on. If not provided, the device of the input tensor will be used.
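
Example:

A minimal sketch converting a batch of images in BCHW format; the random tensor is illustrative only:

>>> import torch
>>> import MinkowskiEngine as ME
>>> x = torch.rand(2, 3, 32, 32)  # BCHW image batch
>>> stensor = ME.to_sparse(x, format="BCXX")  # nonzero entries populate the sparse tensor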

to_sparse_all

MinkowskiEngine.to_sparse_all(dense_tensor: torch.Tensor, coordinates: Optional[torch.Tensor] = None)

Converts a (differentiable) dense tensor to a sparse tensor with all coordinates.

Assumes the input has the BxCxD1xD2x...xDN format.

If the shape of the tensor does not change, use dense_coordinates to cache the coordinates. Please refer to tests/python/dense.py for usage.

Example:

>>> dense_tensor = torch.rand(3, 4, 5, 6, 7, 8)  # BxCxD1xD2xD3xD4
>>> dense_tensor.requires_grad = True
>>> stensor = to_sparse_all(dense_tensor)

SparseCollation

class MinkowskiEngine.utils.SparseCollation(limit_numpoints=-1, dtype=torch.int32, device=None)

Generates a collate function for coords, feats, and labels.

Please refer to the training example for the usage.

Args:

limit_numpoints (int): if a positive integer, limits the batch size so that the total number of input coordinates stays below limit_numpoints. If 0 or False, concatenates all points. -1 by default.

Example:

>>> data_loader = torch.utils.data.DataLoader(
>>>     dataset,
>>>     ...,
>>>     collate_fn=SparseCollation())
>>> for d in iter(data_loader):
>>>     print(d)
__init__(limit_numpoints=-1, dtype=torch.int32, device=None)

Initialize self. See help(type(self)) for accurate signature.

MinkowskiToSparseTensor

class MinkowskiEngine.MinkowskiToSparseTensor(remove_zeros=True, coordinates: Optional[torch.Tensor] = None)

Converts a (differentiable) dense tensor or a MinkowskiEngine.TensorField to a MinkowskiEngine.SparseTensor.

For a dense tensor, the input must have the BxCxD1xD2x...xDN format.

remove_zeros (bool): if True, removes zero-valued coordinates. If False, uses all coordinates to populate the sparse tensor. True by default.

If the shape of the tensor does not change, use dense_coordinates to cache the coordinates. Please refer to tests/python/dense.py for usage.

Example:

>>> # Differentiable dense torch.Tensor to sparse tensor.
>>> dense_tensor = torch.rand(3, 4, 11, 11, 11, 11)  # BxCxD1xD2x....xDN
>>> dense_tensor.requires_grad = True

>>> # Since the shape is fixed, cache the coordinates for faster inference
>>> coordinates = dense_coordinates(dense_tensor.shape)

>>> network = nn.Sequential(
>>>     # Add layers that can be applied on a regular pytorch tensor
>>>     nn.ReLU(),
>>>     MinkowskiToSparseTensor(remove_zeros=False, coordinates=coordinates),
>>>     MinkowskiConvolution(4, 5, kernel_size=3, dimension=4),
>>>     MinkowskiBatchNorm(5),
>>>     MinkowskiReLU(),
>>> )

>>> for i in range(5):
>>>   print(f"Iteration: {i}")
>>>   soutput = network(dense_tensor)
>>>   soutput.F.sum().backward()
>>>   soutput.dense(shape=dense_tensor.shape)
__init__(remove_zeros=True, coordinates: Optional[torch.Tensor] = None)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

MinkowskiToDenseTensor

class MinkowskiEngine.MinkowskiToDenseTensor(shape: Optional[torch.Size] = None)

Converts a (differentiable) sparse tensor to a torch tensor.

The return type has the BxCxD1xD2x...xDN format.

Example:

>>> dense_tensor = torch.rand(3, 4, 11, 11, 11, 11)  # BxCxD1xD2x....xDN
>>> dense_tensor.requires_grad = True

>>> # Since the shape is fixed, cache the coordinates for faster inference
>>> coordinates = dense_coordinates(dense_tensor.shape)

>>> network = nn.Sequential(
>>>     # Add layers that can be applied on a regular pytorch tensor
>>>     nn.ReLU(),
>>>     MinkowskiToSparseTensor(coordinates=coordinates),
>>>     MinkowskiConvolution(4, 5, stride=2, kernel_size=3, dimension=4),
>>>     MinkowskiBatchNorm(5),
>>>     MinkowskiReLU(),
>>>     MinkowskiConvolutionTranspose(5, 6, stride=2, kernel_size=3, dimension=4),
>>>     MinkowskiToDenseTensor(
>>>         dense_tensor.shape
>>>     ),  # must have the same tensor stride.
>>> )

>>> for i in range(5):
>>>     print(f"Iteration: {i}")
>>>     output = network(dense_tensor) # returns a regular pytorch tensor
>>>     output.sum().backward()
__init__(shape: Optional[torch.Size] = None)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

MinkowskiToFeature

class MinkowskiEngine.MinkowskiToFeature

Extracts the features from a sparse tensor and returns them as a PyTorch tensor.

Can be used to make network construction simpler.

Example:

>>> net = nn.Sequential(MinkowskiConvolution(...), MinkowskiGlobalMaxPooling(...), MinkowskiToFeature(), nn.Linear(...))
>>> torch_tensor = net(sparse_tensor)
__init__()

Initializes internal Module state, shared by both nn.Module and ScriptModule.

MinkowskiStackCat

class MinkowskiEngine.MinkowskiStackCat(*args: torch.nn.modules.module.Module)
class MinkowskiEngine.MinkowskiStackCat(arg: OrderedDict[str, Module])

Applies every submodule to the same input and concatenates the resulting output features. The submodules must produce outputs with the same coordinates.

__init__(*args: Any)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
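
Example:

A minimal sketch, assuming a 2D input sparse tensor sin with 3 feature channels:

>>> import MinkowskiEngine as ME
>>> net = ME.MinkowskiStackCat(
>>>     ME.MinkowskiConvolution(3, 8, kernel_size=3, dimension=2),
>>>     ME.MinkowskiConvolution(3, 8, kernel_size=3, dimension=2),
>>> )
>>> sout = net(sin)  # output feature dimension is 8 + 8 = 16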

MinkowskiStackSum

class MinkowskiEngine.MinkowskiStackSum(*args: torch.nn.modules.module.Module)
class MinkowskiEngine.MinkowskiStackSum(arg: OrderedDict[str, Module])

Applies every submodule to the same input and sums the resulting output features elementwise. All submodules must produce the same number of feature channels.

__init__(*args: Any)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
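
Example:

A minimal sketch of a two-branch block, assuming a 2D input sparse tensor sin with 3 feature channels:

>>> import MinkowskiEngine as ME
>>> import torch.nn as nn
>>> net = ME.MinkowskiStackSum(
>>>     ME.MinkowskiConvolution(3, 8, kernel_size=3, dimension=2),
>>>     nn.Sequential(
>>>         ME.MinkowskiConvolution(3, 8, kernel_size=3, stride=2, dimension=2),
>>>         ME.MinkowskiConvolutionTranspose(8, 8, kernel_size=3, stride=2, dimension=2),
>>>     ),
>>> )
>>> sout = net(sin)  # branch outputs are added elementwise; feature dimension stays 8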

MinkowskiStackMean

class MinkowskiEngine.MinkowskiStackMean(*args: torch.nn.modules.module.Module)
class MinkowskiEngine.MinkowskiStackMean(arg: OrderedDict[str, Module])

Like MinkowskiStackSum, but the stacked output features are averaged elementwise rather than summed.

__init__(*args: Any)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

MinkowskiStackVar

class MinkowskiEngine.MinkowskiStackVar(*args: torch.nn.modules.module.Module)
class MinkowskiEngine.MinkowskiStackVar(arg: OrderedDict[str, Module])

Like MinkowskiStackMean, but computes the elementwise variance of the stacked output features instead of their mean.

__init__(*args: Any)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.