HF BERT: Prune, Distill & Quantize

This example shows how to compress a Hugging Face BERT-large model for Question Answering by combining modelopt.torch.prune, modelopt.torch.distill, and modelopt.torch.quantization. More specifically, we will (a condensed sketch of the corresponding ModelOpt calls follows this list):

  1. Prune the BERT-large model to 50% of its original FLOPs with the GradNAS algorithm and fine-tune it with distillation

  2. Quantize the fine-tuned model to INT8 precision with Post-Training Quantization (PTQ), followed by Quantization Aware Training (QAT) with distillation

  3. Export the quantized model to ONNX format for deployment with TensorRT
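
The following sketch condenses the ModelOpt calls behind these three steps so you can see the workflow at a glance. It is illustrative rather than runnable on its own: model, dummy_input, train_dataloader, forward_loop, and kd_config stand for objects that are set up in the full script further down this page.

    import modelopt.torch.distill as mtd
    import modelopt.torch.prune as mtp
    import modelopt.torch.quantization as mtq
    from modelopt.torch._deploy.utils import get_onnx_bytes

    # 1. Prune to 50% of the original FLOPs with the GradNAS algorithm.
    model, _ = mtp.prune(
        model=model,
        mode="gradnas",
        constraints={"flops": "50%"},
        dummy_input=dummy_input,
        config={
            "data_loader": train_dataloader,
            "collect_func": lambda batch: (batch,),
            "loss_func": lambda output, batch: (
                output["loss"] if isinstance(output, dict) else output[0]
            ),
        },
    )

    # 2. INT8 PTQ calibration, then fine-tuning / QAT with distillation from the
    #    unpruned teacher (the same `mtd.convert` call is also used for the
    #    post-pruning fine-tuning stage).
    model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)
    model = mtd.convert(model, mode=[("kd_loss", kd_config)])
    # ... run the regular training loop, using model.compute_kd_loss() as the loss ...
    model = mtd.export(model)

    # 3. Export to ONNX for deployment with TensorRT (output filename is illustrative).
    with open("model_int8.onnx", "wb") as f:
        f.write(get_onnx_bytes(model, dummy_input, onnx_opset=14))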

Prerequisites

  1. Install Model Optimizer with optional torch and huggingface dependencies:

    pip install "nvidia-modelopt[torch,hf]" --extra-index-url https://pypi.nvidia.com
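
     To quickly verify the installation, you can check that the ModelOpt modules used by this example import cleanly. A minimal sanity check (the filename is arbitrary):

    # verify_install.py -- minimal import check for the optional dependencies
    import modelopt.torch.distill as mtd
    import modelopt.torch.opt as mto
    import modelopt.torch.prune as mtp
    import modelopt.torch.quantization as mtq
    import torch
    import transformers

    print("torch:", torch.__version__)
    print("transformers:", transformers.__version__)
    print("ModelOpt modules imported OK:", mtp.__name__, mtd.__name__, mtq.__name__, mto.__name__)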
    

Note

This example has been tested on 8 x 24GB A5000 GPUs with PyTorch 2.4 and CUDA 12.4. It takes about 2 hours to complete all the stages of the optimization. Most of the time is spent on fine-tuning and QAT.

Full code

You can view the full code below; the ModelOpt integration points are marked with `# Model Optimizer:` comments. The source code and scripts are also available in the ModelOpt GitHub repository.

bert_prune_distill_quantize.py
   1# NOTE: This is adapted from run_qa_no_trainer.py and utils_qa.py from
   2# https://github.com/huggingface/transformers/blob/c52b515e/examples/pytorch/question-answering
   3#
   4# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
   5#
   6# Licensed under the Apache License, Version 2.0 (the "License");
   7# you may not use this file except in compliance with the License.
   8# You may obtain a copy of the License at
   9#
  10#     http://www.apache.org/licenses/LICENSE-2.0
  11#
  12# Unless required by applicable law or agreed to in writing, software
  13# distributed under the License is distributed on an "AS IS" BASIS,
  14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15# See the License for the specific language governing permissions and
  16# limitations under the License.
  17
  18# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  19# SPDX-License-Identifier: MIT
  20#
  21# Permission is hereby granted, free of charge, to any person obtaining a
  22# copy of this software and associated documentation files (the "Software"),
  23# to deal in the Software without restriction, including without limitation
  24# the rights to use, copy, modify, merge, publish, distribute, sublicense,
  25# and/or sell copies of the Software, and to permit persons to whom the
  26# Software is furnished to do so, subject to the following conditions:
  27#
  28# The above copyright notice and this permission notice shall be included in
  29# all copies or substantial portions of the Software.
  30#
  31# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  32# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  33# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  34# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  35# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
  36# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
  37# DEALINGS IN THE SOFTWARE.
  38"""
  39Example showcasing how to do end-to-end optimization of a BERT model on SQuAD using Model Optimizer.
  40This includes GradNAS pruning, INT8 quantization, fine-tuning / QAT with distillation, and ONNX export.
  41"""
  42
  43import argparse
  44import collections
  45import json
  46import logging
  47import math
  48import os
  49import random
  50from typing import Any, Dict, List, Optional, Tuple
  51
  52import datasets
  53import evaluate
  54import numpy as np
  55import torch
  56import torch.nn as nn
  57import transformers
  58from accelerate import Accelerator
  59from accelerate.logging import get_logger
  60from accelerate.utils import set_seed
  61from torch.utils.data import DataLoader
  62from tqdm.auto import tqdm
  63from transformers import (
  64    AutoModelForQuestionAnswering,
  65    AutoTokenizer,
  66    DataCollatorWithPadding,
  67    EvalPrediction,
  68    PreTrainedTokenizer,
  69    SchedulerType,
  70    default_data_collator,
  71    get_scheduler,
  72)
  73
  74# Model Optimizer: imports
  75import modelopt.torch.distill as mtd
  76import modelopt.torch.opt as mto
  77import modelopt.torch.prune as mtp
  78import modelopt.torch.quantization as mtq
  79from modelopt.torch._deploy.utils import get_onnx_bytes
  80
  81# Enable automatic save/load of modelopt_state with huggingface checkpointing
  82mto.enable_huggingface_checkpointing()
  83
  84logger = get_logger(__name__)
  85
  86SEED = 123
  87
  88
  89def parse_args(input_args: Optional[List[str]] = None):
  90    parser = argparse.ArgumentParser(
  91        description="Finetune a transformers model on a Question Answering task"
  92    )
  93
  94    # Training arguments
  95    parser.add_argument(
  96        "--model_name_or_path",
  97        type=str,
  98        default="bert-large-uncased-whole-word-masking-finetuned-squad",
  99        help="Path to pretrained model or model identifier from huggingface.co/models.",
 100    )
 101    parser.add_argument(
 102        "--do_train", action="store_true", help="Whether to run training / fine-tuning."
 103    )
 104    parser.add_argument(
 105        "--per_device_train_batch_size",
 106        type=int,
 107        default=16,
 108        help="Batch size (per device) for the training dataloader.",
 109    )
 110    parser.add_argument(
 111        "--per_device_eval_batch_size",
 112        type=int,
 113        default=64,
 114        help="Batch size (per device) for the evaluation dataloader.",
 115    )
 116    parser.add_argument(
 117        "--learning_rate",
 118        type=float,
 119        default=5e-5,
 120        help="Initial learning rate (after the potential warmup period) to use.",
 121    )
 122    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
 123    parser.add_argument(
 124        "--lr_scheduler_type",
 125        type=SchedulerType,
 126        default="linear",
 127        help="The scheduler type to use.",
 128        choices=[
 129            "linear",
 130            "cosine",
 131            "cosine_with_restarts",
 132            "polynomial",
 133            "constant",
 134            "constant_with_warmup",
 135        ],
 136    )
 137    parser.add_argument(
 138        "--num_warmup_steps",
 139        type=int,
 140        default=0,
 141        help="Number of steps for the warmup in the lr scheduler.",
 142    )
 143    parser.add_argument(
 144        "--num_train_epochs",
 145        type=float,
 146        default=2.0,
 147        help="Total number of training epochs to perform.",
 148    )
 149    parser.add_argument(
 150        "--max_train_steps",
 151        type=int,
 152        default=None,
 153        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
 154    )
 155    parser.add_argument(
 156        "--gradient_accumulation_steps",
 157        type=int,
 158        default=1,
 159        help="Number of update steps to accumulate before performing a backward/update pass.",
 160    )
 161    parser.add_argument(
 162        "--preprocessing_num_workers",
 163        type=int,
 164        default=4,
 165        help="The number of processes to use for preprocessing the dataset.",
 166    )
 167
 168    # Logging and checkpointing arguments
 169    parser.add_argument(
 170        "--finetuned_model_path",
 171        type=str,
 172        default=None,
 173        help="Path to save the finetuned (pruned or quantized) model for restoring later with `.from_pretrained()`.",
 174    )
 175    parser.add_argument(
 176        "--with_tracking",
 177        action="store_true",
 178        help="Whether to enable experiment trackers for logging.",
 179    )
 180    parser.add_argument(
 181        "--checkpointing_steps",
 182        type=str,
 183        default="epoch",
 184        help=(
 185            "Whether the training state should be saved every n steps (pass an integer) or at the"
 186            " end of each epoch (pass 'epoch')."
 187        ),
 188    )
 189    parser.add_argument(
 190        "--resume_from_last_ckpt",
 191        action="store_true",
 192        help="If the training should continue from the latest checkpoint in model_name_or_path.",
 193    )
 194    parser.add_argument(
 195        "--onnx_export_path", type=str, default=None, help="Path to export the ONNX model to."
 196    )
 197
 198    # Misc arguments for Bert (should not be modified in most cases)
 199    parser.add_argument(
 200        "--max_seq_length",
 201        type=int,
 202        default=384,
 203        help=(
 204            "The maximum total input sequence length after tokenization. Sequences longer than this"
 205            " will be truncated, and shorter will be padded if `--pad_to_max_length` is passed."
 206        ),
 207    )
 208    parser.add_argument(
 209        "--pad_to_max_length",
 210        action="store_true",
 211        help="If passed, pad all samples to `max_seq_length`. Otherwise, dynamic padding is used.",
 212    )
 213    parser.add_argument(
 214        "--doc_stride",
 215        type=int,
 216        default=128,
 217        help=(
 218            "When splitting up a long document into chunks, how much stride to take between chunks."
 219        ),
 220    )
 221    parser.add_argument(
 222        "--n_best_size",
 223        type=int,
 224        default=20,
 225        help="The total number of n-best predictions to generate when looking for an answer.",
 226    )
 227    parser.add_argument(
 228        "--max_answer_length",
 229        type=int,
 230        default=30,
 231        help=(
 232            "The maximum length of an answer that can be generated. This is needed because the"
 233            " start and end predictions are not conditioned on one another."
 234        ),
 235    )
 236
 237    # Debugging arguments
 238    parser.add_argument(
 239        "--max_train_samples",
 240        type=int,
 241        default=None,
 242        help="For debugging purposes or quicker training.",
 243    )
 244    parser.add_argument(
 245        "--max_eval_samples",
 246        type=int,
 247        default=None,
 248        help="For debugging purposes or quicker training.",
 249    )
 250
 251    # Model Optimizer: pruning arguments
 252    parser.add_argument(
 253        "--do_modelopt_prune",
 254        action="store_true",
 255        help="Whether or not to use Model Optimizer pruning.",
 256    )
 257    parser.add_argument(
 258        "--modelopt_prune_flops_percent",
 259        type=float,
 260        default=None,
 261        help="The percentage (between 0 and 100) of FLOPs to retain in the pruned model.",
 262    )
 263    parser.add_argument(
 264        "--pruned_model_path",
 265        type=str,
 266        default=None,
 267        help="Path to save the pruned model for further finetuning.",
 268    )
 269
 270    # Model Optimizer: quantization arguments
 271    parser.add_argument(
 272        "--modelopt_quantize_cfg",
 273        help="Model Optimizer quantization config.",
 274        choices=mtq.config.choices,
 275    )
 276
 277    # Model Optimizer: Distillation arguments
 278    parser.add_argument(
 279        "--do_modelopt_distill",
 280        action="store_true",
 281        help="Whether or not to use distillation. A teacher model must be specified.",
 282    )
 283    parser.add_argument(
 284        "--temperature", type=float, default=2.0, help="The temperature to use when distilling."
 285    )
 286    parser.add_argument(
 287        "--ptq_model_path",
 288        type=str,
 289        default=None,
 290        help="Path to save the PTQ quantized model for further QAT.",
 291    )
 292
 293    args = parser.parse_args(input_args)
 294
 295    # Sanity checks
 296    if args.do_train and not args.finetuned_model_path:
 297        raise ValueError("`finetuned_model_path` required when `do_train` is passed.")
 298    if args.do_modelopt_prune and not (
 299        args.modelopt_prune_flops_percent and args.pruned_model_path
 300    ):
 301        raise ValueError(
 302            "`modelopt_prune_flops_percent` and `pruned_model_path` required when `do_modelopt_prune` is passed."
 303        )
 304    if args.modelopt_quantize_cfg and not args.ptq_model_path:
 305        raise ValueError("`ptq_model_path` required when `modelopt_quantize_cfg` is passed.")
 306
 307    return args
 308
 309
 310def get_datasets_and_dataloaders(args, tokenizer: PreTrainedTokenizer, accelerator: Accelerator):
 311    """Get the examples, dataset, dataloader, answer_column_name
 312
 313    You can either provide your own CSV/JSON/TXT training and evaluation files (see below)
 314    or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
 315    (the dataset will be downloaded automatically from the datasets Hub).
 316
 317    For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
 318    'text' is found. You can easily tweak this behavior (see below).
 319    """
 320
 321    def prepare_train_features(examples):
 322        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
 323        # truncation of the context fail (the tokenized question will take a lot of space). So we remove that
 324        # left whitespace
 325        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
 326
 327        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
 328        # in one example possibly giving several features when a context is long, each of those features having a
 329        # context that overlaps a bit the context of the previous feature.
 330        tokenized_examples = tokenizer(
 331            examples[question_column_name if pad_on_right else context_column_name],
 332            examples[context_column_name if pad_on_right else question_column_name],
 333            truncation="only_second" if pad_on_right else "only_first",
 334            max_length=max_seq_length,
 335            stride=args.doc_stride,
 336            return_overflowing_tokens=True,
 337            return_offsets_mapping=True,
 338            padding="max_length" if args.pad_to_max_length else False,
 339        )
 340
 341        # Since one example might give us several features if it has a long context, we need a map from a feature to
 342        # its corresponding example. This key gives us just that.
 343        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
 344        # The offset mappings will give us a map from token to character position in the original context. This will
 345        # help us compute the start_positions and end_positions.
 346        offset_mapping = tokenized_examples.pop("offset_mapping")
 347
 348        # Let's label those examples!
 349        tokenized_examples["start_positions"] = []
 350        tokenized_examples["end_positions"] = []
 351
 352        for i, offsets in enumerate(offset_mapping):
 353            # We will label impossible answers with the index of the CLS token.
 354            input_ids = tokenized_examples["input_ids"][i]
 355            cls_index = input_ids.index(tokenizer.cls_token_id)
 356
 357            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
 358            sequence_ids = tokenized_examples.sequence_ids(i)
 359
 360            # One example can give several spans, this is the index of the example containing this span of text.
 361            sample_index = sample_mapping[i]
 362            answers = examples[answer_column_name][sample_index]
 363            # If no answers are given, set the cls_index as answer.
 364            if len(answers["answer_start"]) == 0:
 365                tokenized_examples["start_positions"].append(cls_index)
 366                tokenized_examples["end_positions"].append(cls_index)
 367            else:
 368                # Start/end character index of the answer in the text.
 369                start_char = answers["answer_start"][0]
 370                end_char = start_char + len(answers["text"][0])
 371
 372                # Start token index of the current span in the text.
 373                token_start_index = 0
 374                while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
 375                    token_start_index += 1
 376
 377                # End token index of the current span in the text.
 378                token_end_index = len(input_ids) - 1
 379                while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
 380                    token_end_index -= 1
 381
 382                # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
 383                if not (
 384                    offsets[token_start_index][0] <= start_char
 385                    and offsets[token_end_index][1] >= end_char
 386                ):
 387                    tokenized_examples["start_positions"].append(cls_index)
 388                    tokenized_examples["end_positions"].append(cls_index)
 389                else:
 390                    # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
 391                    # Note: we could go after the last offset if the answer is the last word (edge case).
 392                    while (
 393                        token_start_index < len(offsets)
 394                        and offsets[token_start_index][0] <= start_char
 395                    ):
 396                        token_start_index += 1
 397                    tokenized_examples["start_positions"].append(token_start_index - 1)
 398                    while offsets[token_end_index][1] >= end_char:
 399                        token_end_index -= 1
 400                    tokenized_examples["end_positions"].append(token_end_index + 1)
 401
 402        return tokenized_examples
 403
 404    def prepare_validation_features(examples):
 405        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
 406        # truncation of the context fail (the tokenized question will take a lot of space). So we remove that
 407        # left whitespace
 408        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
 409
 410        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
 411        # in one example possibly giving several features when a context is long, each of those features having a
 412        # context that overlaps a bit the context of the previous feature.
 413        tokenized_examples = tokenizer(
 414            examples[question_column_name if pad_on_right else context_column_name],
 415            examples[context_column_name if pad_on_right else question_column_name],
 416            truncation="only_second" if pad_on_right else "only_first",
 417            max_length=max_seq_length,
 418            stride=args.doc_stride,
 419            return_overflowing_tokens=True,
 420            return_offsets_mapping=True,
 421            padding="max_length" if args.pad_to_max_length else False,
 422        )
 423
 424        # Since one example might give us several features if it has a long context, we need a map from a feature to
 425        # its corresponding example. This key gives us just that.
 426        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
 427
 428        # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
 429        # corresponding example_id and we will store the offset mappings.
 430        tokenized_examples["example_id"] = []
 431
 432        for i in range(len(tokenized_examples["input_ids"])):
 433            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
 434            sequence_ids = tokenized_examples.sequence_ids(i)
 435            context_index = 1 if pad_on_right else 0
 436
 437            # One example can give several spans, this is the index of the example containing this span of text.
 438            sample_index = sample_mapping[i]
 439            tokenized_examples["example_id"].append(examples["id"][sample_index])
 440
 441            # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
 442            # position is part of the context or not.
 443            tokenized_examples["offset_mapping"][i] = [
 444                (o if sequence_ids[k] == context_index else None)
 445                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
 446            ]
 447
 448        return tokenized_examples
 449
 450    examples, dataset, dataloader = {}, {}, {}
 451
 452    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
 453    # download the dataset.
 454    # Downloading and loading a dataset from the hub.
 455    raw_datasets = datasets.load_dataset("squad")
 456    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
 457    # https://huggingface.co/docs/datasets/loading_datasets.
 458
 459    # Preprocessing the datasets.
 460    # Preprocessing is slightly different for training and evaluation.
 461
 462    column_names = raw_datasets["train"].column_names
 463
 464    question_column_name = "question" if "question" in column_names else column_names[0]
 465    context_column_name = "context" if "context" in column_names else column_names[1]
 466    answer_column_name = "answers" if "answers" in column_names else column_names[2]
 467
 468    # Padding side determines if we do (question|context) or (context|question).
 469    pad_on_right = tokenizer.padding_side == "right"
 470
 471    if args.max_seq_length > tokenizer.model_max_length:
 472        logger.warning(
 473            f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length"
 474            f" for the model ({tokenizer.model_max_length}). Using"
 475            f" max_seq_length={tokenizer.model_max_length}."
 476        )
 477
 478    max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
 479
 480    examples["train"] = raw_datasets["train"]
 481    if args.max_train_samples is not None:
 482        # We will select samples from the whole data if the argument is specified
 483        examples["train"] = examples["train"].select(range(args.max_train_samples))
 484
 485    # Create train feature from dataset
 486    with accelerator.main_process_first():
 487        dataset["train"] = examples["train"].map(
 488            prepare_train_features,
 489            batched=True,
 490            num_proc=args.preprocessing_num_workers,
 491            remove_columns=column_names,
 492            load_from_cache_file=True,
 493            desc="Running tokenizer on train dataset",
 494        )
 495        # if args.max_train_samples is not None:
 496        #     # Number of samples might increase during Feature Creation, We select only specified max samples
 497        #     dataset["train"] = dataset["train"].select(range(args.max_train_samples))
 498
 499    examples["eval"] = raw_datasets["validation"]
 500    if args.max_eval_samples is not None:
 501        # We will select samples from the whole data
 502        examples["eval"] = examples["eval"].select(range(args.max_eval_samples))
 503    # Validation Feature Creation
 504    with accelerator.main_process_first():
 505        dataset["eval"] = examples["eval"].map(
 506            prepare_validation_features,
 507            batched=True,
 508            num_proc=args.preprocessing_num_workers,
 509            remove_columns=column_names,
 510            load_from_cache_file=True,
 511            desc="Running tokenizer on validation dataset",
 512        )
 513        # if args.max_eval_samples is not None:
 514        #     # During Feature creation dataset samples might increase, we will select required samples again
 515        #     dataset["eval"] = dataset["eval"].select(range(args.max_eval_samples))
 516
 517    # Log a random sample from the training set:
 518    for index in random.sample(range(len(dataset["train"])), 1):
 519        logger.info(f"Sample {index} of the training set: {dataset['train'][index]}.")
 520
 521    # DataLoaders creation:
 522    if args.pad_to_max_length:
 523        # If padding was already done to max length, we use the default data collator that will just convert everything
 524        # to tensors.
 525        data_collator = default_data_collator
 526    else:
 527        # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of
 528        # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to a multiple
 529        # of 8, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.0 (Volta).
 530        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
 531
 532    dataloader["train"] = DataLoader(
 533        dataset["train"],
 534        shuffle=True,
 535        collate_fn=data_collator,
 536        batch_size=args.per_device_train_batch_size,
 537    )
 538
 539    dataloader["eval"] = DataLoader(
 540        dataset["eval"].remove_columns(["example_id", "offset_mapping"]),
 541        collate_fn=data_collator,
 542        batch_size=args.per_device_eval_batch_size,
 543    )
 544
 545    return examples, dataset, dataloader, answer_column_name
 546
 547
 548def evaluate_model(
 549    args,
 550    model: nn.Module,
 551    accelerator: Accelerator,
 552    eval_examples: Any,
 553    eval_dataset: Any,
 554    eval_dataloader: DataLoader,
 555    answer_column_name: str,
 556    prefix: str = "Eval",
 557):
 558    def create_and_fill_np_array(start_or_end_logits, max_len):
 559        """Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor
 560
 561        Args:
 562            start_or_end_logits: This is the output predictions of the model.
 563                We can only enter either start or end logits.
 564            max_len: The maximum length of the output tensor. (See the model.eval() part for more details)
 565        """
 566        step = 0
 567        # create a numpy array and fill it with -100.
 568        logits_concat = np.full((len(eval_dataset), max_len), -100, dtype=np.float64)
 569        # Now that we have created an array, we will populate it with the outputs using accelerator.gather_for_metrics
 570        for i, output_logit in enumerate(start_or_end_logits):  # populate columns
 571            # Fill it by copying each whole batch of logits into the corresponding rows of the new array,
 572            # advancing the step offset after every batch
 573            batch_size = output_logit.shape[0]
 574            cols = output_logit.shape[1]
 575
 576            if step + batch_size < len(eval_dataset):
 577                logits_concat[step : step + batch_size, :cols] = output_logit
 578            else:
 579                logits_concat[step:, :cols] = output_logit[: len(eval_dataset) - step]
 580
 581            step += batch_size
 582
 583        return logits_concat
 584
 585    def postprocess_qa_predictions(
 586        examples,
 587        features,
 588        predictions: Tuple[np.ndarray, np.ndarray],
 589        version_2_with_negative: bool = False,
 590        n_best_size: int = 20,
 591        max_answer_length: int = 30,
 592        null_score_diff_threshold: float = 0.0,
 593        output_dir: Optional[str] = None,
 594        prefix: Optional[str] = None,
 595    ) -> EvalPrediction:
 596        """Post-processes the predictions of a question-answering model to convert them to answers
 597        that are substrings of the original contexts. This is the base postprocessing function for
 598        models that only return start and end logits.
 599
 600        Args:
 601            examples: The non-preprocessed dataset.
 602            features: The processed dataset.
 603            predictions: The predictions of the model: two arrays containing the start logits and the end logits
 604                respectively. Its first dimension must match the number of elements of `features`.
 605            version_2_with_negative: Whether or not the underlying dataset contains examples with no answers.
 606            n_best_size: The total number of n-best predictions to generate when looking for an answer.
 607            max_answer_length: The maximum length of an answer that can be generated. This is needed
 608                because the start and end predictions are not conditioned on one another.
 609            null_score_diff_threshold: The threshold used to select the null answer: if the best answer
 610                has a score that is less than the score of the null answer minus this threshold, the
 611                null answer is selected for this example (note that the score of the null answer for
 612                an example giving several features is the minimum of the scores for the null answer on
 613                each feature: all features must be aligned on the fact they `want` to predict a null answer).
 614                Only useful when `version_2_with_negative` is `True`.
 615            output_dir: If provided, the dictionaries of predictions, n_best predictions (with their scores and logits)
 616                and, if `version_2_with_negative=True`, the dictionary of the scores differences between best and null
 617                answers, are saved in `output_dir`.
 618            prefix: If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
 619        """
 620        if len(predictions) != 2:
 621            raise ValueError(
 622                "`predictions` should be a tuple with two elements (start_logits, end_logits)."
 623            )
 624        all_start_logits, all_end_logits = predictions
 625
 626        if len(predictions[0]) != len(features):
 627            raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
 628
 629        # Build a map example to its corresponding features.
 630        example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
 631        features_per_example = collections.defaultdict(list)
 632        for i, feature in enumerate(features):
 633            features_per_example[example_id_to_index[feature["example_id"]]].append(i)
 634
 635        # The dictionaries we have to fill.
 636        all_predictions = collections.OrderedDict()
 637        all_nbest_json = collections.OrderedDict()
 638        if version_2_with_negative:
 639            scores_diff_json = collections.OrderedDict()
 640
 641        logger.debug(
 642            f"Post-processing {len(examples)} example predictions split into"
 643            f" {len(features)} features."
 644        )
 645
 646        # Let's loop over all the examples!
 647        for example_index, example in enumerate(examples):
 648            # Those are the indices of the features associated to the current example.
 649            feature_indices = features_per_example[example_index]
 650
 651            min_null_prediction = None
 652            prelim_predictions = []
 653
 654            # Looping through all the features associated to the current example.
 655            for feature_index in feature_indices:
 656                # We grab the predictions of the model for this feature.
 657                start_logits = all_start_logits[feature_index]
 658                end_logits = all_end_logits[feature_index]
 659                # This is what will allow us to map some of the positions in our logits to spans of text in the original
 660                # context.
 661                offset_mapping = features[feature_index]["offset_mapping"]
 662                # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum
 663                # context available in the current feature.
 664                token_is_max_context = features[feature_index].get("token_is_max_context", None)
 665
 666                # Update minimum null prediction.
 667                feature_null_score = start_logits[0] + end_logits[0]
 668                if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
 669                    min_null_prediction = {
 670                        "offsets": (0, 0),
 671                        "score": feature_null_score,
 672                        "start_logit": start_logits[0],
 673                        "end_logit": end_logits[0],
 674                    }
 675
 676                # Go through all possibilities for the `n_best_size` greater start and end logits.
 677                start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
 678                end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
 679                for start_index in start_indexes:
 680                    for end_index in end_indexes:
 681                        # Don't consider out-of-scope answers, either because the indices are out of bounds or
 682                        # correspond to part of the input_ids that are not in the context.
 683                        if (
 684                            start_index >= len(offset_mapping)
 685                            or end_index >= len(offset_mapping)
 686                            or offset_mapping[start_index] is None
 687                            or len(offset_mapping[start_index]) < 2
 688                            or offset_mapping[end_index] is None
 689                            or len(offset_mapping[end_index]) < 2
 690                        ):
 691                            continue
 692                        # Don't consider answers with a length that is either < 0 or > max_answer_length.
 693                        if (
 694                            end_index < start_index
 695                            or end_index - start_index + 1 > max_answer_length
 696                        ):
 697                            continue
 698                        # Don't consider answers that don't have the maximum context available (if such information is
 699                        # provided).
 700                        if token_is_max_context is not None and not token_is_max_context.get(
 701                            str(start_index), False
 702                        ):
 703                            continue
 704
 705                        prelim_predictions.append(
 706                            {
 707                                "offsets": (
 708                                    offset_mapping[start_index][0],
 709                                    offset_mapping[end_index][1],
 710                                ),
 711                                "score": start_logits[start_index] + end_logits[end_index],
 712                                "start_logit": start_logits[start_index],
 713                                "end_logit": end_logits[end_index],
 714                            }
 715                        )
 716            if version_2_with_negative and min_null_prediction is not None:
 717                # Add the minimum null prediction
 718                prelim_predictions.append(min_null_prediction)
 719                null_score = min_null_prediction["score"]
 720
 721            # Only keep the best `n_best_size` predictions.
 722            n_best_preds = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[
 723                :n_best_size
 724            ]
 725
 726            # Add back the minimum null prediction if it was removed because of its low score.
 727            if (
 728                version_2_with_negative
 729                and min_null_prediction is not None
 730                and not any(p["offsets"] == (0, 0) for p in n_best_preds)
 731            ):
 732                n_best_preds.append(min_null_prediction)
 733
 734            # Use the offsets to gather the answer text in the original context.
 735            context = example["context"]
 736            for pred in n_best_preds:
 737                offsets = pred.pop("offsets")
 738                pred["text"] = context[offsets[0] : offsets[1]]
 739
 740            # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
 741            # failure.
 742            if len(n_best_preds) == 0 or (len(n_best_preds) == 1 and n_best_preds[0]["text"] == ""):
 743                n_best_preds.insert(
 744                    0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}
 745                )
 746
 747            # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file,
 748            # using the LogSumExp trick).
 749            scores = np.array([pred.pop("score") for pred in n_best_preds])
 750            exp_scores = np.exp(scores - np.max(scores))
 751            probs = exp_scores / exp_scores.sum()
 752
 753            # Include the probabilities in our n_best_preds.
 754            for prob, pred in zip(probs, n_best_preds):
 755                pred["probability"] = prob
 756
 757            # Pick the best prediction. If the null answer is not possible, this is easy.
 758            if not version_2_with_negative:
 759                all_predictions[example["id"]] = n_best_preds[0]["text"]
 760            else:
 761                # Otherwise we first need to find the best non-empty prediction.
 762                i = 0
 763                while n_best_preds[i]["text"] == "":
 764                    i += 1
 765                best_non_null_pred = n_best_preds[i]
 766
 767                # Then we compare to the null prediction using the threshold.
 768                score_diff = (
 769                    null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
 770                )
 771                scores_diff_json[example["id"]] = float(score_diff)  # To be JSON-serializable.
 772                if score_diff > null_score_diff_threshold:
 773                    all_predictions[example["id"]] = ""
 774                else:
 775                    all_predictions[example["id"]] = best_non_null_pred["text"]
 776
 777            # Make `n_best_preds` JSON-serializable by casting np.float back to float.
 778            all_nbest_json[example["id"]] = [
 779                {
 780                    k: float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v
 781                    for k, v in pred.items()
 782                }
 783                for pred in n_best_preds
 784            ]
 785
 786        # If we have an output_dir, let's save all those dicts.
 787        if output_dir is not None:
 788            if not os.path.isdir(output_dir):
 789                raise EnvironmentError(f"{output_dir} is not a directory.")
 790
 791            prediction_file = os.path.join(
 792                output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
 793            )
 794            nbest_file = os.path.join(
 795                output_dir,
 796                "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json",
 797            )
 798            if version_2_with_negative:
 799                null_odds_file = os.path.join(
 800                    output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
 801                )
 802
 803            logger.info(f"Saving predictions to {prediction_file}.")
 804            with open(prediction_file, "w") as writer:
 805                writer.write(json.dumps(all_predictions, indent=4) + "\n")
 806            logger.info(f"Saving nbest_preds to {nbest_file}.")
 807            with open(nbest_file, "w") as writer:
 808                writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
 809            if version_2_with_negative:
 810                logger.info(f"Saving null_odds to {null_odds_file}.")
 811                with open(null_odds_file, "w") as writer:
 812                    writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
 813
 814        # Convert the result to the format the metric expects.
 815        formatted_predictions = [
 816            {"id": k, "prediction_text": v} for k, v in all_predictions.items()
 817        ]
 818
 819        references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
 820        return EvalPrediction(predictions=formatted_predictions, label_ids=references)
 821
 822    logger.info(f"***** Running {prefix} *****")
 823    logger.info(f"  Num examples = {len(eval_dataset)}")
 824    logger.info(f"  Batch size = {args.per_device_eval_batch_size}")
 825
 826    model.eval()
 827    all_start_logits = []
 828    all_end_logits = []
 829    for _, batch in enumerate(eval_dataloader):
 830        with torch.no_grad():
 831            outputs = model(**batch)
 832            start_logits = outputs.start_logits
 833            end_logits = outputs.end_logits
 834
 835            if (
 836                not args.pad_to_max_length
 837            ):  # necessary to pad predictions and labels for being gathered
 838                start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
 839                end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)
 840
 841            all_start_logits.append(accelerator.gather_for_metrics(start_logits).cpu().numpy())
 842            all_end_logits.append(accelerator.gather_for_metrics(end_logits).cpu().numpy())
 843
 844    # Model Optimizer: clear the intermediate states of the distillation model from the forward passes
 845    if args.do_modelopt_distill:
 846        model.module.compute_kd_loss()
 847
 848    max_len = max([x.shape[1] for x in all_start_logits])  # Get the max_length of the tensor
 849
 850    # concatenate the numpy array
 851    start_logits_concat = create_and_fill_np_array(all_start_logits, max_len)
 852    end_logits_concat = create_and_fill_np_array(all_end_logits, max_len)
 853
 854    outputs_numpy = (start_logits_concat, end_logits_concat)
 855    prediction = postprocess_qa_predictions(
 856        examples=eval_examples,
 857        features=eval_dataset,
 858        predictions=outputs_numpy,
 859        n_best_size=args.n_best_size,
 860        max_answer_length=args.max_answer_length,
 861        # output_dir=args.finetuned_model_path,
 862        prefix=prefix,
 863    )
 864
 865    metric = evaluate.load("squad")
 866    eval_metric = metric.compute(
 867        predictions=prediction.predictions, references=prediction.label_ids
 868    )
 869    logger.info(f"{prefix} metrics: {eval_metric}\n")
 870    return eval_metric
 871
 872
 873# Model Optimizer: Define a teacher factory for initializing the distillation model
 874def teacher_factory(model_name_or_path):
 875    return AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)
 876
 877
 878# Model Optimizer: Define a custom distillation loss function that uses start and end logits
 879class StartEndLogitsDistillationLoss(mtd.LogitsDistillationLoss):
 880    def forward(self, outputs_s, outputs_t):
 881        loss_start = super().forward(outputs_s.start_logits, outputs_t.start_logits)
 882        loss_end = super().forward(outputs_s.end_logits, outputs_t.end_logits)
 883        loss = (loss_start + loss_end) / 2.0
 884
 885        return loss
 886
 887
 888def train_and_evaluate_model(
 889    args,
 890    model: nn.Module,
 891    accelerator: Accelerator,
 892    examples: Dict,
 893    dataset: Dict,
 894    dataloader: Dict[str, DataLoader],
 895    answer_column_name,
 896):
 897    # Optimizer
 898    # Split weights in two groups, one with weight decay and the other not.
 899    no_decay = ["bias", "LayerNorm.weight"]
 900    optimizer_grouped_parameters = [
 901        {
 902            "params": [
 903                p
 904                for n, p in model.named_parameters()
 905                if p.requires_grad and not any(nd in n for nd in no_decay)
 906            ],
 907            "weight_decay": args.weight_decay,
 908        },
 909        {
 910            "params": [
 911                p
 912                for n, p in model.named_parameters()
 913                if p.requires_grad and any(nd in n for nd in no_decay)
 914            ],
 915            "weight_decay": 0.0,
 916        },
 917    ]
 918    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
 919
 920    # Scheduler and math around the number of training steps.
 921    overrode_max_train_steps = False
 922    num_update_steps_per_epoch = math.ceil(
 923        len(dataloader["train"]) / args.gradient_accumulation_steps
 924    )
 925    if args.max_train_steps is None:
 926        args.max_train_steps = int(args.num_train_epochs * num_update_steps_per_epoch)
 927        overrode_max_train_steps = True
 928
 929    lr_scheduler = get_scheduler(
 930        name=args.lr_scheduler_type,
 931        optimizer=optimizer,
 932        num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
 933        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
 934    )
 935
 936    # Prepare everything with our `accelerator`.
 937    model, optimizer, dataloader["train"], dataloader["eval"], lr_scheduler = accelerator.prepare(
 938        model, optimizer, dataloader["train"], dataloader["eval"], lr_scheduler
 939    )
 940
 941    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
 942    num_update_steps_per_epoch = math.ceil(
 943        len(dataloader["train"]) / args.gradient_accumulation_steps
 944    )
 945    if overrode_max_train_steps:
 946        args.max_train_steps = int(args.num_train_epochs * num_update_steps_per_epoch)
 947    # Afterwards we recalculate our number of training epochs
 948    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 949
 950    # Figure out how many steps we should save the Accelerator states
 951    checkpointing_steps = args.checkpointing_steps
 952    if checkpointing_steps is not None and checkpointing_steps.isdigit():
 953        checkpointing_steps = int(checkpointing_steps)
 954
 955    # We need to initialize the trackers we use, and also store our configuration.
 956    # The trackers initialize automatically on the main process.
 957    if args.with_tracking:
 958        experiment_config = vars(args)
 959        # TensorBoard cannot log Enums, need the raw value
 960        experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
 961        accelerator.init_trackers("tensorboard", experiment_config)
 962
 963    # Train!
 964    total_batch_size = (
 965        args.per_device_train_batch_size
 966        * accelerator.num_processes
 967        * args.gradient_accumulation_steps
 968    )
 969
 970    logger.info("***** Running training *****")
 971    logger.info(f"  Num examples = {len(dataset['train'])}")
 972    logger.info(f"  Num Epochs = {args.num_train_epochs}")
 973    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
 974    logger.info(
 975        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
 976    )
 977    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
 978    logger.info(f"  Total optimization steps = {args.max_train_steps}")
 979
 980    # Only show the progress bar once on each machine.
 981    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
 982    completed_steps = 0
 983    starting_epoch = 0
 984    resume_step = None
 985
 986    # Potentially load in the weights and states from a previous save
 987    if args.resume_from_last_ckpt:
 988        # Get the most recent checkpoint
 989        dirs = [
 990            f.path
 991            for f in os.scandir(args.finetuned_model_path)
 992            if f.is_dir() and (f.name.startswith("epoch_") or f.name.startswith("step_"))
 993        ]
 994        if len(dirs) == 0:
 995            logger.warning(
 996                f"No checkpoint found in {args.finetuned_model_path}. Training from scratch!"
 997            )
 998        else:
 999            latest_dir = max(dirs, key=os.path.getctime)
1000            accelerator.load_state(latest_dir)
1001
1002            # Extract `epoch_{i}` or `step_{i}`
1003            latest_dir = os.path.basename(latest_dir)
1004            if "epoch" in latest_dir:
1005                starting_epoch = int(latest_dir.replace("epoch_", "")) + 1
1006                completed_steps = starting_epoch * num_update_steps_per_epoch
1007            else:
1008                # need to multiply by `gradient_accumulation_steps` to reflect real steps
1009                resume_step = (
1010                    int(latest_dir.replace("step_", "")) * args.gradient_accumulation_steps
1011                )
1012                starting_epoch = resume_step // len(dataloader["train"])
1013                completed_steps = resume_step // args.gradient_accumulation_steps
1014                resume_step -= starting_epoch * len(dataloader["train"])
1015
1016    # update the progress_bar if resuming from a checkpoint
1017    progress_bar.update(completed_steps)
1018
1019    # Evaluate before training (e.g. PTQ accuracy before QAT)
1020    eval_metric = evaluate_model(
1021        args,
1022        model,
1023        accelerator,
1024        examples["eval"],
1025        dataset["eval"],
1026        dataloader["eval"],
1027        answer_column_name,
1028    )
1029    for epoch in range(starting_epoch, args.num_train_epochs):
1030        model.train()
1031        if args.with_tracking:
1032            total_loss = 0
1033        if args.resume_from_last_ckpt and epoch == starting_epoch and resume_step is not None:
1034            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
1035            active_dataloader = accelerator.skip_first_batches(dataloader["train"], resume_step)
1036        else:
1037            active_dataloader = dataloader["train"]
1038        for batch in active_dataloader:
1039            optimizer.zero_grad()
1040            outputs = model(**batch)
1041
1042            # Model Optimizer: If using distillation, we unwrap the DDP-wrapped model and compute the distillation loss
1043            if args.do_modelopt_distill:
1044                loss = model.module.compute_kd_loss()
1045            else:
1046                loss, _, _ = outputs.to_tuple()
1047
1048            # We keep track of the loss at each epoch
1049            if args.with_tracking:
1050                total_loss += loss.detach().float()
1051
1052            accelerator.backward(loss)
1053            optimizer.step()
1054            lr_scheduler.step()
1055
1056            # Checks if the accelerator has performed an optimization step behind the scenes
1057            if accelerator.sync_gradients:
1058                progress_bar.update(1)
1059                completed_steps += 1
1060
1061            if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0:
1062                accelerator.save_state(
1063                    os.path.join(args.finetuned_model_path, f"step_{completed_steps}")
1064                )
1065
1066            if completed_steps >= args.max_train_steps:
1067                break
1068
1069        if args.checkpointing_steps == "epoch":
1070            accelerator.save_state(os.path.join(args.finetuned_model_path, f"epoch_{epoch}"))
1071
1072        eval_metric = evaluate_model(
1073            args,
1074            model,
1075            accelerator,
1076            examples["eval"],
1077            dataset["eval"],
1078            dataloader["eval"],
1079            answer_column_name,
1080        )
1081
1082        if args.with_tracking:
1083            log = {
1084                "squad": eval_metric,
1085                "train_loss": total_loss.item() / len(dataloader["train"]),  # type: ignore[attr-defined]
1086                "epoch": epoch,
1087                "step": completed_steps,
1088            }
1089            accelerator.log(log, step=completed_steps)
1090
1091    accelerator.wait_for_everyone()
1092    if accelerator.is_main_process and eval_metric:
1093        logger.info(json.dumps(eval_metric, indent=4))
1094        # Prefix all keys with prefix + '_'
1095        for key in list(eval_metric.keys()):
1096            eval_metric[f"Eval_{key}"] = eval_metric.pop(key)
1097
1098        with open(os.path.join(args.finetuned_model_path, "results.json"), "w") as f:
1099            json.dump(eval_metric, f, indent=4)
1100
1101
1102def main(input_args: Optional[List[str]] = None) -> None:
1103    args = parse_args(input_args)
1104
1105    # Initialize the accelerator
1106    accelerator_log_kwargs = {}
1107    if args.with_tracking:
1108        accelerator_log_kwargs["log_with"] = "tensorboard"
1109        accelerator_log_kwargs["project_dir"] = args.finetuned_model_path
1110    accelerator = Accelerator(
1111        gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs
1112    )
1113
1114    # Setup logging
1115    logging.basicConfig(
1116        format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
1117        datefmt="%m/%d/%Y %H:%M:%S",
1118        level=logging.INFO,
1119    )
1120    logger.info(accelerator.state, main_process_only=False)
1121    if accelerator.is_local_main_process:
1122        datasets.utils.logging.set_verbosity_warning()
1123        transformers.utils.logging.set_verbosity_info()
1124    else:
1125        datasets.utils.logging.set_verbosity_error()
1126        transformers.utils.logging.set_verbosity_error()
1127
1128    # Set the training seed
1129    set_seed(SEED)
1130
1131    accelerator.wait_for_everyone()
1132
1133    # Load pretrained model and tokenizer
1134    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
1135    model = AutoModelForQuestionAnswering.from_pretrained(args.model_name_or_path)
1136    dummy_input = model.dummy_inputs["input_ids"]
1137
1138    # Get datasets
1139    examples, dataset, dataloader, answer_column_name = get_datasets_and_dataloaders(
1140        args, tokenizer, accelerator
1141    )
1142
1143    def save(model, output_path):
1144        if accelerator.is_main_process:
1145            os.makedirs(os.path.dirname(output_path), exist_ok=True)
1146            model = accelerator.unwrap_model(model)
1147            model.save_pretrained(output_path)
1148            tokenizer.save_pretrained(output_path)
1149            logger.info(f"Saved model and tokenizer to {output_path}")
1150
1151    # Model Optimizer: Prune the model to given FLOPS target using GradNAS algorithm
1152    if args.do_modelopt_prune:
1153        logger.info(f"Pruning model to {args.modelopt_prune_flops_percent}% FLOPS")
1154
1155        # NOTE: gradnas does not perform synchronization across data parallel groups
1156        # Use unwrapped model & non-distributed dataloader for gradnas so that all the processes
1157        # in the data parallel group have the same gradients for pruning
1158        model = model.to(accelerator.device)
1159        dummy_input = dummy_input.to(accelerator.device)
1160
1161        # Search for the best pruned model
1162        # To use other NAS algorithms, you can use `mtn.convert` + `mtn.search` here.
1163        model, _ = mtp.prune(
1164            model=model,
1165            mode="gradnas",
1166            constraints={"flops": f"{args.modelopt_prune_flops_percent}%"},
1167            dummy_input=dummy_input,
1168            config={
1169                "data_loader": dataloader["train"],
1170                "collect_func": lambda batch: (batch,),
1171                "loss_func": lambda output, batch: (
1172                    output["loss"] if isinstance(output, dict) else output[0]
1173                ),
1174            },
1175        )
1176        save(model, args.pruned_model_path)
1177
1178    # Model Optimizer: Quantize the model to INT8 precision
1179    if args.modelopt_quantize_cfg:
1180        logger.info(f"Quantizing model with {args.modelopt_quantize_cfg} config")
1181
1182        # NOTE: `mtq.quantize` does not perform synchronization across data parallel groups
1183        # Use unwrapped model & non-distributed dataloader for PTQ calibration so that all the processes
1184        # in the data parallel group have the same calibration statistics
1185        model = model.to(accelerator.device)
1186
1187        def forward_loop(model):
1188            num_samples = 256  # Use only 256 samples for PTQ calibration
1189            num_batches = num_samples // args.per_device_train_batch_size
1190            for idx, batch in tqdm(enumerate(dataloader["train"]), total=num_batches):
1191                batch = {k: v.to(accelerator.device) for k, v in batch.items()}
1192                model(**batch)
1193                if idx >= num_batches:
1194                    break
1195
1196        model = mtq.quantize(model, getattr(mtq, args.modelopt_quantize_cfg), forward_loop)
1197        torch.cuda.empty_cache()
1198        save(model, args.ptq_model_path)
1199
1200    if args.do_train:
1201        # Handle the finetuned_model_path creation
1202        if accelerator.is_main_process:
1203            os.makedirs(args.finetuned_model_path, exist_ok=True)
1204
1205        # Model Optimizer: Convert to a DistillationModel containing teacher to train with distillation
1206        if args.do_modelopt_distill:
1207            logger.info(f"Using distillation with teacher {args.model_name_or_path}")
1208
1209            kd_config = {
1210                "teacher_model": (teacher_factory, (args.model_name_or_path,), {}),
1211                "criterion": StartEndLogitsDistillationLoss(args.temperature),
1212            }
1213            model = mtd.convert(model, mode=[("kd_loss", kd_config)])
1214
1215        train_and_evaluate_model(
1216            args, model, accelerator, examples, dataset, dataloader, answer_column_name
1217        )
1218
1219        # Model Optimizer: Export the distilled model
1220        if args.do_modelopt_distill:
1221            model = mtd.export(model)
1222
1223        save(model, args.finetuned_model_path)
1224
1225    if accelerator.is_main_process and args.onnx_export_path is not None:
1226        logger.info(f"Exporting ONNX model to {args.onnx_export_path}")
1227
1228        # Move the model and dummy_input to the device
1229        model = model.to(accelerator.device)
1230        dummy_input = dummy_input.to(accelerator.device)
1231
1232        with open(args.onnx_export_path, "wb") as f:
1233            f.write(get_onnx_bytes(model, dummy_input, onnx_opset=14))
1234
1235    logger.info("Done!")
1236
1237
1238if __name__ == "__main__":
1239    main()

Commands

  1. First, we prune the BERT-large model to 50% of its original FLOPs with the GradNAS algorithm. Then we fine-tune the pruned model with distillation from the unpruned teacher to recover over 99% of the initial F1 score (93.15). We recommend using multiple GPUs for fine-tuning. Note that we use more epochs than the 2 epochs originally used to fine-tune BERT without distillation: distillation takes longer to converge but achieves much better results. A sketch showing how to restore the resulting checkpoint follows the script below.

    1_prune.sh
    #!/bin/bash
    set -ex
    
    MODEL_NAME_OR_PATH=bert-large-uncased-whole-word-masking-finetuned-squad
    FLOPS_PERCENT=50
    BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
    PRUNED_MODEL_PATH=${BASE_DIR}/pruned/
    FINETUNED_MODEL_PATH=${BASE_DIR}/pruned_finetuned/
    
    modelopt_args=""
    if [ ! -d ${PRUNED_MODEL_PATH} ]; then
        modelopt_args="--do_modelopt_prune \
            --modelopt_prune_flops_percent ${FLOPS_PERCENT} \
            --pruned_model_path ${PRUNED_MODEL_PATH}"
    else
        MODEL_NAME_OR_PATH=${PRUNED_MODEL_PATH}
    fi
    
    # Run pruning followed by distributed fine-tuning on the pruned model
    accelerate launch --multi_gpu --mixed_precision bf16 bert_prune_distill_quantize.py \
        --model_name_or_path ${MODEL_NAME_OR_PATH} \
        --finetuned_model_path ${FINETUNED_MODEL_PATH} \
        ${modelopt_args} \
        --do_train \
        --do_modelopt_distill \
        --lr_scheduler_type cosine \
        --learning_rate 1e-4 \
        --per_device_train_batch_size 16 \
        --num_train_epochs 15 \
        --with_tracking \
        --resume_from_last_ckpt
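
     Because the script calls mto.enable_huggingface_checkpointing(), the modelopt_state (including the pruned architecture) is saved alongside the regular Hugging Face checkpoint, so the fine-tuned pruned model can later be restored with .from_pretrained(). A minimal sketch, assuming the default output paths of 1_prune.sh above:

    import modelopt.torch.opt as mto
    from transformers import AutoModelForQuestionAnswering

    # Enable ModelOpt <-> Hugging Face checkpointing (same call as in the example script)
    # so the pruned architecture is restored together with the weights.
    mto.enable_huggingface_checkpointing()

    # Output directory written by 1_prune.sh (adjust if you changed FLOPS_PERCENT or BASE_DIR).
    finetuned_path = "results/bert_large_pruned_50_percent/pruned_finetuned/"
    model = AutoModelForQuestionAnswering.from_pretrained(finetuned_path)
    print(model.config)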
    
  2. Next, we quantize the fine-tuned model to INT8 precision and run calibration (PTQ). PTQ causes a slight drop in F1 score, which we then recover with QAT. As before, QAT uses distillation from the unpruned teacher model. A sketch for inspecting the resulting quantized checkpoint follows the script below.

    2_int8_quantize.sh
    #!/bin/bash
    set -ex
    
    FLOPS_PERCENT=50
    BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
    PRUNED_MODEL_PATH=${BASE_DIR}/pruned_finetuned/
    PTQ_MODEL_PATH=${BASE_DIR}/int8_ptq/
    QAT_MODEL_PATH=${BASE_DIR}/int8_qat/
    
    modelopt_args=""
    if [ ! -d ${PTQ_MODEL_PATH} ]; then
        modelopt_args="--modelopt_quantize_cfg INT8_DEFAULT_CFG --ptq_model_path ${PTQ_MODEL_PATH}"
        MODEL_NAME_OR_PATH=${PRUNED_MODEL_PATH}
    else
        MODEL_NAME_OR_PATH=${PTQ_MODEL_PATH}
    fi
    
    # Run distributed QAT on the pruned model with a lower learning rate and fewer epochs
    accelerate launch --multi_gpu --mixed_precision bf16 bert_prune_distill_quantize.py \
        --model_name_or_path ${MODEL_NAME_OR_PATH} \
        --finetuned_model_path ${QAT_MODEL_PATH} \
        ${modelopt_args} \
        --do_train \
        --do_modelopt_distill \
        --lr_scheduler_type cosine \
        --learning_rate 5e-6 \
        --per_device_train_batch_size 16 \
        --num_train_epochs 2 \
        --with_tracking \
        --resume_from_last_ckpt
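
     After QAT, the quantized checkpoint can likewise be restored with .from_pretrained() and inspected. A minimal sketch, assuming the default output paths of 2_int8_quantize.sh above; the quantizer submodule names are an assumption based on ModelOpt's INT8 config wildcard patterns:

    import modelopt.torch.opt as mto
    from transformers import AutoModelForQuestionAnswering

    # Enable ModelOpt <-> Hugging Face checkpointing so the quantizers are restored.
    mto.enable_huggingface_checkpointing()

    # Output directory written by 2_int8_quantize.sh (adjust if you changed BASE_DIR).
    qat_path = "results/bert_large_pruned_50_percent/int8_qat/"
    model = AutoModelForQuestionAnswering.from_pretrained(qat_path)

    # Count the quantizer submodules that ModelOpt inserted (suffix names assumed
    # from the "*input_quantizer"/"*weight_quantizer" patterns in INT8_DEFAULT_CFG).
    num_quantizers = sum(
        name.endswith(("input_quantizer", "weight_quantizer", "output_quantizer"))
        for name, _ in model.named_modules()
    )
    print(f"Quantizer modules found: {num_quantizers}")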
    
  3. Finally, we export the quantized model to ONNX format for deployment with TensorRT.

    3_onnx_export.sh
    #!/bin/bash
    set -ex
    
    FLOPS_PERCENT=50
    BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
    
    # Export to ONNX on a single GPU
    python3 bert_prune_distill_quantize.py \
        --model_name_or_path ${BASE_DIR}/int8_qat/ \
        --onnx_export_path ${BASE_DIR}/pruned_model_int8.onnx \