Source code for experimental.server.engine

# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
vLLM-style one-line API for TensorRT Edge-LLM.

Pipeline: HuggingFace model ID or local path -> ONNX export -> TensorRT engine
build -> inference.

Example::

    from experimental.server import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen3-1.7B")
    outputs = llm.generate(
        ["What is the capital of France?"],
        SamplingParams(temperature=0.7, max_tokens=256),
    )
    for output in outputs:
        print(output.text)

    # Or start an OpenAI-compatible server:
    llm.serve(port=8000)
"""

import hashlib
import importlib.util
import json
import logging
import os
import sys
import threading
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Union

logger = logging.getLogger("edgellm.server")

_PLUGIN_LIB_NAME = "libNvInfer_edgellm_plugin.so"

_VLM_MODEL_TYPES = frozenset([
    "qwen3_vl",
    "qwen3_omni",
    "qwen3_5",
    "qwen2_5_vl",
    "internvl",
    "internvl_chat",
    "phi4mm",
    "phi4_multimodal",
])

# ---------------------------------------------------------------------------
# Public data classes
# ---------------------------------------------------------------------------


[docs] @dataclass class SamplingParams: """Sampling parameters (mirrors vLLM's SamplingParams).""" temperature: float = 0.7 top_p: float = 0.9 top_k: int = 50 max_tokens: int = 2048 enable_thinking: bool = False disable_spec_decode: bool = False
[docs] @dataclass class CompletionOutput: """Output of a single generation request.""" text: str = "" token_ids: List[int] = field(default_factory=list) finish_reason: Optional[str] = None
[docs] @dataclass class StreamDelta: """Single delta from a streaming generation.""" text: str = "" token_ids: List[int] = field(default_factory=list) finished: bool = False finish_reason: Optional[str] = None
# --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- def _resolve_model_dir(model: str) -> str: """Resolve a HuggingFace model ID or local path to a local directory.""" if os.path.isdir(model): return os.path.abspath(model) try: from huggingface_hub import snapshot_download except ImportError as exc: raise RuntimeError( "huggingface_hub is not installed. Install it with: " "pip install huggingface_hub") from exc logger.info("Downloading %s from Hugging Face Hub ...", model) return snapshot_download(model) def _artifacts_dir_for_model(model_dir: str) -> str: """Return a deterministic directory for ONNX/engine artifacts. Stored under ``<model_dir>/.edgellm/``. If that is not writable (e.g. shared filesystem), falls back to ``~/.cache/edgellm/<hash>/``. """ preferred = os.path.join(model_dir, ".edgellm") try: os.makedirs(preferred, exist_ok=True) return preferred except OSError: digest = hashlib.sha256( os.path.abspath(model_dir).encode()).hexdigest()[:12] fallback = os.path.join( os.path.expanduser("~"), ".cache", "edgellm", digest, ) os.makedirs(fallback, exist_ok=True) return fallback def _engine_config_tag( max_input_len: int, max_batch_size: int, max_kv_cache_capacity: int, ) -> str: """Return a short tag encoding the engine build config.""" return f"i{max_input_len}_b{max_batch_size}_kv{max_kv_cache_capacity}" def _is_vlm(model_dir: str) -> bool: """Check if the model is a VLM by reading config.json.""" cfg_path = os.path.join(model_dir, "config.json") if not os.path.exists(cfg_path): return False with open(cfg_path) as f: cfg = json.load(f) return cfg.get("model_type", "") in _VLM_MODEL_TYPES def _read_model_type(model_dir: str) -> str: """Read model_type from config.json.""" cfg_path = os.path.join(model_dir, "config.json") if not os.path.exists(cfg_path): return "" with open(cfg_path) as f: return json.load(f).get("model_type", "") def _read_vision_config(model_dir: str) -> dict: """Read vision_config from config.json for visual builder params.""" cfg_path = os.path.join(model_dir, "config.json") if not os.path.exists(cfg_path): return {} with open(cfg_path) as f: return json.load(f).get("vision_config", {}) def _ensure_plugin_path() -> None: """Set EDGELLM_PLUGIN_PATH if not already set. Searches relative to this package and common build locations. """ if os.environ.get("EDGELLM_PLUGIN_PATH"): return project_root = Path(__file__).resolve().parent.parent.parent search_dirs = [ project_root / "build" / "core", project_root / "build" / "lib", ] for d in search_dirs: candidate = d / _PLUGIN_LIB_NAME if candidate.is_file(): os.environ["EDGELLM_PLUGIN_PATH"] = str(candidate) return def _import_runtime(): """Import the C++ pybind module.""" _ensure_plugin_path() try: from tensorrt_edgellm import _edgellm_runtime as _rt return _rt except ImportError: pass project_root = Path(__file__).resolve().parent.parent.parent search_dirs = [ project_root / "experimental" / "pybind" / "build", project_root / "build" / "pybind", ] search_dirs.extend(project_root.glob("build/lib.*")) for cand_dir in search_dirs: if not cand_dir.is_dir(): continue so_files = list(cand_dir.glob("*_edgellm_runtime*.so")) if so_files: spec = importlib.util.spec_from_file_location( "_edgellm_runtime", so_files[0]) mod = importlib.util.module_from_spec(spec) sys.modules["tensorrt_edgellm._edgellm_runtime"] = mod spec.loader.exec_module(mod) return mod raise ImportError( "Could not import _edgellm_runtime. Build the C++ extension first:\n" " TRT_PACKAGE_DIR=/path/to/tensorrt python experimental/server/setup_pybind.py build_ext --inplace" ) def _llm_loader_dir() -> Path: """Return path to the llm_loader package, adding it to sys.path.""" loader_dir = Path(__file__).resolve().parent.parent / "llm_loader" if not loader_dir.is_dir(): raise RuntimeError(f"llm_loader not found at {loader_dir}. " "Ensure experimental/llm_loader/ exists.") experimental_dir = str(loader_dir.parent) if experimental_dir not in sys.path: sys.path.insert(0, experimental_dir) return loader_dir # --------------------------------------------------------------------------- # LLM class # ---------------------------------------------------------------------------
[docs] class LLM: """vLLM-style entry point for TensorRT Edge-LLM inference. Three initialization modes (exactly one of ``model``, ``onnx_dir``, or ``engine_dir`` must be provided): 1. **HuggingFace checkpoint** — exports ONNX, builds engine, loads:: llm = LLM(model="Qwen/Qwen3-1.7B") 2. **ONNX directory** — builds engine from ONNX, loads:: llm = LLM(onnx_dir="/path/to/onnx") 3. **Pre-built engine** — loads directly:: llm = LLM(engine_dir="/path/to/engine") llm = LLM(engine_dir="...", visual_engine_dir="...") See :mod:`experimental.server.engine_layout` for the expected directory layouts. """
[docs] def __init__( self, model: str = "", *, onnx_dir: str = "", visual_onnx_dir: str = "", engine_dir: str = "", visual_engine_dir: str = "", max_input_len: int = 4096, max_batch_size: int = 1, max_kv_cache_capacity: int = 8192, use_trt_native_ops: bool = False, eagle_engine_dir: str = "", draft_top_k: int = 10, draft_step: int = 6, verify_tree_size: int = 60, ): sources = sum(bool(s) for s in (model, onnx_dir, engine_dir)) if sources != 1: raise ValueError( "Exactly one of 'model', 'onnx_dir', or 'engine_dir' " "must be provided.") self._model_id = (model or os.path.basename(onnx_dir) or os.path.basename(engine_dir)) self._eagle_engine_dir = eagle_engine_dir self._draft_top_k = draft_top_k self._draft_step = draft_step self._verify_tree_size = verify_tree_size if engine_dir: self._init_from_engine(engine_dir, visual_engine_dir) elif onnx_dir: self._init_from_onnx( onnx_dir, visual_onnx_dir=visual_onnx_dir, max_input_len=max_input_len, max_batch_size=max_batch_size, max_kv_cache_capacity=max_kv_cache_capacity, use_trt_native_ops=use_trt_native_ops, ) else: self._init_from_model( model, max_input_len=max_input_len, max_batch_size=max_batch_size, max_kv_cache_capacity=max_kv_cache_capacity, use_trt_native_ops=use_trt_native_ops, ) self._load_runtime()
# ------------------------------------------------------------------ # Initialization paths # ------------------------------------------------------------------ def _init_from_engine(self, engine_dir: str, visual_engine_dir: str) -> None: """Load from pre-built engine directories (no export, no build).""" from .engine_layout import (find_visual_engine_dir, validate_llm_engine_dir, validate_visual_engine_dir) if not validate_llm_engine_dir(engine_dir): raise ValueError(f"llm.engine not found in: {engine_dir}") self._engine_dir = engine_dir self._model_dir = engine_dir self._is_vlm = False if visual_engine_dir: if not validate_visual_engine_dir(visual_engine_dir): raise ValueError( f"visual.engine not found in: {visual_engine_dir}") self._visual_engine_dir = visual_engine_dir self._is_vlm = True else: auto = find_visual_engine_dir(engine_dir) self._visual_engine_dir = auto or "" if auto: self._is_vlm = True logger.info("Auto-detected visual engine: %s", auto) logger.info("Using pre-built engine: %s", self._engine_dir) def _init_from_onnx( self, onnx_dir: str, *, visual_onnx_dir: str, max_input_len: int, max_batch_size: int, max_kv_cache_capacity: int, use_trt_native_ops: bool, ) -> None: """Build engine from ONNX directories (no export).""" self._max_input_len = max_input_len self._max_batch_size = max_batch_size self._max_kv_cache_capacity = max_kv_cache_capacity self._use_trt_native_ops = use_trt_native_ops self._onnx_dir = onnx_dir self._visual_onnx_dir = visual_onnx_dir self._model_dir = onnx_dir self._is_vlm = bool(visual_onnx_dir) cfg_tag = _engine_config_tag(max_input_len, max_batch_size, max_kv_cache_capacity) artifacts = _artifacts_dir_for_model(onnx_dir) eagle = self._eagle_engine_dir self._engine_dir = (eagle if eagle else os.path.join( artifacts, "engine", cfg_tag, "llm")) if not eagle and not os.path.exists( os.path.join(self._engine_dir, "llm.engine")): self._build_engine() else: logger.info("Using cached engine: %s", self._engine_dir) self._visual_engine_dir = "" if self._is_vlm: self._visual_engine_dir = os.path.join(artifacts, "engine", cfg_tag, "visual") if not os.path.exists( os.path.join(self._visual_engine_dir, "visual.engine")): self._build_visual_engine() else: logger.info("Using cached visual engine: %s", self._visual_engine_dir) def _init_from_model( self, model: str, *, max_input_len: int, max_batch_size: int, max_kv_cache_capacity: int, use_trt_native_ops: bool, ) -> None: """Export ONNX + build engine from HuggingFace checkpoint.""" self._max_input_len = max_input_len self._max_batch_size = max_batch_size self._max_kv_cache_capacity = max_kv_cache_capacity self._use_trt_native_ops = use_trt_native_ops logger.info("Resolving model: %s", model) self._model_dir = _resolve_model_dir(model) artifacts = _artifacts_dir_for_model(self._model_dir) self._is_vlm = _is_vlm(self._model_dir) self._model_type = _read_model_type(self._model_dir) if self._is_vlm: logger.info("Detected VLM model (type=%s)", self._model_type) self._onnx_dir = os.path.join(artifacts, "onnx", "llm") if not os.path.exists(os.path.join(self._onnx_dir, "model.onnx")): self._export_onnx() else: logger.info("Using cached ONNX: %s", self._onnx_dir) self._visual_onnx_dir = "" if self._is_vlm: self._visual_onnx_dir = os.path.join(artifacts, "onnx", "visual") if not os.path.exists( os.path.join(self._visual_onnx_dir, "model.onnx")): self._export_visual_onnx() else: logger.info("Using cached visual ONNX: %s", self._visual_onnx_dir) # Delegate to _init_from_onnx for the build step self._init_from_onnx( self._onnx_dir, visual_onnx_dir=self._visual_onnx_dir, max_input_len=max_input_len, max_batch_size=max_batch_size, max_kv_cache_capacity=max_kv_cache_capacity, use_trt_native_ops=use_trt_native_ops, ) def _load_runtime(self) -> None: """Load the C++ runtime from engine directories.""" self._rt = _import_runtime() logger.info("Loading TensorRT engine from %s ...", self._engine_dir) if self._visual_engine_dir: logger.info("Loading visual engine from %s ...", self._visual_engine_dir) eagle = self._eagle_engine_dir if eagle: logger.info( "Eagle spec-decode enabled (top_k=%d, step=%d, tree=%d)", self._draft_top_k, self._draft_step, self._verify_tree_size, ) self._runtime = self._rt.LLMRuntime( self._engine_dir, self._visual_engine_dir, {}, self._draft_top_k, self._draft_step, self._verify_tree_size, ) else: self._runtime = self._rt.LLMRuntime( self._engine_dir, self._visual_engine_dir, {}, ) self._runtime.capture_decoding_cuda_graph() logger.info("Engine loaded and ready.") # ------------------------------------------------------------------ # Pipeline stages # ------------------------------------------------------------------ def _export_onnx(self) -> None: """Export the model checkpoint to ONNX via llm_loader.""" logger.info("Exporting ONNX to %s ...", self._onnx_dir) os.makedirs(self._onnx_dir, exist_ok=True) _llm_loader_dir() from llm_loader import AutoModel, export_onnx model = AutoModel.from_pretrained(self._model_dir, device="cpu") output_path = os.path.join(self._onnx_dir, "model.onnx") export_onnx(model, output_path, model_dir=self._model_dir) # Patch image_token_id for VLM models if self._is_vlm: _llm_loader_dir() from llm_loader.export_all_cli import _find_token_id image_token_id = _find_token_id(self._model_dir, "<|image_pad|>") if image_token_id is not None: cfg_path = os.path.join(self._onnx_dir, "config.json") if os.path.exists(cfg_path): with open(cfg_path) as f: cfg = json.load(f) cfg["image_token_id"] = image_token_id with open(cfg_path, "w") as f: json.dump(cfg, f, indent=2) logger.info( "Patched image_token_id=%d into LLM config", image_token_id, ) logger.info("ONNX export complete: %s", output_path) def _export_visual_onnx(self) -> None: """Export the visual encoder to ONNX via llm_loader.""" logger.info( "Exporting visual ONNX to %s ...", self._visual_onnx_dir, ) os.makedirs(self._visual_onnx_dir, exist_ok=True) import torch _llm_loader_dir() from llm_loader.export_all_cli import (_export_visual, _load_all_weights, _load_config) config = _load_config(self._model_dir) weights = _load_all_weights(self._model_dir) _export_visual( self._model_dir, self._visual_onnx_dir, weights, config, self._model_type, torch.float16, ) logger.info( "Visual ONNX export complete: %s", self._visual_onnx_dir, ) def _build_engine(self) -> None: """Build a TensorRT engine from the ONNX directory.""" logger.info( "Building TensorRT engine: %s -> %s", self._onnx_dir, self._engine_dir, ) os.makedirs(self._engine_dir, exist_ok=True) rt = _import_runtime() config = rt.LLMBuilderConfig() config.max_input_len = self._max_input_len config.max_batch_size = self._max_batch_size config.max_kv_cache_capacity = self._max_kv_cache_capacity config.use_trt_native_ops = self._use_trt_native_ops builder = rt.LLMBuilder(self._onnx_dir, self._engine_dir, config) if not builder.build(): raise RuntimeError( f"TensorRT engine build failed. " f"ONNX dir: {self._onnx_dir}, engine dir: {self._engine_dir}") logger.info("Engine build complete: %s", self._engine_dir) def _build_visual_engine(self) -> None: """Build a TensorRT engine for the visual encoder.""" logger.info( "Building visual TensorRT engine: %s -> %s", self._visual_onnx_dir, self._visual_engine_dir, ) os.makedirs(self._visual_engine_dir, exist_ok=True) rt = _import_runtime() config = rt.VisualBuilderConfig() # Derive image token counts from vision_config vis_cfg = _read_vision_config(self._model_dir) image_size = vis_cfg.get("image_size", 448) patch_size = vis_cfg.get("patch_size", 14) if isinstance(image_size, list): image_size = image_size[0] if isinstance(patch_size, list): patch_size = patch_size[0] tokens_per_tile = (image_size // patch_size)**2 # Round up to nearest multiple of tokens_per_tile config.min_image_tokens = tokens_per_tile config.max_image_tokens = tokens_per_tile * 4 config.max_image_tokens_per_image = tokens_per_tile * 2 builder = rt.VisualBuilder( self._visual_onnx_dir, self._visual_engine_dir, config, ) if not builder.build(): raise RuntimeError(f"Visual TensorRT engine build failed. " f"ONNX dir: {self._visual_onnx_dir}, " f"engine dir: {self._visual_engine_dir}") logger.info( "Visual engine build complete: %s", self._visual_engine_dir, ) # ------------------------------------------------------------------ # Inference API (vLLM-style) # ------------------------------------------------------------------
[docs] def generate( self, prompts: Union[str, List[str], List[List[Dict[str, Any]]]], sampling_params: Optional[SamplingParams] = None, ) -> List[CompletionOutput]: """Generate completions for the given prompts. Args: prompts: A single prompt string, a list of prompt strings, or a list of OpenAI-style message lists. sampling_params: Sampling configuration. Defaults to ``SamplingParams()``. Returns: List of ``CompletionOutput`` objects, one per prompt. """ params = sampling_params or SamplingParams() if isinstance(prompts, str): prompts = [prompts] message_batches = [] for p in prompts: if isinstance(p, str): message_batches.append([{"role": "user", "content": p}]) elif isinstance(p, list): message_batches.append(p) else: raise TypeError(f"Unsupported prompt type: {type(p)}") outputs = [] for messages in message_batches: cpp_messages = _convert_messages_to_cpp(self._rt, messages) image_buffers = _load_image_buffers(self._rt, messages) request = self._rt.LLMGenerationRequest() req = self._rt.Request(messages=cpp_messages) req.image_buffers = image_buffers request.requests = [req] request.temperature = params.temperature request.top_p = params.top_p request.top_k = params.top_k request.max_generate_length = params.max_tokens request.apply_chat_template = True request.add_generation_prompt = True request.enable_thinking = params.enable_thinking request.disable_spec_decode = params.disable_spec_decode response = self._runtime.handle_request(request) text = response.output_texts[0] if response.output_texts else "" ids = response.output_ids[0] if response.output_ids else [] reason = ("length" if len(ids) >= params.max_tokens else "stop") outputs.append( CompletionOutput(text=text, token_ids=ids, finish_reason=reason)) return outputs
[docs] def chat( self, messages: List[Dict[str, Any]], sampling_params: Optional[SamplingParams] = None, ) -> CompletionOutput: """Single-turn chat completion (convenience wrapper). Args: messages: OpenAI-style message list. sampling_params: Sampling configuration. Returns: A single ``CompletionOutput``. """ return self.generate([messages], sampling_params)[0]
[docs] def generate_stream( self, messages: List[Dict[str, Any]], sampling_params: Optional[SamplingParams] = None, ) -> Generator[StreamDelta, None, None]: """Stream generation deltas for a single message list. Runs ``handleRequest`` in a background thread with a ``StreamChannel`` attached, yielding ``StreamDelta`` objects as tokens are produced. """ params = sampling_params or SamplingParams() cpp_messages = _convert_messages_to_cpp(self._rt, messages) image_buffers = _load_image_buffers(self._rt, messages) channel = self._rt.StreamChannel.create() channel.set_skip_special_tokens(True) request = self._rt.LLMGenerationRequest() req = self._rt.Request(messages=cpp_messages) req.image_buffers = image_buffers request.requests = [req] request.stream_channels = [channel] request.temperature = params.temperature request.top_p = params.top_p request.top_k = params.top_k request.max_generate_length = params.max_tokens request.apply_chat_template = True request.add_generation_prompt = True request.enable_thinking = params.enable_thinking request.disable_spec_decode = params.disable_spec_decode _FINISH_REASON_MAP = { self._rt.FinishReason.END_ID: "stop", self._rt.FinishReason.LENGTH: "length", self._rt.FinishReason.CANCELLED: "cancelled", self._rt.FinishReason.ERROR: "error", } error_holder = [None] def _run(): try: self._runtime.handle_request(request) except Exception as exc: error_holder[0] = exc channel.cancel() worker = threading.Thread(target=_run, daemon=True) worker.start() try: while True: chunk = channel.wait_pop(timeout_ms=200) if chunk is None: if channel.is_finished() or channel.is_cancelled(): break continue reason = _FINISH_REASON_MAP.get(chunk.reason) yield StreamDelta( text=chunk.text, token_ids=list(chunk.token_ids), finished=chunk.finished, finish_reason=reason, ) if chunk.finished: break finally: worker.join(timeout=5.0) if error_holder[0] is not None: raise error_holder[0]
# ------------------------------------------------------------------ # Server API # ------------------------------------------------------------------
[docs] def serve(self, host: str = "0.0.0.0", port: int = 8000) -> None: """Start an OpenAI-compatible HTTP server. Args: host: Bind address. port: Bind port. """ from .api_server import run_server run_server(self, host=host, port=port)
# ------------------------------------------------------------------ # Properties # ------------------------------------------------------------------ @property def model_dir(self) -> str: """Path to the resolved model checkpoint.""" return self._model_dir @property def engine_dir(self) -> str: """Path to the TensorRT engine directory.""" return self._engine_dir @property def has_draft_model(self) -> bool: """Whether Eagle speculative decoding is active.""" return self._runtime.has_draft_model()
# --------------------------------------------------------------------------- # Message conversion & image loading # --------------------------------------------------------------------------- def _convert_messages_to_cpp(rt_module, messages: List[Dict[str, Any]]): """Convert Python message dicts to C++ Message objects.""" cpp_messages = [] for msg in messages: cpp_msg = rt_module.Message() cpp_msg.role = msg["role"] content = msg["content"] contents_list = [] if isinstance(content, str): contents_list.append(rt_module.MessageContent("text", content)) elif isinstance(content, list): for item in content: if isinstance(item, str): contents_list.append(rt_module.MessageContent( "text", item)) elif isinstance(item, dict): ct = item.get("type", "text") if ct == "text": contents_list.append( rt_module.MessageContent( "text", item.get("text", ""), )) elif ct == "image": contents_list.append( rt_module.MessageContent( "image", item.get("image", ""), )) else: raise ValueError(f"Unsupported content type: {ct}") cpp_msg.contents = contents_list cpp_messages.append(cpp_msg) return cpp_messages def _load_image_buffers(rt_module, messages: List[Dict[str, Any]]): """Load image files referenced in messages into ImageData buffers.""" images = [] for msg in messages: content = msg.get("content", []) if not isinstance(content, list): continue for item in content: if not isinstance(item, dict): continue if item.get("type") == "image": path = item.get("image", "") if path and os.path.isfile(path): images.append(rt_module.load_image_from_path(path)) return images