from __future__ import annotations

import json
import os
from collections import defaultdict
from typing import Any, Dict, List, NamedTuple, Optional

import torch

try:
    import pynvml
except ImportError:
    pynvml = None

from tensorrt_llm._torch.pyexecutor.model_loader import \
    validate_and_set_kv_cache_quant
from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
from tensorrt_llm.bench.dataclasses.general import DatasetMetadata
from tensorrt_llm.bench.dataclasses.statistics import (BenchmarkStatistics,
                                                       PercentileStats,
                                                       RequestRecord)
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.logger import Logger
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode


class PerfItemTuple(NamedTuple):
    """A tuple for recording new requests and their responses.

    Consumed by ``StatsKeeper.register_request_perf_item``: the request
    fields are copied onto the matching record, and the response fields are
    forwarded to that record's ``register_event``.
    """
    # When the request was submitted.
    # NOTE(review): presumably nanoseconds, since the summary stores the
    # difference of these timestamps as ``total_latency_ns`` -- confirm.
    start_timestamp: int
    # When this response was received (same clock as ``start_timestamp``).
    end_timestamp: int
    # Key used to look up the request's record in ``StatsKeeper.requests``.
    request_id: int
    # Number of input tokens in the request.
    num_input_tokens: int
    # True when this is the last response of the request; each final item
    # increments ``StatsKeeper.num_complete``.
    response_is_final: bool
    # Error flag, forwarded as the first argument of ``register_event``.
    error: bool
    # Token IDs carried by this response.
    tokens: List[int]
    # Decoding iteration reported with this response.
    decoding_iteration: int
    # NOTE(review): looks like the timestamp at which the first token was
    # produced -- confirm against the producer of these tuples.
    time_on_first_token: int


class StatsKeeper:
    """A statistics keeper for benchmarking.

    Collects one ``RequestRecord`` per request ID while a benchmark runs and
    condenses the records into a ``BenchmarkStatistics`` summary afterwards.
    """

    def __init__(self) -> None:
        # Records are created on first access, so either registration method
        # may be the first one to mention a given request ID.
        self.requests: Dict[int, RequestRecord] = defaultdict(RequestRecord)
        # Number of perf items registered with ``response_is_final`` set.
        self.num_complete: int = 0
        # Stays None unless set_energy() is called.
        self.total_energy: Optional[float] = None

    def register_request(
        self,
        request_id: int,
        timestamp: int,
        num_tokens: int,
    ) -> None:
        """Register a new request.

        DEPRECATED after switching to LLM API.

        Args:
            request_id (int): Identifier of the request.
            timestamp (int): Timestamp of when the request was submitted.
            num_tokens (int): Number of input tokens in the request.
        """
        record = self.requests[request_id]
        record.num_input_tokens = num_tokens
        record.start_timestamp = timestamp

    def register_request_perf_item(self,
                                   request_perf_item: PerfItemTuple) -> None:
        """Register request perf items, used exclusively with LLM API.

        Copies the request fields onto the record for
        ``request_perf_item.request_id`` (creating it if needed), forwards the
        response fields to the record's ``register_event`` and counts the
        request as complete when the response is final.

        Args:
            request_perf_item (PerfItemTuple): Request/response measurements.
        """
        record = self.requests[request_perf_item.request_id]
        record.id = request_perf_item.request_id
        record.num_input_tokens = request_perf_item.num_input_tokens
        record.start_timestamp = request_perf_item.start_timestamp
        record.register_event(request_perf_item.error,
                              request_perf_item.response_is_final,
                              request_perf_item.end_timestamp,
                              request_perf_item.decoding_iteration,
                              request_perf_item.tokens,
                              request_perf_item.time_on_first_token)
        if request_perf_item.response_is_final:
            self.num_complete += 1

    def set_energy(self, energy: Optional[float]) -> None:
        """Set the total energy for the benchmark."""
        self.total_energy = energy

    def generate_statistics_summary(
            self, max_draft_tokens: int) -> BenchmarkStatistics:
        """Generate summary statistics from internally stored statistics.

        Args:
            max_draft_tokens (int): Draft tokens proposed per decoding
                iteration. Per-request speculative decoding statistics are
                only collected when this is greater than zero; otherwise the
                corresponding percentile fields are None.

        Returns:
            BenchmarkStatistics: Benchmark run statistics.

        Raises:
            ValueError: If no request has been registered, since every
                aggregate below would be a division by zero.
        """
        if not self.requests:
            raise ValueError(
                "Cannot generate a statistics summary: no requests have been "
                "registered.")

        num_requests = len(self.requests)
        total_input_tokens: int = 0
        total_decoding_iterations = 0
        start_time = float("inf")
        end_time = -1
        last_queue_time = 0.0

        request_latencies = []
        generation_latencies = []
        generation_throughputs = []
        output_throughput_per_user = []
        intertoken_avg_latencies = []
        output_tokens = []
        ttft_times = []

        num_draft_tokens = []
        num_accepted_draft_tokens = []
        draft_acceptance_rate = []
        acceptance_length = []

        for entry in self.requests.values():
            start_time = min(entry.start_timestamp, start_time)
            end_time = max(entry.end_timestamp, end_time)
            last_queue_time = max(entry.start_timestamp, last_queue_time)

            request_latencies.append(entry.end_to_end_latency)
            generation_latencies.append(entry.generation_time)
            generation_throughputs.append(entry.generation_token_throughput)
            ttft_times.append(entry.time_to_first_token)
            intertoken_avg_latencies.append(entry.intertoken_latency)
            output_throughput_per_user.append(entry.output_token_throughput)

            # decode_iteration is used as a zero-based index here, hence +1
            # to obtain the number of iterations.
            decode_steps = entry.decode_iteration + 1
            total_decoding_iterations += decode_steps

            output_tokens.append(entry.num_total_output_tokens)
            total_input_tokens += entry.num_input_tokens

            # For speculative decoding, track the number of draft tokens and
            # the number of accepted draft tokens per request.
            if max_draft_tokens > 0:
                drafted = max_draft_tokens * decode_steps
                # Output tokens beyond one per iteration are counted as
                # accepted draft tokens.
                accepted = entry.num_total_output_tokens - decode_steps
                num_draft_tokens.append(drafted)
                num_accepted_draft_tokens.append(accepted)
                draft_acceptance_rate.append(float(accepted) / float(drafted))
                acceptance_length.append(entry.num_total_output_tokens /
                                         decode_steps)

        total_output_tokens = sum(output_tokens)
        global_acceptance_length = total_output_tokens / total_decoding_iterations
        # Span between the first and the last request submission.
        queue_time_total = last_queue_time - start_time

        def percentiles_or_none(values):
            # The speculative decoding lists stay empty when disabled.
            return PercentileStats.from_iterable(values) if values else None

        return BenchmarkStatistics(
            num_requests=num_requests,
            total_latency_ns=end_time - start_time,
            total_output_tokens=total_output_tokens,
            total_input_tokens=total_input_tokens,
            total_energy=self.total_energy,
            request_latency_percentiles=PercentileStats.from_iterable(
                request_latencies),
            tpot_percentiles=PercentileStats.from_iterable(
                intertoken_avg_latencies),
            output_throughput_percentiles=PercentileStats.from_iterable(
                output_throughput_per_user),
            ttft_percentiles=PercentileStats.from_iterable(ttft_times),
            generation_tp_percentiles=PercentileStats.from_iterable(
                generation_throughputs),
            generation_latency_percentiles=PercentileStats.from_iterable(
                generation_latencies),
            token_percentiles=PercentileStats.from_iterable(output_tokens),
            issue_rate_ns=queue_time_total / num_requests,
            acceptance_length=global_acceptance_length,
            num_draft_tokens_percentiles=percentiles_or_none(num_draft_tokens),
            num_accepted_draft_tokens_percentiles=percentiles_or_none(
                num_accepted_draft_tokens),
            draft_acceptance_rate_percentiles=percentiles_or_none(
                draft_acceptance_rate),
            acceptance_length_percentiles=percentiles_or_none(
                acceptance_length),
        )


class ReportUtility:
    """A utility for reporting statistics."""

    def __init__(self,
                 statistics: StatsKeeper,
                 dataset_metadata: DatasetMetadata,
                 rt_cfg: RuntimeConfig,
                 logger: Logger,
                 kwargs: Dict[str, Any],
                 streaming: bool = False) -> None:
        """Initialize the ReportingController.

        Args:
            statistics (StatsKeeper): A statistics container.
            dataset_metadata (DatasetMetadata): Metadata about the dataset.
            rt_cfg (RuntimeConfig): Configuration for the run.
            logger (Logger): A logger for logging.
            streaming (bool, optional): Streaming benchmark used. Defaults to False.
        """
        self.dataset_metadata = dataset_metadata
        self.rt_cfg = rt_cfg
        self.logger = logger
        self.kwargs = kwargs
        self.raw_statistics = statistics
        self.statistics = statistics.generate_statistics_summary(
            self.get_max_draft_len())
        self.streaming = streaming

    def _query_gpu_info(self) -> Dict[str, Any]:
        """Query first GPU info (all GPUs must be identical for TRT-LLM)."""
        if not torch.cuda.is_available():
            return None

        try:
            cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
            physical_idx = int(
                cuda_visible.split(",")[0].strip()) if cuda_visible else 0

            props = torch.cuda.get_device_properties(physical_idx)
            gpu_info = {
                "name":
                getattr(props, "name", "Unknown"),
                "memory.total":
                float(getattr(props, "total_memory", 0.0)) / (1024.0**3),
                "clocks.mem":
                None,
            }
            if pynvml:
                try:
                    # Memory clock information is not reported by torch, using NVML instead
                    pynvml.nvmlInit()
                    handle = pynvml.nvmlDeviceGetHandleByIndex(physical_idx)
                    clocks_mem = pynvml.nvmlDeviceGetMaxClockInfo(
                        handle, pynvml.NVML_CLOCK_MEM) / 1000.0
                    gpu_info["clocks.mem"] = clocks_mem
                except pynvml.NVMLError as e:
                    self.logger.info(
                        f"Error querying GPU clock info with NVML: {e}")
                    gpu_info["clocks.mem"] = None
        except Exception as e:
            # broad catch for any other errors, since this is a non-critical operation
            self.logger.warning(f"Error querying GPU info: {e}")
            return None
        return gpu_info

    @staticmethod
    def convert_to_ms(ns: float) -> float:
        """Convert nanoseconds to milliseconds."""
        return ns * 1.0e-6

    @staticmethod
    def convert_to_s(ns: float) -> float:
        """Convert nanoseconds to seconds."""
        return ns * 1.0e-9

    @staticmethod
    def convert_rate_to_s(rate: float) -> float:
        """Convert rate to seconds."""
        return rate * 1.0e9

    @property
    def request_throughput_req_s(self) -> float:
        """Request throughput in requests per second."""
        return self.convert_rate_to_s(self.statistics.request_throughput_ns)

    @property
    def output_throughput_tok_s(self) -> float:
        """Output throughput in tokens per second."""
        return self.convert_rate_to_s(self.statistics.output_throughput_tok_ns)

    @property
    def total_token_throughput_tok_s(self) -> float:
        """Total token throughput in tokens per second."""
        return self.convert_rate_to_s(
            self.statistics.total_token_throughput_tok_ns)

    @property
    def per_user_generation_token_throughput_s(self) -> float:
        """Output throughput per user in tokens per second."""
        return self.convert_rate_to_s(
            self.statistics.per_user_generation_token_throughput_ns)

    @property
    def per_user_output_throughput_tok_s(self) -> float:
        """Output throughput per user in tokens per second."""
        return self.convert_rate_to_s(
            self.statistics.output_throughput_tok_ns_per_user)

    def get_output_tokens(self, tokenizer) -> Dict[int, List[str]]:
        retval = {}
        for req_id, request in self.raw_statistics.requests.items():
            output_str = tokenizer.decode(request.tokens)
            retval[req_id] = output_str
        return dict(sorted(retval.items()))

    def get_request_info(self, tokenizer) -> Dict[int, List[str]]:
        requests = []
        for request in self.raw_statistics.requests.values():
            entry = request.model_dump()
            entry["output"] = tokenizer.decode(entry["tokens"])
            entry["output_tokens"] = len(entry["tokens"])
            entry.pop("tokens")
            requests.append(entry)
        return requests

    def get_statistics_dict(self) -> Dict[str, Any]:
        """Get statistics as a dictionary.

        Returns:
            Dict[str, Any]: Dictionary containing benchmark statistics.
        """
        stats_dict = {
            "engine": {
                "model": self.rt_cfg.model,
                "model_path": str(self.rt_cfg.model_path),
                "engine_dir": str(self.rt_cfg.engine_dir),
                "revision": self.rt_cfg.revision,
                "version": self.rt_cfg.sw_version,
            },
        }

        # Machine / GPU details - query only first GPU (all GPUs must be identical)
        stats_dict["machine"] = self._query_gpu_info()

        # Retrieve KV cache information.
        kv_cache_config = self.kwargs.get("kv_cache_config", KvCacheConfig())
        if isinstance(kv_cache_config, KvCacheConfig):
            kv_cache_dtype = kv_cache_config.dtype
            kv_cache_mem_percent = kv_cache_config.free_gpu_memory_fraction
        elif isinstance(kv_cache_config, dict):
            kv_cache_dtype = kv_cache_config.get("dtype", "auto")
            kv_cache_mem_percent = kv_cache_config.get(
                "free_gpu_memory_fraction")
        else:
            raise ValueError(
                f"Invalid kv_cache_config type: {type(kv_cache_config)}.")

        kv_cache_mem_percent = kv_cache_mem_percent \
            if kv_cache_mem_percent is not None else None

        # Engine/Backend details
        if self.rt_cfg.backend not in ('pytorch', '_autodeploy'):
            config_path = self.rt_cfg.engine_dir / "config.json"
            with open(config_path, "r") as config:
                engine_config = json.load(config)
            build_cfg = engine_config["build_config"]
            pretrain_cfg = engine_config["pretrained_config"]

            stats_dict["engine"] |= {
                "backend":
                "TRT",
                "dtype":
                pretrain_cfg["dtype"],
                "kv_cache_dtype":
                pretrain_cfg["quantization"]["kv_cache_quant_algo"],
                "quantization":
                pretrain_cfg["quantization"]["quant_algo"],
                "max_input_length":
                build_cfg["max_input_len"],
                "max_sequence_length":
                build_cfg["max_seq_len"]
            }
        else:
            from tensorrt_llm._torch.model_config import ModelConfig
            from tensorrt_llm._utils import torch_dtype_to_str

            model = self.rt_cfg.model_path or self.rt_cfg.model
            model_config = ModelConfig.from_pretrained(model,
                                                       trust_remote_code=True)

            validate_and_set_kv_cache_quant(model_config, kv_cache_dtype)

            stats_dict["engine"] |= {
                "backend":
                "Pytorch",
                "dtype":
                torch_dtype_to_str(model_config.torch_dtype
                                   or model_config.pretrained_config.
                                   get_text_config().torch_dtype),
                "kv_cache_dtype":
                model_config.quant_config.kv_cache_quant_algo,
                "quantization":
                model_config.quant_config.quant_algo
            }

        # World and runtime info
        stats_dict["world_info"] = {
            "tp_size": self.rt_cfg.mapping["tp_size"],
            "pp_size": self.rt_cfg.mapping["pp_size"],
            "ep_size": self.rt_cfg.mapping["moe_ep_size"],
            "world_size": self.rt_cfg.mapping["world_size"],
            "max_batch_size": self.rt_cfg.settings_config.max_batch_size,
            "max_num_tokens": self.rt_cfg.settings_config.max_num_tokens,
            "scheduling_policy": self.rt_cfg.settings_config.scheduler_policy,
            "kv_cache_percentage": kv_cache_mem_percent,
            "issue_rate": self.convert_rate_to_s(self.statistics.issue_rate_ns)
        }

        # Request details
        stats_dict["request_info"] = {
            "num_requests": self.statistics.num_requests,
            "avg_num_concurrent_requests":
            self.statistics.avg_concurrent_requests,
            "avg_input_length": self.statistics.average_input_length,
            "avg_output_length": self.statistics.average_output_length
        }

        # Performance stats
        stats_dict["performance"] = {
            # End-to-End Latency (last request end - 1st request start)
            "total_latency_ms":
            self.convert_to_ms(self.statistics.total_latency_ns),
            # Average per request latency (sum request latencies / num requests)
            "avg_request_latency_ms":
            self.convert_to_ms(
                self.statistics.request_latency_percentiles.average),
            # Request throughput (num requests / end-to-end latency)
            "request_throughput_req_s":
            self.request_throughput_req_s,
            # NOTE: All mention of "output" below is in reference to OSL tokens
            # including the first token.
            # Output throughput (total output (OSL) tokens / end-to-end latency)
            "system_output_throughput_tok_s":
            self.output_throughput_tok_s,
            # Output throughput per user (average per request output throughput)
            "system_total_throughput_tok_s":
            self.total_token_throughput_tok_s,
            "output_throughput_per_user_tok_s":
            self.per_user_output_throughput_tok_s,
            # Output throughput per GPU (total throughput / world size)
            "output_throughput_per_gpu_tok_s":
            self.output_throughput_tok_s / self.rt_cfg.mapping["world_size"],
            # Request latency percentiles
            "request_latency_percentiles_ms":
            self.statistics.request_latency_percentiles.model_dump(
                exclude_none=True, by_alias=True, mode='json') | {
                    k: self.convert_to_ms(v)
                    for k, v in self.statistics.request_latency_percentiles.
                    model_dump().items()
                },
        }

        if self.statistics.total_energy is not None:
            stats_dict["energy"] = {
                "total_energy_j":
                self.statistics.total_energy,
                "output_tps_per_w":
                self.statistics.output_tps_per_w,
                "average_gpu_power":
                self.statistics.total_gpu_power /
                self.rt_cfg.mapping["world_size"]
            }

        if self.streaming:
            avg_tpot = self.convert_to_ms(
                self.statistics.per_user_time_per_output_token_ns)

            stats_dict["streaming_metrics"] = {
                # NOTE: Excludes TTFT by nature as this is a genphase calculation.
                "token_output_speed_tok_s":
                self.per_user_generation_token_throughput_s,
                # Average per request time-to-first-token (TTFT)
                "avg_ttft_ms":
                self.convert_to_ms(
                    self.statistics.per_user_time_to_first_token_ns),
                # Average per request token time-per-output-token (TPOT)
                "avg_tpot_ms":
                avg_tpot,
                # Average per request Time-per-output-token percentiles (TPOT)
                "tpot_percentiles":
                self.statistics.tpot_percentiles.model_dump(
                    exclude_none=True, by_alias=True, mode='json') | {
                        k: self.convert_to_ms(v)
                        for k, v in
                        self.statistics.tpot_percentiles.model_dump().items()
                    },
                # Per request Time-to-first-token percentiles (TTFT)
                "ttft_percentiles":
                self.statistics.ttft_percentiles.model_dump(
                    exclude_none=True, by_alias=True, mode='json') | {
                        k: self.convert_to_ms(v)
                        for k, v in
                        self.statistics.ttft_percentiles.model_dump().items()
                    },
                "gen_tps_percentiles":
                self.statistics.generation_tp_percentiles.model_dump(
                    exclude_none=True, by_alias=True, mode='json') | {
                        k: self.convert_rate_to_s(v)
                        for k, v in self.statistics.generation_tp_percentiles.
                        model_dump().items()
                    },
            }

        spec_decoding, decoding_mode = False, None
        if (self.rt_cfg.decoding_config
                and self.rt_cfg.decoding_config.decoding_mode
                != SpeculativeDecodingMode.NONE):
            # cpp decoding
            spec_decoding = True
            decoding_mode = self.rt_cfg.decoding_config.decoding_mode.values[1]
        elif ("speculative_config" in self.kwargs
              and self.kwargs["speculative_config"] is not None):
            # pytorch speculative decoding
            spec_decoding = True
            spec_config = self.kwargs["speculative_config"]
            # Handle both dict (from YAML) and object types
            if isinstance(spec_config, dict):
                decoding_mode = spec_config.get("decoding_type")
            else:
                decoding_mode = spec_config.decoding_type
        if (spec_decoding):
            stats_dict["decoding_stats"] = {
                "mode":
                decoding_mode,
                "num_draft_tokens_percentiles":
                self.statistics.num_draft_tokens_percentiles.model_dump(
                    exclude_none=True, by_alias=True, mode='json')
                if self.statistics.num_draft_tokens_percentiles else None,
                "num_accepted_draft_tokens_percentiles":
                self.statistics.num_accepted_draft_tokens_percentiles.
                model_dump(exclude_none=True, by_alias=True, mode='json') if
                self.statistics.num_accepted_draft_tokens_percentiles else None,
                "draft_acceptance_rate_percentiles":
                self.statistics.draft_acceptance_rate_percentiles.model_dump(
                    exclude_none=True, by_alias=True, mode='json')
                if self.statistics.draft_acceptance_rate_percentiles else None,
                "acceptance_length_percentiles":
                self.statistics.acceptance_length_percentiles.model_dump(
                    exclude_none=True, by_alias=True, mode='json')
                if self.statistics.acceptance_length_percentiles else None
            }
        # Dataset metadata
        stats_dict["dataset"] = self.dataset_metadata.model_dump(by_alias=True,
                                                                 mode='json')

        return stats_dict

    def report_statistics(self) -> None:
        """Report internal statistics about benchmark.

        Returns:
            BenchmarkStatistics: Benchmark statistics for the provided keeper.
        """
        stats_dict = self.get_statistics_dict()
        engine = stats_dict["engine"]
        machine = stats_dict.get("machine")
        world_info = stats_dict["world_info"]
        requests = stats_dict["request_info"]
        perf = stats_dict["performance"]
        streaming = stats_dict.get("streaming_metrics")
        decoding = stats_dict.get("decoding_stats", None)

        backend_info = ""
        if self.rt_cfg.backend not in ('pytorch', '_autodeploy'):
            config_path = self.rt_cfg.engine_dir / "config.json"
            with open(config_path, "r") as config:
                engine_config = json.load(config)
            build_cfg = engine_config["build_config"]
            pretrain_cfg = engine_config["pretrained_config"]

            backend_info = (
                "\n\n===========================================================\n"
                "= ENGINE DETAILS\n"
                "===========================================================\n"
                f"Model:\t\t\t{engine['model']}\n"
                f"Model Path:\t\t{engine['model_path']}\n"
                f"Revision:\t\t{engine['revision'] or 'N/A'}\n"
                f"Engine Directory:\t{engine['engine_dir']}\n"
                f"TensorRT LLM Version:\t{engine['version']}\n"
                f"Dtype:\t\t\t{pretrain_cfg['dtype']}\n"
                f"KV Cache Dtype:\t\t{pretrain_cfg['quantization']['kv_cache_quant_algo']}\n"
                f"Quantization:\t\t{pretrain_cfg['quantization']['quant_algo']}\n"
                f"Max Input Length:\t{build_cfg['max_input_len']}\n"
                f"Max Sequence Length:\t{build_cfg['max_seq_len']}\n"
                f"\n")
        else:
            backend_info = (
                "\n\n===========================================================\n"
                f"= {self.rt_cfg.backend.upper()} BACKEND\n"
                "===========================================================\n"
                f"Model:\t\t\t{engine['model']}\n"
                f"Model Path:\t\t{engine['model_path']}\n"
                f"Revision:\t\t{engine['revision'] or 'N/A'}\n"
                f"TensorRT LLM Version:\t{engine['version']}\n"
                f"Dtype:\t\t\t{engine['dtype']}\n"
                f"KV Cache Dtype:\t\t{engine['kv_cache_dtype']}\n"
                f"Quantization:\t\t{engine['quantization']}\n"
                # TODO
                # f"Max Input Length:\t{build_cfg['max_input_len']}\n"
                # f"Max Sequence Length:\t{build_cfg['max_seq_len']}\n"
                f"\n")

        kv_cache_percentage = world_info.get("kv_cache_percentage", None)
        if kv_cache_percentage is not None:
            kv_cache_percentage = f"{kv_cache_percentage * 100.0:.2f}%"

        machine_info = (
            "===========================================================\n"
            "= MACHINE DETAILS \n"
            "===========================================================\n")
        if machine is None:
            machine_info += "No GPU info available\n\n"
        else:
            name = machine.get("name", "Unknown")
            mem_total_str = f"{machine['memory.total']:.2f} GB" if machine.get(
                "memory.total") is not None else "N/A"
            mem_clock_str = f"{machine['clocks.mem']:.2f} GHz" if machine.get(
                'clocks.mem') is not None else "N/A"
            machine_info += f"{name}, memory {mem_total_str}, {mem_clock_str}\n\n"

        world_info = (
            "===========================================================\n"
            "= WORLD + RUNTIME INFORMATION \n"
            "===========================================================\n"
            f"TP Size:                {world_info['tp_size']}\n"
            f"PP Size:                {world_info['pp_size']}\n"
            f"EP Size:                {world_info['ep_size']}\n"
            f"Max Runtime Batch Size: {world_info['max_batch_size']}\n"
            f"Max Runtime Tokens:     {world_info['max_num_tokens']}\n"
            f"Scheduling Policy:      {world_info['scheduling_policy']}\n"
            f"KV Memory Percentage:   {kv_cache_percentage}\n"
            f"Issue Rate (req/sec):   {world_info['issue_rate']:.4E}\n"
            f"\n")

        req_lat_info = "\n".join(
            f"[Latency] {key.upper():<7}: {perf['request_latency_percentiles_ms'][key]:.4f}"
            for key in perf['request_latency_percentiles_ms'].keys())

        request_info = (
            "===========================================================\n"
            "= REQUEST DETAILS \n"
            "===========================================================\n"
            f"Number of requests:             {requests['num_requests']}\n"
            f"Number of concurrent requests:  {requests['avg_num_concurrent_requests']:.4f}\n"
            f"Average Input Length (tokens):  {requests['avg_input_length']:.4f}\n"
            f"Average Output Length (tokens): {requests['avg_output_length']:.4f}\n"
        )

        perf_header = (
            "===========================================================\n"
            "= PERFORMANCE OVERVIEW \n"
            "===========================================================\n")

        perf_stats = (
            f"Request Throughput (req/sec):                     {perf['request_throughput_req_s']:.4f}\n"
            f"Total Output Throughput (tokens/sec):             {perf['system_output_throughput_tok_s']:.4f}\n"
            f"Total Token Throughput (tokens/sec):              {perf['system_total_throughput_tok_s']:.4f}\n"
            f"Total Latency (ms):                               {perf['total_latency_ms']:.4f}\n"
            f"Average request latency (ms):                     {perf['avg_request_latency_ms']:.4f}\n"
            # Output Throughput includes context/first token.
            f"Per User Output Throughput [w/ ctx] (tps/user):   {perf['output_throughput_per_user_tok_s']:.4f}\n"
            f"Per GPU Output Throughput (tps/gpu):              {perf['output_throughput_per_gpu_tok_s']:.4f}\n"
        )

        if streaming:
            streaming = stats_dict["streaming_metrics"]
            itl = streaming["tpot_percentiles"]
            ttft = streaming["ttft_percentiles"]

            tpot_stats = "\n".join(
                f"[TPOT] {key.upper():<7}: {itl[key]:.4f}" for key in
                ["minimum", "maximum", "average", "p50", "p90", "p95", "p99"])

            ttft_stats = "\n".join(
                f"[TTFT] {key.upper():<7}: {ttft[key]:.4f}" for key in
                ["minimum", "maximum", "average", "p50", "p90", "p95", "p99"])

            gen_tps_stats = "\n".join(
                f"[GTPS] {key.upper():<7}: {streaming['gen_tps_percentiles'][key]:.4f}"
                for key in
                ["minimum", "maximum", "average", "p50", "p90", "p95", "p99"])

            perf_stats += (
                f"Average time-to-first-token [TTFT] (ms):          {streaming['avg_ttft_ms']:.4f}\n"
                f"Average time-per-output-token [TPOT] (ms):        {streaming['avg_tpot_ms']:.4f}\n"
                f"Per User Output Speed (tps/user):                 {streaming['token_output_speed_tok_s']:.4f}\n"
                "\n-- Per-Request Time-per-Output-Token [TPOT] Breakdown (ms)\n\n"
                f"{tpot_stats}\n"
                "\n-- Per-Request Time-to-First-Token [TTFT] Breakdown (ms) \n\n"
                f"{ttft_stats}\n"
                "\n-- Per-Request Generation Throughput [GTPS] Breakdown (tps/user)\n\n"
                f"{gen_tps_stats}\n")

        if "energy" in stats_dict:
            energy = stats_dict["energy"]
            perf_stats += (
                "\n-- Energy Metrics --------------------------------------\n\n"
                f"Total Energy (J):                                 {energy['total_energy_j']:.4f}\n"
                f"Output Tokens per Second per Watt (tps/W):         {energy['output_tps_per_w']:.4f}\n"
                f"Average GPU Power (W):                            {energy['average_gpu_power']:.4f}\n"
            )

        perf_stats += (
            "\n-- Request Latency Breakdown (ms) -----------------------\n\n"
            f"{req_lat_info}\n")

        decoding_stats = ""
        if decoding is not None:
            decoding = stats_dict["decoding_stats"]
            if self.get_max_draft_len() > 0:
                num_draft_tokens = decoding["num_draft_tokens_percentiles"]
                num_draft_tokens_stats = "\n".join(
                    f"[DT] {key.upper():<7}: {num_draft_tokens[key]:.2f}"
                    for key in [
                        "minimum", "maximum", "average", "p50", "p90", "p95",
                        "p99"
                    ])

                num_accepted_draft_tokens = decoding[
                    "num_accepted_draft_tokens_percentiles"]
                num_accepted_draft_tokens_stats = "\n".join(
                    f"[ADT] {key.upper():<7}: {num_accepted_draft_tokens[key]:.2f}"
                    for key in [
                        "minimum", "maximum", "average", "p50", "p90", "p95",
                        "p99"
                    ])

                draft_acceptance_rate = decoding[
                    "draft_acceptance_rate_percentiles"]
                draft_acceptance_rate_stats = "\n".join(
                    f"[DAR] {key.upper():<7}: {draft_acceptance_rate[key]:.2f}"
                    for key in [
                        "minimum", "maximum", "average", "p50", "p90", "p95",
                        "p99"
                    ])

                acceptance_length = decoding["acceptance_length_percentiles"]
                acceptance_length_stats = "\n".join(
                    f"[AL] {key.upper():<7}: {acceptance_length[key]:.2f}"
                    for key in [
                        "minimum", "maximum", "average", "p50", "p90", "p95",
                        "p99"
                    ])

                decoding_stats = (
                    "===========================================================\n"
                    f"= DECODING STATISTICS ({decoding['mode']})\n"
                    "===========================================================\n"
                    "\n"
                    "-- Number of Draft Tokens Details --------------------------------\n\n"
                    "\n"
                    f"{num_draft_tokens_stats}"
                    f"\n"
                    "-- Number of Accepted Draft Tokens Details --------------------------------\n\n"
                    f"{num_accepted_draft_tokens_stats}"
                    f"\n"
                    "-- Draft Acceptance Rate Details --------------------------------\n\n"
                    f"{draft_acceptance_rate_stats}"
                    f"\n"
                    "-- Acceptance Length Details --------------------------------\n\n"
                    f"{acceptance_length_stats}"
                    f"\n"
                    "===========================================================\n"
                )

        logging_info = (f"{backend_info}"
                        f"{machine_info}"
                        f"{request_info}"
                        f"{world_info}"
                        f"{perf_header}"
                        f"{perf_stats}"
                        f"{decoding_stats}"
                        f"{self.dataset_metadata.get_summary_for_print()}")
        self.logger.info(logging_info)
        return self.statistics

    def get_max_draft_len(self) -> int:
        """Get max_draft_len from speculative_config."""
        # Try to get from speculative_config
        if ("speculative_config" in self.kwargs
                and self.kwargs["speculative_config"] is not None):
            spec_config = self.kwargs["speculative_config"]
            # Handle both dict (from YAML) and object types
            if isinstance(spec_config, dict):
                draft_len = (spec_config.get("max_draft_len")
                             or spec_config.get("num_nextn_predict_layers"))
                return draft_len or 0
            return spec_config.max_draft_len or 0

        return 0
