cpp_extension
Utility functions for loading C++ / CUDA extensions.
Functions
- load_cpp_extension(name, sources, cuda_version_specifiers, fail_msg='', **load_kwargs)
Load a C++ / CUDA extension using torch.utils.cpp_extension.load() if the current CUDA version satisfies cuda_version_specifiers.
The first load may take a few minutes because the sources must be compiled; subsequent loads reuse the cached build and are nearly instantaneous.
- Parameters:
name (str) – Name of the extension.
sources (List[str | Path]) – Source files to compile.
cuda_version_specifiers (str | None) – Version specifier (e.g. ">=11.8,<12") that the current CUDA version must satisfy for the extension to be loaded.
**load_kwargs (Any) – Additional keyword arguments forwarded to torch.utils.cpp_extension.load().
- Return type:
module | None