magnitude

Magnitude-based sparsity, inspired by NVIDIA ASP (Automatic SParsity).

Classes

MagnitudeSearcher

Searcher for magnitude-based sparsity.

Functions

compute_valid_1d_patterns

Computes all possible m:n patterns in a 1D vector.

create_asp_mask

Creates a mask for a given tensor based on a specified sparse pattern.

fill

Calculates the ratio of non-zero elements in a tensor.

get_nmprune_info

Gets the n:m sparsity pattern information from a given string.

m4n2_1d

Finds the best 2:4 pattern in a given matrix.

mn_1d_best

Finds the best m:n pattern in a given matrix.

reshape_1d

Reshapes a given matrix into m-dimensional vectors: (h,w) -> (hw/m, m).

class MagnitudeSearcher

Bases: BaseSparseSearcher

Searcher for magnitude-based sparsity.

compute_valid_1d_patterns(m, n)

Computes all possible m:n patterns in a 1D vector.

The function builds a vector of length m containing n ones and m - n zeros, generates all permutations of it, removes duplicates, and returns the unique patterns as a tensor with one pattern per row. For m = 4 and n = 2 there are C(4, 2) = 6 such patterns.
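A dependency-free sketch of the procedure described above. It uses Python tuples in place of torch tensors so that it runs without PyTorch. This is an illustration, not the module's code:

```python
from itertools import permutations
from typing import List, Tuple


def compute_valid_1d_patterns(m: int, n: int) -> List[Tuple[int, ...]]:
    """Return every length-m 0/1 pattern with exactly n ones (n kept out of m)."""
    if not 0 <= n <= m:
        raise ValueError("n must satisfy 0 <= n <= m")
    base = [1] * n + [0] * (m - n)  # n ones followed by (m - n) zeros
    # permutations() repeats patterns when equal values swap places;
    # a set removes the duplicates and sorting gives a deterministic order.
    return sorted(set(permutations(base)))


print(len(compute_valid_1d_patterns(4, 2)))  # 6
```

Because permutations() enumerates all m! orderings before duplicates are removed, this approach is cheap for m = 4 (24 orderings) but grows quickly with m. itertools.combinations(range(m), n) enumerates the C(m, n) patterns directly.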

create_asp_mask(tensor, pattern)

Creates a mask for a given tensor based on a specified sparse pattern.

The function reshapes the tensor and applies the specified pattern to create a sparse mask. The default pattern is m4n2_1d, which finds the best 2:4 sparsity pattern in the tensor.

Parameters:
  • tensor (Parameter) – the weight tensor to build the mask for.

  • pattern (str) – name of the sparse pattern to apply, e.g. m4n2_1d.

Return type:

BoolTensor
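The reshape-then-apply-pattern steps can be sketched end to end on a plain nested list. This is a hypothetical, dependency-free illustration, not the module's implementation. It makes three assumptions: pattern names have the form m<M>n<N>_1d (only m4n2_1d is named on this page), every row length is a multiple of m, and the result is nested Python booleans instead of a BoolTensor.

```python
import re
from itertools import combinations
from typing import List


def create_asp_mask(tensor: List[List[float]], pattern: str = "m4n2_1d") -> List[List[bool]]:
    """Hypothetical sketch: boolean m:n mask for a 2D weight matrix."""
    # Assumed pattern-name format, e.g. "m4n2_1d" -> m = 4, n = 2.
    match = re.fullmatch(r"m(\d+)n(\d+)_1d", pattern)
    if match is None:
        raise ValueError(f"unsupported pattern: {pattern}")
    m, n = int(match.group(1)), int(match.group(2))
    mask = []
    for row in tensor:
        if len(row) % m != 0:
            raise ValueError("row length must be a multiple of m in this sketch")
        row_mask: List[bool] = []
        # Reshape step: walk the row in groups of m consecutive elements.
        for start in range(0, len(row), m):
            group = row[start:start + m]
            # Pattern step: keep the n positions with the largest magnitude sum.
            best = max(combinations(range(m), n),
                       key=lambda idx: sum(abs(group[i]) for i in idx))
            row_mask.extend(i in best for i in range(m))
        mask.append(row_mask)
    return mask


print(create_asp_mask([[0.1, -0.9, 0.5, 0.2, 3.0, 0.0, -1.0, 0.4]]))
# [[False, True, True, False, True, False, True, False]]
```

For this row the sketch keeps -0.9 and 0.5 in the first group of four and 3.0 and -1.0 in the second.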

fill(x)

Calculates the ratio of non-zero elements in a tensor.
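A minimal sketch of this ratio. It works on a flat Python iterable rather than a tensor so that it stays portable:

```python
from typing import Iterable


def fill(values: Iterable[float]) -> float:
    """Return the fraction of entries that are non-zero (sketch on a flat iterable)."""
    items = list(values)
    if not items:
        raise ValueError("fill() of an empty input is undefined")
    return sum(1 for v in items if v != 0) / len(items)


print(fill([0.0, 1.5, 0.0, -2.0]))  # 0.5
```

With PyTorch the same quantity can be written as torch.count_nonzero(x).item() / x.numel(). After a 2:4 mask is applied to a weight matrix with no zeros among the kept entries, this ratio is 0.5.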

get_nmprune_info(pattern)

Gets the n:m sparsity pattern information from a given string.

Parameters:

pattern (str) – the pattern string to parse, e.g. m4n2_1d.

Return type:

Tuple[bool, int, int]
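The meaning of the Tuple[bool, int, int] return value is not spelled out above. One plausible reading is (is_valid_pattern, m, n), parsed from a name such as m4n2_1d. The sketch below follows that reading. The regular expression and the (False, 0, 0) fallback for unrecognized strings are assumptions, not documented behavior.

```python
import re
from typing import Tuple


def get_nmprune_info(pattern: str) -> Tuple[bool, int, int]:
    """Parse an 'm<M>n<N>_1d' pattern name into (is_nm_pattern, m, n)."""
    match = re.fullmatch(r"m(\d+)n(\d+)_1d", pattern)
    if match is None:
        return False, 0, 0  # assumed sentinel for unrecognized patterns
    return True, int(match.group(1)), int(match.group(2))


print(get_nmprune_info("m4n2_1d"))  # (True, 4, 2)
```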

m4n2_1d(mat)

Finds the best 2:4 pattern in a given matrix.
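Keeping exactly 2 of every 4 weights so that the retained magnitude sum is maximal is the same as keeping the 2 largest-magnitude weights in each group. A top-2 selection therefore yields the same mask as an exhaustive search over the six 2:4 patterns, up to how ties are broken. The sketch below uses that equivalent top-k formulation on nested lists. It is an illustration, not the module's code, and it assumes each row length is a multiple of 4.

```python
from typing import List


def m4n2_1d(mat: List[List[float]]) -> List[List[int]]:
    """2:4 mask sketch: in each group of 4 consecutive values, keep the 2 largest |w|."""
    mask = []
    for row in mat:
        if len(row) % 4 != 0:
            raise ValueError("row length must be a multiple of 4 in this sketch")
        row_mask = [0] * len(row)
        for start in range(0, len(row), 4):
            group = range(start, start + 4)
            # Indices of the two largest magnitudes in this group of 4.
            keep = sorted(group, key=lambda i: abs(row[i]), reverse=True)[:2]
            for i in keep:
                row_mask[i] = 1
        mask.append(row_mask)
    return mask


print(m4n2_1d([[0.1, -0.9, 0.5, 0.2]]))  # [[0, 1, 1, 0]]
```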

mn_1d_best(matrix, m, n)

Finds the best m:n pattern in a given matrix.

The function computes all possible m:n patterns and, for each group of m consecutive elements, selects the pattern that maximizes the sum of the absolute values of the retained (non-masked) weights. The selected patterns are then combined into a mask with the same shape as the matrix.
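The exhaustive search can be sketched on nested lists as follows. Each candidate pattern is scored by the dot product of the group's absolute values with the 0/1 pattern, which equals the sum of the retained magnitudes. A few details are assumptions rather than documented behavior. valid_patterns is a hypothetical helper that uses itertools.combinations instead of deduplicated permutations. Rows are assumed to be multiples of m. Ties go to the first pattern in enumeration order. This is not the module's implementation.

```python
from itertools import combinations
from typing import List, Tuple


def valid_patterns(m: int, n: int) -> List[Tuple[int, ...]]:
    """Hypothetical helper: all 0/1 patterns of length m with exactly n ones."""
    return [tuple(1 if i in ones else 0 for i in range(m))
            for ones in combinations(range(m), n)]


def mn_1d_best(matrix: List[List[float]], m: int, n: int) -> List[List[int]]:
    """For every group of m values, pick the m:n pattern that keeps the most magnitude."""
    patterns = valid_patterns(m, n)
    mask = []
    for row in matrix:
        if len(row) % m != 0:
            raise ValueError("row length must be a multiple of m in this sketch")
        row_mask: List[int] = []
        for start in range(0, len(row), m):
            group = [abs(v) for v in row[start:start + m]]
            # Score = |group| . pattern = sum of magnitudes at kept positions.
            scores = [sum(g * p for g, p in zip(group, pat)) for pat in patterns]
            best = patterns[scores.index(max(scores))]  # first best wins ties
            row_mask.extend(best)
        mask.append(row_mask)
    return mask


print(mn_1d_best([[0.1, -0.9, 0.5, 0.2]], 4, 2))  # [[0, 1, 1, 0]]
```

If the groups are stacked as rows of an (hw/m, m) matrix, the per-group scoring becomes a single matrix product of |W| with the transposed pattern matrix, followed by a row-wise argmax. A tensor implementation would typically vectorize the search this way.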

reshape_1d(matrix, m)

Reshapes a given matrix into m-dimensional vectors: (h,w) -> (hw/m, m).
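A minimal sketch of the (h, w) -> (hw/m, m) regrouping on nested lists. It assumes w is divisible by m, so that no group spans two rows. How the module handles other widths is not stated here.

```python
from typing import List


def reshape_1d(matrix: List[List[float]], m: int) -> List[List[float]]:
    """Flatten (h, w) in row-major order and regroup into rows of m: (h, w) -> (h*w/m, m)."""
    if any(len(row) % m != 0 for row in matrix):
        raise ValueError("row width w must be divisible by m in this sketch")
    flat = [v for row in matrix for v in row]
    return [flat[i:i + m] for i in range(0, len(flat), m)]


print(reshape_1d([[1, 2, 3, 4], [5, 6, 7, 8]], 2))  # [[1, 2], [3, 4], [5, 6], [7, 8]]
```

For a contiguous PyTorch tensor, matrix.reshape(-1, m) performs the same row-major regrouping.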