cpp_extension

Utility functions for loading C++ / CUDA extensions.

Functions

load_cpp_extension

Load a C++ / CUDA extension using torch.utils.cpp_extension.load() if the current CUDA version satisfies the given version specifier.

load_cpp_extension(name, sources, cuda_version_specifiers, fail_msg='', **load_kwargs)

Load a C++ / CUDA extension using torch.utils.cpp_extension.load() if the current CUDA version satisfies the given version specifier.

The first load may take a few minutes because the extension has to be compiled; subsequent loads reuse the cached build and are nearly instantaneous.

Parameters:
  • name (str) – Name of the extension.

  • sources (List[str | Path]) – Source files to compile.

  • cuda_version_specifiers (str | None) – Version specifier (e.g. ">=11.8,<12") describing the CUDA versions for which the extension is enabled.

  • **load_kwargs (Any) – Additional keyword arguments passed through to torch.utils.cpp_extension.load().

Return type:

module | None
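The version gating described above can be sketched as follows. This is a minimal, dependency-free illustration under stated assumptions, not this module's actual implementation: the names `cuda_version_satisfies` and `load_cpp_extension_sketch` are hypothetical, only the plain comparison operators (`>=`, `<=`, `>`, `<`, `==`, `!=`) are supported, `None` is assumed to mean "no CUDA constraint", and emitting `fail_msg` as a warning is a guess at its purpose. Full PEP 440 semantics are available from `packaging.specifiers.SpecifierSet`, which the sketch avoids only to stay self-contained.

```python
import operator
import re
import warnings
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union

# Supported comparison operators (a subset of PEP 440 specifiers).
_OPS = {
    ">=": operator.ge,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">": operator.gt,
    "<": operator.lt,
}

# One clause such as ">=11.8"; two-character operators are tried first.
_CLAUSE = re.compile(r"^\s*(>=|<=|==|!=|>|<)\s*([0-9]+(?:\.[0-9]+)*)\s*$")


def _parse_version(text: str) -> Tuple[int, ...]:
    """Turn "11.8" into (11, 8)."""
    return tuple(int(part) for part in text.split("."))


def _pad(a: Tuple[int, ...], b: Tuple[int, ...]):
    """Zero-pad both tuples to equal length so that 12 == 12.0."""
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)), b + (0,) * (width - len(b))


def cuda_version_satisfies(cuda_version: Optional[str],
                           specifiers: Optional[str]) -> bool:
    """Hypothetical helper: True if cuda_version matches every comma-separated clause."""
    if specifiers is None:
        return True  # assumption: no specifier means no constraint
    if cuda_version is None:
        return False  # no CUDA available (e.g. a CPU-only PyTorch build)
    current = _parse_version(cuda_version)
    for clause in specifiers.split(","):
        match = _CLAUSE.match(clause)
        if match is None:
            raise ValueError(f"unsupported specifier clause: {clause!r}")
        op, version = match.groups()
        lhs, rhs = _pad(current, _parse_version(version))
        if not _OPS[op](lhs, rhs):
            return False
    return True


def load_cpp_extension_sketch(name: str,
                              sources: List[Union[str, Path]],
                              cuda_version_specifiers: Optional[str],
                              fail_msg: str = "",
                              **load_kwargs: Any):
    """Hypothetical wrapper: build the extension only when the CUDA version matches."""
    import torch  # imported lazily so the version check works without torch
    from torch.utils.cpp_extension import load

    if not cuda_version_satisfies(torch.version.cuda, cuda_version_specifiers):
        if fail_msg:
            warnings.warn(fail_msg)  # assumption about how fail_msg is used
        return None
    return load(name=name, sources=[str(s) for s in sources], **load_kwargs)
```

For example, `cuda_version_satisfies("11.8", ">=11.8,<12")` is true, while `"12.1"` fails the `<12` clause, so the sketch wrapper would return `None` instead of compiling.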