Utility Functions and Classes

sparse_quantize

MinkowskiEngine.utils.sparse_quantize(coordinates, features=None, labels=None, ignore_label=-100, return_index=False, return_inverse=False, return_maps_only=False, quantization_size=None, device='cpu')

Given coordinates, and optionally features and labels, generates quantized (voxelized) coordinates.
Args:
    coordinates (numpy.ndarray or torch.Tensor): a matrix of size \(N \times D\), where \(N\) is the number of points in the \(D\)-dimensional space.
    features (numpy.ndarray or torch.Tensor, optional): a matrix of size \(N \times D_F\), where \(N\) is the number of points and \(D_F\) is the dimension of the features. Must use the same container type as coordinates (i.e. if coordinates is a torch.Tensor, features must also be a torch.Tensor).
    labels (numpy.ndarray or torch.IntTensor, optional): integer labels associated with each coordinate. Must use the same container type as coordinates (i.e. if coordinates is a torch.Tensor, labels must also be a torch.Tensor). For classification, where a whole set of points maps to a single label, do not pass labels.
    ignore_label (int, optional): the integer value of the ignore label. Use the same value in torch.nn.CrossEntropyLoss(ignore_index=ignore_label).
    return_index (bool, optional): set to True to return the indices of the quantized coordinates. False by default.
    return_inverse (bool, optional): set to True to return the indices that recover the discretized original coordinates. False by default. return_index must be True when return_inverse is True.
    return_maps_only (bool, optional): if set, returns the unique map (and, optionally, the inverse map) but not the coordinates. Useful if you don't need the final coordinates, or if you use device='cuda' and don't need the coordinates on the GPU. Returns either unique_map alone or (unique_map, inverse_map) if return_inverse is set.
    quantization_size (float, optional): if set, the quantization size defines the smallest distance between coordinates.
    device (str, optional): either 'cpu' or 'cuda'.
Example:
>>> unique_map, inverse_map = sparse_quantize(discrete_coords, return_index=True, return_inverse=True, return_maps_only=True)
>>> unique_coords = discrete_coords[unique_map]
>>> print((unique_coords[inverse_map] == discrete_coords).all())  # True
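The relationship between the unique map and the inverse map can be illustrated without MinkowskiEngine. The following is a minimal pure-Python sketch, not the library's implementation: it assumes voxelization floors coordinate / quantization_size, and that the first point to land in a voxel becomes its representative (the library may choose a different point). The helper name quantize_maps is hypothetical.

```python
import math


def quantize_maps(points, quantization_size=1.0):
    """Hypothetical helper: return (unique_coords, unique_map, inverse_map).

    unique_map[k]  -> index of the input point chosen to represent voxel k
    inverse_map[i] -> index k of the voxel that input point i falls into
    """
    # Voxelize: map each coordinate onto the integer grid (assumed floor rule).
    discrete = [tuple(math.floor(v / quantization_size) for v in p) for p in points]

    slot_of = {}  # voxel -> its position in the unique list
    unique_map, inverse_map = [], []
    for i, voxel in enumerate(discrete):
        if voxel not in slot_of:
            # Assumption: the first point seen in a voxel represents it.
            slot_of[voxel] = len(unique_map)
            unique_map.append(i)
        inverse_map.append(slot_of[voxel])

    unique_coords = [discrete[i] for i in unique_map]
    return unique_coords, unique_map, inverse_map


points = [(0.05, 0.02), (0.31, 0.0), (0.07, 0.09), (0.33, 0.01)]
unique_coords, unique_map, inverse_map = quantize_maps(points, quantization_size=0.1)
# Indexing the unique coordinates with inverse_map recovers every discretized point.
recovered = [unique_coords[k] for k in inverse_map]
```

Here four points collapse into two voxels, and recovered lists the voxel of every original point in input order.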
    quantization_size (float, list, or numpy.ndarray, optional): the length of each side of the hyperrectangle that forms a grid cell.
Example:
>>> # Segmentation
>>> criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)
>>> coords, feats, labels = MinkowskiEngine.utils.sparse_quantize(
>>>     coords, feats, labels, ignore_label=-100, quantization_size=0.1)
>>> output = net(MinkowskiEngine.SparseTensor(feats, coords))
>>> loss = criterion(output.F, labels.long())
>>>
>>> # Classification
>>> criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)
>>> coords, feats = MinkowskiEngine.utils.sparse_quantize(coords, feats)
>>> output = net(MinkowskiEngine.SparseTensor(feats, coords))
>>> loss = criterion(output.F, labels.long())
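When labels are passed, all points that fall into one voxel must be reduced to a single label. A common convention, and our assumption here rather than a documented guarantee of sparse_quantize, is that a voxel whose points disagree receives ignore_label, so that torch.nn.CrossEntropyLoss(ignore_index=ignore_label) skips it. A pure-Python sketch of that rule, using the hypothetical helper quantize_labels:

```python
def quantize_labels(voxels, labels, ignore_label=-100):
    """Hypothetical helper: one label per unique voxel.

    voxels: discretized coordinates (hashable tuples), one per input point.
    labels: integer label for each input point.
    A voxel keeps its label only if every point inside it agrees;
    otherwise it is marked with ignore_label (assumed collision rule).
    """
    voxel_label = {}
    order = []  # unique voxels in first-seen order
    for voxel, label in zip(voxels, labels):
        if voxel not in voxel_label:
            voxel_label[voxel] = label
            order.append(voxel)
        elif voxel_label[voxel] != label:
            voxel_label[voxel] = ignore_label  # conflicting labels inside one voxel
    return order, [voxel_label[v] for v in order]


voxels = [(0, 0), (3, 0), (0, 0), (3, 0), (5, 5)]
labels = [1, 2, 1, 7, 4]
unique_voxels, unique_labels = quantize_labels(voxels, labels)
```

Voxel (0, 0) keeps label 1 because both of its points agree, while voxel (3, 0) receives -100 because its points carry 2 and 7.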
batched_coordinates

MinkowskiEngine.utils.batched_coordinates(coords, dtype=torch.int32, device=None)

Creates ME.SparseTensor coordinates from a sequence of coordinates. Given a list of either numpy or PyTorch tensor coordinates, returns the batched coordinates suitable for ME.SparseTensor.
Args:
    coords (a sequence of torch.Tensor or numpy.ndarray): a list of coordinates.
    dtype: torch data type of the returned tensor. torch.int32 by default.
Returns:
    batched_coordinates (torch.Tensor): the batched coordinates.
Warning
From v0.4 onward, the batch index is prepended to every coordinate, i.e. it occupies the first column of the returned coordinates.
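The layout described in the warning can be sketched in plain Python. This illustrates the v0.4+ convention under the assumption that samples receive batch indices 0, 1, ... in the order given; the real function returns a torch.Tensor of the requested dtype, and batch_coordinates is a hypothetical name.

```python
def batch_coordinates(coords_list):
    """Hypothetical helper: concatenate per-sample coordinates with a leading batch index.

    coords_list: a sequence of samples, each a list of integer coordinate rows.
    Returns one list of rows of the form [batch_index, x1, ..., xD].
    """
    batched = []
    for batch_index, coords in enumerate(coords_list):
        for row in coords:
            batched.append([batch_index, *row])  # batch index occupies the first column
    return batched


batched = batch_coordinates([[[0, 0], [1, 2]], [[3, 3]]])
```

Two 2D samples with two and one points produce three rows of width 3, where the first column tells the samples apart.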
sparse_collate

MinkowskiEngine.utils.sparse_collate(coords, feats, labels=None, dtype=torch.int32, device=None)

Creates input arguments for a sparse tensor. Converts a set of coordinates and features into batched coordinates and batched features.
Args:
    coords (set of torch.Tensor or numpy.ndarray): a set of coordinates.
    feats (set of torch.Tensor or numpy.ndarray): a set of features.
    labels (set of torch.Tensor or numpy.ndarray, optional): a set of labels associated with the inputs.
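The key invariant of collation is that row i of the batched features (and labels) belongs to row i of the batched coordinates. A pure-Python sketch of that bookkeeping, where collate is a hypothetical name; returning a 2-tuple when labels is omitted is our assumption about the shape of the result, not something stated above.

```python
def collate(coords, feats, labels=None):
    """Hypothetical pure-Python analogue of sparse_collate.

    Row i of the returned features (and labels) belongs to row i of the
    returned coordinates, so every sample must contribute the same number
    of coordinate rows and feature rows.
    """
    if len(coords) != len(feats):
        raise ValueError("coords and feats must contain the same number of samples")
    for b, (c, f) in enumerate(zip(coords, feats)):
        if len(c) != len(f):
            raise ValueError(f"sample {b}: {len(c)} coordinates but {len(f)} feature rows")

    # Prepend the batch index to each coordinate row and flatten the samples.
    batched_coords = [[b, *row] for b, sample in enumerate(coords) for row in sample]
    batched_feats = [row for sample in feats for row in sample]
    if labels is None:
        return batched_coords, batched_feats
    batched_labels = [label for sample in labels for label in sample]
    return batched_coords, batched_feats, batched_labels


coords = [[[0, 0], [1, 1]], [[2, 2]]]
feats = [[[0.5], [0.25]], [[1.0]]]
labels = [[3, 4], [5]]
bc, bf, bl = collate(coords, feats, labels)
```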
batch_sparse_collate

MinkowskiEngine.utils.batch_sparse_collate(data, dtype=torch.int32, device=None)

A wrapper function that can be used in conjunction with torch.utils.data.DataLoader to generate inputs for a sparse tensor. Please refer to the training example for usage.

Args:
    data: a list of (coordinates, features, labels) tuples.
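A DataLoader hands its collate function a list with one tuple per sample, so the wrapper's first job is to transpose that list into parallel sequences before batching. A pure-Python sketch, assuming one label per point; batch_collate is a hypothetical name and not the library function.

```python
def batch_collate(data):
    """Hypothetical analogue of batch_sparse_collate for use as a collate_fn.

    data: list of (coordinates, features, labels) tuples, one per sample,
    with one label per point (an assumption of this sketch).
    """
    # Transpose [(c0, f0, l0), (c1, f1, l1), ...] into (c0, c1, ...), (f0, f1, ...), (l0, l1, ...)
    coords, feats, labels = zip(*data)
    batched_coords = [[b, *row] for b, sample in enumerate(coords) for row in sample]
    batched_feats = [row for sample in feats for row in sample]
    batched_labels = [label for sample in labels for label in sample]
    return batched_coords, batched_feats, batched_labels


data = [
    ([[0, 0]], [[1.0]], [5]),
    ([[2, 1], [3, 3]], [[2.0], [3.0]], [6, 7]),
]
coords_batch, feats_batch, labels_batch = batch_collate(data)
```

With PyTorch, a function of this form is what would be passed as torch.utils.data.DataLoader(dataset, collate_fn=...).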
cat

MinkowskiEngine.cat(*sparse_tensors)

Concatenates sparse tensor features. All sparse tensors must have the same coordinate_map_key (i.e. the same coordinates). To concatenate sparse tensors with different sparsity patterns, use SparseTensor binary operations or MinkowskiEngine.MinkowskiUnion.

Example:
>>> import MinkowskiEngine as ME
>>> sin = ME.SparseTensor(feats, coords)
>>> sin2 = ME.SparseTensor(feats2, coordinate_map_key=sin.coordinate_map_key, coordinate_manager=sin.coordinate_manager)
>>> sout = UNet(sin)  # Returns an output sparse tensor on the same coordinates
>>> sout2 = ME.cat(sin, sin2, sout)  # Can concatenate multiple sparse tensors
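Because all inputs share the same coordinates, concatenation only touches the feature matrices: the output keeps the same N rows and the channel counts add up. The plain-Python sketch below shows that bookkeeping; cat_features is a hypothetical name, and on real SparseTensor features the operation is conceptually a column-wise concatenation of the feature matrices.

```python
def cat_features(*feature_rows):
    """Hypothetical helper: channel-wise concatenation of features that share coordinates.

    Each argument is a list of per-point feature rows. Because all inputs are
    assumed to share one coordinate map, row i of every input refers to the same
    point, and the rows are simply joined end to end.
    """
    num_points = len(feature_rows[0])
    if any(len(rows) != num_points for rows in feature_rows):
        raise ValueError("all inputs must have one feature row per shared coordinate")
    return [[v for rows in feature_rows for v in rows[i]] for i in range(num_points)]


a = [[1, 2], [3, 4]]  # 2 points, 2 channels
b = [[5], [6]]        # same 2 points, 1 channel
out = cat_features(a, b)
```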
to_sparse

MinkowskiEngine.to_sparse(x: torch.Tensor, format: Optional[str] = None, coordinates=None, device=None)

Converts a batched tensor (dimension 0 is the batch dimension) to a SparseTensor.

x (torch.Tensor): a batched tensor. The first dimension is the batch dimension.
format (str): format of the tensor. It must include 'B' and 'C', indicating the batch and channel dimensions respectively; the remaining dimensions must be 'X'. E.g., use format="BCXX" for image data in BCHW format. For 3D data with the channel in the last dimension, use format="BXXXC", i.e. Batch x Height x Width x Depth x Channel. If not provided, the format defaults to "BCX…X".
device: device on which the sparse tensor will be generated. If not provided, the device of the input tensor is used.
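The format string is essentially an axis map. Assuming the conversion first brings the tensor into a batch-first, channel-last layout before extracting coordinates (our assumption, not stated above), the required axis order can be derived from the format as follows. channel_last_order is a hypothetical helper.

```python
def channel_last_order(ndim, fmt=None):
    """Hypothetical helper: axis order that rearranges a tensor described by `fmt`
    into Batch x X1 x ... x XN x Channel.

    fmt defaults to "BC" followed by X's, matching the documented "BCX...X" default.
    """
    if fmt is None:
        fmt = "BC" + "X" * (ndim - 2)
    if len(fmt) != ndim:
        raise ValueError(f"format {fmt!r} has {len(fmt)} letters but the tensor has {ndim} dims")
    if fmt.count("B") != 1 or fmt.count("C") != 1:
        raise ValueError("format must contain exactly one 'B' and one 'C'")
    if set(fmt) - {"B", "C", "X"}:
        raise ValueError("all remaining dimensions must be 'X'")
    spatial = [i for i, letter in enumerate(fmt) if letter == "X"]
    return [fmt.index("B"), *spatial, fmt.index("C")]


image_order = channel_last_order(4, "BCXX")    # BCHW image batch
volume_order = channel_last_order(5, "BXXXC")  # channel already last
default_order = channel_last_order(5)          # falls back to "BCXXX"
```

On a torch.Tensor, x.permute(*order) would then produce the channel-last layout.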
to_sparse_all

MinkowskiEngine.to_sparse_all(dense_tensor: torch.Tensor, coordinates: Optional[torch.Tensor] = None)

Converts a (differentiable) dense tensor to a sparse tensor with all coordinates. The input is assumed to have the BxCxD1xD2x...xDN format.

If the shape of the tensor does not change, use dense_coordinates to cache the coordinates. Please refer to tests/python/dense.py for usage.
Example:
>>> dense_tensor = torch.rand(3, 4, 5, 6, 7, 8)  # BxCxD1xD2xD3xD4
>>> dense_tensor.requires_grad = True
>>> stensor = to_sparse_all(dense_tensor)
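Converting with all coordinates means every batch and spatial index becomes an active site, so the coordinates depend only on the tensor shape, which is why they can be cached. A plain-Python sketch of that enumeration, assuming unit tensor stride; all_coordinates is a hypothetical name and is not claimed to match dense_coordinates exactly.

```python
from itertools import product


def all_coordinates(shape):
    """Hypothetical helper: every (batch, x1, ..., xN) index of a BxCxD1x...xDN tensor.

    The channel dimension is skipped because each coordinate carries all C channels
    as its feature vector. A tensor stride of 1 is assumed.
    """
    batch_size, _channels, *spatial = shape
    ranges = [range(batch_size)] + [range(d) for d in spatial]
    return [list(index) for index in product(*ranges)]


coords = all_coordinates((2, 4, 3, 5))  # B=2, C=4, D1=3, D2=5
```

The result has B x D1 x D2 = 30 rows, independent of the tensor values.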
SparseCollation

class MinkowskiEngine.utils.SparseCollation(limit_numpoints=-1, dtype=torch.int32, device=None)

Generates a collate function for coords, feats, and labels. Please refer to the training example for usage.

Args:
    limit_numpoints (int): if a positive integer, limits the batch so that the total number of input coordinates stays below limit_numpoints. If 0 or False, concatenates all points. -1 by default.

Example:
>>> data_loader = torch.utils.data.DataLoader(
>>>     dataset,
>>>     ...,
>>>     collate_fn=SparseCollation())
>>> for d in iter(data_loader):
>>>     print(d)
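The description above does not say what happens to samples that do not fit under limit_numpoints. The sketch below assumes the collate function keeps the longest prefix of the incoming samples whose point total does not exceed the limit and drops the rest; truncate_batch is a hypothetical name.

```python
def truncate_batch(samples, limit_numpoints=-1):
    """Hypothetical sketch of the limit_numpoints rule.

    samples: list of (coordinates, features, labels) tuples.
    Keeps the longest prefix whose total number of coordinates does not exceed
    limit_numpoints; a non-positive limit (-1, 0, False) keeps every sample.
    """
    kept, total = [], 0
    for sample in samples:
        total += len(sample[0])
        if limit_numpoints > 0 and total > limit_numpoints:
            break  # assumption: this sample and all later ones are dropped
        kept.append(sample)
    return kept


def make_sample(n):
    # n points with dummy 1-D features and labels
    return ([[i] for i in range(n)], [[0.0]] * n, [0] * n)


samples = [make_sample(3), make_sample(4), make_sample(2)]
limited = truncate_batch(samples, limit_numpoints=8)
unlimited = truncate_batch(samples)
```

With a limit of 8, the first two samples (7 points) fit and the third would push the total to 9, so it is dropped.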
MinkowskiToSparseTensor

class MinkowskiEngine.MinkowskiToSparseTensor(remove_zeros=True, coordinates: Optional[torch.Tensor] = None)

Converts a (differentiable) dense tensor or a MinkowskiEngine.TensorField to a MinkowskiEngine.SparseTensor. For a dense tensor, the input must have the BxCxD1xD2x...xDN format.

remove_zeros (bool): if True, removes zero-valued coordinates. If False, uses all coordinates to populate the sparse tensor. True by default.

If the shape of the tensor does not change, use dense_coordinates to cache the coordinates. Please refer to tests/python/dense.py for usage.
Example:
>>> # Differentiable dense torch.Tensor to sparse tensor.
>>> dense_tensor = torch.rand(3, 4, 11, 11, 11, 11)  # BxCxD1xD2x....xDN
>>> dense_tensor.requires_grad = True
>>> # Since the shape is fixed, cache the coordinates for faster inference
>>> coordinates = dense_coordinates(dense_tensor.shape)
>>> network = nn.Sequential(
>>>     # Add layers that can be applied on a regular pytorch tensor
>>>     nn.ReLU(),
>>>     MinkowskiToSparseTensor(remove_zeros=False, coordinates=coordinates),
>>>     MinkowskiConvolution(4, 5, kernel_size=3, dimension=4),
>>>     MinkowskiBatchNorm(5),
>>>     MinkowskiReLU(),
>>> )
>>> for i in range(5):
>>>     print(f"Iteration: {i}")
>>>     soutput = network(dense_tensor)
>>>     soutput.F.sum().backward()
>>>     soutput.dense(shape=dense_tensor.shape)
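The remove_zeros switch decides which spatial sites become coordinates. The sketch below works on plain Python lists with a single spatial dimension and assumes that a site counts as zero only when every channel at that site is zero. dense_to_sparse_1d is a hypothetical helper.

```python
def dense_to_sparse_1d(dense, remove_zeros=True):
    """Hypothetical sketch: turn a B x C x D nested list into (coordinates, features).

    Each spatial site (b, x) becomes one coordinate [b, x] whose feature is the
    C-channel vector at that site. With remove_zeros=True, a site is dropped when
    all of its channels are zero (assumed rule); otherwise every site is kept.
    """
    coordinates, features = [], []
    for b, channels in enumerate(dense):
        for x in range(len(channels[0])):
            feature = [channel[x] for channel in channels]
            if remove_zeros and all(v == 0 for v in feature):
                continue
            coordinates.append([b, x])
            features.append(feature)
    return coordinates, features


dense = [[[0, 1, 0], [0, 0, 2]]]  # B=1, C=2, D=3
sparse_coords, sparse_feats = dense_to_sparse_1d(dense)
all_coords, _ = dense_to_sparse_1d(dense, remove_zeros=False)
```

Site x=0 is all-zero and disappears when remove_zeros=True. With remove_zeros=False, all three sites are kept, matching the fixed-coordinate setup in the example above.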
MinkowskiToDenseTensor

class MinkowskiEngine.MinkowskiToDenseTensor(shape: Optional[torch.Size] = None)

Converts a (differentiable) sparse tensor to a torch tensor. The returned tensor has the BxCxD1xD2x...xDN format.
Example:
>>> dense_tensor = torch.rand(3, 4, 11, 11, 11, 11)  # BxCxD1xD2x....xDN
>>> dense_tensor.requires_grad = True
>>> # Since the shape is fixed, cache the coordinates for faster inference
>>> coordinates = dense_coordinates(dense_tensor.shape)
>>> network = nn.Sequential(
>>>     # Add layers that can be applied on a regular pytorch tensor
>>>     nn.ReLU(),
>>>     MinkowskiToSparseTensor(coordinates=coordinates),
>>>     MinkowskiConvolution(4, 5, stride=2, kernel_size=3, dimension=4),
>>>     MinkowskiBatchNorm(5),
>>>     MinkowskiReLU(),
>>>     MinkowskiConvolutionTranspose(5, 6, stride=2, kernel_size=3, dimension=4),
>>>     MinkowskiToDenseTensor(
>>>         dense_tensor.shape
>>>     ),  # must have the same tensor stride.
>>> )
>>> for i in range(5):
>>>     print(f"Iteration: {i}")
>>>     output = network(dense_tensor)  # returns a regular pytorch tensor
>>>     output.sum().backward()
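Going back to a dense tensor is a scatter: allocate zeros of the target shape and write each feature vector at its coordinate. The sketch below uses plain Python with one spatial dimension and assumes unit tensor stride, so coordinates index the dense grid directly. sparse_to_dense_1d is a hypothetical helper.

```python
def sparse_to_dense_1d(coordinates, features, shape):
    """Hypothetical sketch: scatter sparse features back into a B x C x D nested list.

    coordinates: rows [b, x] (tensor stride 1 assumed); features: one C-vector per row.
    Sites without a coordinate stay zero.
    """
    batch_size, num_channels, length = shape
    dense = [[[0.0] * length for _ in range(num_channels)] for _ in range(batch_size)]
    for (b, x), feature in zip(coordinates, features):
        for c, value in enumerate(feature):
            dense[b][c][x] = value
    return dense


dense = sparse_to_dense_1d([[0, 1], [1, 2]], [[1.0, 2.0], [3.0, 4.0]], shape=(2, 2, 3))
```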