cpp_extension

Utility functions for loading C++ / CUDA extensions.

Functions

load_cpp_extension

Load a C++ / CUDA extension using torch.utils.cpp_extension.load() if the current CUDA version satisfies the given version specifiers.

load_cpp_extension(name, sources, cuda_version_specifiers, fail_msg='', raise_if_failed=False, **load_kwargs)

Load a C++ / CUDA extension using torch.utils.cpp_extension.load() if the current CUDA version satisfies the given version specifiers.

The first load may take a few minutes because the extension has to be compiled; subsequent loads reuse the cached build and are much faster.
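As a rough illustration of a call, the sketch below wraps load_cpp_extension in a helper. The import path modelopt.torch.utils.cpp_extension, the extension name, and the source file names are assumptions made for this example and are not confirmed by this page. The return type suggests that None comes back when loading is skipped or fails with raise_if_failed=False, which the sketch assumes.

```python
from pathlib import Path


def load_my_extension(source_dir: Path):
    """Try to build and load a hypothetical extension; return the module or None."""
    # Assumed import path; adjust it to wherever load_cpp_extension lives in your install.
    from modelopt.torch.utils.cpp_extension import load_cpp_extension

    return load_cpp_extension(
        name="my_cuda_ext",  # hypothetical extension name
        sources=[
            source_dir / "my_ext.cpp",        # hypothetical C++ binding file
            source_dir / "my_ext_kernel.cu",  # hypothetical CUDA kernel file
        ],
        cuda_version_specifiers=">=11.8,<12",
        fail_msg="Falling back to the pure-PyTorch implementation.",
        raise_if_failed=False,
        verbose=True,  # extra keyword, passed on to torch.utils.cpp_extension.load()
    )
```

A caller would typically check the result for None before using any symbol from the module, and keep a slower fallback path for that case.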

Parameters:
  • name (str) – Name of the extension.

  • sources (List[str | Path]) – Source files to compile.

  • cuda_version_specifiers (str | None) – Version specifier (e.g. ">=11.8,<12") that the current CUDA version must satisfy for the extension to be loaded.

  • fail_msg (str) – Additional message to display if the extension fails to load.

  • raise_if_failed (bool) – Whether to raise an exception if the extension fails to load.

  • **load_kwargs (Any) – Additional keyword arguments passed to torch.utils.cpp_extension.load().

Return type:

module | None
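To make the meaning of a specifier such as ">=11.8,<12" concrete, here is a minimal, standard-library-only sketch of a version check. It is not the library's implementation, which may well rely on a dedicated specifier parser. The helper name cuda_version_satisfies is hypothetical, and the handling of None (no requirement when the specifier is None, no match when no CUDA version is available) is an assumption.

```python
import operator
import re

# One clause looks like ">=11.8" or "<12": an operator followed by a dotted version.
_CLAUSE = re.compile(r"^\s*(>=|<=|==|!=|>|<)\s*(\d+(?:\.\d+)*)\s*$")
_OPS = {
    ">=": operator.ge,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">": operator.gt,
    "<": operator.lt,
}


def _as_tuple(version, width=3):
    """Turn "11.8" into (11, 8, 0) so versions compare numerically."""
    parts = [int(p) for p in version.split(".")]
    return tuple(parts + [0] * (width - len(parts)))


def cuda_version_satisfies(cuda_version, specifiers):
    """Return True if cuda_version meets every comma-separated clause."""
    if specifiers is None:
        return True  # assumption: None means "no version requirement"
    if cuda_version is None:
        return False  # assumption: without CUDA, no specifier can be met
    current = _as_tuple(cuda_version)
    for clause in specifiers.split(","):
        match = _CLAUSE.match(clause)
        if match is None:
            raise ValueError(f"Unsupported specifier clause: {clause!r}")
        op, required = match.groups()
        if not _OPS[op](current, _as_tuple(required)):
            return False
    return True


print(cuda_version_satisfies("11.8", ">=11.8,<12"))  # True
print(cuda_version_satisfies("12.1", ">=11.8,<12"))  # False
```

All clauses are combined with a logical AND, so ">=11.8,<12" accepts any CUDA 11.x release from 11.8 onward and rejects every 12.x release.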