Utility Functions and Classes
sparse_quantize

MinkowskiEngine.utils.sparse_quantize(coordinates, features=None, labels=None, ignore_label=-100, return_index=False, return_inverse=False, return_maps_only=False, quantization_size=None, device='cpu')

Given coordinates and features (and optionally labels), the function generates quantized (voxelized) coordinates.
Args:
    coordinates (numpy.ndarray or torch.Tensor): a matrix of size \(N \times D\), where \(N\) is the number of points in the \(D\)-dimensional space.
    features (numpy.ndarray or torch.Tensor, optional): a matrix of size \(N \times D_F\), where \(N\) is the number of points and \(D_F\) is the dimension of the features. Must use the same container type as coordinates (i.e. if coordinates is a torch.Tensor, features must also be a torch.Tensor).
    labels (numpy.ndarray or torch.IntTensor, optional): integer labels associated with each coordinate. Must use the same container type as coordinates (i.e. if coordinates is a torch.Tensor, labels must also be a torch.Tensor). For classification, where a whole set of points maps to a single label, do not pass labels.
    ignore_label (int, optional): the integer value of the ignore label. Use the same value in torch.nn.CrossEntropyLoss(ignore_index=ignore_label).
    return_index (bool, optional): set to True to also return the indices of the quantized coordinates. False by default.
    return_inverse (bool, optional): set to True to also return the indices that recover the discretized original coordinates from the quantized ones. False by default. return_index must be True when return_inverse is True.
    return_maps_only (bool, optional): if set, return only the unique map (and optionally the inverse map), not the coordinates. Useful if you do not need the final coordinates, or if you use device='cuda' and do not need the coordinates on the GPU. Returns either unique_map alone, or (unique_map, inverse_map) if return_inverse is set.
    quantization_size (float, list, or numpy.ndarray, optional): the length of each side of the hyperrectangular grid cell. If set, it defines the smallest distance between quantized coordinates.
    device (str, optional): either 'cpu' or 'cuda'.

Example:
    >>> unique_map, inverse_map = sparse_quantize(discrete_coords, return_index=True, return_inverse=True, return_maps_only=True)
    >>> unique_coords = discrete_coords[unique_map]
    >>> print((unique_coords[inverse_map] == discrete_coords).all())  # True

Example:
    >>> # Segmentation
    >>> criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)
    >>> coords, feats, labels = MinkowskiEngine.utils.sparse_quantize(
    ...     coords, feats, labels, ignore_label=-100, quantization_size=0.1)
    >>> output = net(MinkowskiEngine.SparseTensor(feats, MinkowskiEngine.utils.batched_coordinates([coords])))
    >>> loss = criterion(output.F, labels.long())
    >>>
    >>> # Classification
    >>> criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)
    >>> coords, feats = MinkowskiEngine.utils.sparse_quantize(coords, feats)
    >>> output = net(MinkowskiEngine.SparseTensor(feats, MinkowskiEngine.utils.batched_coordinates([coords])))
    >>> loss = criterion(output.F, labels.long())
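To make the return_index / return_inverse contract concrete, here is a minimal pure-Python sketch. It is not MinkowskiEngine's implementation: the helper quantize_sketch is hypothetical, it assumes each axis is discretized as floor(x / quantization_size), and it keeps the first point that falls into each voxel as that voxel's representative (the library may choose a different representative).

```python
import math
from typing import Dict, List, Sequence, Tuple

Voxel = Tuple[int, ...]


def quantize_sketch(
    points: Sequence[Sequence[float]], quantization_size: float = 1.0
) -> Tuple[List[Voxel], List[int], List[int]]:
    """Hypothetical helper, not part of MinkowskiEngine.

    Returns (unique_coords, unique_map, inverse_map) such that
    unique_coords[k] is the voxel of points[unique_map[k]] and
    unique_coords[inverse_map[i]] is the voxel of points[i].
    """
    voxel_to_slot: Dict[Voxel, int] = {}
    unique_coords: List[Voxel] = []
    unique_map: List[int] = []
    inverse_map: List[int] = []
    for i, point in enumerate(points):
        # Assumption: discretize every axis with floor(x / quantization_size).
        voxel = tuple(math.floor(x / quantization_size) for x in point)
        if voxel not in voxel_to_slot:
            # The first point seen in a voxel becomes its representative.
            voxel_to_slot[voxel] = len(unique_coords)
            unique_coords.append(voxel)
            unique_map.append(i)
        inverse_map.append(voxel_to_slot[voxel])
    return unique_coords, unique_map, inverse_map


points = [(0.1, 0.2), (0.3, 0.4), (1.2, 0.1), (0.2, 1.7)]
coords, unique_map, inverse_map = quantize_sketch(points, quantization_size=0.5)
print(coords)       # [(0, 0), (2, 0), (0, 3)]
print(unique_map)   # [0, 2, 3]
print(inverse_map)  # [0, 0, 1, 2]
```

Indexing the unique coordinates with inverse_map reproduces the voxel of every input point, which is the same property checked by the first example above.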
batched_coordinates

MinkowskiEngine.utils.batched_coordinates(coords, dtype=torch.int32, device=None)

Create ME.SparseTensor coordinates from a sequence of coordinates.

Given a list of coordinates, as either numpy arrays or PyTorch tensors, return the batched coordinates suitable for ME.SparseTensor.

Args:
    coords (a sequence of torch.Tensor or numpy.ndarray): a list of coordinates.
    dtype: torch data type of the returned tensor. torch.int32 by default.

Returns:
    batched_coordinates (torch.Tensor): the batched coordinates.

Warning: Starting from v0.4, the batch index is prepended to all coordinates.
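The layout described in the warning can be illustrated with a pure-Python sketch: every row of sample b gets b prepended as its first column, and the rows of all samples are stacked. batched_coordinates_sketch is a hypothetical stand-in; the real function returns a torch.Tensor of the requested dtype, whereas this sketch returns nested lists.

```python
from typing import List, Sequence


def batched_coordinates_sketch(coords: Sequence[Sequence[Sequence[int]]]) -> List[List[int]]:
    """Hypothetical stand-in for illustration only, not MinkowskiEngine code."""
    batched: List[List[int]] = []
    for batch_index, sample in enumerate(coords):
        for row in sample:
            # Batch index in column 0, followed by the D spatial coordinates.
            batched.append([batch_index, *row])
    return batched


print(batched_coordinates_sketch([[[0, 0], [1, 2]], [[3, 4]]]))
# [[0, 0, 0], [0, 1, 2], [1, 3, 4]]
```

An N x D coordinate matrix per sample therefore becomes one stacked matrix with D + 1 columns.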
sparse_collate

MinkowskiEngine.utils.sparse_collate(coords, feats, labels=None, dtype=torch.int32, device=None)

Create input arguments for a sparse tensor; see the sparse tensor documentation.

Convert a set of coordinates and features into batched coordinates and batched features.

Args:
    coords (set of torch.Tensor or numpy.ndarray): a set of coordinates.
    feats (set of torch.Tensor or numpy.ndarray): a set of features.
    labels (set of torch.Tensor or numpy.ndarray, optional): a set of labels associated with the inputs.
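As a rough picture of what "batched coordinates and batched features" means, the pure-Python sketch below prepends the batch index to each sample's coordinates and concatenates the feature rows (and, if given, the labels) in the same order, so row i of every output refers to the same point. sparse_collate_sketch is hypothetical; it assumes per-point labels and returns lists instead of the tensors the library produces.

```python
from typing import List, Optional, Sequence, Tuple


def sparse_collate_sketch(
    coords: Sequence[Sequence[Sequence[int]]],
    feats: Sequence[Sequence[Sequence[float]]],
    labels: Optional[Sequence[Sequence[int]]] = None,
) -> Tuple:
    """Hypothetical sketch: batch coordinates and concatenate features (and labels)."""
    if len(coords) != len(feats):
        raise ValueError("coords and feats must contain the same number of samples")
    batch_coords: List[List[int]] = []
    batch_feats: List[List[float]] = []
    for b, (c, f) in enumerate(zip(coords, feats)):
        if len(c) != len(f):
            raise ValueError(f"sample {b}: {len(c)} coordinates but {len(f)} feature rows")
        batch_coords.extend([b, *row] for row in c)  # prepend the batch index
        batch_feats.extend(list(row) for row in f)   # keep rows aligned with coordinates
    if labels is None:
        return batch_coords, batch_feats
    # Assumption: one label per point, concatenated in the same order.
    batch_labels = [label for sample in labels for label in sample]
    return batch_coords, batch_feats, batch_labels


bc, bf, bl = sparse_collate_sketch(
    [[[0, 0], [1, 1]], [[2, 2]]],  # two samples with 2 and 1 points
    [[[0.5], [1.5]], [[2.5]]],     # one feature channel per point
    [[0, 1], [1]],                 # per-point labels
)
print(bc)  # [[0, 0, 0], [0, 1, 1], [1, 2, 2]]
print(bf)  # [[0.5], [1.5], [2.5]]
print(bl)  # [0, 1, 1]
```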
batch_sparse_collate

MinkowskiEngine.utils.batch_sparse_collate(data, dtype=torch.int32, device=None)

A wrapper function that can be used in conjunction with torch.utils.data.DataLoader to generate inputs for a sparse tensor.

Please refer to the training example for usage.

Args:
    data: a list of (coordinates, features, labels) tuples.
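torch.utils.data.DataLoader passes collate_fn a list with one item per sample, so a wrapper like this one first has to split the (coordinates, features, labels) tuples into three parallel lists before collating them. The pure-Python sketch below shows only that unzipping step; unzip_batch is hypothetical, and that batch_sparse_collate proceeds exactly this way is an assumption.

```python
from typing import List, Sequence, Tuple


def unzip_batch(data: Sequence[Tuple[list, list, list]]) -> Tuple[List[list], List[list], List[list]]:
    """Hypothetical sketch: split (coordinates, features, labels) tuples into three lists."""
    if not data:
        return [], [], []
    coords, feats, labels = zip(*data)  # transpose the list of tuples
    return list(coords), list(feats), list(labels)


data = [
    ([[0, 0]], [[1.0]], [3]),                    # sample 0: one point
    ([[1, 1], [2, 2]], [[2.0], [3.0]], [4, 5]),  # sample 1: two points
]
coords, feats, labels = unzip_batch(data)
print(coords)  # [[[0, 0]], [[1, 1], [2, 2]]]
print(labels)  # [[3], [4, 5]]
```

The three lists are then in the shape that sparse_collate's coords, feats, and labels arguments describe.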
cat

MinkowskiEngine.cat(*sparse_tensors)

Concatenate sparse tensors.

Concatenate sparse tensor features. All sparse tensors must have the same coordinate_map_key (the same coordinates). To concatenate sparse tensors with different sparsity patterns, use SparseTensor binary operations or MinkowskiEngine.MinkowskiUnion.

Example:
    >>> import MinkowskiEngine as ME
    >>> sin = ME.SparseTensor(feats, coords)
    >>> sin2 = ME.SparseTensor(feats2, coordinate_map_key=sin.coordinate_map_key, coordinate_manager=sin.coordinate_manager)
    >>> sout = UNet(sin)  # Returns an output sparse tensor on the same coordinates
    >>> sout2 = ME.cat(sin, sin2, sout)  # Can concatenate multiple sparse tensors
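Because all inputs share one coordinate_map_key, row i of every feature matrix refers to the same coordinate, so concatenation reduces to joining rows along the channel axis. The pure-Python sketch below illustrates that idea; cat_features_sketch is hypothetical and only checks that the row counts agree, whereas the library requires identical coordinate map keys.

```python
from typing import List, Sequence


def cat_features_sketch(*feature_matrices: Sequence[Sequence[float]]) -> List[List[float]]:
    """Hypothetical sketch: concatenate row-aligned feature matrices channel-wise."""
    if not feature_matrices:
        raise ValueError("need at least one feature matrix")
    num_rows = len(feature_matrices[0])
    if any(len(m) != num_rows for m in feature_matrices):
        # Stands in for the requirement that all inputs share the same coordinates.
        raise ValueError("all inputs must have the same number of rows (same coordinates)")
    # Row i of the result is row i of each input, joined in argument order.
    return [[v for m in feature_matrices for v in m[i]] for i in range(num_rows)]


print(cat_features_sketch([[1, 2], [3, 4]], [[5], [6]]))  # [[1, 2, 5], [3, 4, 6]]
```

The output channel count is the sum of the input channel counts (2 + 1 here), while the number of rows, one per coordinate, is unchanged.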
to_sparse

MinkowskiEngine.to_sparse(x: torch.Tensor, format: Optional[str] = None, coordinates=None, device=None)

Convert a batched tensor (dimension 0 is the batch dimension) to a SparseTensor.

Args:
    x (torch.Tensor): a batched tensor. The first dimension is the batch dimension.
    format (str): the format of the tensor. It must include 'B' and 'C', indicating the batch and channel dimensions respectively; the rest of the dimensions must be 'X'. For example, use format="BCXX" for image data in BCHW format. For 3D data with the channel in the last dimension, use format="BXXXC", indicating Batch x Height x Width x Depth x Channel. If not provided, the format is "BCX…X".
    device: the device on which the sparse tensor is generated. If not provided, the device of the input tensor is used.
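The format rules can be made concrete with a small pure-Python sketch that maps a format string to the batch, channel, and spatial axes of a tensor shape. parse_format is hypothetical, not part of MinkowskiEngine, and requiring exactly one 'B' and one 'C' is an assumption read from the description above.

```python
from typing import List, Optional, Sequence, Tuple


def parse_format(shape: Sequence[int], fmt: Optional[str] = None) -> Tuple[int, int, List[int]]:
    """Hypothetical helper: return (batch_dim, channel_dim, spatial_dims) for a format string."""
    if fmt is None:
        # Default described above: "BC" followed by one "X" per spatial dimension.
        fmt = "BC" + "X" * (len(shape) - 2)
    if len(fmt) != len(shape):
        raise ValueError(f"format {fmt!r} does not match a {len(shape)}-D tensor")
    if fmt.count("B") != 1 or fmt.count("C") != 1:
        raise ValueError("format must contain exactly one 'B' and one 'C'")
    if set(fmt) - {"B", "C", "X"}:
        raise ValueError("all remaining dimensions must be 'X'")
    spatial_dims = [i for i, ch in enumerate(fmt) if ch == "X"]
    return fmt.index("B"), fmt.index("C"), spatial_dims


print(parse_format((2, 3, 32, 32), "BCXX"))    # (0, 1, [2, 3])
print(parse_format((2, 8, 8, 8, 4), "BXXXC"))  # (0, 4, [1, 2, 3])
print(parse_format((2, 3, 16, 16, 16)))        # (0, 1, [2, 3, 4])
```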
to_sparse_all

MinkowskiEngine.to_sparse_all(dense_tensor: torch.Tensor, coordinates: Optional[torch.Tensor] = None)

Converts a (differentiable) dense tensor to a sparse tensor with all coordinates.

Assumes the input has the BxCxD1xD2x...xDN format.

If the shape of the tensor does not change, use dense_coordinates to cache the coordinates. Please refer to tests/python/dense.py for usage.

Example:
    >>> dense_tensor = torch.rand(3, 4, 5, 6, 7, 8)  # BxCxD1xD2xD3xD4
    >>> dense_tensor.requires_grad = True
    >>> stensor = to_sparse_all(dense_tensor)
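Keeping all coordinates means the coordinate set depends only on the shape, not on the values: there are B x D1 x ... x DN rows, and the channel axis contributes features rather than coordinates. That is why the coordinates can be cached while the shape stays fixed. The pure-Python sketch below enumerates that set; all_coordinates_sketch is hypothetical, and the row ordering used by dense_coordinates is an assumption.

```python
from itertools import product
from typing import List, Sequence, Tuple


def all_coordinates_sketch(shape: Sequence[int]) -> List[Tuple[int, ...]]:
    """Hypothetical sketch: every (batch, d1, ..., dN) index of a BxCxD1x...xDN tensor."""
    batch, _channels, *spatial = shape  # the channel axis does not add coordinates
    return list(product(range(batch), *(range(d) for d in spatial)))


coords = all_coordinates_sketch((2, 4, 3, 5))
print(len(coords))  # 30, i.e. 2 * 3 * 5
print(coords[:3])   # [(0, 0, 0), (0, 0, 1), (0, 0, 2)]
```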
SparseCollation

class MinkowskiEngine.utils.SparseCollation(limit_numpoints=-1, dtype=torch.int32, device=None)

Generates a collate function for coords, feats, and labels.

Please refer to the training example for usage.

Args:
    limit_numpoints (int): if a positive integer, limits the batch size so that the number of input coordinates stays below limit_numpoints. If 0 or False, concatenate all points. -1 by default.

Example:
    >>> data_loader = torch.utils.data.DataLoader(
    ...     dataset,
    ...     ...,
    ...     collate_fn=SparseCollation())
    >>> for d in iter(data_loader):
    ...     print(d)
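The effect of limit_numpoints can be sketched in pure Python as a budget over per-sample point counts. select_within_limit is hypothetical, and its rules are assumptions: samples are taken in order, collation stops at the first sample that would bring the running total to limit_numpoints or more (following the "below limit_numpoints" wording above), and any non-positive value, including the default -1, keeps every sample.

```python
from typing import List, Sequence


def select_within_limit(num_points_per_sample: Sequence[int], limit_numpoints: int = -1) -> List[int]:
    """Hypothetical sketch: indices of the samples kept under a total-point budget."""
    if limit_numpoints <= 0:
        # Assumption: 0/False and the default -1 both mean "no limit".
        return list(range(len(num_points_per_sample)))
    kept: List[int] = []
    total = 0
    for i, n in enumerate(num_points_per_sample):
        total += n
        if total >= limit_numpoints:
            break  # assumption: stop at the first sample that breaks the budget
        kept.append(i)
    return kept


print(select_within_limit([400, 300, 500, 100], limit_numpoints=1000))  # [0, 1]
print(select_within_limit([400, 300, 500, 100]))                        # [0, 1, 2, 3]
```

Capping points rather than samples keeps the memory footprint of each batch bounded when sample sizes vary widely, at the cost of a variable batch size.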
MinkowskiToSparseTensor

class MinkowskiEngine.MinkowskiToSparseTensor(remove_zeros=True, coordinates: Optional[torch.Tensor] = None)

Converts a (differentiable) dense tensor or a MinkowskiEngine.TensorField to a MinkowskiEngine.SparseTensor.

For a dense tensor, the input must have the BxCxD1xD2x...xDN format.

Args:
    remove_zeros (bool): if True, removes zero-valued coordinates. If False, uses all coordinates to populate the sparse tensor. True by default.

If the shape of the tensor does not change, use dense_coordinates to cache the coordinates. Please refer to tests/python/dense.py for usage.

Example:
    >>> # Differentiable dense torch.Tensor to sparse tensor.
    >>> dense_tensor = torch.rand(3, 4, 11, 11, 11, 11)  # BxCxD1xD2x....xDN
    >>> dense_tensor.requires_grad = True
    >>> # Since the shape is fixed, cache the coordinates for faster inference
    >>> coordinates = dense_coordinates(dense_tensor.shape)
    >>> network = nn.Sequential(
    ...     # Add layers that can be applied on a regular pytorch tensor
    ...     nn.ReLU(),
    ...     MinkowskiToSparseTensor(remove_zeros=False, coordinates=coordinates),
    ...     MinkowskiConvolution(4, 5, kernel_size=3, dimension=4),
    ...     MinkowskiBatchNorm(5),
    ...     MinkowskiReLU(),
    ... )
    >>> for i in range(5):
    ...     print(f"Iteration: {i}")
    ...     soutput = network(dense_tensor)
    ...     soutput.F.sum().backward()
    ...     soutput.dense(shape=dense_tensor.shape)
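To illustrate what remove_zeros changes, here is a pure-Python sketch for the simplest case, a BxCxD tensor stored as nested lists. sparsify_1d is hypothetical, and treating a site as zero only when all of its channels are zero is an assumption; the exact rule the library applies is not stated on this page.

```python
from typing import List, Sequence, Tuple

Site = Tuple[int, int]  # (batch index, spatial index) for a BxCxD tensor


def sparsify_1d(
    dense: Sequence[Sequence[Sequence[float]]], remove_zeros: bool = True
) -> Tuple[List[Site], List[List[float]]]:
    """Hypothetical sketch for a BxCxD nested list: returns (coordinates, features)."""
    coordinates: List[Site] = []
    features: List[List[float]] = []
    for b, channels in enumerate(dense):
        for x in range(len(channels[0])):
            # Gather the C channel values at spatial position x into one feature row.
            feature = [channel[x] for channel in channels]
            # Assumption: a site counts as zero only when every channel is zero.
            if remove_zeros and all(v == 0 for v in feature):
                continue
            coordinates.append((b, x))
            features.append(feature)
    return coordinates, features


dense = [[[0.0, 1.0, 0.0], [0.0, 2.0, 3.0]]]  # B=1, C=2, D=3
print(sparsify_1d(dense))  # ([(0, 1), (0, 2)], [[1.0, 2.0], [0.0, 3.0]])
print(len(sparsify_1d(dense, remove_zeros=False)[0]))  # 3
```

With remove_zeros=False every site is kept, so the coordinates depend only on the shape, which is what makes caching them worthwhile.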
__init__(remove_zeros=True, coordinates: Optional[torch.Tensor] = None)
    Initializes internal Module state, shared by both nn.Module and ScriptModule.
MinkowskiToDenseTensor

class MinkowskiEngine.MinkowskiToDenseTensor(shape: Optional[torch.Size] = None)

Converts a (differentiable) sparse tensor to a torch tensor.

The returned tensor has the BxCxD1xD2x...xDN format.

Example:
    >>> dense_tensor = torch.rand(3, 4, 11, 11, 11, 11)  # BxCxD1xD2x....xDN
    >>> dense_tensor.requires_grad = True
    >>> # Since the shape is fixed, cache the coordinates for faster inference
    >>> coordinates = dense_coordinates(dense_tensor.shape)
    >>> network = nn.Sequential(
    ...     # Add layers that can be applied on a regular pytorch tensor
    ...     nn.ReLU(),
    ...     MinkowskiToSparseTensor(coordinates=coordinates),
    ...     MinkowskiConvolution(4, 5, stride=2, kernel_size=3, dimension=4),
    ...     MinkowskiBatchNorm(5),
    ...     MinkowskiReLU(),
    ...     MinkowskiConvolutionTranspose(5, 6, stride=2, kernel_size=3, dimension=4),
    ...     MinkowskiToDenseTensor(
    ...         dense_tensor.shape
    ...     ),  # must have the same tensor stride.
    ... )
    >>> for i in range(5):
    ...     print(f"Iteration: {i}")
    ...     output = network(dense_tensor)  # returns a regular pytorch tensor
    ...     output.sum().backward()
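Converting a sparse tensor to a dense one amounts to scattering each feature row into the BxCxD1x...xDN grid at its coordinate. The pure-Python sketch below does this for a BxCxD nested list; densify_1d is hypothetical, and it assumes a tensor stride of 1 and that sites without a coordinate are filled with zeros.

```python
from typing import List, Sequence, Tuple


def densify_1d(
    coordinates: Sequence[Tuple[int, int]],
    features: Sequence[Sequence[float]],
    shape: Tuple[int, int, int],
) -> List[List[List[float]]]:
    """Hypothetical sketch: scatter (batch, x) -> feature rows into a BxCxD nested list."""
    batch, channels, length = shape
    # Assumption: every site starts at 0.0 and only listed coordinates are overwritten.
    dense = [[[0.0] * length for _ in range(channels)] for _ in range(batch)]
    for (b, x), feature in zip(coordinates, features):
        if len(feature) != channels:
            raise ValueError("feature width must match the channel dimension C")
        for c, value in enumerate(feature):
            dense[b][c][x] = value
    return dense


print(densify_1d([(0, 1), (0, 2)], [[1.0, 2.0], [0.0, 3.0]], (1, 2, 3)))
# [[[0.0, 1.0, 0.0], [0.0, 2.0, 3.0]]]
```

The tensor-stride comment in the example above matters for the same reason: a coordinate only maps to a unique dense index when the output is back at the stride of the original grid.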
MinkowskiToFeature

class MinkowskiEngine.MinkowskiToFeature

Extracts the features from a sparse tensor and returns them as a PyTorch tensor.

Can be used to make network construction simpler.

Example:
    >>> net = nn.Sequential(MinkowskiConvolution(...), MinkowskiGlobalMaxPooling(...), MinkowskiToFeature(), nn.Linear(...))
    >>> torch_tensor = net(sparse_tensor)
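After global pooling there is one row per batch item, so the feature matrix alone is what a dense layer such as nn.Linear expects. The pure-Python sketch below assumes this layer simply returns the sparse tensor's F attribute; SparseTensorStandIn and ToFeatureSketch are hypothetical stand-ins, not MinkowskiEngine classes.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class SparseTensorStandIn:
    """Hypothetical stand-in with C (coordinates) and F (features) attributes."""

    C: List[Tuple[int, ...]]
    F: List[List[float]]


class ToFeatureSketch:
    """Hypothetical sketch: drop the coordinates and pass on only the feature matrix."""

    def __call__(self, x: SparseTensorStandIn) -> List[List[float]]:
        return x.F


# After global pooling: one row per batch item (batch indices 0 and 1).
pooled = SparseTensorStandIn(C=[(0, 0, 0), (1, 0, 0)], F=[[0.2, 0.8], [0.5, 0.1]])
print(ToFeatureSketch()(pooled))  # [[0.2, 0.8], [0.5, 0.1]]
```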
MinkowskiStackCat

class MinkowskiEngine.MinkowskiStackCat(*args: torch.nn.modules.module.Module)
class MinkowskiEngine.MinkowskiStackCat(arg: OrderedDict[str, Module])
forward(x)
    Defines the computation performed at every call. Should be overridden by all subclasses.

    Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
MinkowskiStackSum

class MinkowskiEngine.MinkowskiStackSum(*args: torch.nn.modules.module.Module)
class MinkowskiEngine.MinkowskiStackSum(arg: OrderedDict[str, Module])
MinkowskiStackMean

class MinkowskiEngine.MinkowskiStackMean(*args: torch.nn.modules.module.Module)
class MinkowskiEngine.MinkowskiStackMean(arg: OrderedDict[str, Module])
MinkowskiStackVar

class MinkowskiEngine.MinkowskiStackVar(*args: torch.nn.modules.module.Module)
class MinkowskiEngine.MinkowskiStackVar(arg: OrderedDict[str, Module])
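The stack containers above are documented only through their nn.Sequential-style constructors. Judging from the class names, each one appears to run every submodule on the same input and combine the outputs by concatenation, sum, mean, or variance; that reading is an assumption not confirmed by this page. The pure-Python sketch below illustrates it on a single feature vector instead of a sparse tensor; stack_apply, double, and negate are hypothetical, and the use of population variance is also an assumption.

```python
from statistics import fmean, pvariance
from typing import Callable, List, Sequence

Layer = Callable[[List[float]], List[float]]


def stack_apply(modules: Sequence[Layer], x: List[float], mode: str) -> List[float]:
    """Hypothetical sketch: run every module on the same input, then combine the outputs."""
    outputs = [m(x) for m in modules]
    if mode == "cat":
        return [v for out in outputs for v in out]  # channel-wise concatenation
    columns = list(zip(*outputs))  # group the i-th channel across modules
    if mode == "sum":
        return [sum(col) for col in columns]
    if mode == "mean":
        return [fmean(col) for col in columns]
    if mode == "var":
        return [pvariance(col) for col in columns]  # assumption: population variance
    raise ValueError(f"unknown mode {mode!r}")


def double(v: List[float]) -> List[float]:
    return [2 * a for a in v]


def negate(v: List[float]) -> List[float]:
    return [-a for a in v]


x = [1.0, 3.0]
print(stack_apply([double, negate], x, "cat"))   # [2.0, 6.0, -1.0, -3.0]
print(stack_apply([double, negate], x, "sum"))   # [1.0, 3.0]
print(stack_apply([double, negate], x, "mean"))  # [0.5, 1.5]
print(stack_apply([double, negate], x, "var"))   # [2.25, 20.25]
```

Under this reading, the "cat" variant widens the channel dimension while the sum, mean, and variance variants require all branches to produce the same number of channels.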