LoRA Manager
-
class LoRAManager
Manages LoRA adapter weights and switching.
Maintains a map of named adapters, each containing a set of weight tensors identified by binding name. At any time at most one adapter is “active”. Switching adapters is O(1) — no GPU copies are required because the TensorRT engine binds directly to the stored tensor pointers.
Usage:
- loadWeights() or addWeights() to register adapter(s).
- switchWeights(name) to activate an adapter.
- getActiveWeight(bindingName) to retrieve the tensor for engine binding.
- resetWeights() to deactivate (bind dummy zero-tensors).
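The register / switch / retrieve / reset flow can be illustrated with a small host-only stand-in. This is a minimal sketch, not the real class: `ToyTensor`, `ToyWeights`, `ToyLoRAManager`, the adapter names and `demoSwitching()` are hypothetical, and a `std::vector<float>` takes the place of GPU memory. It only mirrors the behaviour described above, where switching changes a pointer and copies no tensor data.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for rt::Tensor: host floats instead of GPU memory.
struct ToyTensor {
    std::vector<float> data;
};

using ToyWeights = std::unordered_map<std::string, ToyTensor>;

// Hypothetical miniature of the documented behaviour, not the real LoRAManager.
class ToyLoRAManager {
public:
    // Analogue of addWeights(): take ownership of the tensors under a name.
    void addWeights(std::string const& name, ToyWeights weights) {
        mAdapters[name] = std::move(weights);
    }

    // Analogue of switchWeights(): a hash lookup plus a pointer assignment.
    void switchWeights(std::string const& name) {
        auto it = mAdapters.find(name);
        if (it == mAdapters.end()) {
            throw std::runtime_error("unknown adapter: " + name);
        }
        mActive = &it->second;
        mActiveName = name;
    }

    // Analogue of resetWeights(): afterwards lookups yield the dummy tensor.
    void resetWeights() {
        mActive = nullptr;
        mActiveName.clear();
    }

    // Analogue of getActiveWeight(): dummy when inactive, throws if missing.
    ToyTensor& getActiveWeight(std::string const& bindingName) {
        if (mActive == nullptr) {
            return mDummy;
        }
        auto it = mActive->find(bindingName);
        if (it == mActive->end()) {
            throw std::runtime_error("no weight for binding: " + bindingName);
        }
        return it->second;
    }

    std::string const& getActiveAdapterName() const noexcept { return mActiveName; }

private:
    std::unordered_map<std::string, ToyWeights> mAdapters;
    ToyWeights* mActive{nullptr};  // element pointers survive map rehashing
    std::string mActiveName;
    ToyTensor mDummy{{0.0F}};      // zero-filled, shape [1]
};

// Walks through the usage steps and records the value seen at each stage.
std::vector<float> demoSwitching() {
    ToyWeights styleA;
    styleA["lora_A_layer_0"] = ToyTensor{{1.0F}};
    ToyWeights styleB;
    styleB["lora_A_layer_0"] = ToyTensor{{2.0F}};

    ToyLoRAManager mgr;
    mgr.addWeights("style_a", std::move(styleA));
    mgr.addWeights("style_b", std::move(styleB));

    std::vector<float> seen;
    mgr.switchWeights("style_a");
    seen.push_back(mgr.getActiveWeight("lora_A_layer_0").data[0]);
    mgr.switchWeights("style_b");
    seen.push_back(mgr.getActiveWeight("lora_A_layer_0").data[0]);
    mgr.resetWeights();
    assert(mgr.getActiveAdapterName().empty());
    seen.push_back(mgr.getActiveWeight("lora_A_layer_0").data[0]);
    return seen;
}
```

In this sketch `demoSwitching()` observes the first adapter's value, then the second adapter's value, then the zero of the dummy tensor. The real class additionally manages GPU-resident tensors, which is why it is move-only.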
Public Functions
-
LoRAManager() = default
Default constructor.
-
LoRAManager(LoRAManager const&) = delete
Deleted copy to prevent accidental duplication of GPU resources.
-
LoRAManager &operator=(LoRAManager const&) = delete
-
LoRAManager(LoRAManager&&) noexcept = default
Allow move.
-
LoRAManager &operator=(LoRAManager&&) noexcept = default
-
void loadWeights(std::string const &name, std::filesystem::path const &path, cudaStream_t stream)
Load LoRA adapter weights from a safetensors file.
Each tensor in the safetensors file is stored under its tensor name as the binding name.
- Parameters:
name – Adapter name (user-facing identifier).
path – Path to the .safetensors file.
stream – CUDA stream for async loading.
- Throws:
std::runtime_error – if the file cannot be read or its format is invalid.
-
void addWeights( )
Register adapter weights directly (useful for unit testing without I/O).
- Parameters:
name – Adapter name.
weights – Map of binding-name to tensor (tensors are moved from).
-
void switchWeights(std::string const &name)
Activate an adapter by name.
- Parameters:
name – Adapter name (must have been loaded/added previously).
- Throws:
std::runtime_error – if the adapter name is not found.
-
void resetWeights()
Deactivate any adapter. After this call, getActiveWeight() returns a reference to a zero-filled dummy tensor (of shape [1]).
-
rt::Tensor &getActiveWeight(std::string const &bindingName)
Retrieve the currently active tensor for a given binding name. O(1).
- Parameters:
bindingName – The engine binding name (e.g. “lora_A_layer_0”).
- Throws:
std::runtime_error – if bindingName is not found in the active adapter.
- Returns:
Reference to the weight tensor (dummy tensor if no adapter is active).
-
std::string const &getActiveAdapterName() const noexcept
Return the name of the currently active adapter, or an empty string if no adapter is active.
-
std::vector<std::string> getBindingNames() const
Return all binding names across all loaded adapters. Useful for initialising a TensorMap with the correct keys.
-
std::vector<std::string> getAdapterNames() const
Return all loaded adapter names.
-
bool hasActiveAdapter() const noexcept
Check whether any adapter is currently active.
-
bool hasWeightFor(std::string const &bindingName) const noexcept
Check whether the active adapter contains a weight under bindingName.
Fused engines and non-fused adapters sometimes use different naming conventions (e.g. qkv_proj.* vs separate q_proj.* / k_proj.* / v_proj.*). refreshTensorMap uses this predicate to decide whether to bind the adapter’s weight or fall back to a dummy tensor, without paying the cost of try/catch around getActiveWeight.
Returns false when no adapter is active.
-
void initializeEngineBindings(EngineExecutor const &runner)
Register the engine’s LoRA I/O bindings and create rank=1 dummy tensors with the correct engine shapes. Must be called once after the EngineExecutor is constructed so that refreshTensorMap() knows which names to populate.
Encapsulates the LoRA binding-shape convention:
- lora_A_* weights have shape [k, rank]; dummy sets last dim to 1.
- lora_B_* weights have shape [rank, n]; dummy sets first dim to 1.
- Parameters:
runner – Source of the engine I/O list and per-binding max shapes.
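The binding-shape convention can be written down as a small pure function. The sketch below is illustrative only: `Shape`, `startsWith` and `makeDummyShape` are hypothetical names, the shape is a plain vector of dimensions instead of whatever type the engine reports, and the example sizes are made up. It collapses the rank dimension to 1 according to the `lora_A_` / `lora_B_` prefix of the binding name.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical shape type: one entry per dimension.
using Shape = std::vector<std::int64_t>;

// True if s begins with prefix (works on C++11 and later).
bool startsWith(std::string const& s, std::string const& prefix) {
    return s.compare(0, prefix.size(), prefix) == 0;
}

// Derives the rank=1 dummy shape from a binding's engine shape:
//   lora_A_* : [k, rank] -> [k, 1]  (last dim set to 1)
//   lora_B_* : [rank, n] -> [1, n]  (first dim set to 1)
Shape makeDummyShape(std::string const& bindingName, Shape shape) {
    assert(!shape.empty() && "a LoRA binding needs at least one dimension");
    if (startsWith(bindingName, "lora_A_")) {
        shape.back() = 1;
    } else if (startsWith(bindingName, "lora_B_")) {
        shape.front() = 1;
    } else {
        throw std::invalid_argument("not a LoRA binding: " + bindingName);
    }
    return shape;
}
```

For instance, under these assumptions a `lora_A_layer_0` binding with engine shape [4096, 64] gets a [4096, 1] dummy, and a `lora_B_layer_0` binding with shape [64, 4096] gets a [1, 4096] dummy.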
-
void refreshTensorMap(TensorMap &map)
Refresh all LoRA entries in the given TensorMap.
For each registered engine binding name, either the active adapter’s weight tensor or the per-binding dummy tensor is written into map. Must be called after every switchWeights() / resetWeights().
- Parameters:
map – TensorMap to update with the current LoRA bindings.
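The refresh rule (adapter weight if present, otherwise the per-binding dummy) can be sketched as follows. Everything here is a hypothetical stand-in: `ToyTensor`, `ToyWeights`, `ToyTensorMap`, `ToyBindingState` and `demoRefresh()` are invented for illustration, and the real `TensorMap` and `rt::Tensor` types are not reproduced. The sketch assumes the map stores non-owning pointers, in line with the class description that the engine binds to the stored tensors directly.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-ins for rt::Tensor and TensorMap.
struct ToyTensor {
    std::vector<float> data;
};
using ToyWeights = std::unordered_map<std::string, ToyTensor>;
using ToyTensorMap = std::unordered_map<std::string, ToyTensor*>;

struct ToyBindingState {
    std::vector<std::string> engineBindings;  // names registered from the engine
    ToyWeights dummies;                       // one dummy per engine binding
    ToyWeights* active{nullptr};              // nullptr: no adapter is active

    // Analogue of hasWeightFor(): a plain lookup, no exceptions involved.
    bool hasWeightFor(std::string const& bindingName) const {
        return active != nullptr && active->count(bindingName) != 0;
    }

    // Analogue of refreshTensorMap(): rebind every registered LoRA entry.
    void refreshTensorMap(ToyTensorMap& map) {
        for (auto const& name : engineBindings) {
            assert(dummies.count(name) != 0 && "dummy must exist per binding");
            map[name] = hasWeightFor(name) ? &active->at(name) : &dummies.at(name);
        }
    }
};

// Records the first value bound to each name after a switch and a reset.
std::vector<float> demoRefresh() {
    ToyBindingState state;
    state.engineBindings = {"lora_A_layer_0", "lora_B_layer_0"};
    state.dummies["lora_A_layer_0"] = ToyTensor{{0.0F}};
    state.dummies["lora_B_layer_0"] = ToyTensor{{0.0F}};

    ToyWeights adapter;  // deliberately lacks lora_B_layer_0
    adapter["lora_A_layer_0"] = ToyTensor{{5.0F}};

    ToyTensorMap map;
    std::vector<float> seen;

    state.active = &adapter;  // analogue of switchWeights()
    state.refreshTensorMap(map);
    seen.push_back(map.at("lora_A_layer_0")->data[0]);  // adapter weight
    seen.push_back(map.at("lora_B_layer_0")->data[0]);  // dummy fallback

    state.active = nullptr;   // analogue of resetWeights()
    state.refreshTensorMap(map);
    seen.push_back(map.at("lora_A_layer_0")->data[0]);  // dummy again
    return seen;
}
```

Because the map holds pointers in this sketch, a stale map would keep pointing at the previous adapter's tensors, which is one way to read the requirement to refresh after every switch or reset.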