Source code for sdp.processors.inference.llm.vllm.vllm

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

import yaml
from tqdm import tqdm

from sdp.processors.base_processor import BaseProcessor


class vLLMInference(BaseProcessor):
    """A processor that performs inference using a vLLM model on entries from an input manifest.

    This class supports three prompt configuration modes:

    - a static prompt template (`prompt`)
    - a field in each entry containing the prompt (`prompt_field`)
    - a YAML file containing the prompt structure (`prompt_file`)

    The prompts are converted into chat-style input using the tokenizer's chat
    template, passed to the vLLM engine for generation, and the results are
    written to an output manifest.

    Args:
        prompt (str, optional): A fixed prompt used for all entries.
        prompt_field (str, optional): The key in each entry that holds the prompt template.
        prompt_file (str, optional): Path to a YAML file containing the prompt structure.
        generation_field (str): Name of the output field that stores the generated text. Defaults to 'generation'.
        model (dict): Parameters used to initialize the vLLM model.
        inference (dict): Sampling parameters passed to `vllm.SamplingParams`.
        apply_chat_template (dict): Arguments passed to the tokenizer's `apply_chat_template` method.
        **kwargs: Passed to `BaseProcessor` (includes `input_manifest_file` and `output_manifest_file`).

    Raises:
        ValueError: If none or more than one of the prompt configuration options is provided.

    Returns:
        A line-delimited JSON manifest in which each entry keeps its original
        fields plus a field containing the generated output.

    .. note::
        For detailed parameter options, refer to the following documentation:

        - model: https://docs.vllm.ai/en/latest/api/vllm/index.html#vllm.LLM
        - inference: https://docs.vllm.ai/en/v0.6.4/dev/sampling_params.html
        - apply_chat_template: https://huggingface.co/docs/transformers/main/en/chat_templating#applychattemplate

        Make sure to install `optree>=0.13.0` and `vllm` before using this processor::

            pip install "optree>=0.13.0" vllm
    """

    def __init__(
        self,
        prompt: str = None,
        prompt_field: str = None,
        prompt_file: str = None,
        generation_field: str = 'generation',
        model: dict = None,
        inference: dict = None,
        apply_chat_template: dict = None,
        **kwargs,
    ):
        from vllm import SamplingParams
        from transformers import AutoTokenizer

        super().__init__(**kwargs)
        self.prompt = prompt
        self.prompt_field = prompt_field
        self.generation_field = generation_field

        # Ensure that exactly one prompt configuration method is used.
        prompt_args_counter = sum([prompt is not None, prompt_field is not None, prompt_file is not None])
        if prompt_args_counter < 1:
            raise ValueError('One of `prompt`, `prompt_field` or `prompt_file` should be provided.')
        elif prompt_args_counter > 1:
            err = []
            if prompt:
                err.append(f'`prompt` ({prompt})')
            if prompt_field:
                err.append(f'`prompt_field` ({prompt_field})')
            if prompt_file:
                err.append(f'`prompt_file` ({prompt_file})')
            raise ValueError(f'Found more than one prompt value: {", ".join(err)}.')

        if prompt_file:
            self.prompt = self._read_prompt_file(prompt_file)

        # `None` defaults avoid the mutable-default-argument pitfall;
        # fall back to empty dicts here.
        self.model_params = model or {}
        self.sampling_params = SamplingParams(**(inference or {}))
        self.chat_template_params = apply_chat_template or {}
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_params['model'])

    def _read_prompt_file(self, prompt_filepath):
        """Read a YAML file with a chat-style prompt template."""
        with open(prompt_filepath, 'r') as prompt:
            return yaml.safe_load(prompt)

    def get_entry_prompt(self, data_entry):
        """Format the prompt for a single data entry using the chat template."""
        entry_chat = []
        prompt = self.prompt
        if self.prompt_field:
            prompt = data_entry[self.prompt_field]
        for role in prompt:
            entry_chat.append(dict(
                role=role,
                content=prompt[role].format(**data_entry),
            ))
        entry_prompt = self.tokenizer.apply_chat_template(
            entry_chat,
            **self.chat_template_params,
        )
        return entry_prompt
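    # For illustration only: a hypothetical `prompt_file` YAML maps chat roles
    # to templates whose placeholders are filled from each manifest entry (the
    # role names and the `text` placeholder below are assumptions, not part of
    # this module):
    #
    #     system: "You are a helpful assistant."
    #     user: "Correct the transcript: {text}"
    #
    # For an entry {"text": "helo wrld"}, `get_entry_prompt` builds
    #     [{'role': 'system', 'content': 'You are a helpful assistant.'},
    #      {'role': 'user', 'content': 'Correct the transcript: helo wrld'}]
    # and then renders it with the tokenizer's chat template.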
    def process(self):
        """Main processing function: reads entries, builds prompts, runs generation, writes results."""
        from vllm import LLM

        entries = []
        entry_prompts = []

        # Read entries and build their prompts.
        with open(self.input_manifest_file, 'r', encoding='utf8') as fin:
            for line in tqdm(fin, desc="Building prompts: "):
                data_entry = json.loads(line)
                entries.append(data_entry)
                entry_prompt = self.get_entry_prompt(data_entry)
                entry_prompts.append(entry_prompt)

        # Run vLLM inference on all prompts in a single batch.
        llm = LLM(**self.model_params)
        outputs = llm.generate(entry_prompts, self.sampling_params)

        # Write results to the output manifest.
        with open(self.output_manifest_file, 'w', encoding='utf8') as fout:
            for data_entry, output in tqdm(zip(entries, outputs), desc="Writing outputs: "):
                data_entry[self.generation_field] = output.outputs[0].text
                line = json.dumps(data_entry)
                fout.write(f'{line}\n')
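
# A minimal usage sketch, assuming `vllm` and `transformers` are installed and
# that the manifest and prompt YAML exist at the paths below; all paths and the
# model name are placeholders, not values shipped with this module.
if __name__ == '__main__':
    processor = vLLMInference(
        input_manifest_file='input_manifest.json',
        output_manifest_file='output_manifest.json',
        prompt_file='prompt.yaml',
        model={'model': 'Qwen/Qwen2.5-0.5B-Instruct'},
        inference={'temperature': 0.0, 'max_tokens': 128},
        apply_chat_template={'tokenize': False, 'add_generation_prompt': True},
    )
    processor.process()
    # Each output line keeps the entry's original fields plus the generated
    # text, e.g. {"text": "helo wrld", "generation": "hello world"}.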