# Source code for tensorrt_llm.models.mmdit_sd3.model

# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import Any, Dict, List, Optional

from ..._utils import str_dtype_to_torch
from ...functional import (Tensor, allgather, chunk, concat, einsum, pad, shape,
                           unsqueeze)
from ...layers import LayerNorm, Linear
from ...layers.attention import DiffusersAttention
from ...layers.embedding import (CombinedTimestepTextProjEmbeddings,
                                 SD3PatchEmbed)
from ...layers.mlp import (LinearActivation, LinearApproximateGELU, LinearGEGLU,
                           LinearGELU, LinearSwiGLU)
from ...layers.normalization import (AdaLayerNormContinuous, AdaLayerNormZero,
                                     SD35AdaLayerNormZeroX)
from ...logger import logger
from ...mapping import Mapping
from ...module import Module, ModuleList
from ..model_weights_loader import ModelWeightsLoader
from ..modeling_utils import PretrainedModel
from .config import SD3Transformer2DModelConfig


class FeedForward(Module):
    """Two-layer feed-forward network: an activation projection followed by
    an output ``Linear``.

    ``self.net[0]`` is the activation projection selected by
    ``activation_fn`` (``dim -> inner_dim``); ``self.net[1]`` is the output
    projection (``inner_dim -> dim_out``).

    Args:
        dim: Input feature size.
        dim_out: Output feature size; defaults to ``dim``.
        mult: Multiplier used to derive the hidden size when ``inner_dim`` is
            not given (``inner_dim = int(dim * mult)``).
        activation_fn: Name of the activation projection. Supported values are
            ``"gelu-approximate"``, ``"geglu"``, ``"geglu-approximate"``,
            ``"swiglu"`` and ``"linear-silu"``.
        inner_dim: Explicit hidden size; takes precedence over ``mult``.
        bias: Whether both projections use a bias.
        mapping: Parallelism mapping forwarded to the sub-layers.
        dtype: Data type of the layer parameters.

    Raises:
        NotImplementedError: If ``activation_fn == "gelu"`` (only the tanh
            approximation is supported).
        ValueError: If ``activation_fn`` is not a supported name.
    """

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            mult: int = 4,
            activation_fn: str = "geglu",
            inner_dim=None,
            bias: bool = True,
            mapping=Mapping(),
            dtype=None,
    ):
        super().__init__()

        self.mapping = mapping
        self.dtype = dtype

        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            raise NotImplementedError('GELU only support tanh now.')
        elif activation_fn == "gelu-approximate":
            act_fn = LinearGELU(dim,
                                inner_dim,
                                approximate="tanh",
                                bias=bias,
                                mapping=mapping,
                                dtype=dtype)
        elif activation_fn == "geglu":
            act_fn = LinearGEGLU(dim,
                                 inner_dim,
                                 approximate="tanh",
                                 bias=bias,
                                 mapping=mapping,
                                 dtype=dtype)
        elif activation_fn == "geglu-approximate":
            act_fn = LinearApproximateGELU(dim,
                                           inner_dim,
                                           bias=bias,
                                           mapping=mapping,
                                           dtype=dtype)
        elif activation_fn == "swiglu":
            act_fn = LinearSwiGLU(dim,
                                  inner_dim,
                                  bias=bias,
                                  mapping=mapping,
                                  dtype=dtype)
        elif activation_fn == "linear-silu":
            act_fn = LinearActivation(dim,
                                      inner_dim,
                                      bias=bias,
                                      activation="silu",
                                      mapping=mapping,
                                      dtype=dtype)
        else:
            # Fail early with a clear message instead of an UnboundLocalError
            # on ``act_fn`` when building ``self.net`` below.
            raise ValueError(
                f"Unsupported activation_fn: {activation_fn!r}. Supported "
                "values: 'gelu-approximate', 'geglu', 'geglu-approximate', "
                "'swiglu', 'linear-silu'.")

        self.net = ModuleList([
            act_fn,
            Linear(inner_dim,
                   dim_out,
                   bias=bias,
                   tp_group=self.mapping.tp_group,
                   tp_size=self.mapping.tp_size,
                   dtype=self.dtype)
        ])

    def forward(self, hidden_states: Tensor):
        """Apply every module of ``self.net`` to ``hidden_states`` in order."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


class JointTransformerBlock(Module):
    """One MMDiT joint-attention block over an image and a context stream.

    The image stream (``hidden_states``) and the context stream
    (``encoder_hidden_states``) are normalized by adaptive norms driven by
    ``temb``, attended jointly by ``self.attn``, and each then passes through
    its own gated feed-forward network. With ``use_dual_attention`` the image
    stream also goes through a second attention (``self.attn2``). With
    ``context_pre_only`` the context stream is only consumed by attention and
    the block returns ``None`` in its place.

    Args:
        dim: Hidden size of both streams.
        num_attention_heads: Number of attention heads.
        attention_head_dim: Size of each attention head.
        context_pre_only: If True, the context stream is not updated by this
            block (``SD3Transformer2DModel`` sets this on its last block).
        qk_norm: Optional query/key normalization passed to the attention
            layers.
        use_dual_attention: If True, use ``SD35AdaLayerNormZeroX`` for
            ``norm1`` and add the second attention ``attn2``.
        mapping: Parallelism mapping passed to most sub-layers.
        dtype: Data type of the layer parameters.
    """

    def __init__(self,
                 dim: int,
                 num_attention_heads: int,
                 attention_head_dim: int,
                 context_pre_only: bool = False,
                 qk_norm: Optional[str] = None,
                 use_dual_attention: bool = False,
                 mapping=Mapping(),
                 dtype=None):
        super().__init__()

        self.use_dual_attention = use_dual_attention
        self.context_pre_only = context_pre_only
        context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"

        if use_dual_attention:
            self.norm1 = SD35AdaLayerNormZeroX(dim,
                                               mapping=mapping,
                                               dtype=dtype)
        else:
            self.norm1 = AdaLayerNormZero(dim, mapping=mapping, dtype=dtype)

        if context_norm_type == "ada_norm_continous":
            self.norm1_context = AdaLayerNormContinuous(
                dim,
                dim,
                elementwise_affine=False,
                eps=1e-6,
                bias=True,
                norm_type="layer_norm",
                dtype=dtype)
        elif context_norm_type == "ada_norm_zero":
            # NOTE(review): unlike ``norm1``, no ``mapping`` is passed here, so
            # this layer uses the default ``Mapping()``; confirm that is
            # intended under tensor parallelism.
            self.norm1_context = AdaLayerNormZero(dim, dtype=dtype)
        else:
            raise ValueError(
                f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
            )

        # Joint attention: the context stream is projected through the
        # ``added_kv_proj_dim`` path and attended together with the image
        # stream.
        self.attn = DiffusersAttention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=context_pre_only,
            bias=True,
            qk_norm=qk_norm,
            eps=1e-6,
            mapping=mapping,
            dtype=dtype,
        )

        if use_dual_attention:
            # Second attention that only sees the image stream.
            self.attn2 = DiffusersAttention(
                query_dim=dim,
                cross_attention_dim=None,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                out_dim=dim,
                bias=True,
                qk_norm=qk_norm,
                eps=1e-6,
                mapping=mapping,
                dtype=dtype,
            )
        else:
            self.attn2 = None

        self.norm2 = LayerNorm(dim,
                               elementwise_affine=False,
                               eps=1e-6,
                               dtype=dtype)
        self.ff = FeedForward(dim=dim,
                              dim_out=dim,
                              activation_fn="gelu-approximate",
                              mapping=mapping,
                              dtype=dtype)

        if not context_pre_only:
            self.norm2_context = LayerNorm(dim,
                                           elementwise_affine=False,
                                           eps=1e-6,
                                           dtype=dtype)
            self.ff_context = FeedForward(dim=dim,
                                          dim_out=dim,
                                          activation_fn="gelu-approximate",
                                          mapping=mapping,
                                          dtype=dtype)
        else:
            self.norm2_context = None
            self.ff_context = None

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self,
                               chunk_size: Optional[int] = None,
                               dim: int = 0):
        """Enable (or, with ``chunk_size=None``, disable) chunked feed-forward.

        Args:
            chunk_size: Size of each slice along ``dim``; ``None`` disables
                chunking.
            dim: Dimension along which the feed-forward input is sliced.

        Raises:
            ValueError: If ``chunk_size`` is given but not positive.
        """
        # A zero or negative chunk size would otherwise only fail later inside
        # ``_chunked_feed_forward`` (ZeroDivisionError or a bogus chunk count).
        if chunk_size is not None and chunk_size <= 0:
            raise ValueError(
                f"`chunk_size` must be a positive integer or None, got {chunk_size}."
            )
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    @staticmethod
    def _chunked_feed_forward(ff: Module, hidden_states: Tensor, chunk_dim: int,
                              chunk_size: int):
        """Apply ``ff`` to ``chunk_size`` slices of ``hidden_states`` along
        ``chunk_dim`` and concatenate the results along the same dimension.

        Raises:
            ValueError: If ``hidden_states.shape[chunk_dim]`` is not divisible
                by ``chunk_size``.
        """
        # "feed_forward_chunk_size" can be used to save memory
        # NOTE(review): this reads ``hidden_states.shape[chunk_dim]`` as a
        # concrete size; presumably a dynamic (-1) dimension cannot be chunked
        # this way — confirm before chunking along a dynamic dimension.
        if hidden_states.shape[chunk_dim] % chunk_size != 0:
            raise ValueError(
                f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `set_chunk_feed_forward`."
            )

        num_chunks = hidden_states.shape[chunk_dim] // chunk_size
        ff_output = concat(
            [
                ff(hid_slice)
                for hid_slice in chunk(hidden_states, num_chunks, dim=chunk_dim)
            ],
            dim=chunk_dim,
        )
        return ff_output

    def forward(self,
                hidden_states: Tensor,
                encoder_hidden_states: Tensor,
                temb: Tensor,
                joint_attention_kwargs: Optional[Dict[str, Any]] = None,
                *args,
                **kwargs):
        """Run the block on both streams.

        Args:
            hidden_states: Image-stream tokens.
            encoder_hidden_states: Context-stream tokens.
            temb: Embedding passed to the adaptive norms (as ``emb``).
            joint_attention_kwargs: Extra keyword arguments forwarded to the
                attention layers.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Returns:
            Tuple ``(encoder_hidden_states, hidden_states)``. The first item is
            ``None`` when ``context_pre_only`` is set.
        """
        joint_attention_kwargs = joint_attention_kwargs or {}
        # norm1 also returns the gates / shift / scale used further below.
        if self.use_dual_attention:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                hidden_states, emb=temb)
        else:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, emb=temb)

        if self.context_pre_only:
            norm_encoder_hidden_states = self.norm1_context(
                encoder_hidden_states, temb)
        else:
            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
                encoder_hidden_states, emb=temb)

        # Attention.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            **joint_attention_kwargs,
        )

        # Process attention outputs for the `hidden_states`.
        # Gate the attention output, then add it to the residual stream.
        attn_output = unsqueeze(gate_msa, 1) * attn_output
        hidden_states = hidden_states + attn_output

        if self.use_dual_attention:
            attn_output2 = self.attn2(hidden_states=norm_hidden_states2,
                                      **joint_attention_kwargs)
            attn_output2 = unsqueeze(gate_msa2, 1) * attn_output2
            hidden_states = hidden_states + attn_output2

        # Adaptive modulation: norm(x) * (1 + scale) + shift.
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (
            1 + unsqueeze(scale_mlp, 1)) + unsqueeze(shift_mlp, 1)

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = self._chunked_feed_forward(self.ff, norm_hidden_states,
                                                   self._chunk_dim,
                                                   self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)
        ff_output = unsqueeze(gate_mlp, 1) * ff_output
        hidden_states = hidden_states + ff_output

        # Process attention outputs for the `encoder_hidden_states`.
        if self.context_pre_only:
            encoder_hidden_states = None
        else:
            context_attn_output = unsqueeze(c_gate_msa, 1) * context_attn_output
            encoder_hidden_states = encoder_hidden_states + context_attn_output

            norm_encoder_hidden_states = self.norm2_context(
                encoder_hidden_states)
            norm_encoder_hidden_states = norm_encoder_hidden_states * (
                1 + unsqueeze(c_scale_mlp, 1)) + unsqueeze(c_shift_mlp, 1)
            if self._chunk_size is not None:
                # "feed_forward_chunk_size" can be used to save memory
                context_ff_output = self._chunked_feed_forward(
                    self.ff_context, norm_encoder_hidden_states,
                    self._chunk_dim, self._chunk_size)
            else:
                context_ff_output = self.ff_context(norm_encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states + unsqueeze(
                c_gate_mlp, 1) * context_ff_output

        return encoder_hidden_states, hidden_states


class SD3Transformer2DModel(PretrainedModel):
    """SD3 MMDiT transformer built from TensorRT-LLM layers.

    ``forward`` patch-embeds the latent ``hidden_states`` (``pos_embed``),
    projects the text ``encoder_hidden_states`` (``context_embedder``), runs
    the stack of ``JointTransformerBlock``s conditioned on the combined
    timestep / pooled-text embedding, and unpatchifies the result into a
    ``(batch, out_channels, height, width)`` tensor marked as the engine
    output ``"output"``.
    """
    config_class = SD3Transformer2DModelConfig

    def __init__(self, config: SD3Transformer2DModelConfig):
        super().__init__(config)
        self.quant_mode = config.quant_mode
        self.mapping = config.mapping
        self.dtype = config.dtype

        self.in_channels = config.in_channels
        default_out_channels = config.in_channels
        self.out_channels = (config.out_channels if config.out_channels
                             is not None else default_out_channels)
        self.inner_dim = config.num_attention_heads * config.attention_head_dim

        self.pos_embed = SD3PatchEmbed(
            height=config.sample_size,
            width=config.sample_size,
            patch_size=config.patch_size,
            in_channels=self.in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=config.
            pos_embed_max_size,  # hard-code as HF implementation
            dtype=self.dtype)
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=self.inner_dim,
            pooled_projection_dim=config.pooled_projection_dim,
            mapping=self.mapping,
            dtype=self.dtype)
        self.context_embedder = Linear(config.joint_attention_dim,
                                       config.caption_projection_dim,
                                       tp_group=self.mapping.tp_group,
                                       tp_size=self.mapping.tp_size,
                                       dtype=self.dtype)

        # Only the last block is ``context_pre_only`` (its context output is
        # not needed afterwards).
        self.transformer_blocks = ModuleList([
            JointTransformerBlock(
                dim=self.inner_dim,
                num_attention_heads=config.num_attention_heads,
                attention_head_dim=config.attention_head_dim,
                context_pre_only=(i == config.num_layers - 1),
                qk_norm=config.qk_norm,
                use_dual_attention=i in config.dual_attention_layers,
                mapping=self.mapping,
                dtype=self.dtype) for i in range(config.num_layers)
        ])

        self.norm_out = AdaLayerNormContinuous(self.inner_dim,
                                               self.inner_dim,
                                               elementwise_affine=False,
                                               eps=1e-6,
                                               dtype=self.dtype)
        self.proj_out = Linear(self.inner_dim,
                               config.patch_size * config.patch_size *
                               self.out_channels,
                               bias=True,
                               tp_group=self.mapping.tp_group,
                               tp_size=self.mapping.tp_size,
                               dtype=self.dtype)

        self.skip_layers = config.skip_layers
        self.use_pretrained_pos_emb = config.use_pretrained_pos_emb
        self.config = config

    def forward(
            self,
            hidden_states: Tensor,
            encoder_hidden_states: Optional[Tensor] = None,
            pooled_projections: Optional[Tensor] = None,
            timestep: Optional[Tensor] = None,
            block_controlnet_hidden_states: List[Tensor] = None,
            joint_attention_kwargs: Optional[Dict[str, Any]] = None):
        """Build the transformer graph and mark its output.

        Args:
            hidden_states: Latent input; its last two dimensions are read as
                height and width.
            encoder_hidden_states: Text embeddings fed to ``context_embedder``.
            pooled_projections: Pooled text embedding for ``time_text_embed``.
            timestep: Timesteps for ``time_text_embed``.
            block_controlnet_hidden_states: Optional ControlNet residuals added
                after every block that is not ``context_pre_only``.
            joint_attention_kwargs: Extra keyword arguments forwarded to every
                block's attention.

        Returns:
            The unpatchified output tensor (also marked as output
            ``"output"``).
        """
        height, width = hidden_states.shape[-2:]

        hidden_states = self.pos_embed(
            hidden_states)  # takes care of adding positional embeddings too.
        temb = self.time_text_embed(timestep, pooled_projections)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if self.mapping.cp_size > 1:
            cp_size = self.mapping.cp_size
            cp_rank = self.mapping.cp_rank
            # Context parallelism: each rank keeps its slice of the image
            # tokens along dim 1.
            hidden_states = chunk(hidden_states, chunks=cp_size,
                                  dim=1)[cp_rank]
            # Pad the text tokens along dim 1 up to a multiple of cp_size so
            # they can be split evenly across ranks.
            # NOTE(review): the padded tokens presumably take part in joint
            # attention unmasked, since no mask is built here; confirm this is
            # acceptable.
            encoder_redundant = encoder_hidden_states.shape[1] % cp_size
            if encoder_redundant != 0:
                encoder_padding_index = tuple(
                    [0, 0] * (encoder_hidden_states.ndim() - 2) +
                    [0, cp_size - encoder_redundant])
                encoder_hidden_states = pad(encoder_hidden_states,
                                            pad=encoder_padding_index)
            encoder_hidden_states = chunk(encoder_hidden_states,
                                          chunks=cp_size,
                                          dim=1)[cp_rank]

        for index_block, block in enumerate(self.transformer_blocks):
            # Skip specified layers
            is_skip = (self.skip_layers is not None
                       and index_block in self.skip_layers)

            if not is_skip:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual (also applied to skipped blocks)
            if block_controlnet_hidden_states is not None and not block.context_pre_only:
                interval_control = len(self.transformer_blocks) / len(
                    block_controlnet_hidden_states)
                hidden_states = hidden_states + block_controlnet_hidden_states[
                    int(index_block / interval_control)]

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        if self.mapping.cp_size > 1:
            # Reassemble the full image-token sequence from all CP ranks.
            hidden_states = allgather(hidden_states,
                                      group=self.mapping.cp_group,
                                      gather_dim=1)

        # unpatchify
        patch_size = self.config.patch_size
        height = height // patch_size
        width = width // patch_size

        hidden_states = hidden_states.view(
            concat([
                shape(hidden_states, 0), height, width, patch_size, patch_size,
                self.out_channels
            ]))
        hidden_states = einsum("nhwpqc->nchpwq", [hidden_states])
        output = hidden_states.view(
            concat([
                shape(hidden_states, 0), self.out_channels, height * patch_size,
                width * patch_size
            ]))

        output.mark_output("output")
        return output

    def prepare_inputs(self, max_batch_size, **kwargs):
        """Create the engine input tensors.

        The batch dimension is dynamic with range
        ``[1, max(1, (max_batch_size + 1) // 2), max_batch_size]``; every
        other dimension is fixed. Extra ``kwargs`` are accepted and ignored.
        """
        batch_range = [1, max(1, (max_batch_size + 1) // 2), max_batch_size]
        prompt_embeds_len = 256 + 77  # [NOTE] tokenizer_max_length = 77; max_sequence_length = 256
        sample_size = self.config.sample_size
        joint_attention_dim = self.config.joint_attention_dim
        pooled_projection_dim = self.config.pooled_projection_dim

        hidden_states = Tensor(name='hidden_states',
                               dtype=self.dtype,
                               shape=[-1, self.in_channels, sample_size, sample_size],
                               dim_range=OrderedDict([
                                   ('batch_size', [batch_range]),
                                   ('in_channels', [[self.in_channels] * 3]),
                                   ('height', [[sample_size] * 3]),
                                   ('width', [[sample_size] * 3]),
                               ]))
        encoder_hidden_states = Tensor(
            name='encoder_hidden_states',
            dtype=self.dtype,
            shape=[-1, prompt_embeds_len, joint_attention_dim],
            dim_range=OrderedDict([
                ('batch_size', [batch_range]),
                ('txt_len', [[prompt_embeds_len] * 3]),
                ('joint_attention_dim', [[joint_attention_dim] * 3]),
            ]))
        pooled_projections = Tensor(
            name='pooled_projections',
            dtype=self.dtype,
            shape=[-1, pooled_projection_dim],
            dim_range=OrderedDict([
                ('batch_size', [batch_range]),
                ('pooled_projection_dim', [[pooled_projection_dim] * 3]),
            ]))
        timestep = Tensor(name='timestep',
                          dtype=self.dtype,
                          shape=[-1],
                          dim_range=OrderedDict([
                              ('batch_size', [batch_range]),
                          ]))

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_projections,
            "timestep": timestep,
        }

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path: str,
                        dtype='float16',
                        mapping=Mapping(),
                        **kwargs):
        """Build the model from a diffusers SD3 pipeline checkpoint.

        Args:
            pretrained_model_name_or_path: Path or hub id passed to
                ``StableDiffusion3Pipeline.from_pretrained``.
            dtype: Model dtype as a string (also used as the torch dtype when
                loading the pipeline).
            mapping: Parallelism mapping stored in the config.
            **kwargs: ``quant_ckpt_path`` (optional) selects the directory to
                load weights from; the rest go to
                ``SD3Transformer2DModelConfig.from_hugging_face_config``.
        """
        quant_ckpt_path = kwargs.pop('quant_ckpt_path', None)

        from diffusers import StableDiffusion3Pipeline

        transformer = StableDiffusion3Pipeline.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=str_dtype_to_torch(dtype)).transformer

        config = SD3Transformer2DModelConfig.from_hugging_face_config(
            transformer.config, dtype=dtype, mapping=mapping, **kwargs)

        hf_model_dir = transformer.config._name_or_path
        # Only the config and checkpoint path are needed from the PyTorch
        # module; drop the reference before the weights are loaded again.
        del transformer

        custom_dict = {}
        if quant_ckpt_path is not None:
            hf_model_dir = quant_ckpt_path

        loader = SD3ModelWeightsLoader(hf_model_dir, custom_dict)
        model = cls(config)
        loader.generate_tllm_weights(model)
        return model

    def load(self, weights, from_pruned=False):
        """Assign ``weights`` to the model's parameters.

        Parameters that are already initialized are not required, except
        ``pos_embed`` parameters when ``use_pretrained_pos_emb`` is set, which
        are always required.

        Args:
            weights: Mapping from parameter name to weight value.
            from_pruned: If True, assign with ``set_value_or_dummy`` instead of
                setting ``param.value`` directly.

        Raises:
            RuntimeError: If a required tensor is missing from ``weights`` or
                assigning a value fails.
        """
        required_names = set()
        for name, param in self.named_parameters():
            if self.use_pretrained_pos_emb and 'pos_embed' in name:
                required_names.add(name)
                continue
            if param.is_inited():
                continue
            if name not in weights:
                # Exemption for embedding sharing
                if name.endswith('lm_head.weight') and any(
                        k.endswith('vocab_embedding.weight')
                        for k in weights.keys()):
                    continue
                if name.endswith('lm_head.per_channel_scale') and any(
                        k.endswith('vocab_embedding.per_channel_scale')
                        for k in weights.keys()):
                    continue
            required_names.add(name)

        provided_names = set(weights.keys())
        if not required_names.issubset(provided_names):
            raise RuntimeError(
                f"Required but not provided tensors:{required_names.difference(provided_names)}"
            )
        if not provided_names.issubset(required_names):
            logger.warning(
                f"Provided but not required tensors: {provided_names.difference(required_names)}"
            )

        for name, param in self.named_parameters():
            if name in provided_names:
                if not from_pruned:
                    try:
                        param.value = weights[name]
                    except Exception as e:
                        # Chain the original error so its traceback is kept.
                        raise RuntimeError(
                            f"Encounter error '{e}' for parameter '{name}'"
                        ) from e
                else:
                    param.set_value_or_dummy(weights[name])

    # The diffusers-style hooks below are not supported by this
    # implementation.
    def enable_forward_chunking(self,
                                chunk_size: Optional[int] = None,
                                dim: int = 0):
        raise NotImplementedError()

    def disable_forward_chunking(self):
        raise NotImplementedError()

    @property
    def attn_processors(self):
        return None

    def set_attn_processor(self, processor):
        raise NotImplementedError()

    def fuse_qkv_projections(self):
        raise NotImplementedError()

    def unfuse_qkv_projections(self):
        raise NotImplementedError()

    def _set_gradient_checkpointing(self, module, value=False):
        raise NotImplementedError()
class SD3ModelWeightsLoader(ModelWeightsLoader):
    """Weights loader that maps SD3 TensorRT-LLM parameter names onto the
    external checkpoint's names."""

    def translate_to_external_key(self, tllm_key: str,
                                  tllm_to_externel_key_dict: dict):
        """Translate a TensorRT-LLM parameter name into the external key.

        ``FeedForward`` in this model keeps its output projection at
        ``net.1``, while the external checkpoint stores it at ``net.2``. Keys
        for that projection's ``weight`` and ``bias`` are rewritten
        accordingly; every other key is returned unchanged.

        Args:
            tllm_key: Parameter name in the TensorRT-LLM model.
            tllm_to_externel_key_dict: Unused; kept for signature
                compatibility with ``ModelWeightsLoader``.

        Returns:
            The translated key, or ``tllm_key`` when no rule applies.
        """
        import re

        # Dots are escaped and the whole key must match. With a prefix-only
        # match, keys such as ``...ff.net.1.weights_scaling_factor`` were
        # rewritten to ``...ff.net.2.weight``, silently losing their suffix.
        trtllm_to_hf_name = {
            r"transformer_blocks\.(\d+)\.ff(\w*)\.net\.1\.weight":
            "transformer_blocks.*.ff*.net.2.weight",
            r"transformer_blocks\.(\d+)\.ff(\w*)\.net\.1\.bias":
            "transformer_blocks.*.ff*.net.2.bias",
        }
        for pattern, template in trtllm_to_hf_name.items():
            m = re.fullmatch(pattern, tllm_key)
            if m is None:
                continue
            groups = m.groups()
            # Each capture group fills one '*' placeholder, left to right.
            assert len(groups) == template.count('*')
            for group in groups:
                template = template.replace('*', group, 1)
            return template
        return tllm_key