cpp_extension
Utility functions for loading C++ / CUDA extensions.
Functions
- load_cpp_extension(name, sources, cuda_version_specifiers, fail_msg='', raise_if_failed=False, **load_kwargs)
Load a C++ / CUDA extension using torch.utils.cpp_extension.load() if the current CUDA version satisfies cuda_version_specifiers.
The first load may take a few minutes because the sources must be compiled; subsequent loads reuse the cached build and are much faster.
- Parameters:
name (str) – Name of the extension.
sources (List[str | Path]) – Source files to compile.
cuda_version_specifiers (str | None) – Version specifier (e.g. “>=11.8,<12”) that the current CUDA version must satisfy for the extension to be enabled.
fail_msg (str) – Additional message to display if the extension fails to load.
raise_if_failed (bool) – Raise an exception if the extension fails to load.
**load_kwargs (Any) – Additional keyword arguments passed to torch.utils.cpp_extension.load().
- Return type:
module | None
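The CUDA version gate and the `module | None` contract can be illustrated with a small, dependency-free sketch. It is not the library's implementation. `cuda_version_satisfies` and `load_if_supported` are hypothetical names. The parser handles only comma-separated clauses over numeric versions, which is a subset of PEP 440 specifiers that a real implementation would more likely parse with `packaging.specifiers.SpecifierSet`. The sketch assumes that a missing specifier means "no constraint" and that a missing CUDA version (a CPU-only build) means the extension is skipped. The `loader` argument stands in for `torch.utils.cpp_extension.load()`, and the CUDA version would normally come from `torch.version.cuda`.

```python
import operator
import warnings
from typing import Any, Callable, List, Optional

# Operators understood by this simplified parser (a subset of PEP 440).
# Two-character operators come first so ">=" is not mistaken for ">".
_OPS = {
    ">=": operator.ge,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">": operator.gt,
    "<": operator.lt,
}


def _parse_version(text: str) -> tuple:
    """Turn '11.8' into (11, 8); assumes purely numeric dotted versions."""
    return tuple(int(part) for part in text.strip().split("."))


def _compare(version: tuple, op: str, target: tuple) -> bool:
    """Zero-pad both versions to equal length so that 12.0 == 12."""
    width = max(len(version), len(target))
    v = version + (0,) * (width - len(version))
    t = target + (0,) * (width - len(target))
    return _OPS[op](v, t)


def cuda_version_satisfies(cuda_version: Optional[str], specifiers: Optional[str]) -> bool:
    """Hypothetical helper: True if cuda_version satisfies every clause."""
    if cuda_version is None:  # no CUDA runtime available
        return False
    if not specifiers:  # assumption: None or "" means "no constraint"
        return True
    version = _parse_version(cuda_version)
    for clause in specifiers.split(","):
        clause = clause.strip()
        for op in _OPS:
            if clause.startswith(op):
                if not _compare(version, op, _parse_version(clause[len(op):])):
                    return False
                break
        else:
            raise ValueError(f"Unsupported specifier clause: {clause!r}")
    return True


def load_if_supported(
    name: str,
    sources: List[str],
    loader: Callable[..., Any],
    cuda_version: Optional[str],
    cuda_version_specifiers: Optional[str],
    fail_msg: str = "",
    raise_if_failed: bool = False,
    **load_kwargs: Any,
) -> Optional[Any]:
    """Hypothetical stand-in for load_cpp_extension; `loader` plays the
    role of torch.utils.cpp_extension.load."""
    if cuda_version is None:
        reason = "CUDA is not available"
    elif not cuda_version_satisfies(cuda_version, cuda_version_specifiers):
        reason = f"CUDA {cuda_version} does not satisfy {cuda_version_specifiers!r}"
    else:
        try:
            return loader(name=name, sources=sources, **load_kwargs)
        except Exception as exc:  # e.g. a compilation or import failure
            reason = f"{type(exc).__name__}: {exc}"
    message = f"Could not load extension {name!r} ({reason}). {fail_msg}".strip()
    if raise_if_failed:
        raise RuntimeError(message)
    warnings.warn(message)
    return None
```

Passing the loader in as an argument keeps the sketch testable without a CUDA toolchain. A caller would check the return value for `None` before using the extension, for example by falling back to a pure-PyTorch code path.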