Source code for tensorrt_llm.models.mmdit_sd3.model
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import Any, Dict, List, Optional
from ..._utils import str_dtype_to_torch
from ...functional import (Tensor, allgather, chunk, concat, einsum, pad, shape,
unsqueeze)
from ...layers import LayerNorm, Linear
from ...layers.attention import DiffusersAttention
from ...layers.embedding import (CombinedTimestepTextProjEmbeddings,
SD3PatchEmbed)
from ...layers.mlp import (LinearActivation, LinearApproximateGELU, LinearGEGLU,
LinearGELU, LinearSwiGLU)
from ...layers.normalization import (AdaLayerNormContinuous, AdaLayerNormZero,
SD35AdaLayerNormZeroX)
from ...logger import logger
from ...mapping import Mapping
from ...module import Module, ModuleList
from ..model_weights_loader import ModelWeightsLoader
from ..modeling_utils import PretrainedModel
from .config import SD3Transformer2DModelConfig
class FeedForward(Module):
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
activation_fn: str = "geglu",
inner_dim=None,
bias: bool = True,
mapping=Mapping(),
dtype=None,
):
super().__init__()
self.mapping = mapping
self.dtype = dtype
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
if activation_fn == "gelu":
            raise NotImplementedError(
                'GELU only supports the tanh approximation for now.')
if activation_fn == "gelu-approximate":
act_fn = LinearGELU(dim,
inner_dim,
approximate="tanh",
bias=bias,
mapping=mapping,
dtype=dtype)
elif activation_fn == "geglu":
act_fn = LinearGEGLU(dim,
inner_dim,
approximate="tanh",
bias=bias,
mapping=mapping,
dtype=dtype)
elif activation_fn == "geglu-approximate":
act_fn = LinearApproximateGELU(dim,
inner_dim,
bias=bias,
mapping=mapping,
dtype=dtype)
elif activation_fn == "swiglu":
act_fn = LinearSwiGLU(dim,
inner_dim,
bias=bias,
mapping=mapping,
dtype=dtype)
elif activation_fn == "linear-silu":
act_fn = LinearActivation(dim,
inner_dim,
bias=bias,
activation="silu",
mapping=mapping,
                                      dtype=dtype)
        else:
            raise ValueError(f"Unsupported activation_fn: {activation_fn}")
self.net = ModuleList([
act_fn,
Linear(inner_dim,
dim_out,
bias=bias,
tp_group=self.mapping.tp_group,
tp_size=self.mapping.tp_size,
dtype=self.dtype)
])
def forward(self, hidden_states: Tensor):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
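# Usage sketch (illustrative only; the width and dtype below are hypothetical
# and are normally taken from the transformer config):
#
#   ff = FeedForward(dim=1536,
#                    dim_out=1536,
#                    activation_fn="gelu-approximate",
#                    mapping=Mapping(),
#                    dtype="float16")
#   hidden_states = ff(hidden_states)  # [batch, seq_len, 1536] -> [batch, seq_len, 1536]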
class JointTransformerBlock(Module):
def __init__(self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
context_pre_only: bool = False,
qk_norm: Optional[str] = None,
use_dual_attention: bool = False,
mapping=Mapping(),
dtype=None):
super().__init__()
self.use_dual_attention = use_dual_attention
self.context_pre_only = context_pre_only
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
if use_dual_attention:
self.norm1 = SD35AdaLayerNormZeroX(dim,
mapping=mapping,
dtype=dtype)
else:
self.norm1 = AdaLayerNormZero(dim, mapping=mapping, dtype=dtype)
if context_norm_type == "ada_norm_continous":
self.norm1_context = AdaLayerNormContinuous(
dim,
dim,
elementwise_affine=False,
eps=1e-6,
bias=True,
norm_type="layer_norm",
dtype=dtype)
elif context_norm_type == "ada_norm_zero":
self.norm1_context = AdaLayerNormZero(dim, dtype=dtype)
else:
raise ValueError(
f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
)
self.attn = DiffusersAttention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=context_pre_only,
bias=True,
qk_norm=qk_norm,
eps=1e-6,
mapping=mapping,
dtype=dtype,
)
if use_dual_attention:
self.attn2 = DiffusersAttention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
qk_norm=qk_norm,
eps=1e-6,
mapping=mapping,
dtype=dtype,
)
else:
self.attn2 = None
self.norm2 = LayerNorm(dim,
elementwise_affine=False,
eps=1e-6,
dtype=dtype)
self.ff = FeedForward(dim=dim,
dim_out=dim,
activation_fn="gelu-approximate",
mapping=mapping,
dtype=dtype)
if not context_pre_only:
self.norm2_context = LayerNorm(dim,
elementwise_affine=False,
eps=1e-6,
dtype=dtype)
self.ff_context = FeedForward(dim=dim,
dim_out=dim,
activation_fn="gelu-approximate",
mapping=mapping,
dtype=dtype)
else:
self.norm2_context = None
self.ff_context = None
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self,
chunk_size: Optional[int] = None,
dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
@staticmethod
def _chunked_feed_forward(ff: Module, hidden_states: Tensor, chunk_dim: int,
chunk_size: int):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
ff_output = concat(
[
ff(hid_slice)
for hid_slice in chunk(hidden_states, num_chunks, dim=chunk_dim)
],
dim=chunk_dim,
)
return ff_output
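    # Chunking sketch (illustrative numbers): after
    # block.set_chunk_feed_forward(chunk_size=1024, dim=1), a sequence of
    # 4096 tokens is processed by the MLP in 4 slices of 1024 tokens each and
    # the results are concatenated along dim 1, trading peak activation
    # memory for extra kernel launches. The chunked dimension must be
    # divisible by chunk_size, otherwise the check above raises a ValueError.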
def forward(self,
hidden_states: Tensor,
encoder_hidden_states: Tensor,
temb: Tensor,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
*args,
**kwargs):
joint_attention_kwargs = joint_attention_kwargs or {}
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
hidden_states, emb=temb)
else:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, emb=temb)
if self.context_pre_only:
norm_encoder_hidden_states = self.norm1_context(
encoder_hidden_states, temb)
else:
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb)
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
**joint_attention_kwargs,
)
# Process attention outputs for the `hidden_states`.
attn_output = unsqueeze(gate_msa, 1) * attn_output
hidden_states = hidden_states + attn_output
if self.use_dual_attention:
attn_output2 = self.attn2(hidden_states=norm_hidden_states2,
**joint_attention_kwargs)
attn_output2 = unsqueeze(gate_msa2, 1) * attn_output2
hidden_states = hidden_states + attn_output2
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (
1 + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1)
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = self._chunked_feed_forward(self.ff, norm_hidden_states,
self._chunk_dim,
self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
ff_output = unsqueeze(gate_mlp, 1) * ff_output
hidden_states = hidden_states + ff_output
# Process attention outputs for the `encoder_hidden_states`.
if self.context_pre_only:
encoder_hidden_states = None
else:
context_attn_output = unsqueeze(c_gate_msa, 1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(
encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (
1 + unsqueeze(c_scale_mlp, 1)) + unsqueeze(c_shift_mlp, 1)
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
context_ff_output = self._chunked_feed_forward(
self.ff_context, norm_encoder_hidden_states,
self._chunk_dim, self._chunk_size)
else:
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + unsqueeze(
c_gate_mlp, 1) * context_ff_output
return encoder_hidden_states, hidden_states
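# Call sketch for a single block (shapes are illustrative; `temb` is the
# pooled timestep/text embedding produced by
# CombinedTimestepTextProjEmbeddings):
#
#   encoder_hidden_states, hidden_states = block(
#       hidden_states=hidden_states,                  # [batch, img_seq, dim]
#       encoder_hidden_states=encoder_hidden_states,  # [batch, txt_seq, dim]
#       temb=temb)                                    # [batch, dim]
#
# For the last block (context_pre_only=True) the returned
# encoder_hidden_states is None, so only hidden_states is used afterwards.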
class SD3Transformer2DModel(PretrainedModel):
config_class = SD3Transformer2DModelConfig
def __init__(self, config: SD3Transformer2DModelConfig):
super().__init__(config)
self.quant_mode = config.quant_mode
self.mapping = config.mapping
self.dtype = config.dtype
self.in_channels = config.in_channels
default_out_channels = config.in_channels
self.out_channels = config.out_channels if config.out_channels is not None else default_out_channels
self.inner_dim = config.num_attention_heads * config.attention_head_dim
self.pos_embed = SD3PatchEmbed(
height=config.sample_size,
width=config.sample_size,
patch_size=config.patch_size,
in_channels=self.in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=config.
            pos_embed_max_size,  # hard-coded to match the HF implementation
dtype=self.dtype)
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=self.inner_dim,
pooled_projection_dim=config.pooled_projection_dim,
mapping=self.mapping,
dtype=self.dtype)
self.context_embedder = Linear(config.joint_attention_dim,
config.caption_projection_dim,
tp_group=self.mapping.tp_group,
tp_size=self.mapping.tp_size,
dtype=self.dtype)
self.transformer_blocks = ModuleList([
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=config.num_attention_heads,
attention_head_dim=config.attention_head_dim,
context_pre_only=(i == config.num_layers - 1),
qk_norm=config.qk_norm,
                use_dual_attention=(i in config.dual_attention_layers),
mapping=self.mapping,
dtype=self.dtype) for i in range(config.num_layers)
])
self.norm_out = AdaLayerNormContinuous(self.inner_dim,
self.inner_dim,
elementwise_affine=False,
eps=1e-6,
dtype=self.dtype)
self.proj_out = Linear(self.inner_dim,
config.patch_size * config.patch_size *
self.out_channels,
bias=True,
tp_group=self.mapping.tp_group,
tp_size=self.mapping.tp_size,
dtype=self.dtype)
self.skip_layers = config.skip_layers
self.use_pretrained_pos_emb = config.use_pretrained_pos_emb
self.config = config
def forward(self,
hidden_states: Tensor,
encoder_hidden_states: Optional[Tensor] = None,
pooled_projections: Optional[Tensor] = None,
timestep: Optional[Tensor] = None,
block_controlnet_hidden_states: List[Tensor] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None):
height, width = hidden_states.shape[-2:]
hidden_states = self.pos_embed(
hidden_states) # takes care of adding positional embeddings too.
temb = self.time_text_embed(timestep, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if self.mapping.cp_size > 1:
hidden_states = chunk(hidden_states,
chunks=self.mapping.cp_size,
dim=1)[self.mapping.cp_rank]
encoder_redundant = encoder_hidden_states.shape[
1] % self.mapping.cp_size
encoder_padding_index = tuple(
[0, 0] * (encoder_hidden_states.ndim() - 2) +
[0, self.mapping.cp_size - encoder_redundant])
if encoder_redundant != 0:
encoder_hidden_states = pad(encoder_hidden_states,
pad=encoder_padding_index)
encoder_hidden_states = chunk(encoder_hidden_states,
chunks=self.mapping.cp_size,
dim=1)[self.mapping.cp_rank]
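        # The context-parallel split above pads the text sequence so it divides
        # evenly across ranks. Illustrative example: with cp_size=2 and the
        # default prompt length of 333 (256 + 77, see prepare_inputs), one
        # padding token brings it to 334 and each rank keeps a chunk of 167;
        # the image tokens are assumed to split evenly already.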
for index_block, block in enumerate(self.transformer_blocks):
# Skip specified layers
            is_skip = (self.skip_layers is not None
                       and index_block in self.skip_layers)
if not is_skip:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if block_controlnet_hidden_states is not None and block.context_pre_only is False:
interval_control = len(self.transformer_blocks) / len(
block_controlnet_hidden_states)
hidden_states = hidden_states + block_controlnet_hidden_states[
int(index_block / interval_control)]
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
if self.mapping.cp_size > 1:
hidden_states = allgather(hidden_states,
group=self.mapping.cp_group,
gather_dim=1)
# unpatchify
patch_size = self.config.patch_size
height = height // patch_size
width = width // patch_size
hidden_states = hidden_states.view(
concat([
shape(hidden_states, 0), height, width, patch_size, patch_size,
self.out_channels
]))
hidden_states = einsum("nhwpqc->nchpwq", [hidden_states])
output = hidden_states.view(
concat([
shape(hidden_states, 0), self.out_channels, height * patch_size,
width * patch_size
]))
output.mark_output("output")
return output
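    # Unpatchify sketch (illustrative numbers, assuming patch_size=2 and a
    # 128x128 latent): proj_out emits [batch, 64*64, 2*2*out_channels]; the
    # view -> einsum -> view sequence above folds the per-patch pixels back
    # into an image-shaped tensor [batch, out_channels, 128, 128].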
def prepare_inputs(self, max_batch_size, **kwargs):
def sd3_default_range(max_batch_size):
return [1, max(1, (max_batch_size + 1) // 2), max_batch_size]
default_range = sd3_default_range
prompt_embeds_len = 256 + 77 # [NOTE] tokenizer_max_length = 77; max_sequence_length = 256
hidden_states = Tensor(name='hidden_states',
dtype=self.dtype,
shape=[
-1, self.in_channels,
self.config.sample_size,
self.config.sample_size
],
dim_range=OrderedDict([
('batch_size',
[default_range(max_batch_size)]),
('in_channels', [[self.in_channels] * 3]),
('height', [[self.config.sample_size] * 3]),
('width', [[self.config.sample_size] * 3]),
]))
encoder_hidden_states = Tensor(
name='encoder_hidden_states',
dtype=self.dtype,
shape=[-1, prompt_embeds_len, self.config.joint_attention_dim],
dim_range=OrderedDict([
('batch_size', [default_range(max_batch_size)]),
('txt_len', [[prompt_embeds_len] * 3]),
('joint_attention_dim', [[self.config.joint_attention_dim] * 3
]),
]))
pooled_projections = Tensor(
name='pooled_projections',
dtype=self.dtype,
shape=[-1, self.config.pooled_projection_dim],
dim_range=OrderedDict([
('batch_size', [default_range(max_batch_size)]),
('pooled_projection_dim',
[[self.config.pooled_projection_dim] * 3]),
]))
timestep = Tensor(name='timestep',
dtype=self.dtype,
shape=[-1],
dim_range=OrderedDict([
('batch_size', [default_range(max_batch_size)]),
]))
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"timestep": timestep,
}
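    # Engine-build sketch (hypothetical variable names): the tensors returned
    # above are meant to be passed straight into forward() when tracing the
    # network, e.g.
    #
    #   inputs = model.prepare_inputs(max_batch_size=2)
    #   model.forward(**inputs)
    #
    # Batch size is the only dynamic dimension; height/width are pinned to
    # config.sample_size and the text length to 256 + 77 = 333.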
@classmethod
def from_pretrained(cls,
pretrained_model_name_or_path: str,
dtype='float16',
mapping=Mapping(),
**kwargs):
quant_ckpt_path = kwargs.pop('quant_ckpt_path', None)
from diffusers import StableDiffusion3Pipeline
transformer = StableDiffusion3Pipeline.from_pretrained(
pretrained_model_name_or_path,
torch_dtype=str_dtype_to_torch(dtype)).transformer
config = SD3Transformer2DModelConfig.from_hugging_face_config(
transformer.config, dtype=dtype, mapping=mapping, **kwargs)
hf_model_dir = transformer.config._name_or_path
custom_dict = {}
if quant_ckpt_path is not None:
hf_model_dir = quant_ckpt_path
loader = SD3ModelWeightsLoader(hf_model_dir, custom_dict)
model = cls(config)
loader.generate_tllm_weights(model)
return model
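    # Conversion sketch (the checkpoint path below is only an example):
    #
    #   model = SD3Transformer2DModel.from_pretrained(
    #       "stabilityai/stable-diffusion-3-medium-diffusers",
    #       dtype="float16",
    #       mapping=Mapping())
    #
    # This extracts the transformer from a diffusers StableDiffusion3Pipeline
    # and loads its weights through SD3ModelWeightsLoader below.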
def load(self, weights, from_pruned=False):
required_names = set()
for name, param in self.named_parameters():
if self.use_pretrained_pos_emb and 'pos_embed' in name:
required_names.add(name)
continue
if param.is_inited():
continue
if name not in weights:
# Exemption for embedding sharing
if name.endswith('lm_head.weight') and any(
k.endswith('vocab_embedding.weight')
for k in weights.keys()):
continue
if name.endswith('lm_head.per_channel_scale') and any(
k.endswith('vocab_embedding.per_channel_scale')
for k in weights.keys()):
continue
required_names.add(name)
provided_names = set(weights.keys())
if not required_names.issubset(provided_names):
raise RuntimeError(
f"Required but not provided tensors:{required_names.difference(provided_names)}"
)
if not provided_names.issubset(required_names):
logger.warning(
f"Provided but not required tensors: {provided_names.difference(required_names)}"
)
for name, param in self.named_parameters():
if name in provided_names:
if not from_pruned:
try:
param.value = weights[name]
except Exception as e:
raise RuntimeError(
f"Encounter error '{e}' for parameter '{name}'")
else:
param.set_value_or_dummy(weights[name])
def enable_forward_chunking(self,
chunk_size: Optional[int] = None,
dim: int = 0):
raise NotImplementedError()
@property
def attn_processors(self):
return None
def _set_gradient_checkpointing(self, module, value=False):
raise NotImplementedError()
class SD3ModelWeightsLoader(ModelWeightsLoader):
def translate_to_external_key(self, tllm_key: str,
tllm_to_externel_key_dict: dict):
"""Convert and load external checkpoint into a TensorRT-LLM model.
"""
trtllm_to_hf_name = {
r"transformer_blocks.(\d+).ff(\w*).net.1.weight":
"transformer_blocks.*.ff*.net.2.weight",
r"transformer_blocks.(\d+).ff(\w*).net.1.bias":
"transformer_blocks.*.ff*.net.2.bias",
}
import re
for k, v in trtllm_to_hf_name.items():
m = re.match(k, tllm_key)
if m is not None:
matched_pos = m.groups()
placeholders = v.count('*')
assert len(matched_pos) == placeholders
for i in range(len(matched_pos)):
v = v.replace('*', matched_pos[i], 1)
return v
return tllm_key
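# Translation sketch: the patterns above remap the output Linear of each
# feed-forward, which sits at index 1 of `self.net` in this file but at index 2
# in the HF diffusers checkpoint (diffusers keeps a Dropout at index 1).
# For example (hypothetical loader instance):
#
#   loader = SD3ModelWeightsLoader(hf_model_dir, {})
#   loader.translate_to_external_key("transformer_blocks.3.ff.net.1.weight", {})
#   # -> "transformer_blocks.3.ff.net.2.weight"
#
# Keys that match neither pattern are returned unchanged.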