import json
import time
from pathlib import Path
# isort: off
import torch
import tensorrt as trt
from ..logger import logger
from .._utils import torch_to_numpy, trt_dtype_to_torch, mpi_world_size, mpi_rank
from ..plugin.plugin import CustomAllReduceHelper
from .generation import ModelConfig, SamplingConfig, LoraManager, GenerationSession
from ..mapping import Mapping
from .session import Session
from ..models.modeling_utils import get_kv_cache_type_from_legacy
def get_engine_name(rank):
return 'rank{}.engine'.format(rank)
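# Example: get_engine_name(0) -> 'rank0.engine'; each rank loads its own serialized engine file from the component directory.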
def read_config(config_path: Path):
with open(config_path, "r") as f:
config = json.load(f)
builder_config = config['build_config']
plugin_config = builder_config['plugin_config']
pretrained_config = config['pretrained_config']
lora_config = builder_config['lora_config']
auto_parallel_config = builder_config['auto_parallel_config']
use_gpt_attention_plugin = plugin_config["gpt_attention_plugin"]
remove_input_padding = plugin_config["remove_input_padding"]
use_lora_plugin = plugin_config["lora_plugin"]
tp_size = pretrained_config['mapping']['tp_size']
pp_size = pretrained_config['mapping']['pp_size']
gpus_per_node = auto_parallel_config['gpus_per_node']
world_size = tp_size * pp_size
assert world_size == mpi_world_size(), \
f'Engine world size ({world_size}) != Runtime world size ({mpi_world_size()})'
num_heads = pretrained_config["num_attention_heads"]
hidden_size = pretrained_config["hidden_size"]
head_size = pretrained_config["head_size"]
vocab_size = pretrained_config["vocab_size"]
max_batch_size = builder_config["max_batch_size"]
max_beam_width = builder_config["max_beam_width"]
num_layers = pretrained_config["num_hidden_layers"]
num_kv_heads = pretrained_config.get('num_kv_heads', num_heads)
assert (num_heads % tp_size) == 0
num_heads = num_heads // tp_size
hidden_size = hidden_size // tp_size
num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size
cross_attention = pretrained_config["architecture"] == "DecoderModel"
skip_cross_kv = pretrained_config.get('skip_cross_kv', False)
has_position_embedding = pretrained_config["has_position_embedding"]
has_token_type_embedding = "type_vocab_size" in pretrained_config  # pretrained_config is a dict, so use a membership test rather than hasattr
dtype = pretrained_config["dtype"]
paged_kv_cache = plugin_config['paged_kv_cache']
tokens_per_block = plugin_config['tokens_per_block']
gather_context_logits = builder_config.get('gather_context_logits', False)
gather_generation_logits = builder_config.get('gather_generation_logits',
False)
max_prompt_embedding_table_size = builder_config.get(
'max_prompt_embedding_table_size', 0)
kv_cache_type = get_kv_cache_type_from_legacy(True, paged_kv_cache)
model_config = ModelConfig(
num_heads=num_heads,
num_kv_heads=num_kv_heads,
hidden_size=hidden_size,
head_size=head_size,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
vocab_size=vocab_size,
num_layers=num_layers,
gpt_attention_plugin=use_gpt_attention_plugin,
remove_input_padding=remove_input_padding,
kv_cache_type=kv_cache_type,
tokens_per_block=tokens_per_block,
cross_attention=cross_attention,
has_position_embedding=has_position_embedding,
has_token_type_embedding=has_token_type_embedding,
dtype=dtype,
gather_context_logits=gather_context_logits,
gather_generation_logits=gather_generation_logits,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
lora_plugin=use_lora_plugin,
lora_target_modules=lora_config.get('lora_target_modules'),
trtllm_modules_to_hf_modules=lora_config.get(
'trtllm_modules_to_hf_modules'),
skip_cross_kv=skip_cross_kv,
)
return model_config, tp_size, pp_size, gpus_per_node, dtype
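# Illustrative sketch (derived from the keys read above, not an exhaustive schema)
# of the config.json layout that read_config() expects:
#   {
#     "pretrained_config": {
#       "architecture": ..., "dtype": ..., "num_attention_heads": ...,
#       "num_hidden_layers": ..., "hidden_size": ..., "head_size": ...,
#       "vocab_size": ..., "has_position_embedding": ...,
#       "mapping": {"tp_size": ..., "pp_size": ...}
#     },
#     "build_config": {
#       "max_batch_size": ..., "max_beam_width": ...,
#       "plugin_config": {"gpt_attention_plugin": ..., "remove_input_padding": ...,
#                         "lora_plugin": ..., "paged_kv_cache": ..., "tokens_per_block": ...},
#       "lora_config": {...},
#       "auto_parallel_config": {"gpus_per_node": ...}
#     }
#   }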
class EncDecModelRunner:
def __init__(self,
engine_name,
engine_dir,
lora_dir=None,
lora_task_uids=None,
debug_mode=False,
skip_encoder=False,
stream: torch.cuda.Stream = None,
enable_context_fmha_fp32_acc: bool = None):
# in a multi-node setup, it's important to call set_device at the very beginning so that .to('cuda') refers to the current device
# accordingly, all input & output tensors should be moved to the current device
# otherwise, they default to 'cuda:0'
self.runtime_rank = mpi_rank()
device_id = self.runtime_rank % torch.cuda.device_count()
torch.cuda.set_device(device_id)
self.device = torch.cuda.current_device()
self.skip_encoder = skip_encoder
self.lora_task_uids = lora_task_uids
self.enable_context_fmha_fp32_acc = enable_context_fmha_fp32_acc
# when enc-dec runs by itself, stream can be None and we create a new stream here
# when enc-dec runs as a component in a bigger workflow (e.g., multimodal), earlier components may have produced results on their own stream; pass that stream in to avoid unnecessary stream synchronization
self.stream = stream
if self.stream is None:
self.stream = torch.cuda.Stream(self.device)
torch.cuda.set_stream(self.stream)
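# hypothetical usage of the stream argument discussed above: a multimodal caller that already
# produced inputs on its own torch.cuda.Stream can construct the runner with stream=that_stream
# so its results are visible without an extra synchronization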
engine_dir = Path(engine_dir)
def engine_setup(component):
# model config
config_path = engine_dir / component / "config.json"
logger.info(f"Using config path {config_path}")
model_config, tp_size, pp_size, gpus_per_node, dtype = read_config(
config_path)
# MGMN (multi-GPU, multi-node) config
world_size = tp_size * pp_size
runtime_rank = mpi_rank()
assert runtime_rank < world_size, "Runtime GPU rank exceeds the engine world size (tp_size * pp_size). Did you launch more MPI processes than required?"
runtime_mapping = Mapping(world_size,
runtime_rank,
tp_size=tp_size,
pp_size=pp_size,
gpus_per_node=gpus_per_node)
# load engine
engine_fname = get_engine_name(runtime_rank)
with open(engine_dir / component / engine_fname, "rb") as f:
engine_buffer = f.read()
return model_config, runtime_mapping, engine_buffer
# Note: encoder and decoder don't necessarily have the same TP & PP config
if not skip_encoder:
self.encoder_model_config, self.encoder_runtime_mapping, encoder_engine_buffer = engine_setup(
component='encoder')
self.nccl_comm = None
if self.encoder_runtime_mapping.has_pp():
# for Pipeline Parallelism in encoder
self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp(
self.encoder_runtime_mapping.tp_size,
self.encoder_runtime_mapping.pp_size,
self.encoder_runtime_mapping.rank)
# session setup
self.encoder_session = Session.from_serialized_engine(
encoder_engine_buffer)
# encoder lora manager setup
if self.encoder_model_config.lora_plugin:
self.encoder_lora_manager = LoraManager()
# TODO: this is only for bart
self.encoder_lora_manager.load_from_hf(
model_dirs=lora_dir,
model_config=self.encoder_model_config,
runtime_mapping=self.encoder_runtime_mapping,
component='encoder',
)
else:
self.encoder_lora_manager = None
else:
self.encoder_model_config, self.encoder_runtime_mapping, encoder_engine_buffer = None, None, None
self.nccl_comm, self.encoder_session = None, None
self.decoder_model_config, self.decoder_runtime_mapping, decoder_engine_buffer = engine_setup(
component='decoder')
self.decoder_session = GenerationSession(self.decoder_model_config,
decoder_engine_buffer,
self.decoder_runtime_mapping,
debug_mode=debug_mode)
# decoder lora manager setup
if self.decoder_model_config.lora_plugin:
self.decoder_lora_manager = LoraManager()
# TODO: this is only for bart
self.decoder_lora_manager.load_from_hf(
model_dirs=lora_dir,
model_config=self.decoder_model_config,
runtime_mapping=self.decoder_runtime_mapping,
component='decoder',
)
else:
self.decoder_lora_manager = None
@classmethod
def from_engine(cls,
engine_name,
engine_dir,
lora_dir=None,
lora_task_uids=None,
debug_mode=False,
skip_encoder=False,
stream=None,
enable_context_fmha_fp32_acc=None):
return cls(engine_name,
engine_dir,
lora_dir,
lora_task_uids,
debug_mode=debug_mode,
skip_encoder=skip_encoder,
stream=stream,
enable_context_fmha_fp32_acc=enable_context_fmha_fp32_acc)
def encoder_run(self,
input_ids,
input_lengths,
max_input_length,
position_ids=None,
token_type_ids=None,
debug_mode=False,
prompt_embedding_table=None,
prompt_tasks=None,
prompt_vocab_size=None,
attention_mask=None):
# each rank's engine holds hidden_size / tp_size, so multiply by tp_size to recover the full hidden dimension
hidden_size = self.encoder_model_config.hidden_size * self.encoder_runtime_mapping.tp_size
if input_ids.dim() == 1:
hidden_states_shape = (input_ids.shape[0], hidden_size
) # [num_tokens,D]
else:
hidden_states_shape = (input_ids.shape[0], input_ids.shape[1],
hidden_size) # [BS,seqlen,D]
hidden_states_dtype = lambda name: trt_dtype_to_torch(
self.encoder_session.engine.get_tensor_dtype(name))
# input tensors. only the first PP rank takes input_ids; the other ranks take hidden_states as input
inputs = {}
if self.encoder_runtime_mapping.is_first_pp_rank():
inputs['input_ids'] = input_ids.contiguous()
if self.encoder_model_config.has_position_embedding:
if position_ids is None:
if self.encoder_model_config.remove_input_padding:
position_ids = [
torch.arange(sample_length,
dtype=torch.int32,
device=input_ids.device)
for sample_length in torch_to_numpy(input_lengths)
]
position_ids = torch.cat(position_ids)
else:
bsz, seq_len = input_ids.shape[:2]
position_ids = torch.arange(
seq_len, dtype=torch.int32,
device=input_ids.device).expand(bsz, -1)
inputs['position_ids'] = position_ids.contiguous()
if self.encoder_model_config.has_token_type_embedding:
inputs['token_type_ids'] = token_type_ids.contiguous()
if self.encoder_model_config.max_prompt_embedding_table_size > 0:
inputs[
'prompt_embedding_table'] = prompt_embedding_table.contiguous(
)
inputs['tasks'] = prompt_tasks.contiguous()
inputs['prompt_vocab_size'] = prompt_vocab_size.contiguous()
else:
# just a placeholder; the engine will call NCCL to recv and fill the data from the previous rank
inputs['hidden_states_input'] = torch.empty(
hidden_states_shape,
dtype=hidden_states_dtype('hidden_states_input'),
device=self.device).contiguous()
if attention_mask is not None and not self.encoder_model_config.gpt_attention_plugin:
inputs['attention_mask'] = attention_mask.contiguous()
inputs['input_lengths'] = input_lengths
# use the tensor's shape to pass the max length info in remove-padding mode
inputs['max_input_length'] = torch.empty(
(max_input_length, ),
dtype=hidden_states_dtype('max_input_length'),
device=self.device).contiguous()
if self.encoder_runtime_mapping.tp_size > 1:
ipc_buffers, all_reduce_workspace = CustomAllReduceHelper.allocate_workspace(
self.encoder_runtime_mapping,
CustomAllReduceHelper.max_workspace_size_auto(
self.encoder_runtime_mapping.tp_size))
inputs['all_reduce_workspace'] = all_reduce_workspace
if self.encoder_model_config.lora_plugin:
inputs.update(
self.encoder_lora_manager.input_buffers(
self.lora_task_uids,
self.encoder_runtime_mapping,
self.encoder_model_config.num_layers,
))
batch_size = input_lengths.size(0)
inputs['host_request_types'] = torch.IntTensor([0] *
batch_size).to('cpu')
if self.encoder_model_config.remove_input_padding:
inputs['host_context_lengths'] = input_lengths.to('cpu')
# Note: runtime.Session's run() method will set the input/output tensor addresses; here we only need to provide tensor shapes
self.encoder_session.set_shapes(inputs)
# output tensors. only the last PP rank produces the final encoder output; the other ranks produce intermediate hidden_states. A broadcast is needed later
outputs = {}
if self.encoder_runtime_mapping.is_last_pp_rank():
outputs['encoder_output'] = torch.empty(
hidden_states_shape,
dtype=hidden_states_dtype('encoder_output'),
device=self.device).contiguous()
else:
outputs['hidden_states_output'] = torch.empty(
hidden_states_shape,
dtype=hidden_states_dtype('hidden_states_output'),
device=self.device).contiguous()
# -------------------------------------------
if debug_mode:
engine = self.encoder_session.engine
context = self.encoder_session.context
# set up debug buffers for the encoder
for i in range(self.encoder_session.engine.num_io_tensors):
name = engine.get_tensor_name(i)
if engine.get_tensor_mode(
name
) == trt.TensorIOMode.OUTPUT and name not in outputs.keys():
dtype = engine.get_tensor_dtype(name)
shape = context.get_tensor_shape(name)
outputs[name] = torch.zeros(tuple(shape),
dtype=trt_dtype_to_torch(dtype),
device=self.device)
context.set_tensor_address(name, outputs[name].data_ptr())
# -------------------------------------------
# TRT session run
# Note: run() needs the CUDA stream handle, not a torch.cuda.Stream object
ok = self.encoder_session.run(inputs, outputs, self.stream.cuda_stream)
assert ok, "Runtime execution failed"
self.stream.synchronize()
# Tensor Parallelism is handled by the model/engine definition,
# but we need to broadcast among the PP group at the end of the encoder's Pipeline Parallelism.
# After this, all ranks have received the encoder output, and the world might be re-configured using the decoder's TP-PP config
def pp_communicate_encoder_output(encoder_output):
if self.encoder_runtime_mapping.is_last_pp_rank():
for pp_rank in self.encoder_runtime_mapping.pp_group:
if pp_rank != self.encoder_runtime_mapping.rank:
self.nccl_comm.send(encoder_output, pp_rank)
return encoder_output
else:
self.nccl_comm.recv(encoder_output,
self.encoder_runtime_mapping.pp_group[-1])
return encoder_output
if self.encoder_runtime_mapping.has_pp():
# use the hidden_states output buffer to receive the output, since the shapes are the same
encoder_output_buf = outputs[
'encoder_output'] if self.encoder_runtime_mapping.is_last_pp_rank(
) else outputs['hidden_states_output']
encoder_output = pp_communicate_encoder_output(encoder_output_buf)
else:
encoder_output = outputs['encoder_output']
return encoder_output
def generate(self,
encoder_input_ids,
decoder_input_ids,
max_new_tokens,
num_beams=1,
pad_token_id=None,
eos_token_id=None,
bos_token_id=None,
debug_mode=False,
return_dict=False,
prompt_embedding_table=None,
prompt_tasks=None,
prompt_vocab_size=None,
attention_mask=None,
time_encoder=False,
return_encoder_output=False):
## ensure all externally provided tensors are on the correct device.
encoder_input_ids = encoder_input_ids.to(self.device)
decoder_input_ids = decoder_input_ids.to(self.device)
if attention_mask is not None:
attention_mask = torch.tensor(attention_mask,
dtype=torch.int32,
device=self.device)
## encoder run
encoder_remove_input_padding = self.encoder_model_config.remove_input_padding if self.encoder_model_config else self.decoder_model_config.remove_input_padding
encoder_input_ids, encoder_input_lengths, encoder_max_input_length, prompt_tasks = self.process_input(
encoder_input_ids, encoder_remove_input_padding, pad_token_id,
prompt_tasks)
if not self.skip_encoder:
logger.info(f"Rank {self.runtime_rank} Running encoder engine ...")
if time_encoder:
tik = time.time()
encoder_output = self.encoder_run(
encoder_input_ids,
encoder_input_lengths,
encoder_max_input_length,
debug_mode=debug_mode,
prompt_embedding_table=prompt_embedding_table,
prompt_tasks=prompt_tasks,
prompt_vocab_size=prompt_vocab_size,
attention_mask=attention_mask)
if time_encoder:
tok = time.time()
print(f"TRT-LLM Encoder time {(tok-tik)*1000}ms")
else:
encoder_output = prompt_embedding_table
if encoder_input_ids.dim() > 1:
encoder_output = encoder_output.unsqueeze(0)
## decoder run
logger.info(f"Rank {self.runtime_rank} Running decoder engine ...")
decoder_input_ids, decoder_input_lengths, decoder_max_input_length, _ = self.process_input(
decoder_input_ids, self.decoder_model_config.remove_input_padding,
pad_token_id)
# `cross_attention_mask` in the context phase has shape [batch_size, query_len, encoder_input_len],
# where query_len happens to be 1 in current cases but is not necessarily always 1;
# `cross_attention_mask` in the generation phase has shape [batch_size, 1, encoder_input_len], where
# query_len is always 1 since we have the KV cache. We index
# cross_attention_mask[:, step, :] during generation
cross_attention_mask = None
if attention_mask is not None:
cross_attention_mask = torch.tensor(attention_mask,
dtype=torch.int32,
device=self.device).reshape(
attention_mask.shape[0], 1,
attention_mask.shape[1])
cross_attention_mask = cross_attention_mask.repeat(
[1, decoder_max_input_length + max_new_tokens, 1])
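# Shape walk-through (illustrative numbers): an attention_mask of shape [2, 10]
# (batch_size=2, encoder_input_len=10) is reshaped to [2, 1, 10] and then repeated
# to [2, decoder_max_input_length + max_new_tokens, 10], e.g. [2, 33, 10] when
# decoder_max_input_length=1 and max_new_tokens=32, so step t can index [:, t, :]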
# generation config
sampling_config = SamplingConfig(end_id=eos_token_id,
pad_id=pad_token_id,
num_beams=num_beams,
min_length=1,
return_dict=return_dict)
sampling_config.update(output_cum_log_probs=return_dict,
output_log_probs=return_dict)
# decoder autoregressive generation
self.decoder_session.setup(
decoder_input_lengths.size(0),
decoder_max_input_length,
max_new_tokens,
num_beams,
max_attention_window_size=None,
encoder_max_input_length=encoder_max_input_length,
lora_manager=self.decoder_lora_manager,
lora_uids=self.lora_task_uids,
enable_context_fmha_fp32_acc=self.enable_context_fmha_fp32_acc)
output = self.decoder_session.decode(
decoder_input_ids,
decoder_input_lengths,
sampling_config,
encoder_output=encoder_output,
encoder_input_lengths=encoder_input_lengths,
return_dict=return_dict,
cross_attention_mask=cross_attention_mask)
if return_dict and return_encoder_output:
output['encoder_output'] = encoder_output
return output
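# Minimal usage sketch, not part of the library: the engine directory and token ids
# below are hypothetical placeholders; adapt them to your own engines and tokenizer,
# and launch with mpirun when tp_size * pp_size > 1.
if __name__ == '__main__':
    runner = EncDecModelRunner.from_engine(
        engine_name='enc_dec',  # hypothetical name
        engine_dir='/path/to/engines')  # expects <engine_dir>/{encoder,decoder}/config.json
    encoder_ids = torch.tensor([[100, 200, 300, 1]], dtype=torch.int32)
    decoder_ids = torch.tensor([[0]], dtype=torch.int32)  # decoder start token id
    output_ids = runner.generate(encoder_ids,
                                 decoder_ids,
                                 max_new_tokens=24,
                                 num_beams=1,
                                 pad_token_id=0,
                                 eos_token_id=1,
                                 bos_token_id=0)
    print(output_ids)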