SparseTensor and TensorField

SparseTensor

class MinkowskiEngine.MinkowskiSparseTensor.SparseTensor(features: torch.Tensor, coordinates: Optional[torch.Tensor] = None, tensor_stride: Union[int, collections.abc.Sequence, numpy.ndarray, torch.IntTensor] = 1, coordinate_map_key: Optional[MinkowskiEngineBackend._C.CoordinateMapKey] = None, coordinate_manager: Optional[MinkowskiCoordinateManager.CoordinateManager] = None, quantization_mode: MinkowskiTensor.SparseTensorQuantizationMode = <SparseTensorQuantizationMode.RANDOM_SUBSAMPLE: 0>, allocator_type: Optional[MinkowskiEngineBackend._C.GPUMemoryAllocatorType] = None, minkowski_algorithm: Optional[MinkowskiEngineBackend._C.MinkowskiAlgorithm] = None, requires_grad=None, device=None)

A sparse tensor class. Can be accessed via MinkowskiEngine.SparseTensor.

The SparseTensor class is the basic tensor in MinkowskiEngine. For the definition of a sparse tensor, please visit the terminology page. We use the COOrdinate (COO) format to save a sparse tensor [1]. This representation is simply a concatenation of coordinates in a matrix \(\mathbf{C}\) and associated features \(\mathbf{F}\).

\[\begin{split}\mathbf{C} = \begin{bmatrix} b_1 & x_1^1 & x_1^2 & \cdots & x_1^D \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ b_N & x_N^1 & x_N^2 & \cdots & x_N^D \end{bmatrix}, \; \mathbf{F} = \begin{bmatrix} \mathbf{f}_1^T\\ \vdots\\ \mathbf{f}_N^T \end{bmatrix}\end{split}\]

where \(\mathbf{x}_i \in \mathcal{Z}^D\) is a \(D\)-dimensional coordinate and \(b_i \in \mathcal{Z}_+\) denotes the corresponding batch index. \(N\) is the number of non-zero elements in the sparse tensor, each with the coordinate \((b_i, x_i^1, x_i^2, \cdots, x_i^D)\) and the associated feature \(\mathbf{f}_i\). Internally, we handle the batch index as an additional spatial dimension.
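As a rough illustration of the layout of the coordinate matrix, the following pure-Python sketch prepends a batch index to every coordinate of each sample. The helper name batch_coordinates is hypothetical and not part of MinkowskiEngine; it only mirrors the row layout \((b_i, x_i^1, \cdots, x_i^D)\) described above.

```python
from typing import List, Sequence, Tuple


def batch_coordinates(samples: Sequence[Sequence[Tuple[int, ...]]]) -> List[Tuple[int, ...]]:
    """Stack per-sample D-dimensional coordinates into rows (b, x^1, ..., x^D).

    Hypothetical helper for illustration only, not a MinkowskiEngine function.
    """
    rows = []
    for batch_index, coords in enumerate(samples):
        for coord in coords:
            # The batch index occupies the first column of every row.
            rows.append((batch_index, *coord))
    return rows


sample0 = [(0, 0), (1, 2)]  # two 2D points in the first sample
sample1 = [(3, 1)]          # one 2D point in the second sample
C = batch_coordinates([sample0, sample1])
print(C)  # [(0, 0, 0), (0, 1, 2), (1, 3, 1)]
```

Each row has \(D + 1\) entries, so a batch of 2D samples yields three columns.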

Example:

>>> coords, feats = ME.utils.sparse_collate([coords_batch0, coords_batch1], [feats_batch0, feats_batch1])
>>> A = ME.SparseTensor(features=feats, coordinates=coords)
>>> B = ME.SparseTensor(features=feats, coordinate_map_key=A.coordinate_map_key, coordinate_manager=A.coordinate_manager)
>>> C = ME.SparseTensor(features=feats, coordinates=coords, quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
>>> D = ME.SparseTensor(features=feats, coordinates=coords, quantization_mode=ME.SparseTensorQuantizationMode.RANDOM_SUBSAMPLE)
>>> E = ME.SparseTensor(features=feats, coordinates=coords, tensor_stride=2)

Warning

To use the GPU backend for coordinate management, the coordinates must be a torch tensor on the GPU. Applying to(device) after initializing a MinkowskiEngine.SparseTensor with CPU coordinates wastes time and computation: an unnecessary CPU CoordinateMap is created first, and the GPU CoordinateMap is then created from scratch as well.

Warning

Before MinkowskiEngine version 0.4, we put the batch indices on the last column. Thus, direct manipulation of coordinates will be incompatible with the latest versions. Instead, please use MinkowskiEngine.utils.batched_coordinates or MinkowskiEngine.utils.sparse_collate to create batched coordinates.

Also, to access coordinates or features batch-wise, use the functions coordinates_at(batch_index : int) and features_at(batch_index : int) of a sparse tensor. To access all batch-wise coordinates and features at once, use decomposed_coordinates, decomposed_features, or decomposed_coordinates_and_features of a sparse tensor.

Example:

>>> coords, feats = ME.utils.sparse_collate([coords_batch0, coords_batch1], [feats_batch0, feats_batch1])
>>> A = ME.SparseTensor(features=feats, coordinates=coords)
>>> coords_batch0 = A.coordinates_at(batch_index=0)
>>> feats_batch1 = A.features_at(batch_index=1)
>>> list_of_coords, list_of_features = A.decomposed_coordinates_and_features
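To make the batch-wise accessors concrete, here is a minimal pure-Python sketch of splitting a batched coordinate matrix by its first column. The function names mirror the accessors above but are standalone stand-ins written for illustration; dropping the batch column from the returned coordinates is an assumption about the library's behavior.

```python
from typing import List, Sequence, Tuple

Row = Tuple[int, ...]


def coordinates_at(C: Sequence[Row], batch_index: int) -> List[Row]:
    """Rows of C that belong to one batch item, with the batch column dropped."""
    return [row[1:] for row in C if row[0] == batch_index]


def features_at(C: Sequence[Row], F: Sequence[list], batch_index: int) -> List[list]:
    """Feature rows whose coordinate carries the given batch index."""
    return [feat for row, feat in zip(C, F) if row[0] == batch_index]


def decomposed_coordinates_and_features(C: Sequence[Row], F: Sequence[list]):
    """Split C and F into one list entry per batch item."""
    num_batches = max(row[0] for row in C) + 1
    coords = [coordinates_at(C, b) for b in range(num_batches)]
    feats = [features_at(C, F, b) for b in range(num_batches)]
    return coords, feats


C = [(0, 0, 0), (0, 1, 2), (1, 3, 1)]  # first column is the batch index
F = [[0.1], [0.2], [0.3]]
coords_list, feats_list = decomposed_coordinates_and_features(C, F)
print(coords_list)  # [[(0, 0), (1, 2)], [(3, 1)]]
print(feats_list)   # [[[0.1], [0.2]], [[0.3]]]
```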
__init__(features: torch.Tensor, coordinates: Optional[torch.Tensor] = None, tensor_stride: Union[int, collections.abc.Sequence, numpy.ndarray, torch.IntTensor] = 1, coordinate_map_key: Optional[MinkowskiEngineBackend._C.CoordinateMapKey] = None, coordinate_manager: Optional[MinkowskiCoordinateManager.CoordinateManager] = None, quantization_mode: MinkowskiTensor.SparseTensorQuantizationMode = <SparseTensorQuantizationMode.RANDOM_SUBSAMPLE: 0>, allocator_type: Optional[MinkowskiEngineBackend._C.GPUMemoryAllocatorType] = None, minkowski_algorithm: Optional[MinkowskiEngineBackend._C.MinkowskiAlgorithm] = None, requires_grad=None, device=None)
Args:

features (torch.FloatTensor, torch.DoubleTensor, torch.cuda.FloatTensor, or torch.cuda.DoubleTensor): The features of a sparse tensor.

coordinates (torch.IntTensor): The coordinates associated to the features. If not provided, coordinate_map_key must be provided.

tensor_stride (int, list, numpy.array, or torch.Tensor): The tensor stride of the current sparse tensor. By default, it is 1.

coordinate_map_key (MinkowskiEngine.CoordinateMapKey): When the coordinates are already cached in the MinkowskiEngine, we can reuse the same coordinate map by simply providing the coordinate map key. In most cases, this process is done automatically. When you provide a coordinate_map_key, coordinates will be ignored.

coordinate_manager (MinkowskiEngine.CoordinateManager): The MinkowskiEngine manages all coordinate maps using the _C.CoordinateMapManager. If not provided, the MinkowskiEngine will create a new computation graph. In most cases, this process is handled automatically and you do not need to use this.

quantization_mode (MinkowskiEngine.SparseTensorQuantizationMode): Defines how continuous coordinates will be quantized to define a sparse tensor. Please refer to SparseTensorQuantizationMode for details.

allocator_type (MinkowskiEngine.GPUMemoryAllocatorType): Defines the GPU memory allocator type. By default, it uses the c10 allocator.

minkowski_algorithm (MinkowskiEngine.MinkowskiAlgorithm): Controls the mode in which the Minkowski Engine runs. Use MinkowskiAlgorithm.MEMORY_EFFICIENT if you want to reduce the memory footprint, or MinkowskiAlgorithm.SPEED_OPTIMIZED if you want it to run faster at the cost of more memory.

requires_grad (bool): Set the requires_grad flag.

device (torch.device): Set the device on which the sparse tensor is defined.

cat_slice(X)
Args:

X (MinkowskiEngine.SparseTensor): a sparse tensor that discretized the original input.

Returns:

tensor_field (MinkowskiEngine.TensorField): the resulting tensor field is defined on the original continuous coordinates that generated the input X and contains the concatenation of the features of X and of self.

Example:

>>> # coords, feats from a data loader
>>> print(len(coords))  # 227742
>>> sinput = ME.SparseTensor(coordinates=coords, features=feats, quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
>>> print(len(sinput))  # 161890 quantization results in fewer voxels
>>> soutput = network(sinput)
>>> print(len(soutput))  # 161890 Output with the same resolution
>>> ofield = soutput.cat_slice(sinput)
>>> assert soutput.F.size(1) + sinput.F.size(1) == ofield.F.size(1)  # concatenation of features
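The feature concatenation can be pictured with the pure-Python sketch below. It assumes an inverse mapping (one voxel index per original point) is already available, and it places the voxel features before the per-point features; both the helper name cat_slice_features and that column order are assumptions made for illustration, not a statement of the library's implementation.

```python
def cat_slice_features(voxel_feats, inverse, point_feats):
    """For each original point, concatenate its voxel's feature with its own feature.

    Hypothetical helper: `inverse[i]` is the index of the voxel that point i fell into.
    """
    return [list(voxel_feats[v]) + list(p) for v, p in zip(inverse, point_feats)]


point_feats = [[1.0], [2.0], [3.0]]     # one channel per original point
inverse = [0, 0, 1]                     # points 0 and 1 share voxel 0
voxel_feats = [[0.5, 0.6], [0.7, 0.8]]  # two channels per voxel, e.g. a network output
out = cat_slice_features(voxel_feats, inverse, point_feats)
print(out)  # [[0.5, 0.6, 1.0], [0.5, 0.6, 2.0], [0.7, 0.8, 3.0]]
```

The number of output channels is the sum of the two input channel counts, which matches the assertion in the example above.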
coordinate_map_key
dense(shape=None, min_coordinate=None, contract_stride=True)

Convert the MinkowskiEngine.SparseTensor to a torch dense tensor.

Args:

shape (torch.Size, optional): The size of the output tensor.

min_coordinate (torch.IntTensor, optional): The min coordinates of the output tensor. Must be divisible by the current tensor_stride. If 0 is given, it will use the origin for the min coordinate.

contract_stride (bool, optional): The output coordinates will be divided by the tensor stride to make features spatially contiguous. True by default.

Returns:

tensor (torch.Tensor): the torch tensor with size [Batch Dim, Feature Dim, Spatial Dim…, Spatial Dim]. The coordinate of each feature can be accessed via min_coordinate + tensor_stride * [the coordinate of the dense tensor].

min_coordinate (torch.IntTensor): the D-dimensional vector defining the minimum coordinate of the output tensor.

tensor_stride (torch.IntTensor): the D-dimensional vector defining the stride between tensor elements.
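The relation min_coordinate + tensor_stride * [the coordinate of the dense tensor] can be checked with a small sketch. The two helpers below are hypothetical and cover only the spatial dimensions, assuming the stride has been contracted (contract_stride=True).

```python
def to_dense_index(coord, min_coordinate, tensor_stride):
    """Map a sparse spatial coordinate to its index in the dense tensor."""
    return tuple((c - m) // s for c, m, s in zip(coord, min_coordinate, tensor_stride))


def to_sparse_coordinate(index, min_coordinate, tensor_stride):
    """Inverse map: min_coordinate + tensor_stride * dense index."""
    return tuple(m + s * i for i, m, s in zip(index, min_coordinate, tensor_stride))


min_coordinate = (-4, 0)  # divisible by the tensor stride
tensor_stride = (2, 2)
index = to_dense_index((2, 6), min_coordinate, tensor_stride)
print(index)  # (3, 3)
print(to_sparse_coordinate(index, min_coordinate, tensor_stride))  # (2, 6)
```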

features_at_coordinates(query_coordinates: torch.Tensor)

Extract features at the specified continuous coordinate matrix.

Args:

query_coordinates (torch.FloatTensor): a coordinate matrix of size \(N \times (D + 1)\) where \(D\) is the size of the spatial dimension.

Returns:

queried_features (torch.Tensor): a feature matrix of size \(N \times D_F\) where \(D_F\) is the number of channels in the feature. For coordinates not present in the current sparse tensor, corresponding feature rows will be zeros.
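The zero-filling behavior for missing coordinates can be sketched with a dictionary lookup. This is an illustration only: lookup_features is a hypothetical helper, it takes integer query coordinates, and it does not model how the library maps continuous query coordinates onto the stored ones.

```python
def lookup_features(coords, feats, query_coordinates):
    """Return one feature row per query; rows for unknown coordinates are all zeros."""
    table = {tuple(c): f for c, f in zip(coords, feats)}
    zero_row = [0.0] * len(feats[0])
    # Copy each row so callers cannot mutate the stored features or the shared zero row.
    return [list(table.get(tuple(q), zero_row)) for q in query_coordinates]


coords = [(0, 1, 1), (0, 2, 3)]              # rows of (batch index, x, y)
feats = [[0.5, 1.5], [2.5, 3.5]]             # two channels per coordinate
queries = [(0, 2, 3), (0, 9, 9), (0, 1, 1)]  # the second query is not present
queried = lookup_features(coords, feats, queries)
print(queried)  # [[2.5, 3.5], [0.0, 0.0], [0.5, 1.5]]
```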

initialize_coordinates(coordinates, features, coordinate_map_key)
inverse_mapping
quantization_mode
slice(X)
Args:

X (MinkowskiEngine.SparseTensor): a sparse tensor that discretized the original input.

Returns:

tensor_field (MinkowskiEngine.TensorField): the resulting tensor field contains features on the continuous coordinates that generated the input X.

Example:

>>> # coords, feats from a data loader
>>> print(len(coords))  # 227742
>>> tfield = ME.TensorField(coordinates=coords, features=feats, quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
>>> print(len(tfield))  # 227742
>>> sinput = tfield.sparse() # 161890 quantization results in fewer voxels
>>> soutput = MinkUNet(sinput)
>>> print(len(soutput))  # 161890 Output with the same resolution
>>> ofield = soutput.slice(tfield)
>>> assert isinstance(ofield, ME.TensorField)
>>> assert len(ofield) == len(coords)  # recovers the original ordering and length
>>> assert isinstance(ofield.F, torch.Tensor)  # .F returns the features
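Conceptually, slicing gathers one voxel feature per original point through an inverse mapping. The sketch below shows that idea in pure Python; quantize_with_inverse and slice_features are hypothetical helpers, and the real library obtains the mapping from its coordinate manager rather than from a Python dictionary.

```python
def quantize_with_inverse(coords):
    """Return the unique coordinates and, for every input row, the index of its unique coordinate."""
    unique = {}
    inverse = []
    for c in coords:
        # The first occurrence of a coordinate is assigned the next free index.
        idx = unique.setdefault(tuple(c), len(unique))
        inverse.append(idx)
    return list(unique), inverse


def slice_features(voxel_feats, inverse):
    """Gather a voxel feature for every original point, preserving the original order."""
    return [voxel_feats[i] for i in inverse]


coords = [(0, 1, 1), (0, 1, 1), (0, 2, 3), (0, 1, 1)]
voxels, inverse = quantize_with_inverse(coords)
print(voxels)   # [(0, 1, 1), (0, 2, 3)]
print(inverse)  # [0, 0, 1, 0]

voxel_feats = [[0.5], [0.9]]  # e.g. one network output row per voxel
point_feats = slice_features(voxel_feats, inverse)
print(point_feats)  # [[0.5], [0.5], [0.9], [0.5]]
```

The result has one row per original point, which is why the sliced field recovers the original length and ordering.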
sparse(min_coords=None, max_coords=None, contract_coords=True)

Convert the MinkowskiEngine.SparseTensor to a torch sparse tensor.

Args:

min_coords (torch.IntTensor, optional): The min coordinates of the output sparse tensor. Must be divisible by the current tensor_stride.

max_coords (torch.IntTensor, optional): The max coordinates of the output sparse tensor (inclusive). Must be divisible by the current tensor_stride.

contract_coords (bool, optional): Given True, the output coordinates will be divided by the tensor stride to make features contiguous.

Returns:

sparse_tensor (torch.sparse.Tensor): the torch sparse tensor representation of self in [Batch Dim, Spatial Dims…, Feature Dim]. The coordinate of each feature can be accessed via min_coords + tensor_stride * [the coordinate of the dense tensor].

min_coords (torch.IntTensor): the D-dimensional vector defining the minimum coordinate of the output sparse tensor. If contract_coords is True, the min_coords will also be contracted.

tensor_stride (torch.IntTensor): the D-dimensional vector defining the stride between tensor elements.

unique_index

TensorField

class MinkowskiEngine.MinkowskiTensorField.TensorField(features: torch.Tensor, coordinates: Optional[torch.Tensor] = None, tensor_stride: Union[int, collections.abc.Sequence, numpy.ndarray, torch.IntTensor] = 1, coordinate_field_map_key: Optional[MinkowskiEngineBackend._C.CoordinateMapKey] = None, coordinate_manager: Optional[MinkowskiCoordinateManager.CoordinateManager] = None, quantization_mode: MinkowskiTensor.SparseTensorQuantizationMode = <SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE: 1>, allocator_type: Optional[MinkowskiEngineBackend._C.GPUMemoryAllocatorType] = None, minkowski_algorithm: Optional[MinkowskiEngineBackend._C.MinkowskiAlgorithm] = None, requires_grad=None, device=None)
__init__(features: torch.Tensor, coordinates: Optional[torch.Tensor] = None, tensor_stride: Union[int, collections.abc.Sequence, numpy.ndarray, torch.IntTensor] = 1, coordinate_field_map_key: Optional[MinkowskiEngineBackend._C.CoordinateMapKey] = None, coordinate_manager: Optional[MinkowskiCoordinateManager.CoordinateManager] = None, quantization_mode: MinkowskiTensor.SparseTensorQuantizationMode = <SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE: 1>, allocator_type: Optional[MinkowskiEngineBackend._C.GPUMemoryAllocatorType] = None, minkowski_algorithm: Optional[MinkowskiEngineBackend._C.MinkowskiAlgorithm] = None, requires_grad=None, device=None)
Args:

features (torch.FloatTensor, torch.DoubleTensor, torch.cuda.FloatTensor, or torch.cuda.DoubleTensor): The features of a sparse tensor.

coordinates (torch.IntTensor): The coordinates associated with the features. If not provided, coordinate_field_map_key must be provided.

tensor_stride (int, list, numpy.array, or torch.Tensor): The tensor stride of the current sparse tensor. By default, it is 1.

coordinate_field_map_key (MinkowskiEngine.CoordinateMapKey): When the coordinates are already cached in the MinkowskiEngine, we can reuse the same coordinate map by simply providing the coordinate map key. In most cases, this process is done automatically. When you provide a coordinate_field_map_key, coordinates will be ignored.

coordinate_manager (MinkowskiEngine.CoordinateManager): The MinkowskiEngine manages all coordinate maps using the _C.CoordinateMapManager. If not provided, the MinkowskiEngine will create a new computation graph. In most cases, this process is handled automatically and you do not need to use this.

quantization_mode (MinkowskiEngine.SparseTensorQuantizationMode): Defines how continuous coordinates will be quantized to define a sparse tensor. Please refer to SparseTensorQuantizationMode for details.

allocator_type (MinkowskiEngine.GPUMemoryAllocatorType): Defines the GPU memory allocator type. By default, it uses the c10 allocator.

minkowski_algorithm (MinkowskiEngine.MinkowskiAlgorithm): Controls the mode in which the Minkowski Engine runs. Use MinkowskiAlgorithm.MEMORY_EFFICIENT if you want to reduce the memory footprint, or MinkowskiAlgorithm.SPEED_OPTIMIZED if you want it to run faster at the cost of more memory.

requires_grad (bool): Set the requires_grad flag.

device (torch.device): Set the device on which the sparse tensor is defined.

property C

An alias of coordinates.

coordinate_field_map_key
property coordinates

The coordinates of the current sparse tensor. The coordinates are represented as an \(N \times (D + 1)\) dimensional matrix where \(N\) is the number of points in the space and \(D\) is the dimension of the space (e.g. 3 for 3D, 4 for 3D + Time). The additional column of the matrix C holds the batch indices, which are internally treated as an additional spatial dimension to disassociate different instances in a batch.

inverse_mapping(sparse_tensor_map_key: MinkowskiEngineBackend._C.CoordinateMapKey)
quantization_mode
sparse(tensor_stride: Union[int, collections.abc.Sequence, numpy.array] = 1, coordinate_map_key: Optional[MinkowskiEngineBackend._C.CoordinateMapKey] = None, quantization_mode=None)

Converts the current sparse tensor field to a sparse tensor.

SparseTensorOperationMode

class MinkowskiEngine.MinkowskiTensor.SparseTensorOperationMode(value)

Enum class for SparseTensor internal instantiation modes.

SEPARATE_COORDINATE_MANAGER: always create a new coordinate manager.

SHARE_COORDINATE_MANAGER: always use the globally defined coordinate manager. Must clear the coordinate manager manually by MinkowskiEngine.SparseTensor.clear_global_coordinate_manager.

SparseTensorQuantizationMode

class MinkowskiEngine.MinkowskiTensor.SparseTensorQuantizationMode(value)

RANDOM_SUBSAMPLE: Subsample one coordinate per quantization block randomly.

UNWEIGHTED_AVERAGE: Average all features within a quantization block equally.

UNWEIGHTED_SUM: Sum all features within a quantization block equally.

NO_QUANTIZATION: No quantization is applied. Should not be used for normal operation.
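The difference between the first three modes can be illustrated with a short pure-Python sketch. It is not the library's implementation: the helpers group_by_block and quantize are hypothetical, the modes are passed as plain strings, and a quantization block is assumed to be the voxel obtained by flooring each coordinate divided by a voxel size.

```python
import math
import random
from collections import defaultdict


def group_by_block(points, feats, voxel_size):
    """Group feature rows by quantization block (assumed: floor(coordinate / voxel_size))."""
    blocks = defaultdict(list)
    for point, feat in zip(points, feats):
        key = tuple(math.floor(x / voxel_size) for x in point)
        blocks[key].append(feat)
    return blocks


def quantize(points, feats, voxel_size, mode, seed=0):
    """Reduce the features of each block according to the given mode name."""
    rng = random.Random(seed)
    out = {}
    for key, rows in group_by_block(points, feats, voxel_size).items():
        if mode == "RANDOM_SUBSAMPLE":
            out[key] = list(rng.choice(rows))  # keep one point of the block
        elif mode == "UNWEIGHTED_SUM":
            out[key] = [sum(col) for col in zip(*rows)]
        elif mode == "UNWEIGHTED_AVERAGE":
            out[key] = [sum(col) / len(rows) for col in zip(*rows)]
        else:
            raise ValueError(f"unsupported mode: {mode}")
    return out


points = [(0.1, 0.2), (0.4, 0.9), (1.5, 0.3)]  # the first two points share a block
feats = [[1.0], [3.0], [10.0]]
averaged = quantize(points, feats, 1.0, "UNWEIGHTED_AVERAGE")
summed = quantize(points, feats, 1.0, "UNWEIGHTED_SUM")
sampled = quantize(points, feats, 1.0, "RANDOM_SUBSAMPLE")
print(averaged)  # {(0, 0): [2.0], (1, 0): [10.0]}
print(summed)    # {(0, 0): [4.0], (1, 0): [10.0]}
```

With RANDOM_SUBSAMPLE, the block (0, 0) keeps either [1.0] or [3.0], depending on the random draw.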

set_sparse_tensor_operation_mode

MinkowskiEngine.MinkowskiTensor.set_sparse_tensor_operation_mode(operation_mode: MinkowskiEngine.MinkowskiTensor.SparseTensorOperationMode)

Define the sparse tensor coordinate manager operation mode.

By default, a MinkowskiEngine.SparseTensor.SparseTensor instantiation creates a new coordinate manager that is not shared with other sparse tensors. By calling this function with MinkowskiEngine.SparseTensorOperationMode.SHARE_COORDINATE_MANAGER, you can share the coordinate manager globally with other sparse tensors. However, you must explicitly clear the coordinate manager after use. Please refer to MinkowskiEngine.clear_global_coordinate_manager.

Args:

operation_mode (MinkowskiEngine.SparseTensorOperationMode): The operation mode for the sparse tensor coordinate manager. By default MinkowskiEngine.SparseTensorOperationMode.SEPARATE_COORDINATE_MANAGER.

Example:

>>> import MinkowskiEngine as ME
>>> ME.set_sparse_tensor_operation_mode(ME.SparseTensorOperationMode.SHARE_COORDINATE_MANAGER)
>>> ...
>>> a = ME.SparseTensor(...)
>>> b = ME.SparseTensor(...)  # coords_man shared
>>> ...  # one feed forward and backward
>>> ME.clear_global_coordinate_manager()  # Must use to clear the coordinates after one forward/backward

sparse_tensor_operation_mode

MinkowskiEngine.MinkowskiTensor.sparse_tensor_operation_mode() → MinkowskiEngine.MinkowskiTensor.SparseTensorOperationMode

Return the current sparse tensor operation mode.

global_coordinate_manager

MinkowskiEngine.MinkowskiTensor.global_coordinate_manager()

Return the current global coordinate manager.

set_global_coordinate_manager

MinkowskiEngine.MinkowskiTensor.set_global_coordinate_manager(coordinate_manager)

Set the global coordinate manager.

Args:

coordinate_manager (MinkowskiEngine.CoordinateManager): The coordinate manager which will be set as the global coordinate manager.

clear_global_coordinate_manager

MinkowskiEngine.MinkowskiTensor.clear_global_coordinate_manager()

Clear the global coordinate manager cache.

When you use the operation mode: MinkowskiEngine.SparseTensor.SparseTensorOperationMode.SHARE_COORDINATE_MANAGER, you must explicitly clear the coordinate manager after each feed forward/backward.