HF BERT: Prune, Distill & Quantize

This example shows how to compress a Hugging Face BERT large model for question answering by combining modelopt.torch.prune, modelopt.torch.distill, and modelopt.torch.quantize. More specifically, we will:

  1. Prune the BERT large model to 50% FLOPs with the GradNAS algorithm and fine-tune it with distillation (a condensed sketch of the pruning call is shown after this list)

  2. Quantize the fine-tuned model to INT8 precision with Post-Training Quantization (PTQ), then run Quantization Aware Training (QAT) with distillation

  3. Export the quantized model to ONNX format for deployment with TensorRT
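
For orientation, here is a minimal, illustrative sketch of step 1 (GradNAS pruning) using the same ModelOpt calls as the full script further below. It assumes a CUDA GPU and substitutes a few fake question-answering batches for the real SQuAD dataloader (the fake_batches helper is our stand-in, not part of the example), so it demonstrates the API shape rather than meaningful accuracy:

    import torch
    from transformers import AutoModelForQuestionAnswering

    import modelopt.torch.prune as mtp

    model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
    model = AutoModelForQuestionAnswering.from_pretrained(model_name).cuda()
    dummy_input = model.dummy_inputs["input_ids"].cuda()

    def fake_batches(n=4, batch_size=2, seq_len=128):
        # GradNAS scores prunable heads/neurons from gradients of the task loss,
        # so each fake batch carries QA labels (start/end positions).
        for _ in range(n):
            yield {
                "input_ids": torch.randint(1000, 2000, (batch_size, seq_len), device="cuda"),
                "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long, device="cuda"),
                "token_type_ids": torch.zeros(batch_size, seq_len, dtype=torch.long, device="cuda"),
                "start_positions": torch.randint(0, seq_len, (batch_size,), device="cuda"),
                "end_positions": torch.randint(0, seq_len, (batch_size,), device="cuda"),
            }

    # Prune to 50% of the original FLOPs; the config mirrors the full script below.
    model, _ = mtp.prune(
        model=model,
        mode="gradnas",
        constraints={"flops": "50%"},
        dummy_input=dummy_input,
        config={
            "data_loader": list(fake_batches()),
            "collect_func": lambda batch: (batch,),
            "loss_func": lambda output, batch: (
                output["loss"] if isinstance(output, dict) else output[0]
            ),
        },
    )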

Prerequisites

  1. Install Model Optimizer with the optional torch and huggingface dependencies (a quick import sanity check follows the command):

    pip install "nvidia-modelopt[torch,hf]" --extra-index-url https://pypi.nvidia.com
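
To verify the installation, you can run a quick import check like the one below (our own snippet, not part of the example); the exact versions printed will differ on your machine:

    from importlib.metadata import version

    import torch
    import transformers

    import modelopt.torch.distill as mtd       # distillation ("kd_loss" mode)
    import modelopt.torch.opt as mto            # save/restore of ModelOpt-modified models
    import modelopt.torch.prune as mtp          # GradNAS pruning
    import modelopt.torch.quantization as mtq   # PTQ / QAT configs such as INT8_DEFAULT_CFG

    print("nvidia-modelopt:", version("nvidia-modelopt"))
    print("torch:", torch.__version__, "| transformers:", transformers.__version__)
    print("INT8 PTQ config available:", hasattr(mtq, "INT8_DEFAULT_CFG"))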
    

Note

This example has been tested on 8 x 24GB A5000 GPUs with PyTorch 2.4 and CUDA 12.4. It takes about 2 hours to complete all the stages of the optimization. Most of the time is spent on fine-tuning and QAT.

Full code

You can view the full code below with the ModelOpt integration points highlighted. The source code and scripts are also available on the ModelOpt GitHub repository.

bert_prune_distill_quantize.py
   1# NOTE: This is adapted from run_qa_no_trainer.py and utils_qa.py from https://github.com/huggingface/
   2# transformers/blob/c52b515e948fc12ff58ad773a0385860d0162f61/examples/pytorch/question-answering
   3#
   4# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
   5#
   6# Licensed under the Apache License, Version 2.0 (the "License");
   7# you may not use this file except in compliance with the License.
   8# You may obtain a copy of the License at
   9#
  10#     http://www.apache.org/licenses/LICENSE-2.0
  11#
  12# Unless required by applicable law or agreed to in writing, software
  13# distributed under the License is distributed on an "AS IS" BASIS,
  14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15# See the License for the specific language governing permissions and
  16# limitations under the License.
  17
  18# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  19# SPDX-License-Identifier: MIT
  20
  21# Permission is hereby granted, free of charge, to any person obtaining a
  22# copy of this software and associated documentation files (the "Software"),
  23# to deal in the Software without restriction, including without limitation
  24# the rights to use, copy, modify, merge, publish, distribute, sublicense,
  25# and/or sell copies of the Software, and to permit persons to whom the
  26# Software is furnished to do so, subject to the following conditions:
  27
  28# The above copyright notice and this permission notice shall be included in
  29# all copies or substantial portions of the Software.
  30
  31# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  32# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  33# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  34# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  35# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
  36# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
  37# DEALINGS IN THE SOFTWARE.
  38"""
  39Example showcasing how to do end-to-end optimization of a BERT model on SQuAD using Model Optimizer.
  40This includes GradNAS pruning, INT8 quantization, fine-tuning / QAT with distillation, and ONNX export.
  41"""
  42import argparse
  43import collections
  44import json
  45import logging
  46import math
  47import os
  48import random
  49from typing import Any, Dict, List, Optional, Tuple
  50
  51import datasets
  52import evaluate
  53import numpy as np
  54import torch
  55import torch.nn as nn
  56import transformers
  57from accelerate import Accelerator
  58from accelerate.logging import get_logger
  59from accelerate.utils import set_seed
  60from torch.utils.data import DataLoader
  61from tqdm.auto import tqdm
  62from transformers import (
  63    AutoModelForQuestionAnswering,
  64    AutoTokenizer,
  65    DataCollatorWithPadding,
  66    EvalPrediction,
  67    PreTrainedTokenizer,
  68    SchedulerType,
  69    default_data_collator,
  70    get_scheduler,
  71)
  72
  73import modelopt.torch.distill as mtd
  74import modelopt.torch.opt as mto
  75import modelopt.torch.prune as mtp
  76import modelopt.torch.quantization as mtq
  77from modelopt.torch._deploy.utils import get_onnx_bytes
  78
  79logger = get_logger(__name__)
  80
  81SEED = 123
  82
  83
  84def parse_args(input_args: Optional[List[str]] = None):
  85    parser = argparse.ArgumentParser(
  86        description="Finetune a transformers model on a Question Answering task"
  87    )
  88
  89    # Training arguments
  90    parser.add_argument(
  91        "--model_name_or_path",
  92        type=str,
  93        default="bert-large-uncased-whole-word-masking-finetuned-squad",
  94        help="Path to pretrained model or model identifier from huggingface.co/models.",
  95    )
  96    parser.add_argument(
  97        "--do_train",
  98        action="store_true",
  99        help="Whether to run training / fine-tuning.",
 100    )
 101    parser.add_argument(
 102        "--per_device_train_batch_size",
 103        type=int,
 104        default=16,
 105        help="Batch size (per device) for the training dataloader.",
 106    )
 107    parser.add_argument(
 108        "--per_device_eval_batch_size",
 109        type=int,
 110        default=64,
 111        help="Batch size (per device) for the evaluation dataloader.",
 112    )
 113    parser.add_argument(
 114        "--learning_rate",
 115        type=float,
 116        default=5e-5,
 117        help="Initial learning rate (after the potential warmup period) to use.",
 118    )
 119    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
 120    parser.add_argument(
 121        "--lr_scheduler_type",
 122        type=SchedulerType,
 123        default="linear",
 124        help="The scheduler type to use.",
 125        choices=[
 126            "linear",
 127            "cosine",
 128            "cosine_with_restarts",
 129            "polynomial",
 130            "constant",
 131            "constant_with_warmup",
 132        ],
 133    )
 134    parser.add_argument(
 135        "--num_warmup_steps",
 136        type=int,
 137        default=0,
 138        help="Number of steps for the warmup in the lr scheduler.",
 139    )
 140    parser.add_argument(
 141        "--num_train_epochs",
 142        type=float,
 143        default=2.0,
 144        help="Total number of training epochs to perform.",
 145    )
 146    parser.add_argument(
 147        "--max_train_steps",
 148        type=int,
 149        default=None,
 150        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
 151    )
 152    parser.add_argument(
 153        "--gradient_accumulation_steps",
 154        type=int,
 155        default=1,
 156        help="Number of update steps to accumulate before performing a backward/update pass.",
 157    )
 158    parser.add_argument(
 159        "--preprocessing_num_workers",
 160        type=int,
 161        default=4,
 162        help="The number of processes to use for preprocessing the dataset.",
 163    )
 164
 165    # Logging and checkpointing arguments
 166    parser.add_argument(
 167        "--output_dir", type=str, default="results", help="Where to store the optimized models."
 168    )
 169    parser.add_argument(
 170        "--with_tracking",
 171        action="store_true",
 172        help="Whether to enable experiment trackers for logging.",
 173    )
 174    parser.add_argument(
 175        "--checkpointing_steps",
 176        type=str,
 177        default=None,
 178        help=(
 179            "Whether the various states should be saved at the end of every n steps, or 'epoch' for"
 180            " each epoch."
 181        ),
 182    )
 183    parser.add_argument(
 184        "--resume_from_last_ckpt",
 185        action="store_true",
 186        help="If the training should continue from the latest checkpoint in output_dir.",
 187    )
 188
 189    # Misc arguments for Bert (should not be modified in most cases)
 190    parser.add_argument(
 191        "--max_seq_length",
 192        type=int,
 193        default=384,
 194        help=(
 195            "The maximum total input sequence length after tokenization. Sequences longer than this"
 196            " will be truncated, and shorter will be padded if `--pad_to_max_length` is passed."
 197        ),
 198    )
 199    parser.add_argument(
 200        "--pad_to_max_length",
 201        action="store_true",
 202        help="If passed, pad all samples to `max_seq_length`. Otherwise, dynamic padding is used.",
 203    )
 204    parser.add_argument(
 205        "--doc_stride",
 206        type=int,
 207        default=128,
 208        help=(
 209            "When splitting up a long document into chunks, how much stride to take between chunks."
 210        ),
 211    )
 212    parser.add_argument(
 213        "--n_best_size",
 214        type=int,
 215        default=20,
 216        help="The total number of n-best predictions to generate when looking for an answer.",
 217    )
 218    parser.add_argument(
 219        "--max_answer_length",
 220        type=int,
 221        default=30,
 222        help=(
 223            "The maximum length of an answer that can be generated. This is needed because the"
 224            " start and end predictions are not conditioned on one another."
 225        ),
 226    )
 227
 228    # Debugging arguments
 229    parser.add_argument(
 230        "--max_train_samples",
 231        type=int,
 232        default=None,
 233        help="For debugging purposes or quicker training.",
 234    )
 235    parser.add_argument(
 236        "--max_eval_samples",
 237        type=int,
 238        default=None,
 239        help="For debugging purposes or quicker training.",
 240    )
 241
 242    # Model Optimizer: pruning arguments
 243    parser.add_argument(
 244        "--do_modelopt_prune",
 245        action="store_true",
 246        help="Whether or not to use Model Optimizer pruning.",
 247    )
 248    parser.add_argument(
 249        "--modelopt_prune_flops_percent",
 250        type=float,
 251        default=None,
 252        help="The percentage (between 0 and 100) of FLOPs to retain in the pruned model.",
 253    )
 254
 255    # Model Optimizer: quantization arguments
 256    parser.add_argument(
 257        "--modelopt_quantize_cfg",
 258        help="Model Optimizer quantization config.",
 259        choices=mtq.config.choices,
 260    )
 261
 262    # Model Optimizer: Distillation arguments
 263    parser.add_argument(
 264        "--do_modelopt_distill",
 265        action="store_true",
 266        help="Whether or not to use distillation. A teacher model must be specified.",
 267    )
 268    parser.add_argument(
 269        "--temperature",
 270        type=float,
 271        default=2.0,
 272        help="The temperature to use when distilling.",
 273    )
 274
 275    # Model Optimizer: save and restore arguments
 276    parser.add_argument(
 277        "--modelopt_save_file",
 278        type=str,
 279        default=None,
 280        help="File (inside output_dir) to save the modelopt modified model to.",
 281    )
 282    parser.add_argument(
 283        "--modelopt_restore_path",
 284        type=str,
 285        default=None,
 286        help="Path to restore the modelopt modified model from.",
 287    )
 288
 289    # ONNX export arguments
 290    parser.add_argument(
 291        "--onnx_export_file",
 292        type=str,
 293        default=None,
 294        help="File (inside output_dir) to export the ONNX model to.",
 295    )
 296
 297    args = parser.parse_args(input_args)
 298
 299    # Sanity checks
 300    if args.do_modelopt_prune and not args.modelopt_prune_flops_percent:
 301        raise ValueError(
 302            "Need a `modelopt_prune_flops_percent` when `do_modelopt_prune` is passed."
 303        )
 304
 305    return args
 306
 307
 308def get_datasets_and_dataloaders(args, tokenizer: PreTrainedTokenizer, accelerator: Accelerator):
 309    """Get the examples, dataset, dataloader, answer_column_name
 310
 311    You can either provide your own CSV/JSON/TXT training and evaluation files (see below)
 312    or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
 313    (the dataset will be downloaded automatically from the datasets Hub).
 314
 315    For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
 316    'text' is found. You can easily tweak this behavior (see below).
 317    """
 318
 319    def prepare_train_features(examples):
 320        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
 321        # truncation of the context fail (the tokenized question will take a lot of space). So we remove that
 322        # left whitespace
 323        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
 324
 325        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
 326        # in one example possibly giving several features when a context is long, each of those features having a
 327        # context that overlaps a bit the context of the previous feature.
 328        tokenized_examples = tokenizer(
 329            examples[question_column_name if pad_on_right else context_column_name],
 330            examples[context_column_name if pad_on_right else question_column_name],
 331            truncation="only_second" if pad_on_right else "only_first",
 332            max_length=max_seq_length,
 333            stride=args.doc_stride,
 334            return_overflowing_tokens=True,
 335            return_offsets_mapping=True,
 336            padding="max_length" if args.pad_to_max_length else False,
 337        )
 338
 339        # Since one example might give us several features if it has a long context, we need a map from a feature to
 340        # its corresponding example. This key gives us just that.
 341        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
 342        # The offset mappings will give us a map from token to character position in the original context. This will
 343        # help us compute the start_positions and end_positions.
 344        offset_mapping = tokenized_examples.pop("offset_mapping")
 345
 346        # Let's label those examples!
 347        tokenized_examples["start_positions"] = []
 348        tokenized_examples["end_positions"] = []
 349
 350        for i, offsets in enumerate(offset_mapping):
 351            # We will label impossible answers with the index of the CLS token.
 352            input_ids = tokenized_examples["input_ids"][i]
 353            cls_index = input_ids.index(tokenizer.cls_token_id)
 354
 355            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
 356            sequence_ids = tokenized_examples.sequence_ids(i)
 357
 358            # One example can give several spans, this is the index of the example containing this span of text.
 359            sample_index = sample_mapping[i]
 360            answers = examples[answer_column_name][sample_index]
 361            # If no answers are given, set the cls_index as answer.
 362            if len(answers["answer_start"]) == 0:
 363                tokenized_examples["start_positions"].append(cls_index)
 364                tokenized_examples["end_positions"].append(cls_index)
 365            else:
 366                # Start/end character index of the answer in the text.
 367                start_char = answers["answer_start"][0]
 368                end_char = start_char + len(answers["text"][0])
 369
 370                # Start token index of the current span in the text.
 371                token_start_index = 0
 372                while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
 373                    token_start_index += 1
 374
 375                # End token index of the current span in the text.
 376                token_end_index = len(input_ids) - 1
 377                while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
 378                    token_end_index -= 1
 379
 380                # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
 381                if not (
 382                    offsets[token_start_index][0] <= start_char
 383                    and offsets[token_end_index][1] >= end_char
 384                ):
 385                    tokenized_examples["start_positions"].append(cls_index)
 386                    tokenized_examples["end_positions"].append(cls_index)
 387                else:
 388                    # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
 389                    # Note: we could go after the last offset if the answer is the last word (edge case).
 390                    while (
 391                        token_start_index < len(offsets)
 392                        and offsets[token_start_index][0] <= start_char
 393                    ):
 394                        token_start_index += 1
 395                    tokenized_examples["start_positions"].append(token_start_index - 1)
 396                    while offsets[token_end_index][1] >= end_char:
 397                        token_end_index -= 1
 398                    tokenized_examples["end_positions"].append(token_end_index + 1)
 399
 400        return tokenized_examples
 401
 402    def prepare_validation_features(examples):
 403        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
 404        # truncation of the context fail (the tokenized question will take a lot of space). So we remove that
 405        # left whitespace
 406        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
 407
 408        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
 409        # in one example possibly giving several features when a context is long, each of those features having a
 410        # context that overlaps a bit the context of the previous feature.
 411        tokenized_examples = tokenizer(
 412            examples[question_column_name if pad_on_right else context_column_name],
 413            examples[context_column_name if pad_on_right else question_column_name],
 414            truncation="only_second" if pad_on_right else "only_first",
 415            max_length=max_seq_length,
 416            stride=args.doc_stride,
 417            return_overflowing_tokens=True,
 418            return_offsets_mapping=True,
 419            padding="max_length" if args.pad_to_max_length else False,
 420        )
 421
 422        # Since one example might give us several features if it has a long context, we need a map from a feature to
 423        # its corresponding example. This key gives us just that.
 424        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
 425
 426        # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
 427        # corresponding example_id and we will store the offset mappings.
 428        tokenized_examples["example_id"] = []
 429
 430        for i in range(len(tokenized_examples["input_ids"])):
 431            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
 432            sequence_ids = tokenized_examples.sequence_ids(i)
 433            context_index = 1 if pad_on_right else 0
 434
 435            # One example can give several spans, this is the index of the example containing this span of text.
 436            sample_index = sample_mapping[i]
 437            tokenized_examples["example_id"].append(examples["id"][sample_index])
 438
 439            # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
 440            # position is part of the context or not.
 441            tokenized_examples["offset_mapping"][i] = [
 442                (o if sequence_ids[k] == context_index else None)
 443                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
 444            ]
 445
 446        return tokenized_examples
 447
 448    examples, dataset, dataloader = {}, {}, {}
 449
 450    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
 451    # download the dataset.
 452    # Downloading and loading a dataset from the hub.
 453    raw_datasets = datasets.load_dataset("squad")
 454    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
 455    # https://huggingface.co/docs/datasets/loading_datasets.
 456
 457    # Preprocessing the datasets.
 458    # Preprocessing is slightly different for training and evaluation.
 459
 460    column_names = raw_datasets["train"].column_names
 461
 462    question_column_name = "question" if "question" in column_names else column_names[0]
 463    context_column_name = "context" if "context" in column_names else column_names[1]
 464    answer_column_name = "answers" if "answers" in column_names else column_names[2]
 465
 466    # Padding side determines if we do (question|context) or (context|question).
 467    pad_on_right = tokenizer.padding_side == "right"
 468
 469    if args.max_seq_length > tokenizer.model_max_length:
 470        logger.warning(
 471            f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length"
 472            f" for the model ({tokenizer.model_max_length}). Using"
 473            f" max_seq_length={tokenizer.model_max_length}."
 474        )
 475
 476    max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
 477
 478    examples["train"] = raw_datasets["train"]
 479    if args.max_train_samples is not None:
 480        # We will select samples from the whole data if the argument is specified
 481        examples["train"] = examples["train"].select(range(args.max_train_samples))
 482
 483    # Create train feature from dataset
 484    with accelerator.main_process_first():
 485        dataset["train"] = examples["train"].map(
 486            prepare_train_features,
 487            batched=True,
 488            num_proc=args.preprocessing_num_workers,
 489            remove_columns=column_names,
 490            load_from_cache_file=True,
 491            desc="Running tokenizer on train dataset",
 492        )
 493        # if args.max_train_samples is not None:
 494        #     # Number of samples might increase during Feature Creation, We select only specified max samples
 495        #     dataset["train"] = dataset["train"].select(range(args.max_train_samples))
 496
 497    examples["eval"] = raw_datasets["validation"]
 498    if args.max_eval_samples is not None:
 499        # We will select samples from the whole data
 500        examples["eval"] = examples["eval"].select(range(args.max_eval_samples))
 501    # Validation Feature Creation
 502    with accelerator.main_process_first():
 503        dataset["eval"] = examples["eval"].map(
 504            prepare_validation_features,
 505            batched=True,
 506            num_proc=args.preprocessing_num_workers,
 507            remove_columns=column_names,
 508            load_from_cache_file=True,
 509            desc="Running tokenizer on validation dataset",
 510        )
 511        # if args.max_eval_samples is not None:
 512        #     # During Feature creation dataset samples might increase, we will select required samples again
 513        #     dataset["eval"] = dataset["eval"].select(range(args.max_eval_samples))
 514
 515    # Log a random sample from the training set:
 516    for index in random.sample(range(len(dataset["train"])), 1):
 517        logger.info(f"Sample {index} of the training set: {dataset['train'][index]}.")
 518
 519    # DataLoaders creation:
 520    if args.pad_to_max_length:
 521        # If padding was already done to max length, we use the default data collator that will just convert everything
 522        # to tensors.
 523        data_collator = default_data_collator
 524    else:
 525        # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of
 526        # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple
 527        # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
 528        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
 529
 530    dataloader["train"] = DataLoader(
 531        dataset["train"],
 532        shuffle=True,
 533        collate_fn=data_collator,
 534        batch_size=args.per_device_train_batch_size,
 535    )
 536
 537    dataloader["eval"] = DataLoader(
 538        dataset["eval"].remove_columns(["example_id", "offset_mapping"]),
 539        collate_fn=data_collator,
 540        batch_size=args.per_device_eval_batch_size,
 541    )
 542
 543    return examples, dataset, dataloader, answer_column_name
 544
 545
 546def evaluate_model(
 547    args,
 548    model: nn.Module,
 549    accelerator: Accelerator,
 550    eval_examples: Any,
 551    eval_dataset: Any,
 552    eval_dataloader: DataLoader,
 553    answer_column_name: str,
 554    prefix: str = "Eval",
 555):
 556    def create_and_fill_np_array(start_or_end_logits, max_len):
 557        """Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor
 558
 559        Args:
 560            start_or_end_logits: This is the output predictions of the model.
 561                We can only enter either start or end logits.
 562            max_len: The maximum length of the output tensor. (See the model.eval() part for more details)
 563        """
 564        step = 0
 565        # create a numpy array and fill it with -100.
 566        logits_concat = np.full((len(eval_dataset), max_len), -100, dtype=np.float64)
 567        # Now since we have created an array we will populate it with the outputs using accelerator.gather_for_metrics
 568        for i, output_logit in enumerate(start_or_end_logits):  # populate columns
 569            # We have to fill it such that we have to take the whole tensor and replace it on the newly created array
 570            # And after every iteration we have to change the step
 571            batch_size = output_logit.shape[0]
 572            cols = output_logit.shape[1]
 573
 574            if step + batch_size < len(eval_dataset):
 575                logits_concat[step : step + batch_size, :cols] = output_logit
 576            else:
 577                logits_concat[step:, :cols] = output_logit[: len(eval_dataset) - step]
 578
 579            step += batch_size
 580
 581        return logits_concat
 582
 583    def postprocess_qa_predictions(
 584        examples,
 585        features,
 586        predictions: Tuple[np.ndarray, np.ndarray],
 587        version_2_with_negative: bool = False,
 588        n_best_size: int = 20,
 589        max_answer_length: int = 30,
 590        null_score_diff_threshold: float = 0.0,
 591        output_dir: Optional[str] = None,
 592        prefix: Optional[str] = None,
 593    ) -> EvalPrediction:
 594        """Post-processes the predictions of a question-answering model to convert them to answers
 595        that are substrings of the original contexts. This is the base postprocessing function for
 596        models that only return start and end logits.
 597
 598        Args:
 599            examples: The non-preprocessed dataset.
 600            features: The processed dataset.
 601            predictions: The predictions of the model: two arrays containing the start logits and the end logits
 602                respectively. Its first dimension must match the number of elements of `features`.
 603            version_2_with_negative: Whether or not the underlying dataset contains examples with no answers.
 604            n_best_size: The total number of n-best predictions to generate when looking for an answer.
 605            max_answer_length: The maximum length of an answer that can be generated. This is needed
 606                because the start and end predictions are not conditioned on one another.
 607            null_score_diff_threshold: The threshold used to select the null answer: if the best answer
 608                has a score that is less than the score of the null answer minus this threshold, the
 609                null answer is selected for this example (note that the score of the null answer for
 610                an example giving several features is the minimum of the scores for the null answer on
 611                each feature: all features must be aligned on the fact they `want` to predict a null answer).
 612                Only useful when `version_2_with_negative` is `True`.
 613            output_dir: If provided, the dictionaries of predictions, n_best predictions (with their scores and logits)
 614                and, if `version_2_with_negative=True`, the dictionary of the scores differences between best and null
 615                answers, are saved in `output_dir`.
 616            prefix: If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
 617        """
 618        if len(predictions) != 2:
 619            raise ValueError(
 620                "`predictions` should be a tuple with two elements (start_logits, end_logits)."
 621            )
 622        all_start_logits, all_end_logits = predictions
 623
 624        if len(predictions[0]) != len(features):
 625            raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
 626
 627        # Build a map example to its corresponding features.
 628        example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
 629        features_per_example = collections.defaultdict(list)
 630        for i, feature in enumerate(features):
 631            features_per_example[example_id_to_index[feature["example_id"]]].append(i)
 632
 633        # The dictionaries we have to fill.
 634        all_predictions = collections.OrderedDict()
 635        all_nbest_json = collections.OrderedDict()
 636        if version_2_with_negative:
 637            scores_diff_json = collections.OrderedDict()
 638
 639        logger.debug(
 640            f"Post-processing {len(examples)} example predictions split into"
 641            f" {len(features)} features."
 642        )
 643
 644        # Let's loop over all the examples!
 645        for example_index, example in enumerate(examples):
 646            # Those are the indices of the features associated to the current example.
 647            feature_indices = features_per_example[example_index]
 648
 649            min_null_prediction = None
 650            prelim_predictions = []
 651
 652            # Looping through all the features associated to the current example.
 653            for feature_index in feature_indices:
 654                # We grab the predictions of the model for this feature.
 655                start_logits = all_start_logits[feature_index]
 656                end_logits = all_end_logits[feature_index]
 657                # This is what will allow us to map some of the positions in our logits to spans of text in the original
 658                # context.
 659                offset_mapping = features[feature_index]["offset_mapping"]
 660                # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum
 661                # context available in the current feature.
 662                token_is_max_context = features[feature_index].get("token_is_max_context", None)
 663
 664                # Update minimum null prediction.
 665                feature_null_score = start_logits[0] + end_logits[0]
 666                if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
 667                    min_null_prediction = {
 668                        "offsets": (0, 0),
 669                        "score": feature_null_score,
 670                        "start_logit": start_logits[0],
 671                        "end_logit": end_logits[0],
 672                    }
 673
 674                # Go through all possibilities for the `n_best_size` greatest start and end logits.
 675                start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
 676                end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
 677                for start_index in start_indexes:
 678                    for end_index in end_indexes:
 679                        # Don't consider out-of-scope answers, either because the indices are out of bounds or
 680                        # correspond to part of the input_ids that are not in the context.
 681                        if (
 682                            start_index >= len(offset_mapping)
 683                            or end_index >= len(offset_mapping)
 684                            or offset_mapping[start_index] is None
 685                            or len(offset_mapping[start_index]) < 2
 686                            or offset_mapping[end_index] is None
 687                            or len(offset_mapping[end_index]) < 2
 688                        ):
 689                            continue
 690                        # Don't consider answers with a length that is either < 0 or > max_answer_length.
 691                        if (
 692                            end_index < start_index
 693                            or end_index - start_index + 1 > max_answer_length
 694                        ):
 695                            continue
 696                        # Don't consider answers that don't have the maximum context available (if such information is
 697                        # provided).
 698                        if token_is_max_context is not None and not token_is_max_context.get(
 699                            str(start_index), False
 700                        ):
 701                            continue
 702
 703                        prelim_predictions.append(
 704                            {
 705                                "offsets": (
 706                                    offset_mapping[start_index][0],
 707                                    offset_mapping[end_index][1],
 708                                ),
 709                                "score": start_logits[start_index] + end_logits[end_index],
 710                                "start_logit": start_logits[start_index],
 711                                "end_logit": end_logits[end_index],
 712                            }
 713                        )
 714            if version_2_with_negative and min_null_prediction is not None:
 715                # Add the minimum null prediction
 716                prelim_predictions.append(min_null_prediction)
 717                null_score = min_null_prediction["score"]
 718
 719            # Only keep the best `n_best_size` predictions.
 720            n_best_preds = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[
 721                :n_best_size
 722            ]
 723
 724            # Add back the minimum null prediction if it was removed because of its low score.
 725            if (
 726                version_2_with_negative
 727                and min_null_prediction is not None
 728                and not any(p["offsets"] == (0, 0) for p in n_best_preds)
 729            ):
 730                n_best_preds.append(min_null_prediction)
 731
 732            # Use the offsets to gather the answer text in the original context.
 733            context = example["context"]
 734            for pred in n_best_preds:
 735                offsets = pred.pop("offsets")
 736                pred["text"] = context[offsets[0] : offsets[1]]
 737
 738            # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
 739            # failure.
 740            if len(n_best_preds) == 0 or (len(n_best_preds) == 1 and n_best_preds[0]["text"] == ""):
 741                n_best_preds.insert(
 742                    0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}
 743                )
 744
 745            # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file,
 746            # using the LogSumExp trick).
 747            scores = np.array([pred.pop("score") for pred in n_best_preds])
 748            exp_scores = np.exp(scores - np.max(scores))
 749            probs = exp_scores / exp_scores.sum()
 750
 751            # Include the probabilities in our n_best_preds.
 752            for prob, pred in zip(probs, n_best_preds):
 753                pred["probability"] = prob
 754
 755            # Pick the best prediction. If the null answer is not possible, this is easy.
 756            if not version_2_with_negative:
 757                all_predictions[example["id"]] = n_best_preds[0]["text"]
 758            else:
 759                # Otherwise we first need to find the best non-empty prediction.
 760                i = 0
 761                while n_best_preds[i]["text"] == "":
 762                    i += 1
 763                best_non_null_pred = n_best_preds[i]
 764
 765                # Then we compare to the null prediction using the threshold.
 766                score_diff = (
 767                    null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
 768                )
 769                scores_diff_json[example["id"]] = float(score_diff)  # To be JSON-serializable.
 770                if score_diff > null_score_diff_threshold:
 771                    all_predictions[example["id"]] = ""
 772                else:
 773                    all_predictions[example["id"]] = best_non_null_pred["text"]
 774
 775            # Make `n_best_preds` JSON-serializable by casting np.float back to float.
 776            all_nbest_json[example["id"]] = [
 777                {
 778                    k: float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v
 779                    for k, v in pred.items()
 780                }
 781                for pred in n_best_preds
 782            ]
 783
 784        # If we have an output_dir, let's save all those dicts.
 785        if output_dir is not None:
 786            if not os.path.isdir(output_dir):
 787                raise EnvironmentError(f"{output_dir} is not a directory.")
 788
 789            prediction_file = os.path.join(
 790                output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
 791            )
 792            nbest_file = os.path.join(
 793                output_dir,
 794                "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json",
 795            )
 796            if version_2_with_negative:
 797                null_odds_file = os.path.join(
 798                    output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
 799                )
 800
 801            logger.info(f"Saving predictions to {prediction_file}.")
 802            with open(prediction_file, "w") as writer:
 803                writer.write(json.dumps(all_predictions, indent=4) + "\n")
 804            logger.info(f"Saving nbest_preds to {nbest_file}.")
 805            with open(nbest_file, "w") as writer:
 806                writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
 807            if version_2_with_negative:
 808                logger.info(f"Saving null_odds to {null_odds_file}.")
 809                with open(null_odds_file, "w") as writer:
 810                    writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
 811
 812        # Format the result to the format the metric expects.
 813        formatted_predictions = [
 814            {"id": k, "prediction_text": v} for k, v in all_predictions.items()
 815        ]
 816
 817        references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
 818        return EvalPrediction(predictions=formatted_predictions, label_ids=references)
 819
 820    logger.info(f"***** Running {prefix} *****")
 821    logger.info(f"  Num examples = {len(eval_dataset)}")
 822    logger.info(f"  Batch size = {args.per_device_eval_batch_size}")
 823
 824    model.eval()
 825    all_start_logits = []
 826    all_end_logits = []
 827    for _, batch in enumerate(eval_dataloader):
 828        with torch.no_grad():
 829            outputs = model(**batch)
 830            start_logits = outputs.start_logits
 831            end_logits = outputs.end_logits
 832
 833            if (
 834                not args.pad_to_max_length
 835            ):  # necessary to pad predictions and labels for being gathered
 836                start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
 837                end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)
 838
 839            all_start_logits.append(accelerator.gather_for_metrics(start_logits).cpu().numpy())
 840            all_end_logits.append(accelerator.gather_for_metrics(end_logits).cpu().numpy())
 841
 842    # Model Optimizer: clear the intermediate states of the distillation model from the forward passes
 843    if args.do_modelopt_distill:
 844        model.module.compute_kd_loss()
 845
 846    max_len = max([x.shape[1] for x in all_start_logits])  # Get the max_length of the tensor
 847
 848    # concatenate the numpy array
 849    start_logits_concat = create_and_fill_np_array(all_start_logits, max_len)
 850    end_logits_concat = create_and_fill_np_array(all_end_logits, max_len)
 851
 852    outputs_numpy = (start_logits_concat, end_logits_concat)
 853    prediction = postprocess_qa_predictions(
 854        examples=eval_examples,
 855        features=eval_dataset,
 856        predictions=outputs_numpy,
 857        n_best_size=args.n_best_size,
 858        max_answer_length=args.max_answer_length,
 859        # output_dir=args.output_dir,
 860        prefix=prefix,
 861    )
 862
 863    metric = evaluate.load("squad")
 864    eval_metric = metric.compute(
 865        predictions=prediction.predictions, references=prediction.label_ids
 866    )
 867    logger.info(f"{prefix} metrics: {eval_metric}\n")
 868    return eval_metric
 869
 870
 871# Model Optimizer: Define a teacher factory for initializing the distillation model
 872def teacher_factory(model_name_or_path):
 873    return AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)
 874
 875
 876# Model Optimizer: Define a custom distillation loss function that uses start and end logits
 877class StartEndLogitsDistillationLoss(mtd.LogitsDistillationLoss):
 878
 879    def forward(self, outputs_s, outputs_t):
 880        loss_start = super().forward(outputs_s.start_logits, outputs_t.start_logits)
 881        loss_end = super().forward(outputs_s.end_logits, outputs_t.end_logits)
 882        loss = (loss_start + loss_end) / 2.0
 883
 884        return loss
 885
 886
 887def train_and_evaluate_model(
 888    args,
 889    model: nn.Module,
 890    tokenizer: PreTrainedTokenizer,
 891    accelerator: Accelerator,
 892    examples: Dict,
 893    dataset: Dict,
 894    dataloader: Dict[str, DataLoader],
 895    answer_column_name,
 896):
 897    # Optimizer
 898    # Split weights in two groups, one with weight decay and the other not.
 899    no_decay = ["bias", "LayerNorm.weight"]
 900    optimizer_grouped_parameters = [
 901        {
 902            "params": [
 903                p
 904                for n, p in model.named_parameters()
 905                if p.requires_grad and not any(nd in n for nd in no_decay)
 906            ],
 907            "weight_decay": args.weight_decay,
 908        },
 909        {
 910            "params": [
 911                p
 912                for n, p in model.named_parameters()
 913                if p.requires_grad and any(nd in n for nd in no_decay)
 914            ],
 915            "weight_decay": 0.0,
 916        },
 917    ]
 918    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
 919
 920    # Scheduler and math around the number of training steps.
 921    overrode_max_train_steps = False
 922    num_update_steps_per_epoch = math.ceil(
 923        len(dataloader["train"]) / args.gradient_accumulation_steps
 924    )
 925    if args.max_train_steps is None:
 926        args.max_train_steps = int(args.num_train_epochs * num_update_steps_per_epoch)
 927        overrode_max_train_steps = True
 928
 929    lr_scheduler = get_scheduler(
 930        name=args.lr_scheduler_type,
 931        optimizer=optimizer,
 932        num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
 933        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
 934    )
 935
 936    # Prepare everything with our `accelerator`.
 937    model, optimizer, dataloader["train"], dataloader["eval"], lr_scheduler = accelerator.prepare(
 938        model, optimizer, dataloader["train"], dataloader["eval"], lr_scheduler
 939    )
 940
 941    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
 942    num_update_steps_per_epoch = math.ceil(
 943        len(dataloader["train"]) / args.gradient_accumulation_steps
 944    )
 945    if overrode_max_train_steps:
 946        args.max_train_steps = int(args.num_train_epochs * num_update_steps_per_epoch)
 947    # Afterwards we recalculate our number of training epochs
 948    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 949
 950    # Figure out how often we should save the Accelerator states
 951    checkpointing_steps = args.checkpointing_steps
 952    if checkpointing_steps is not None and checkpointing_steps.isdigit():
 953        checkpointing_steps = int(checkpointing_steps)
 954
 955    # We need to initialize the trackers we use, and also store our configuration.
 956    # The trackers initialize automatically on the main process.
 957    if args.with_tracking:
 958        experiment_config = vars(args)
 959        # TensorBoard cannot log Enums, need the raw value
 960        experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
 961        accelerator.init_trackers("tensorboard", experiment_config)
 962
 963    # Train!
 964    total_batch_size = (
 965        args.per_device_train_batch_size
 966        * accelerator.num_processes
 967        * args.gradient_accumulation_steps
 968    )
 969
 970    logger.info("***** Running training *****")
 971    logger.info(f"  Num examples = {len(dataset['train'])}")
 972    logger.info(f"  Num Epochs = {args.num_train_epochs}")
 973    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
 974    logger.info(
 975        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
 976    )
 977    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
 978    logger.info(f"  Total optimization steps = {args.max_train_steps}")
 979
 980    # Only show the progress bar once on each machine.
 981    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
 982    completed_steps = 0
 983    starting_epoch = 0
 984    resume_step = None
 985
 986    # Potentially load in the weights and states from a previous save
 987    if args.resume_from_last_ckpt:
 988        # Get the most recent checkpoint
 989        dirs = [
 990            f.path
 991            for f in os.scandir(args.output_dir)
 992            if f.is_dir() and (f.name.startswith("epoch_") or f.name.startswith("step_"))
 993        ]
 994        if len(dirs) == 0:
 995            logger.warning("No checkpoint found in output_dir. Training from scratch!")
 996        else:
 997            latest_dir = max(dirs, key=os.path.getctime)
 998            accelerator.load_state(latest_dir)
 999
1000            # Extract `epoch_{i}` or `step_{i}`
1001            latest_dir = os.path.basename(latest_dir)
1002            if "epoch" in latest_dir:
1003                starting_epoch = int(latest_dir.replace("epoch_", "")) + 1
1004                completed_steps = starting_epoch * num_update_steps_per_epoch
1005            else:
1006                # need to multiply `gradient_accumulation_steps` to reflect real steps
1007                resume_step = (
1008                    int(latest_dir.replace("step_", "")) * args.gradient_accumulation_steps
1009                )
1010                starting_epoch = resume_step // len(dataloader["train"])
1011                completed_steps = resume_step // args.gradient_accumulation_steps
1012                resume_step -= starting_epoch * len(dataloader["train"])
1013
1014    # update the progress_bar if resuming from a checkpoint
1015    progress_bar.update(completed_steps)
1016
1017    # Evaluate before training (e.g. PTQ accuracy before QAT)
1018    eval_metric = evaluate_model(
1019        args,
1020        model,
1021        accelerator,
1022        examples["eval"],
1023        dataset["eval"],
1024        dataloader["eval"],
1025        answer_column_name,
1026    )
1027    for epoch in range(starting_epoch, args.num_train_epochs):
1028        model.train()
1029        if args.with_tracking:
1030            total_loss = 0
1031        if args.resume_from_last_ckpt and epoch == starting_epoch and resume_step is not None:
1032            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
1033            active_dataloader = accelerator.skip_first_batches(dataloader["train"], resume_step)
1034        else:
1035            active_dataloader = dataloader["train"]
1036        for batch in active_dataloader:
1037            optimizer.zero_grad()
1038            outputs = model(**batch)
1039
1040            # Model Optimizer: If using distillation, we unwrap the model and extract the custom loss function
1041            if args.do_modelopt_distill:
1042                loss = model.module.compute_kd_loss()
1043            else:
1044                loss, _, _ = outputs.to_tuple()
1045
1046            # We keep track of the loss at each epoch
1047            if args.with_tracking:
1048                total_loss += loss.detach().float()
1049
1050            accelerator.backward(loss)
1051            optimizer.step()
1052            lr_scheduler.step()
1053
1054            # Checks if the accelerator has performed an optimization step behind the scenes
1055            if accelerator.sync_gradients:
1056                progress_bar.update(1)
1057                completed_steps += 1
1058
1059            if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0:
1060                accelerator.save_state(os.path.join(args.output_dir, f"step_{completed_steps}"))
1061
1062            if completed_steps >= args.max_train_steps:
1063                break
1064
1065        if args.checkpointing_steps == "epoch":
1066            accelerator.save_state(os.path.join(args.output_dir, f"epoch_{epoch}"))
1067
1068        eval_metric = evaluate_model(
1069            args,
1070            model,
1071            accelerator,
1072            examples["eval"],
1073            dataset["eval"],
1074            dataloader["eval"],
1075            answer_column_name,
1076        )
1077
1078        if args.with_tracking:
1079            log = {
1080                "squad": eval_metric,
1081                "train_loss": total_loss.item() / len(dataloader["train"]),  # type: ignore[attr-defined]
1082                "epoch": epoch,
1083                "step": completed_steps,
1084            }
1085            accelerator.log(log, step=completed_steps)
1086
1087    accelerator.wait_for_everyone()
1088    if accelerator.is_main_process and eval_metric:
1089        logger.info(json.dumps(eval_metric, indent=4))
1090        # Prefix all keys with prefix + '_'
1091        for key in list(eval_metric.keys()):
1092            eval_metric[f"Eval_{key}"] = eval_metric.pop(key)
1093
1094        with open(os.path.join(args.output_dir, "results.json"), "w") as f:
1095            json.dump(eval_metric, f, indent=4)
1096
1097
1098def main(input_args: Optional[List[str]] = None) -> None:
1099    args = parse_args(input_args)
1100
1101    # Initialize the accelerator
1102    accelerator_log_kwargs = {}
1103    if args.with_tracking:
1104        accelerator_log_kwargs["log_with"] = "tensorboard"
1105        accelerator_log_kwargs["project_dir"] = args.output_dir
1106    accelerator = Accelerator(
1107        gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs
1108    )
1109
1110    # Setup logging
1111    logging.basicConfig(
1112        format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
1113        datefmt="%m/%d/%Y %H:%M:%S",
1114        level=logging.INFO,
1115    )
1116    logger.info(accelerator.state, main_process_only=False)
1117    if accelerator.is_local_main_process:
1118        datasets.utils.logging.set_verbosity_warning()
1119        transformers.utils.logging.set_verbosity_info()
1120    else:
1121        datasets.utils.logging.set_verbosity_error()
1122        transformers.utils.logging.set_verbosity_error()
1123
1124    # Set the training seed
1125    set_seed(SEED)
1126
1127    # Handle the output_dir creation
1128    if accelerator.is_main_process:
1129        os.makedirs(args.output_dir, exist_ok=True)
1130    accelerator.wait_for_everyone()
1131
1132    # Load pretrained model and tokenizer
1133    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
1134    model = AutoModelForQuestionAnswering.from_pretrained(args.model_name_or_path)
1135    dummy_input = model.dummy_inputs["input_ids"]
1136
1137    # Get datasets
1138    examples, dataset, dataloader, answer_column_name = get_datasets_and_dataloaders(
1139        args, tokenizer, accelerator
1140    )
1141
1142    # Model Optimizer: Save the pruned or quantized model
1143    def save_modelopt_model(model):
1144        if accelerator.is_main_process and args.modelopt_save_file is not None:
1145            save_path = os.path.join(args.output_dir, args.modelopt_save_file)
1146            logger.info(f"Saving modelopt optimized model to {save_path}")
1147            mto.save(model, save_path)
1148
1149    # Model Optimizer: Restore the pruned or quantized model from a checkpoint
1150    if args.modelopt_restore_path:
1151        assert os.path.exists(
1152            args.modelopt_restore_path
1153        ), f"{args.modelopt_restore_path} does not exist."
1154        logger.info(f"Restoring model from {args.modelopt_restore_path}")
1155        model = mto.restore(model, args.modelopt_restore_path)
1156
1157    # Model Optimizer: Prune the model to given FLOPS target using GradNAS algorithm
1158    if args.do_modelopt_prune:
1159        logger.info(f"Pruning model to {args.modelopt_prune_flops_percent}% FLOPS")
1160
1161        # NOTE: gradnas does not perform synchronization across data parallel groups
1162        # Use unwrapped model & non-distributed dataloader for gradnas so that all the processes
1163        # in the data parallel group have the same gradients for pruning
1164        model = model.to(accelerator.device)
1165        dummy_input = dummy_input.to(accelerator.device)
1166
1167        # Search for the best pruned model
1168        # To use other NAS algorithms, you can use `mtn.convert` + `mtn.search` here.
1169        model, _ = mtp.prune(
1170            model=model,
1171            mode="gradnas",
1172            constraints={"flops": f"{args.modelopt_prune_flops_percent}%"},
1173            dummy_input=dummy_input,
1174            config={
1175                "data_loader": dataloader["train"],
1176                "collect_func": lambda batch: (batch,),
1177                "loss_func": lambda output, batch: (
1178                    output["loss"] if isinstance(output, dict) else output[0]
1179                ),
1180            },
1181        )
1182        save_modelopt_model(model)
1183
1184    # Model Optimizer: Quantize the model to INT8 precision
1185    if args.modelopt_quantize_cfg:
1186        logger.info(f"Quantizing model with {args.modelopt_quantize_cfg} config")
1187
1188        # NOTE: `mtq.quantize` does not perform synchronization across data parallel groups
1189        # Use unwrapped model & non-distributed dataloader for PTQ calibration so that all the processes
1190        # in the data parallel group have the same calibration statistics
1191        model = model.to(accelerator.device)
1192
1193        def forward_loop(model):
1194            num_samples = 256  # Use only 256 samples for PTQ calibration
1195            num_batches = num_samples // args.per_device_train_batch_size
1196            for idx, batch in tqdm(enumerate(dataloader["train"]), total=num_batches):
1197                batch = {k: v.to(accelerator.device) for k, v in batch.items()}
1198                model(**batch)
1199                if idx >= num_batches:
1200                    break
1201
1202        model = mtq.quantize(model, getattr(mtq, args.modelopt_quantize_cfg), forward_loop)
1203        torch.cuda.empty_cache()
1204        save_modelopt_model(model)
1205
1206    if args.do_train:
1207        # Model Optimizer: Convert to a DistillationModel containing teacher to train with distillation
1208        if args.do_modelopt_distill:
1209            logger.info(f"Using distillation with teacher {args.model_name_or_path}")
1210
1211            kd_config = {
1212                "teacher_model": (teacher_factory, (args.model_name_or_path,), {}),
1213                "criterion": StartEndLogitsDistillationLoss(args.temperature),
1214            }
1215            model = mtd.convert(model, mode=[("kd_loss", kd_config)])
1216
1217        train_and_evaluate_model(
1218            args,
1219            model,
1220            tokenizer,
1221            accelerator,
1222            examples,
1223            dataset,
1224            dataloader,
1225            answer_column_name,
1226        )
1227
1228        # Model Optimizer: Export the distilled model
1229        if args.do_modelopt_distill:
1230            model = mtd.export(model)
1231
1232        save_modelopt_model(model)
1233
1234    if accelerator.is_main_process and args.onnx_export_file is not None:
1235        save_path = os.path.join(args.output_dir, args.onnx_export_file)
1236        logger.info(f"Exporting ONNX model to {save_path}")
1237
1238        # Move the model and dummy_input to the device
1239        model = model.to(accelerator.device)
1240        dummy_input = dummy_input.to(accelerator.device)
1241
1242        with open(save_path, "wb") as f:
1243            f.write(get_onnx_bytes(model, dummy_input, onnx_opset=14))
1244
1245    logger.info("Done!")
1246
1247
1248if __name__ == "__main__":
1249    main()

Commands

  1. First, we prune the BERT large model to 50% FLOPs with the GradNAS algorithm. Then, we fine-tune the pruned model with distillation from the unpruned teacher model to recover over 99% of the initial F1 score (93.15). We recommend using multiple GPUs for fine-tuning. Note that we fine-tune for more epochs than the 2 epochs originally used to fine-tune BERT without distillation: distillation takes longer to converge but achieves much better results. A condensed Python sketch of the distillation setup is shown after the script.

    1_prune.sh
    #!/bin/bash
    set -e
    set -x
    
    FLOPS_PERCENT=50
    BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
    
    if [ ! -f "${BASE_DIR}/pruned_model.pth" ]; then
        modelopt_args="--do_modelopt_prune --modelopt_prune_flops_percent ${FLOPS_PERCENT}"
    else
        modelopt_args="--modelopt_restore_path ${BASE_DIR}/pruned_model.pth"
    fi
    modelopt_args="${modelopt_args} --modelopt_save_file pruned_model.pth"
    
    # Run pruning followed by distributed fine-tuning on the pruned model
    accelerate launch --multi_gpu --mixed_precision bf16 bert_prune_distill_quantize.py \
        --model_name_or_path bert-large-uncased-whole-word-masking-finetuned-squad \
        --output_dir ${BASE_DIR} \
        ${modelopt_args} \
        --do_train \
        --do_modelopt_distill \
        --lr_scheduler_type cosine \
        --learning_rate 1e-4 \
        --per_device_train_batch_size 16 \
        --num_train_epochs 15 \
        --with_tracking \
        --checkpointing_steps epoch \
        --resume_from_last_ckpt
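
     Condensed to the Python API, the distillation setup that this script enables via --do_modelopt_distill boils down to the sketch below. It is illustrative only: it reuses the teacher factory and loss definition from the full code above, but feeds a fake batch instead of real SQuAD features and omits the optimizer/scheduler plumbing.

        import torch
        from transformers import AutoModelForQuestionAnswering

        import modelopt.torch.distill as mtd

        model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
        student = AutoModelForQuestionAnswering.from_pretrained(model_name)  # normally the pruned model

        def teacher_factory(name):
            return AutoModelForQuestionAnswering.from_pretrained(name)

        class StartEndLogitsDistillationLoss(mtd.LogitsDistillationLoss):
            # Average the soft-label loss over the start- and end-logit heads.
            def forward(self, outputs_s, outputs_t):
                loss_start = super().forward(outputs_s.start_logits, outputs_t.start_logits)
                loss_end = super().forward(outputs_s.end_logits, outputs_t.end_logits)
                return (loss_start + loss_end) / 2.0

        kd_config = {
            "teacher_model": (teacher_factory, (model_name,), {}),
            "criterion": StartEndLogitsDistillationLoss(2.0),  # temperature 2.0, the script's default
        }
        model = mtd.convert(student, mode=[("kd_loss", kd_config)])

        # One training step: forward the student (the teacher runs internally),
        # then optimize the distillation loss instead of the plain QA loss.
        batch = {
            "input_ids": torch.randint(1000, 2000, (2, 128)),
            "attention_mask": torch.ones(2, 128, dtype=torch.long),
            "token_type_ids": torch.zeros(2, 128, dtype=torch.long),
        }
        _ = model(**batch)
        loss = model.compute_kd_loss()
        loss.backward()

        # After fine-tuning, unwrap the student for saving/export.
        model = mtd.export(model)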
    
  2. Quantize the fine-tuned model to INT8 precision and run calibration (PTQ). PTQ results in a slight drop in F1 score, which we then recover with QAT. We run QAT with distillation from the unpruned teacher model as well. A condensed Python sketch of the PTQ calibration call is shown after the script.

    2_int8_quantize.sh
    #!/bin/bash
    set -e
    set -x
    
    FLOPS_PERCENT=50
    BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
    OUTPUT_DIR=${BASE_DIR}/int8_quantized
    
    if [ ! -f "${OUTPUT_DIR}/quantized_model.pth" ]; then
        modelopt_args="--modelopt_quantize_cfg INT8_DEFAULT_CFG --modelopt_restore_path ${BASE_DIR}/pruned_model.pth"
    else
        modelopt_args="--modelopt_restore_path ${OUTPUT_DIR}/quantized_model.pth"
    fi
    
    modelopt_args="${modelopt_args} --modelopt_save_file quantized_model.pth"
    
    # Run distributed QAT on the pruned model with 0.1x LR and fewer epochs
    accelerate launch --multi_gpu --mixed_precision bf16 bert_prune_distill_quantize.py \
        --model_name_or_path bert-large-uncased-whole-word-masking-finetuned-squad \
        --output_dir ${OUTPUT_DIR} \
        ${modelopt_args} \
        --do_train \
        --do_modelopt_distill \
        --lr_scheduler_type cosine \
        --learning_rate 5e-6 \
        --per_device_train_batch_size 16 \
        --num_train_epochs 2 \
        --with_tracking \
        --checkpointing_steps epoch \
        --resume_from_last_ckpt
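
     Stripped of the training loop, the PTQ calibration that this stage performs before QAT looks roughly like the sketch below (illustrative only: fake batches replace the ~256 real SQuAD calibration samples used by the script, so the resulting accuracy is meaningless).

        import torch
        from transformers import AutoModelForQuestionAnswering

        import modelopt.torch.opt as mto
        import modelopt.torch.quantization as mtq

        model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
        model = AutoModelForQuestionAnswering.from_pretrained(model_name)  # normally the pruned & fine-tuned model

        def forward_loop(m):
            # mtq.quantize runs this loop to collect the calibration statistics for the INT8 scales.
            for _ in range(4):
                m(
                    input_ids=torch.randint(1000, 2000, (2, 128)),
                    attention_mask=torch.ones(2, 128, dtype=torch.long),
                    token_type_ids=torch.zeros(2, 128, dtype=torch.long),
                )

        model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)
        mto.save(model, "quantized_model.pth")  # restore later with mto.restore(model, path) before QAT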
    
  3. Export the quantized model to ONNX format for deployment with TensorRT.

    3_onnx_export.sh
    #!/bin/bash
    set -e
    set -x
    
    FLOPS_PERCENT=50
    BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
    
    # Export to ONNX on a single GPU
    python3 bert_prune_distill_quantize.py \
        --model_name_or_path bert-large-uncased-whole-word-masking-finetuned-squad \
        --output_dir ${BASE_DIR}/onnx \
        --modelopt_restore_path ${BASE_DIR}/int8_quantized/quantized_model.pth \
        --onnx_export_file pruned_model_int8.onnx \