SparseTensor and TensorField
SparseTensor
class MinkowskiEngine.MinkowskiSparseTensor.SparseTensor(features: torch.Tensor, coordinates: Optional[torch.Tensor] = None, tensor_stride: Union[int, collections.abc.Sequence, numpy.ndarray, torch.IntTensor] = 1, coordinate_map_key: Optional[MinkowskiEngineBackend._C.CoordinateMapKey] = None, coordinate_manager: Optional[MinkowskiCoordinateManager.CoordinateManager] = None, quantization_mode: MinkowskiTensor.SparseTensorQuantizationMode = <SparseTensorQuantizationMode.RANDOM_SUBSAMPLE: 0>, allocator_type: Optional[MinkowskiEngineBackend._C.GPUMemoryAllocatorType] = None, minkowski_algorithm: Optional[MinkowskiEngineBackend._C.MinkowskiAlgorithm] = None, requires_grad=None, device=None)

A sparse tensor class. Can be accessed via MinkowskiEngine.SparseTensor.

The SparseTensor class is the basic tensor in MinkowskiEngine. For the definition of a sparse tensor, please visit the terminology page. We use the COOrdinate (COO) format to store a sparse tensor [1]. This representation is simply a concatenation of coordinates in a matrix \(\mathbf{C}\) and associated features \(\mathbf{F}\).

\[\begin{split}\mathbf{C} = \begin{bmatrix} b_1 & x_1^1 & x_1^2 & \cdots & x_1^D \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ b_N & x_N^1 & x_N^2 & \cdots & x_N^D \end{bmatrix}, \; \mathbf{F} = \begin{bmatrix} \mathbf{f}_1^T\\ \vdots\\ \mathbf{f}_N^T \end{bmatrix}\end{split}\]

where \(\mathbf{x}_i \in \mathcal{Z}^D\) is a \(D\)-dimensional coordinate and \(b_i \in \mathcal{Z}_+\) denotes the corresponding batch index. \(N\) is the number of non-zero elements in the sparse tensor, each with the coordinate \((b_i, x_i^1, x_i^2, \cdots, x_i^D)\) and the associated feature \(\mathbf{f}_i\). Internally, we handle the batch index as an additional spatial dimension.
Example:
>>> import MinkowskiEngine as ME
>>> coords, feats = ME.utils.sparse_collate([coords_batch0, coords_batch1], [feats_batch0, feats_batch1])
>>> A = ME.SparseTensor(features=feats, coordinates=coords)
>>> B = ME.SparseTensor(features=feats, coordinate_map_key=A.coordinate_map_key, coordinate_manager=A.coordinate_manager)
>>> C = ME.SparseTensor(features=feats, coordinates=coords, quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
>>> D = ME.SparseTensor(features=feats, coordinates=coords, quantization_mode=ME.SparseTensorQuantizationMode.RANDOM_SUBSAMPLE)
>>> E = ME.SparseTensor(features=feats, coordinates=coords, tensor_stride=2)
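To make the COO layout concrete, the following is a minimal pure-Python sketch (not MinkowskiEngine code) that stacks per-instance coordinate lists into one matrix C whose first column is the batch index b_i, and stacks the feature rows into F. The helper name `collate_coo` is hypothetical; in MinkowskiEngine this job is done by `ME.utils.sparse_collate` or `ME.utils.batched_coordinates`, which return torch tensors and may differ in detail.

```python
from typing import List, Sequence, Tuple

Coord = Tuple[int, ...]


def collate_coo(
    coords_per_batch: Sequence[Sequence[Coord]],
    feats_per_batch: Sequence[Sequence[Sequence[float]]],
) -> Tuple[List[Coord], List[List[float]]]:
    """Hypothetical helper: build the COO matrices C and F for a batch.

    Row i of C is (b_i, x_i^1, ..., x_i^D) with the batch index first,
    and row i of F is the feature vector f_i of that coordinate.
    """
    if len(coords_per_batch) != len(feats_per_batch):
        raise ValueError("need one feature list per coordinate list")
    C: List[Coord] = []
    F: List[List[float]] = []
    for b, (coords, feats) in enumerate(zip(coords_per_batch, feats_per_batch)):
        if len(coords) != len(feats):
            raise ValueError(f"batch {b}: coordinates and features differ in length")
        for x, f in zip(coords, feats):
            C.append((b, *x))  # prepend the batch index as an extra column
            F.append(list(f))
    return C, F
```

For example, `collate_coo([[(0, 0), (1, 2)], [(0, 0)]], [[[1.0], [2.0]], [[3.0]]])` yields the rows `(0, 0, 0)`, `(0, 1, 2)`, `(1, 0, 0)`. The two `(0, 0)` points stay distinct because they carry different batch indices, which is why the batch index can be treated as one more spatial dimension.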
Warning
To use the GPU backend for coordinate management, the coordinates must be a torch tensor on the GPU. Calling to(device) after initializing a MinkowskiEngine.SparseTensor with CPU coordinates wastes time and computation: it creates an unnecessary CPU CoordinateMap, and the GPU CoordinateMap is then created from scratch as well.

Warning
Before MinkowskiEngine version 0.4, the batch indices were stored in the last column. Code that manipulates coordinates directly under that convention is therefore incompatible with the latest versions. Instead, please use MinkowskiEngine.utils.batched_coordinates or MinkowskiEngine.utils.sparse_collate to create batched coordinates.

To access coordinates or features batch-wise, use the functions coordinates_at(batch_index : int) and features_at(batch_index : int) of a sparse tensor. To access the coordinates and features of all batches at once, use the decomposed_coordinates, decomposed_features, or decomposed_coordinates_and_features attributes of a sparse tensor.

Example:
>>> import MinkowskiEngine as ME
>>> coords, feats = ME.utils.sparse_collate([coords_batch0, coords_batch1], [feats_batch0, feats_batch1])
>>> A = ME.SparseTensor(features=feats, coordinates=coords)
>>> coords_batch0 = A.coordinates_at(batch_index=0)
>>> feats_batch1 = A.features_at(batch_index=1)
>>> list_of_coords, list_of_features = A.decomposed_coordinates_and_features
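What batch-wise access means can be sketched in pure Python: group the rows of the batched COO matrices by their first column and drop that column from the returned coordinates. This is an illustration under stated assumptions, not MinkowskiEngine's implementation: the name `decompose_by_batch` is hypothetical, rows keep their order of appearance within each batch (an assumption), and the real methods return torch tensors.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def decompose_by_batch(
    C: Sequence[Sequence[int]], F: Sequence[Sequence[float]]
) -> Tuple[List[List[Tuple[int, ...]]], List[List[List[float]]]]:
    """Hypothetical helper: split batched COO rows into per-instance lists.

    Position b of each returned list holds the rows whose batch index is b,
    so an index without any points yields an empty list.
    """
    coords: Dict[int, List[Tuple[int, ...]]] = defaultdict(list)
    feats: Dict[int, List[List[float]]] = defaultdict(list)
    for row, f in zip(C, F):
        b = row[0]
        coords[b].append(tuple(row[1:]))  # drop the batch-index column
        feats[b].append(list(f))
    num_batches = max(coords, default=-1) + 1
    return (
        [coords[b] for b in range(num_batches)],
        [feats[b] for b in range(num_batches)],
    )
```

Indexing the first returned list with a batch index plays the role of `coordinates_at(batch_index)`, and the pair as a whole corresponds to the idea behind `decomposed_coordinates_and_features`.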
__init__(features: torch.Tensor, coordinates: Optional[torch.Tensor] = None, tensor_stride: Union[int, collections.abc.Sequence, numpy.ndarray, torch.IntTensor] = 1, coordinate_map_key: Optional[MinkowskiEngineBackend._C.CoordinateMapKey] = None, coordinate_manager: Optional[MinkowskiCoordinateManager.CoordinateManager] = None, quantization_mode: MinkowskiTensor.SparseTensorQuantizationMode = <SparseTensorQuantizationMode.RANDOM_SUBSAMPLE: 0>, allocator_type: Optional[MinkowskiEngineBackend._C.GPUMemoryAllocatorType] = None, minkowski_algorithm: Optional[MinkowskiEngineBackend._C.MinkowskiAlgorithm] = None, requires_grad=None, device=None)

Args:
features (torch.FloatTensor, torch.DoubleTensor, torch.cuda.FloatTensor, or torch.cuda.DoubleTensor): The features of a sparse tensor.
coordinates (torch.IntTensor): The coordinates associated with the features. If not provided, coordinate_map_key must be provided.
tensor_stride (int, list, numpy.array, or torch.Tensor): The tensor stride of the current sparse tensor. By default, it is 1.
coordinate_map_key (MinkowskiEngine.CoordinateMapKey): When the coordinates are already cached in MinkowskiEngine, the same coordinate map can be reused by providing its coordinate map key. In most cases, this is done automatically. When a coordinate_map_key is provided, coordinates will be ignored.
coordinate_manager (MinkowskiEngine.CoordinateManager): MinkowskiEngine manages all coordinate maps using the _C.CoordinateMapManager. If not provided, MinkowskiEngine will create a new computation graph. In most cases, this is handled automatically and you do not need to use this argument.
quantization_mode (MinkowskiEngine.SparseTensorQuantizationMode): Defines how continuous coordinates will be quantized to define a sparse tensor. Please refer to SparseTensorQuantizationMode for details.
allocator_type (MinkowskiEngine.GPUMemoryAllocatorType): Defines the GPU memory allocator type. By default, it uses the c10 allocator.
minkowski_algorithm (MinkowskiEngine.MinkowskiAlgorithm): Controls the mode in which MinkowskiEngine runs. Use MinkowskiAlgorithm.MEMORY_EFFICIENT to reduce the memory footprint, or MinkowskiAlgorithm.SPEED_OPTIMIZED to run faster at the cost of more memory.
requires_grad (bool): Sets the requires_grad flag.
device (torch.device): Sets the device on which the sparse tensor is defined.
cat_slice(X)

Args:
X (MinkowskiEngine.SparseTensor): a sparse tensor that discretized the original input.

Returns:
tensor_field (MinkowskiEngine.TensorField): the resulting tensor field, which contains the concatenation of the features of self and X on the original continuous coordinates that generated the input X.
Example:
>>> import MinkowskiEngine as ME
>>> # coords, feats from a data loader
>>> print(len(coords))  # 227742
>>> sinput = ME.SparseTensor(coordinates=coords, features=feats, quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
>>> print(len(sinput))  # 161890: quantization results in fewer voxels
>>> soutput = network(sinput)
>>> print(len(soutput))  # 161890: output with the same resolution
>>> ofield = soutput.cat_slice(sinput)
>>> assert soutput.F.size(1) + sinput.F.size(1) == ofield.F.size(1)  # concatenation of features
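The feature-width check in the example above can be mirrored with a small pure-Python sketch. It assumes that self and X are indexed by the same voxels and that a hypothetical `inverse_map[i]` gives the voxel into which original point i was quantized; each original point then receives its voxel's output features followed by its voxel's input features. This only illustrates the shape bookkeeping and is not the library's implementation; `cat_slice_sketch` is a hypothetical name.

```python
from typing import List, Sequence


def cat_slice_sketch(
    out_feats: Sequence[Sequence[float]],
    in_feats: Sequence[Sequence[float]],
    inverse_map: Sequence[int],
) -> List[List[float]]:
    """Hypothetical helper: one row per original point.

    Row i is out_feats[v] followed by in_feats[v], where v = inverse_map[i]
    is the voxel that point i was quantized into (assumed shared by both inputs).
    """
    if len(out_feats) != len(in_feats):
        raise ValueError("both feature matrices must be indexed by the same voxels")
    return [list(out_feats[v]) + list(in_feats[v]) for v in inverse_map]
```

The result has one row per original point (the length of `inverse_map`) and a width equal to the sum of the two feature widths, matching the assertion in the example.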
coordinate_map_key
dense(shape=None, min_coordinate=None, contract_stride=True)

Convert the MinkowskiEngine.SparseTensor to a torch dense tensor.

Args:
shape (torch.Size, optional): The size of the output tensor.
min_coordinate (torch.IntTensor, optional): The minimum coordinate of the output tensor. Must be divisible by the current tensor_stride. If 0 is given, the origin is used as the minimum coordinate.
contract_stride (bool, optional): If True, the output coordinates are divided by the tensor stride to make features spatially contiguous. True by default.

Returns:
tensor (torch.Tensor): the torch tensor with size [Batch Dim, Feature Dim, Spatial Dim…, Spatial Dim]. The coordinate of each feature can be accessed via min_coordinate + tensor_stride * [the coordinate of the dense tensor].
min_coordinate (torch.IntTensor): the D-dimensional vector defining the minimum coordinate of the output tensor.
tensor_stride (torch.IntTensor): the D-dimensional vector defining the stride between tensor elements.
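The index arithmetic behind dense() can be illustrated with a pure-Python sketch that scatters COO rows into nested lists shaped [Batch, Feature, Spatial...]. Assumptions: `dense_sketch` is a hypothetical name, nested lists stand in for a torch tensor, the `shape` argument is not modeled, and when `min_coordinate` is omitted the sketch uses the per-axis minimum of the stored coordinates.

```python
from typing import List, Optional, Sequence


def _zeros(shape: Sequence[int]):
    """Nested Python lists of 0.0 with the given shape (a stand-in for torch.zeros)."""
    if not shape:
        return 0.0
    return [_zeros(shape[1:]) for _ in range(shape[0])]


def dense_sketch(
    coords: Sequence[Sequence[int]],  # rows (b, x^1, ..., x^D)
    feats: Sequence[Sequence[float]],  # one feature row per coordinate
    tensor_stride: Sequence[int],  # D strides
    min_coordinate: Optional[Sequence[int]] = None,
    contract_stride: bool = True,
):
    """Scatter COO rows into nested lists shaped [B, C, S_1, ..., S_D]."""
    D = len(tensor_stride)
    if min_coordinate is None:
        # Assumption: default to the per-axis minimum of the stored coordinates.
        min_coordinate = [min(row[d + 1] for row in coords) for d in range(D)]
    if any(m % s for m, s in zip(min_coordinate, tensor_stride)):
        raise ValueError("min_coordinate must be divisible by tensor_stride")

    indices: List[List[int]] = []
    for row in coords:
        offsets = [row[d + 1] - min_coordinate[d] for d in range(D)]
        if any(o < 0 for o in offsets):
            raise ValueError("coordinate lies below min_coordinate")
        # Contracting divides by the stride so that neighbouring voxels land in adjacent cells.
        indices.append([o // s for o, s in zip(offsets, tensor_stride)] if contract_stride else offsets)

    num_batches = max(row[0] for row in coords) + 1
    num_channels = len(feats[0])
    spatial = [max(ix[d] for ix in indices) + 1 for d in range(D)]
    out = _zeros([num_batches, num_channels, *spatial])
    for row, ix, f in zip(coords, indices, feats):
        for c, value in enumerate(f):
            cell = out[row[0]][c]
            for i in ix[:-1]:
                cell = cell[i]
            cell[ix[-1]] = value
    return out, list(min_coordinate), list(tensor_stride)
```

With `coords = [(0, 0, 0), (0, 2, 4), (1, 4, 2)]` and `tensor_stride = [2, 2]`, the point `(2, 4)` of batch 0 lands at dense index `(1, 2)`, and `min_coordinate + tensor_stride * (1, 2) = (2, 4)` recovers it, as the Returns description states.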
features_at_coordinates(query_coordinates: torch.Tensor)

Extract features at the specified continuous coordinate matrix.

Args:
query_coordinates (torch.FloatTensor): a coordinate matrix of size \(N \times (D + 1)\) where \(D\) is the size of the spatial dimension.

Returns:
queried_features (torch.Tensor): a feature matrix of size \(N \times D_F\) where \(D_F\) is the number of channels in the feature. For coordinates not present in the current sparse tensor, the corresponding feature rows will be zeros.
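The "zeros for missing coordinates" rule can be sketched as a dictionary lookup keyed by the full row (b, x^1, ..., x^D). This sketch only covers queries that fall exactly on integer lattice points; how the real method treats off-lattice queries is not modeled, so the sketch rejects them. `features_at_coordinates_sketch` is a hypothetical name, and plain lists stand in for torch tensors.

```python
from typing import Dict, List, Sequence, Tuple


def features_at_coordinates_sketch(
    coords: Sequence[Sequence[int]],
    feats: Sequence[Sequence[float]],
    query_coordinates: Sequence[Sequence[float]],
) -> List[List[float]]:
    """Return one feature row per query row (b, x^1, ..., x^D).

    Only on-lattice queries are handled; a query with no stored coordinate
    gets a row of zeros with the same number of channels.
    """
    num_channels = len(feats[0]) if feats else 0
    table: Dict[Tuple[int, ...], List[float]] = {
        tuple(c): list(f) for c, f in zip(coords, feats)
    }
    result: List[List[float]] = []
    for q in query_coordinates:
        if not all(float(v).is_integer() for v in q):
            raise ValueError("this sketch does not model off-lattice queries")
        key = tuple(int(v) for v in q)
        # Copy so callers never alias the stored rows.
        result.append(list(table.get(key, [0.0] * num_channels)))
    return result
```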
initialize_coordinates(coordinates, features, coordinate_map_key)

inverse_mapping

quantization_mode
slice(X)

Args:
X (MinkowskiEngine.SparseTensor): a sparse tensor that discretized the original input.

Returns:
tensor_field (MinkowskiEngine.TensorField): the resulting tensor field, which contains features on the continuous coordinates that generated the input X.
Example:
>>> import torch
>>> import MinkowskiEngine as ME
>>> # coords, feats from a data loader
>>> print(len(coords))  # 227742
>>> tfield = ME.TensorField(coordinates=coords, features=feats, quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
>>> print(len(tfield))  # 227742
>>> sinput = tfield.sparse()  # 161890: quantization results in fewer voxels
>>> soutput = MinkUNet(sinput)
>>> print(len(soutput))  # 161890: output with the same resolution
>>> ofield = soutput.slice(tfield)
>>> assert isinstance(ofield, ME.TensorField)
>>> assert len(ofield) == len(coords)  # recovers the original ordering and length
>>> assert isinstance(ofield.F, torch.Tensor)  # .F returns the features
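Conceptually, slicing undoes quantization by gathering voxel features back to the original points. A minimal pure-Python sketch, assuming a hypothetical `inverse_map` in which entry i is the voxel that original point i was quantized into:

```python
from typing import List, Sequence


def slice_sketch(
    voxel_feats: Sequence[Sequence[float]], inverse_map: Sequence[int]
) -> List[List[float]]:
    """Hypothetical helper: gather one feature row per original point.

    Row i of the result is voxel_feats[inverse_map[i]], so the output has the
    original length and ordering even though there are fewer voxels.
    """
    return [list(voxel_feats[v]) for v in inverse_map]
```

This is why `len(ofield) == len(coords)` holds in the example: the output length comes from the inverse map, not from the number of voxels. The actual slice() operates on torch tensors managed by the coordinate manager; this sketch only shows the indexing.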
sparse(min_coords=None, max_coords=None, contract_coords=True)

Convert the MinkowskiEngine.SparseTensor to a torch sparse tensor.

Args:
min_coords (torch.IntTensor, optional): The min coordinates of the output sparse tensor. Must be divisible by the current tensor_stride.
max_coords (torch.IntTensor, optional): The max coordinates of the output sparse tensor (inclusive). Must be divisible by the current tensor_stride.
contract_coords (bool, optional): If True, the output coordinates will be divided by the tensor stride to make features contiguous.

Returns:
sparse_tensor (torch.sparse.Tensor): the torch sparse tensor representation of self in [Batch Dim, Spatial Dims…, Feature Dim]. The coordinate of each feature can be accessed via min_coords + tensor_stride * [the coordinate of the torch sparse tensor].
min_coords (torch.IntTensor): the D-dimensional vector defining the minimum coordinate of the output sparse tensor. If contract_coords is True, min_coords will also be contracted.
tensor_stride (torch.IntTensor): the D-dimensional vector defining the stride between tensor elements.
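The index bookkeeping described above can be sketched in pure Python: shift coordinates by `min_coords`, optionally contract by the stride, and derive an inclusive size from `max_coords`. This mirrors the documented semantics under assumptions: `sparse_sketch` is a hypothetical name, it returns plain index tuples and feature rows instead of a `torch.sparse` tensor, and `min_coords` defaults to the per-axis minimum of the stored coordinates.

```python
from typing import List, Optional, Sequence


def sparse_sketch(
    coords: Sequence[Sequence[int]],  # rows (b, x^1, ..., x^D)
    feats: Sequence[Sequence[float]],
    tensor_stride: Sequence[int],
    min_coords: Optional[Sequence[int]] = None,
    max_coords: Optional[Sequence[int]] = None,  # inclusive upper bound
    contract_coords: bool = True,
):
    """Return (indices, values, size, min_coords) for a hybrid COO layout
    [Batch, Spatial..., Feature]; size is None when max_coords is omitted."""
    D = len(tensor_stride)
    if min_coords is None:
        # Assumption: default to the per-axis minimum of the stored coordinates.
        min_coords = [min(row[d + 1] for row in coords) for d in range(D)]
    for name, bound in (("min_coords", min_coords), ("max_coords", max_coords)):
        if bound is not None and any(v % s for v, s in zip(bound, tensor_stride)):
            raise ValueError(f"{name} must be divisible by tensor_stride")

    def shift(values: Sequence[int]) -> List[int]:
        off = [v - m for v, m in zip(values, min_coords)]
        return [o // s for o, s in zip(off, tensor_stride)] if contract_coords else off

    indices = [(row[0], *shift(row[1:])) for row in coords]
    size = None
    if max_coords is not None:
        extent = [e + 1 for e in shift(max_coords)]  # +1 because max_coords is inclusive
        size = (max(row[0] for row in coords) + 1, *extent, len(feats[0]))
    out_min = [m // s for m, s in zip(min_coords, tensor_stride)] if contract_coords else list(min_coords)
    return indices, [list(f) for f in feats], size, out_min
```

Note that the returned `min_coords` is contracted together with the indices when `contract_coords` is True, as the Returns description says.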
unique_index

TensorField
class MinkowskiEngine.MinkowskiTensorField.TensorField(features: torch.Tensor, coordinates: Optional[torch.Tensor] = None, tensor_stride: Union[int, collections.abc.Sequence, numpy.ndarray, torch.IntTensor] = 1, coordinate_field_map_key: Optional[MinkowskiEngineBackend._C.CoordinateMapKey] = None, coordinate_manager: Optional[MinkowskiCoordinateManager.CoordinateManager] = None, quantization_mode: MinkowskiTensor.SparseTensorQuantizationMode = <SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE: 1>, allocator_type: Optional[MinkowskiEngineBackend._C.GPUMemoryAllocatorType] = None, minkowski_algorithm: Optional[MinkowskiEngineBackend._C.MinkowskiAlgorithm] = None, requires_grad=None, device=None)

__init__(features: torch.Tensor, coordinates: Optional[torch.Tensor] = None, tensor_stride: Union[int, collections.abc.Sequence, numpy.ndarray, torch.IntTensor] = 1, coordinate_field_map_key: Optional[MinkowskiEngineBackend._C.CoordinateMapKey] = None, coordinate_manager: Optional[MinkowskiCoordinateManager.CoordinateManager] = None, quantization_mode: MinkowskiTensor.SparseTensorQuantizationMode = <SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE: 1>, allocator_type: Optional[MinkowskiEngineBackend._C.GPUMemoryAllocatorType] = None, minkowski_algorithm: Optional[MinkowskiEngineBackend._C.MinkowskiAlgorithm] = None, requires_grad=None, device=None)

Args:
features (torch.FloatTensor, torch.DoubleTensor, torch.cuda.FloatTensor, or torch.cuda.DoubleTensor): The features of a sparse tensor.
coordinates (torch.IntTensor): The coordinates associated with the features. If not provided, coordinate_field_map_key must be provided.
tensor_stride (int, list, numpy.array, or torch.Tensor): The tensor stride of the current sparse tensor. By default, it is 1.
coordinate_field_map_key (MinkowskiEngine.CoordinateMapKey): When the coordinates are already cached in MinkowskiEngine, the same coordinate map can be reused by providing its coordinate map key. In most cases, this is done automatically. When a coordinate_field_map_key is provided, coordinates will be ignored.
coordinate_manager (MinkowskiEngine.CoordinateManager): MinkowskiEngine manages all coordinate maps using the _C.CoordinateMapManager. If not provided, MinkowskiEngine will create a new computation graph. In most cases, this is handled automatically and you do not need to use this argument.
quantization_mode (MinkowskiEngine.SparseTensorQuantizationMode): Defines how continuous coordinates will be quantized to define a sparse tensor. Please refer to SparseTensorQuantizationMode for details.
allocator_type (MinkowskiEngine.GPUMemoryAllocatorType): Defines the GPU memory allocator type. By default, it uses the c10 allocator.
minkowski_algorithm (MinkowskiEngine.MinkowskiAlgorithm): Controls the mode in which MinkowskiEngine runs. Use MinkowskiAlgorithm.MEMORY_EFFICIENT to reduce the memory footprint, or MinkowskiAlgorithm.SPEED_OPTIMIZED to run faster at the cost of more memory.
requires_grad (bool): Sets the requires_grad flag.
device (torch.device): Sets the device on which the sparse tensor is defined.
property C

The alias of coords.

coordinate_field_map_key
property coordinates

The coordinates of the current sparse tensor. The coordinates are represented as an \(N \times (D + 1)\) matrix, where \(N\) is the number of points in the space and \(D\) is the dimension of the space (e.g. 3 for 3D, 4 for 3D + Time). The additional column of the matrix C holds the batch indices, which are internally treated as an additional spatial dimension to separate the instances in a batch.
inverse_mapping(sparse_tensor_map_key: MinkowskiEngineBackend._C.CoordinateMapKey)

quantization_mode
sparse(tensor_stride: Union[int, collections.abc.Sequence, numpy.array] = 1, coordinate_map_key: Optional[MinkowskiEngineBackend._C.CoordinateMapKey] = None, quantization_mode=None)

Converts the current sparse tensor field to a sparse tensor.
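Converting a field to a sparse tensor means assigning every continuous point to a voxel and remembering the correspondence in both directions. The following pure-Python sketch shows that bookkeeping under stated assumptions: each point goes to the voxel obtained by flooring its spatial coordinates to the stride grid (the batch index is kept as-is), voxels are numbered in order of first appearance, and the representative of each voxel is its first point. The name `voxelize` is hypothetical, and the names `unique_index` and `inverse_mapping` only echo the attributes listed in this section; the library's internal ordering and representative choice may differ.

```python
import math
from typing import Dict, List, Sequence, Tuple


def voxelize(points: Sequence[Sequence[float]], tensor_stride: int = 1):
    """Quantize rows (b, x^1, ..., x^D) of continuous coordinates to voxels.

    Returns (voxels, unique_index, inverse_mapping): voxels[k] is the k-th
    distinct voxel, unique_index[k] is the first point that fell into it, and
    inverse_mapping[i] is the voxel index of point i.
    """
    voxel_of: Dict[Tuple[int, ...], int] = {}
    voxels: List[Tuple[int, ...]] = []
    unique_index: List[int] = []
    inverse_mapping: List[int] = []
    for i, p in enumerate(points):
        # Assumption: floor to the stride grid; the batch index is kept as-is.
        key = (int(p[0]), *(math.floor(x / tensor_stride) * tensor_stride for x in p[1:]))
        if key not in voxel_of:
            voxel_of[key] = len(voxels)
            voxels.append(key)
            unique_index.append(i)
        inverse_mapping.append(voxel_of[key])
    return voxels, unique_index, inverse_mapping
```

For example, the points `(0, 0.2, 0.7)` and `(0, 0.9, 0.1)` share voxel `(0, 0, 0)` at stride 1, so four input points can yield only three voxels while `inverse_mapping` still has four entries. That is what lets slice() restore the original length.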
SparseTensorOperationMode

class MinkowskiEngine.MinkowskiTensor.SparseTensorOperationMode(value)

Enum class for SparseTensor internal instantiation modes.

SEPARATE_COORDINATE_MANAGER: always create a new coordinate manager.
SHARE_COORDINATE_MANAGER: always use the globally defined coordinate manager. The coordinate manager must be cleared manually with MinkowskiEngine.clear_global_coordinate_manager.
SparseTensorQuantizationMode

class MinkowskiEngine.MinkowskiTensor.SparseTensorQuantizationMode(value)

RANDOM_SUBSAMPLE: Randomly subsample one coordinate per quantization block.
UNWEIGHTED_AVERAGE: Average all features within a quantization block equally.
UNWEIGHTED_SUM: Sum all features within a quantization block equally.
NO_QUANTIZATION: No quantization is applied. Should not be used for normal operation.
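The three reducing modes differ only in how the feature rows that share a quantization block are combined. A pure-Python sketch, assuming a hypothetical `inverse_mapping` that assigns each input row to a block index; the mode strings mirror the enum member names but the function `reduce_features` is not part of MinkowskiEngine.

```python
import random
from typing import List, Optional, Sequence

MODES = {"RANDOM_SUBSAMPLE", "UNWEIGHTED_AVERAGE", "UNWEIGHTED_SUM"}


def reduce_features(
    feats: Sequence[Sequence[float]],
    inverse_mapping: Sequence[int],
    num_voxels: int,
    mode: str,
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """Combine the feature rows that fall into each quantization block."""
    if mode not in MODES:
        raise ValueError(f"unknown mode: {mode}")
    groups: List[List[Sequence[float]]] = [[] for _ in range(num_voxels)]
    for f, v in zip(feats, inverse_mapping):
        groups[v].append(f)
    if any(not g for g in groups):
        raise ValueError("every block must receive at least one row")
    out: List[List[float]] = []
    for g in groups:
        if mode == "RANDOM_SUBSAMPLE":
            # Keep one row of the block, chosen at random.
            out.append(list((rng or random).choice(g)))
        elif mode == "UNWEIGHTED_SUM":
            out.append([sum(col) for col in zip(*g)])
        else:  # UNWEIGHTED_AVERAGE: every row gets the same weight 1/len(g)
            out.append([sum(col) / len(g) for col in zip(*g)])
    return out
```

With `feats = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]` and `inverse_mapping = [0, 0, 1]`, UNWEIGHTED_SUM gives `[[4.0, 6.0], [5.0, 6.0]]` and UNWEIGHTED_AVERAGE gives `[[2.0, 3.0], [5.0, 6.0]]`, while RANDOM_SUBSAMPLE keeps one of the first two rows for block 0.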
set_sparse_tensor_operation_mode

MinkowskiEngine.MinkowskiTensor.set_sparse_tensor_operation_mode(operation_mode: MinkowskiEngine.MinkowskiTensor.SparseTensorOperationMode)

Define the sparse tensor coordinate manager operation mode.

By default, a MinkowskiEngine.SparseTensor.SparseTensor instantiation creates a new coordinate manager that is not shared with other sparse tensors. By calling this function with MinkowskiEngine.SparseTensorOperationMode.SHARE_COORDINATE_MANAGER, you can share the coordinate manager globally with other sparse tensors. However, you must then explicitly clear the coordinate manager after use. Please refer to MinkowskiEngine.clear_global_coordinate_manager.

Args:
operation_mode (MinkowskiEngine.SparseTensorOperationMode): The operation mode for the sparse tensor coordinate manager. MinkowskiEngine.SparseTensorOperationMode.SEPARATE_COORDINATE_MANAGER by default.
Example:
>>> import MinkowskiEngine as ME
>>> ME.set_sparse_tensor_operation_mode(ME.SparseTensorOperationMode.SHARE_COORDINATE_MANAGER)
>>> ...
>>> a = ME.SparseTensor(...)
>>> b = ME.SparseTensor(...)  # coords_man shared
>>> ...  # one feed forward and backward
>>> ME.clear_global_coordinate_manager()  # Must use to clear the coordinates after one forward/backward
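The difference between the two operation modes can be modeled with a small module-level registry. Everything below is hypothetical and self-contained (`OperationMode`, `CoordinateManagerStub`, `set_mode`, `get_manager`, `clear_global_manager`); it only illustrates the documented behavior: a fresh manager per tensor in separate mode, and one shared manager in shared mode until it is explicitly cleared.

```python
from enum import Enum
from typing import Optional


class OperationMode(Enum):  # hypothetical stand-in for SparseTensorOperationMode
    SEPARATE_COORDINATE_MANAGER = 0
    SHARE_COORDINATE_MANAGER = 1


class CoordinateManagerStub:  # hypothetical placeholder for a coordinate manager
    pass


_mode = OperationMode.SEPARATE_COORDINATE_MANAGER
_global_manager: Optional[CoordinateManagerStub] = None


def set_mode(mode: OperationMode) -> None:
    global _mode
    _mode = mode


def get_manager() -> CoordinateManagerStub:
    """Return the manager a newly constructed tensor would use under the current mode."""
    global _global_manager
    if _mode is OperationMode.SEPARATE_COORDINATE_MANAGER:
        return CoordinateManagerStub()  # a fresh, unshared manager every time
    if _global_manager is None:
        _global_manager = CoordinateManagerStub()
    return _global_manager  # shared until explicitly cleared


def clear_global_manager() -> None:
    global _global_manager
    _global_manager = None
```

The sketch makes the cost of forgetting to clear visible: in shared mode the same manager keeps accumulating state across iterations, which is why the example calls the clear function after every forward/backward pass.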
sparse_tensor_operation_mode

MinkowskiEngine.MinkowskiTensor.sparse_tensor_operation_mode() → MinkowskiEngine.MinkowskiTensor.SparseTensorOperationMode

Return the current sparse tensor operation mode.
global_coordinate_manager

MinkowskiEngine.MinkowskiTensor.global_coordinate_manager()

Return the current global coordinate manager.
set_global_coordinate_manager

MinkowskiEngine.MinkowskiTensor.set_global_coordinate_manager(coordinate_manager)

Set the global coordinate manager.

Args:
coordinate_manager (MinkowskiEngine.CoordinateManager): The coordinate manager to set as the global coordinate manager.
clear_global_coordinate_manager

MinkowskiEngine.MinkowskiTensor.clear_global_coordinate_manager()

Clear the global coordinate manager cache.

When you use the operation mode MinkowskiEngine.SparseTensorOperationMode.SHARE_COORDINATE_MANAGER, you must explicitly clear the coordinate manager after each feed forward/backward.