HF BERT: Prune, Distill & Quantize

This example shows how to compress a Hugging Face BERT-large model for question answering using a combination of modelopt.torch.prune, modelopt.torch.distill and modelopt.torch.quantization. More specifically, we will:

  1. Prune the BERT-large model to 50% of its original FLOPs with the GradNAS algorithm and fine-tune it with distillation

  2. Quantize the fine-tuned model to INT8 precision with Post-Training Quantization (PTQ), followed by Quantization-Aware Training (QAT) with distillation

  3. Export the quantized model to ONNX format for deployment with TensorRT
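
To make step 2 more concrete, here is a minimal, self-contained sketch of what symmetric INT8 "fake" quantization does to a list of weights: calibrate a scale from the largest absolute value, round onto the integer grid in [-127, 127], and map back to floating point. This is an illustration only and not ModelOpt's implementation; the helper names calibrate_scale and fake_quantize_int8 are hypothetical, and real quantizers typically add per-channel scales and other calibration methods.

```python
from typing import List


def calibrate_scale(values: List[float], num_bits: int = 8) -> float:
    """Return a per-tensor scale from max calibration (hypothetical helper)."""
    qmax = 2 ** (num_bits - 1) - 1  # 127 for INT8
    amax = max(abs(v) for v in values)
    # Guard against an all-zero tensor to avoid dividing by zero later.
    return amax / qmax if amax > 0 else 1.0


def fake_quantize_int8(values: List[float], scale: float) -> List[float]:
    """Quantize to INT8 and immediately dequantize (hypothetical helper)."""
    qmax = 127
    out = []
    for v in values:
        q = round(v / scale)          # map onto the integer grid
        q = max(-qmax, min(qmax, q))  # clamp to the symmetric INT8 range
        out.append(q * scale)         # map back to floating point
    return out


weights = [0.02, -1.27, 0.5, 0.635]
scale = calibrate_scale(weights)
dequantized = fake_quantize_int8(weights, scale)
max_error = max(abs(a - b) for a, b in zip(weights, dequantized))
print(f"scale={scale:.4f}, max abs error={max_error:.6f}")
```

With max calibration no value is clamped, so the rounding error of each weight is at most half the scale. PTQ stops after this calibration pass, whereas QAT keeps training with the rounding applied in the forward pass so that the weights can adapt to it.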

Prerequisites

  1. Install Model Optimizer with optional torch and huggingface dependencies:

    pip install "nvidia-modelopt[torch,hf]" --extra-index-url https://pypi.nvidia.com
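
After installing, a quick sanity check can confirm that the top-level packages imported by the script are discoverable in the current environment. The sketch below uses only the standard library; the helper name check_packages is hypothetical, and the package list is simply read off the import statements of the script in this example.

```python
import importlib.util
from typing import Dict, Iterable


def check_packages(names: Iterable[str]) -> Dict[str, bool]:
    """Map each top-level package name to whether it can be found (hypothetical helper)."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


# Top-level packages imported by bert_prune_distill_quantize.py.
status = check_packages(
    ["modelopt", "torch", "transformers", "datasets", "evaluate", "accelerate"]
)
for name, found in status.items():
    print(f"{name}: {'found' if found else 'MISSING'}")
```

This only checks that each package can be located; it does not import them or verify versions.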
    

Note

This example has been tested on 8 x 24GB A5000 GPUs with PyTorch 2.4 and CUDA 12.4. It takes about 2 hours to complete all the stages of the optimization. Most of the time is spent on fine-tuning and QAT.

Full code

You can view the full code below, with ModelOpt integration points marked by "# Model Optimizer:" comments. The source code and scripts are also available on the ModelOpt GitHub repository.

bert_prune_distill_quantize.py
   1# NOTE: This is adapted from run_qa_no_trainer.py and utils_qa.py from
   2# https://github.com/huggingface/transformers/blob/c52b515e/examples/pytorch/question-answering
   3#
   4# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
   5#
   6# Licensed under the Apache License, Version 2.0 (the "License");
   7# you may not use this file except in compliance with the License.
   8# You may obtain a copy of the License at
   9#
  10#     http://www.apache.org/licenses/LICENSE-2.0
  11#
  12# Unless required by applicable law or agreed to in writing, software
  13# distributed under the License is distributed on an "AS IS" BASIS,
  14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15# See the License for the specific language governing permissions and
  16# limitations under the License.
  17
  18# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  19# SPDX-License-Identifier: MIT
  20#
  21# Permission is hereby granted, free of charge, to any person obtaining a
  22# copy of this software and associated documentation files (the "Software"),
  23# to deal in the Software without restriction, including without limitation
  24# the rights to use, copy, modify, merge, publish, distribute, sublicense,
  25# and/or sell copies of the Software, and to permit persons to whom the
  26# Software is furnished to do so, subject to the following conditions:
  27#
  28# The above copyright notice and this permission notice shall be included in
  29# all copies or substantial portions of the Software.
  30#
  31# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  32# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  33# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  34# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  35# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
  36# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
  37# DEALINGS IN THE SOFTWARE.
  38"""
  39Example showcasing how to do end-to-end optimization of a BERT model on SQuAD using Model Optimizer.
  40This includes GradNAS pruning, INT8 quantization, fine-tuning / QAT with distillation, and ONNX export.
  41"""
  42
  43import argparse
  44import collections
  45import json
  46import logging
  47import math
  48import os
  49import random
  50from typing import Any, Dict, List, Optional, Tuple
  51
  52import datasets
  53import evaluate
  54import numpy as np
  55import torch
  56import torch.nn as nn
  57import transformers
  58from accelerate import Accelerator
  59from accelerate.logging import get_logger
  60from accelerate.utils import set_seed
  61from torch.utils.data import DataLoader
  62from tqdm.auto import tqdm
  63from transformers import (
  64    AutoModelForQuestionAnswering,
  65    AutoTokenizer,
  66    DataCollatorWithPadding,
  67    EvalPrediction,
  68    PreTrainedTokenizer,
  69    SchedulerType,
  70    default_data_collator,
  71    get_scheduler,
  72)
  73
  74# Model Optimizer: imports
  75import modelopt.torch.distill as mtd
  76import modelopt.torch.opt as mto
  77import modelopt.torch.prune as mtp
  78import modelopt.torch.quantization as mtq
  79from modelopt.torch._deploy.utils import get_onnx_bytes
  80
  81# Enable automatic save/load of modelopt_state with huggingface checkpointing
  82mto.enable_huggingface_checkpointing()
  83
  84logger = get_logger(__name__)
  85
  86SEED = 123
  87
  88
  89def parse_args(input_args: Optional[List[str]] = None):
  90    parser = argparse.ArgumentParser(
  91        description="Finetune a transformers model on a Question Answering task"
  92    )
  93
  94    # Training arguments
  95    parser.add_argument(
  96        "--model_name_or_path",
  97        type=str,
  98        default="bert-large-uncased-whole-word-masking-finetuned-squad",
  99        help="Path to pretrained model or model identifier from huggingface.co/models.",
 100    )
 101    parser.add_argument(
 102        "--do_train", action="store_true", help="Whether to run training / fine-tuning."
 103    )
 104    parser.add_argument(
 105        "--per_device_train_batch_size",
 106        type=int,
 107        default=16,
 108        help="Batch size (per device) for the training dataloader.",
 109    )
 110    parser.add_argument(
 111        "--per_device_eval_batch_size",
 112        type=int,
 113        default=64,
 114        help="Batch size (per device) for the evaluation dataloader.",
 115    )
 116    parser.add_argument(
 117        "--learning_rate",
 118        type=float,
 119        default=5e-5,
 120        help="Initial learning rate (after the potential warmup period) to use.",
 121    )
 122    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
 123    parser.add_argument(
 124        "--lr_scheduler_type",
 125        type=SchedulerType,
 126        default="linear",
 127        help="The scheduler type to use.",
 128        choices=[
 129            "linear",
 130            "cosine",
 131            "cosine_with_restarts",
 132            "polynomial",
 133            "constant",
 134            "constant_with_warmup",
 135        ],
 136    )
 137    parser.add_argument(
 138        "--num_warmup_steps",
 139        type=int,
 140        default=0,
 141        help="Number of steps for the warmup in the lr scheduler.",
 142    )
 143    parser.add_argument(
 144        "--num_train_epochs",
 145        type=float,
 146        default=2.0,
 147        help="Total number of training epochs to perform.",
 148    )
 149    parser.add_argument(
 150        "--max_train_steps",
 151        type=int,
 152        default=None,
 153        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
 154    )
 155    parser.add_argument(
 156        "--gradient_accumulation_steps",
 157        type=int,
 158        default=1,
 159        help="Number of updates steps to accumulate before performing a backward/update pass.",
 160    )
 161    parser.add_argument(
 162        "--preprocessing_num_workers",
 163        type=int,
 164        default=4,
 165        help="The number of processes to use for preprocessing the dataset.",
 166    )
 167
 168    # Logging and checkpointing arguments
 169    parser.add_argument(
 170        "--finetuned_model_path",
 171        type=str,
 172        default=None,
 173        help="Path to save the finetuned (pruned or quantized) model for restoring later with `.from_pretrained()`.",
 174    )
 175    parser.add_argument(
 176        "--with_tracking",
 177        action="store_true",
 178        help="Whether to enable experiment trackers for logging.",
 179    )
 180    parser.add_argument(
 181        "--checkpointing_steps",
 182        type=str,
 183        default="epoch",
 184        help=(
 185            "Whether the various states should be saved at the end of every n steps, or 'epoch' for"
 186            " each epoch."
 187        ),
 188    )
 189    parser.add_argument(
 190        "--resume_from_last_ckpt",
 191        action="store_true",
 192        help="If the training should continue from the latest checkpoint in model_name_or_path.",
 193    )
 194    parser.add_argument(
 195        "--onnx_export_path", type=str, default=None, help="Path to export the ONNX model to."
 196    )
 197
 198    # Misc arguments for Bert (should not be modified in most cases)
 199    parser.add_argument(
 200        "--max_seq_length",
 201        type=int,
 202        default=384,
 203        help=(
 204            "The maximum total input sequence length after tokenization. Sequences longer than this"
 205            " will be truncated, and shorter will be padded if `--pad_to_max_lengh` is passed."
 206        ),
 207    )
 208    parser.add_argument(
 209        "--pad_to_max_length",
 210        action="store_true",
 211        help="If passed, pad all samples to `max_seq_length`. Otherwise, dynamic padding is used.",
 212    )
 213    parser.add_argument(
 214        "--doc_stride",
 215        type=int,
 216        default=128,
 217        help=(
 218            "When splitting up a long document into chunks how much stride to take between chunks."
 219        ),
 220    )
 221    parser.add_argument(
 222        "--n_best_size",
 223        type=int,
 224        default=20,
 225        help="The total number of n-best predictions to generate when looking for an answer.",
 226    )
 227    parser.add_argument(
 228        "--max_answer_length",
 229        type=int,
 230        default=30,
 231        help=(
 232            "The maximum length of an answer that can be generated. This is needed because the"
 233            " start and end predictions are not conditioned on one another."
 234        ),
 235    )
 236
 237    # Debugging arguments
 238    parser.add_argument(
 239        "--max_train_samples",
 240        type=int,
 241        default=None,
 242        help="For debugging purposes or quicker training.",
 243    )
 244    parser.add_argument(
 245        "--max_eval_samples",
 246        type=int,
 247        default=None,
 248        help="For debugging purposes or quicker training.",
 249    )
 250
 251    # Model Optimizer: pruning arguments
 252    parser.add_argument(
 253        "--do_modelopt_prune",
 254        action="store_true",
 255        help="Whether or not to use Model Optimizer pruning.",
 256    )
 257    parser.add_argument(
 258        "--modelopt_prune_flops_percent",
 259        type=float,
 260        default=None,
 261        help="The percentage (between 0 and 100) of FLOPs to retain in the pruned model.",
 262    )
 263    parser.add_argument(
 264        "--pruned_model_path",
 265        type=str,
 266        default=None,
 267        help="Path to save the pruned model for further finetuning.",
 268    )
 269
 270    # Model Optimizer: quantization arguments
 271    parser.add_argument(
 272        "--modelopt_quantize_cfg",
 273        help="Model Optimizer quantization config.",
 274        choices=mtq.config.choices,
 275    )
 276
 277    # Model Optimizer: Distillation arguments
 278    parser.add_argument(
 279        "--do_modelopt_distill",
 280        action="store_true",
 281        help="Whether or not to use distillation. A teacher model must be specified.",
 282    )
 283    parser.add_argument(
 284        "--temperature", type=float, default=2.0, help="The temperature to use when distilling."
 285    )
 286    parser.add_argument(
 287        "--ptq_model_path",
 288        type=str,
 289        default=None,
 290        help="Path to save the PTQ quantized model for further QAT.",
 291    )
 292
 293    args = parser.parse_args(input_args)
 294
 295    # Sanity checks
 296    if args.do_train and not args.finetuned_model_path:
 297        raise ValueError("`finetuned_model_path` required when `do_train` is passed.")
 298    if args.do_modelopt_prune and not (
 299        args.modelopt_prune_flops_percent and args.pruned_model_path
 300    ):
 301        raise ValueError(
 302            "`modelopt_prune_flops_percent` and `pruned_model_path` required when `do_modelopt_prune` is passed."
 303        )
 304    if args.modelopt_quantize_cfg and not args.ptq_model_path:
 305        raise ValueError("`ptq_model_path` required when `modelopt_quantize_cfg` is passed.")
 306
 307    return args
 308
 309
 310def get_datasets_and_dataloaders(args, tokenizer: PreTrainedTokenizer, accelerator: Accelerator):
 311    """Get the examples, dataset, dataloader, answer_column_name
 312
 313    You can either provide your own CSV/JSON/TXT training and evaluation files (see below)
 314    or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
 315    (the dataset will be downloaded automatically from the datasets Hub).
 316
 317    For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
 318    'text' is found. You can easily tweak this behavior (see below).
 319    """
 320
 321    def prepare_train_features(examples):
 322        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
 323        # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
 324        # left whitespace
 325        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
 326
 327        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
 328        # in one example possible giving several features when a context is long, each of those features having a
 329        # context that overlaps a bit the context of the previous feature.
 330        tokenized_examples = tokenizer(
 331            examples[question_column_name if pad_on_right else context_column_name],
 332            examples[context_column_name if pad_on_right else question_column_name],
 333            truncation="only_second" if pad_on_right else "only_first",
 334            max_length=max_seq_length,
 335            stride=args.doc_stride,
 336            return_overflowing_tokens=True,
 337            return_offsets_mapping=True,
 338            padding="max_length" if args.pad_to_max_length else False,
 339        )
 340
 341        # Since one example might give us several features if it has a long context, we need a map from a feature to
 342        # its corresponding example. This key gives us just that.
 343        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
 344        # The offset mappings will give us a map from token to character position in the original context. This will
 345        # help us compute the start_positions and end_positions.
 346        offset_mapping = tokenized_examples.pop("offset_mapping")
 347
 348        # Let's label those examples!
 349        tokenized_examples["start_positions"] = []
 350        tokenized_examples["end_positions"] = []
 351
 352        for i, offsets in enumerate(offset_mapping):
 353            # We will label impossible answers with the index of the CLS token.
 354            input_ids = tokenized_examples["input_ids"][i]
 355            cls_index = input_ids.index(tokenizer.cls_token_id)
 356
 357            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
 358            sequence_ids = tokenized_examples.sequence_ids(i)
 359
 360            # One example can give several spans, this is the index of the example containing this span of text.
 361            sample_index = sample_mapping[i]
 362            answers = examples[answer_column_name][sample_index]
 363            # If no answers are given, set the cls_index as answer.
 364            if len(answers["answer_start"]) == 0:
 365                tokenized_examples["start_positions"].append(cls_index)
 366                tokenized_examples["end_positions"].append(cls_index)
 367            else:
 368                # Start/end character index of the answer in the text.
 369                start_char = answers["answer_start"][0]
 370                end_char = start_char + len(answers["text"][0])
 371
 372                # Start token index of the current span in the text.
 373                token_start_index = 0
 374                while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
 375                    token_start_index += 1
 376
 377                # End token index of the current span in the text.
 378                token_end_index = len(input_ids) - 1
 379                while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
 380                    token_end_index -= 1
 381
 382                # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
 383                if not (
 384                    offsets[token_start_index][0] <= start_char
 385                    and offsets[token_end_index][1] >= end_char
 386                ):
 387                    tokenized_examples["start_positions"].append(cls_index)
 388                    tokenized_examples["end_positions"].append(cls_index)
 389                else:
 390                    # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
 391                    # Note: we could go after the last offset if the answer is the last word (edge case).
 392                    while (
 393                        token_start_index < len(offsets)
 394                        and offsets[token_start_index][0] <= start_char
 395                    ):
 396                        token_start_index += 1
 397                    tokenized_examples["start_positions"].append(token_start_index - 1)
 398                    while offsets[token_end_index][1] >= end_char:
 399                        token_end_index -= 1
 400                    tokenized_examples["end_positions"].append(token_end_index + 1)
 401
 402        return tokenized_examples
 403
 404    def prepare_validation_features(examples):
 405        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
 406        # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
 407        # left whitespace
 408        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
 409
 410        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
 411        # in one example possible giving several features when a context is long, each of those features having a
 412        # context that overlaps a bit the context of the previous feature.
 413        tokenized_examples = tokenizer(
 414            examples[question_column_name if pad_on_right else context_column_name],
 415            examples[context_column_name if pad_on_right else question_column_name],
 416            truncation="only_second" if pad_on_right else "only_first",
 417            max_length=max_seq_length,
 418            stride=args.doc_stride,
 419            return_overflowing_tokens=True,
 420            return_offsets_mapping=True,
 421            padding="max_length" if args.pad_to_max_length else False,
 422        )
 423
 424        # Since one example might give us several features if it has a long context, we need a map from a feature to
 425        # its corresponding example. This key gives us just that.
 426        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
 427
 428        # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
 429        # corresponding example_id and we will store the offset mappings.
 430        tokenized_examples["example_id"] = []
 431
 432        for i in range(len(tokenized_examples["input_ids"])):
 433            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
 434            sequence_ids = tokenized_examples.sequence_ids(i)
 435            context_index = 1 if pad_on_right else 0
 436
 437            # One example can give several spans, this is the index of the example containing this span of text.
 438            sample_index = sample_mapping[i]
 439            tokenized_examples["example_id"].append(examples["id"][sample_index])
 440
 441            # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
 442            # position is part of the context or not.
 443            tokenized_examples["offset_mapping"][i] = [
 444                (o if sequence_ids[k] == context_index else None)
 445                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
 446            ]
 447
 448        return tokenized_examples
 449
 450    examples, dataset, dataloader = {}, {}, {}
 451
 452    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
 453    # download the dataset.
 454    # Downloading and loading a dataset from the hub.
 455    raw_datasets = datasets.load_dataset("squad")
 456    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
 457    # https://huggingface.co/docs/datasets/loading_datasets.
 458
 459    # Preprocessing the datasets.
 460    # Preprocessing is slighlty different for training and evaluation.
 461
 462    column_names = raw_datasets["train"].column_names
 463
 464    question_column_name = "question" if "question" in column_names else column_names[0]
 465    context_column_name = "context" if "context" in column_names else column_names[1]
 466    answer_column_name = "answers" if "answers" in column_names else column_names[2]
 467
 468    # Padding side determines if we do (question|context) or (context|question).
 469    pad_on_right = tokenizer.padding_side == "right"
 470
 471    if args.max_seq_length > tokenizer.model_max_length:
 472        logger.warning(
 473            f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length"
 474            f" for the model ({tokenizer.model_max_length}). Using"
 475            f" max_seq_length={tokenizer.model_max_length}."
 476        )
 477
 478    max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
 479
 480    examples["train"] = raw_datasets["train"]
 481    if args.max_train_samples is not None:
 482        # We will select sample from whole data if agument is specified
 483        examples["train"] = examples["train"].select(range(args.max_train_samples))
 484
 485    # Create train feature from dataset
 486    with accelerator.main_process_first():
 487        dataset["train"] = examples["train"].map(
 488            prepare_train_features,
 489            batched=True,
 490            num_proc=args.preprocessing_num_workers,
 491            remove_columns=column_names,
 492            load_from_cache_file=True,
 493            desc="Running tokenizer on train dataset",
 494        )
 495        # if args.max_train_samples is not None:
 496        #     # Number of samples might increase during Feature Creation, We select only specified max samples
 497        #     dataset["train"] = dataset["train"].select(range(args.max_train_samples))
 498
 499    examples["eval"] = raw_datasets["validation"]
 500    if args.max_eval_samples is not None:
 501        # We will select sample from whole data
 502        examples["eval"] = examples["eval"].select(range(args.max_eval_samples))
 503    # Validation Feature Creation
 504    with accelerator.main_process_first():
 505        dataset["eval"] = examples["eval"].map(
 506            prepare_validation_features,
 507            batched=True,
 508            num_proc=args.preprocessing_num_workers,
 509            remove_columns=column_names,
 510            load_from_cache_file=True,
 511            desc="Running tokenizer on validation dataset",
 512        )
 513        # if args.max_eval_samples is not None:
 514        #     # During Feature creation dataset samples might increase, we will select required samples again
 515        #     dataset["eval"] = dataset["eval"].select(range(args.max_eval_samples))
 516
 517    # Log a random sample from the training set:
 518    for index in random.sample(range(len(dataset["train"])), 1):
 519        logger.info(f"Sample {index} of the training set: {dataset['train'][index]}.")
 520
 521    # DataLoaders creation:
 522    if args.pad_to_max_length:
 523        # If padding was already done ot max length, we use the default data collator that will just convert everything
 524        # to tensors.
 525        data_collator = default_data_collator
 526    else:
 527        # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of
 528        # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple
 529        # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
 530        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
 531
 532    dataloader["train"] = DataLoader(
 533        dataset["train"],
 534        shuffle=True,
 535        collate_fn=data_collator,
 536        batch_size=args.per_device_train_batch_size,
 537    )
 538
 539    dataloader["eval"] = DataLoader(
 540        dataset["eval"].remove_columns(["example_id", "offset_mapping"]),
 541        collate_fn=data_collator,
 542        batch_size=args.per_device_eval_batch_size,
 543    )
 544
 545    return examples, dataset, dataloader, answer_column_name
 546
 547
 548def evaluate_model(
 549    args,
 550    model: nn.Module,
 551    accelerator: Accelerator,
 552    eval_examples: Any,
 553    eval_dataset: Any,
 554    eval_dataloader: DataLoader,
 555    answer_column_name: str,
 556    prefix: str = "Eval",
 557):
 558    def create_and_fill_np_array(start_or_end_logits, max_len):
 559        """Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor
 560
 561        Args:
 562            start_or_end_logits: This is the output predictions of the model.
 563                We can only enter either start or end logits.
 564            max_len: The maximum length of the output tensor. (See the model.eval() part for more details)
 565        """
 566        step = 0
 567        # create a numpy array and fill it with -100.
 568        logits_concat = np.full((len(eval_dataset), max_len), -100, dtype=np.float64)
 569        # Now since we have create an array we will populate it with the outputs using accelerator.gather_for_metrics
 570        for i, output_logit in enumerate(start_or_end_logits):  # populate columns
 571            # We have to fill it such that we have to take the whole tensor and replace it on the newly created array
 572            # And after every iteration we have to change the step
 573            batch_size = output_logit.shape[0]
 574            cols = output_logit.shape[1]
 575
 576            if step + batch_size < len(eval_dataset):
 577                logits_concat[step : step + batch_size, :cols] = output_logit
 578            else:
 579                logits_concat[step:, :cols] = output_logit[: len(eval_dataset) - step]
 580
 581            step += batch_size
 582
 583        return logits_concat
 584
 585    def postprocess_qa_predictions(
 586        examples,
 587        features,
 588        predictions: Tuple[np.ndarray, np.ndarray],
 589        version_2_with_negative: bool = False,
 590        n_best_size: int = 20,
 591        max_answer_length: int = 30,
 592        null_score_diff_threshold: float = 0.0,
 593        output_dir: Optional[str] = None,
 594        prefix: Optional[str] = None,
 595    ) -> EvalPrediction:
 596        """Post-processes the predictions of a question-answering model to convert them to answers
 597        that are substrings of  the original contexts. This is the base postprocessing functions for
 598        models that only return start and end logits.
 599
 600        Args:
 601            examples: The non-preprocessed dataset.
 602            features: The processed dataset.
 603            predictions: The predictions of the model: two arrays containing the start logits and the end logits
 604                respectively. Its first dimension must match the number of elements of `features`.
 605            version_2_with_negative: Whether or not the underlying dataset contains examples with no answers.
 606            n_best_size: The total number of n-best predictions to generate when looking for an answer.
 607            max_answer_length: The maximum length of an answer that can be generated. This is needed
 608                because the start and end predictions are not conditioned on one another.
 609            null_score_diff_threshold: The threshold used to select the null answer: if the best answer
 610                has a score that is less than the score of the null answer minus this threshold, the
 611                null answer is selected for this example (note that the score of the null answer for
 612                an example giving several features is the minimum of the scores for the null answer on
 613                each feature: all features must be aligned on the fact they `want` to predict a null answer).
 614                Only useful when `version_2_with_negative` is `True`.
 615            output_dir: If provided, the dictionaries of predictions, n_best predictions (with their scores and logits)
 616                and, if `version_2_with_negative=True`, the dictionary of the scores differences between best and null
 617                answers, are saved in `output_dir`.
 618            prefix: If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
 619        """
 620        if len(predictions) != 2:
 621            raise ValueError(
 622                "`predictions` should be a tuple with two elements (start_logits, end_logits)."
 623            )
 624        all_start_logits, all_end_logits = predictions
 625
 626        if len(predictions[0]) != len(features):
 627            raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
 628
 629        # Build a map example to its corresponding features.
 630        example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
 631        features_per_example = collections.defaultdict(list)
 632        for i, feature in enumerate(features):
 633            features_per_example[example_id_to_index[feature["example_id"]]].append(i)
 634
 635        # The dictionaries we have to fill.
 636        all_predictions = collections.OrderedDict()
 637        all_nbest_json = collections.OrderedDict()
 638        if version_2_with_negative:
 639            scores_diff_json = collections.OrderedDict()
 640
 641        logger.debug(
 642            f"Post-processing {len(examples)} example predictions split into"
 643            f" {len(features)} features."
 644        )
 645
 646        # Let's loop over all the examples!
 647        for example_index, example in enumerate(examples):
 648            # Those are the indices of the features associated to the current example.
 649            feature_indices = features_per_example[example_index]
 650
 651            min_null_prediction = None
 652            prelim_predictions = []
 653
 654            # Looping through all the features associated to the current example.
 655            for feature_index in feature_indices:
 656                # We grab the predictions of the model for this feature.
 657                start_logits = all_start_logits[feature_index]
 658                end_logits = all_end_logits[feature_index]
 659                # This is what will allow us to map some of the positions in our logits to spans of text in the original
 660                # context.
 661                offset_mapping = features[feature_index]["offset_mapping"]
 662                # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum
 663                # context available in the current feature.
 664                token_is_max_context = features[feature_index].get("token_is_max_context", None)
 665
 666                # Update minimum null prediction.
 667                feature_null_score = start_logits[0] + end_logits[0]
 668                if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
 669                    min_null_prediction = {
 670                        "offsets": (0, 0),
 671                        "score": feature_null_score,
 672                        "start_logit": start_logits[0],
 673                        "end_logit": end_logits[0],
 674                    }
 675
 676                # Go through all possibilities for the `n_best_size` greatest start and end logits.
 677                start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
 678                end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
 679                for start_index in start_indexes:
 680                    for end_index in end_indexes:
 681                        # Don't consider out-of-scope answers, either because the indices are out of bounds or
 682                        # correspond to part of the input_ids that are not in the context.
 683                        if (
 684                            start_index >= len(offset_mapping)
 685                            or end_index >= len(offset_mapping)
 686                            or offset_mapping[start_index] is None
 687                            or len(offset_mapping[start_index]) < 2
 688                            or offset_mapping[end_index] is None
 689                            or len(offset_mapping[end_index]) < 2
 690                        ):
 691                            continue
 692                        # Don't consider answers with a length that is either < 0 or > max_answer_length.
 693                        if (
 694                            end_index < start_index
 695                            or end_index - start_index + 1 > max_answer_length
 696                        ):
 697                            continue
 698                        # Don't consider answers that don't have the maximum context available (if such information is
 699                        # provided).
 700                        if token_is_max_context is not None and not token_is_max_context.get(
 701                            str(start_index), False
 702                        ):
 703                            continue
 704
 705                        prelim_predictions.append(
 706                            {
 707                                "offsets": (
 708                                    offset_mapping[start_index][0],
 709                                    offset_mapping[end_index][1],
 710                                ),
 711                                "score": start_logits[start_index] + end_logits[end_index],
 712                                "start_logit": start_logits[start_index],
 713                                "end_logit": end_logits[end_index],
 714                            }
 715                        )
 716            if version_2_with_negative and min_null_prediction is not None:
 717                # Add the minimum null prediction
 718                prelim_predictions.append(min_null_prediction)
 719                null_score = min_null_prediction["score"]
 720
 721            # Only keep the best `n_best_size` predictions.
 722            n_best_preds = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[
 723                :n_best_size
 724            ]
 725
 726            # Add back the minimum null prediction if it was removed because of its low score.
 727            if (
 728                version_2_with_negative
 729                and min_null_prediction is not None
 730                and not any(p["offsets"] == (0, 0) for p in n_best_preds)
 731            ):
 732                n_best_preds.append(min_null_prediction)
 733
 734            # Use the offsets to gather the answer text in the original context.
 735            context = example["context"]
 736            for pred in n_best_preds:
 737                offsets = pred.pop("offsets")
 738                pred["text"] = context[offsets[0] : offsets[1]]
 739
 740            # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
 741            # failure.
 742            if len(n_best_preds) == 0 or (len(n_best_preds) == 1 and n_best_preds[0]["text"] == ""):
 743                n_best_preds.insert(
 744                    0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}
 745                )
 746
 747            # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file,
 748            # using the LogSumExp trick).
 749            scores = np.array([pred.pop("score") for pred in n_best_preds])
 750            exp_scores = np.exp(scores - np.max(scores))
 751            probs = exp_scores / exp_scores.sum()
 752
 753            # Include the probabilities in our n_best_preds.
 754            for prob, pred in zip(probs, n_best_preds):
 755                pred["probability"] = prob
 756
 757            # Pick the best prediction. If the null answer is not possible, this is easy.
 758            if not version_2_with_negative:
 759                all_predictions[example["id"]] = n_best_preds[0]["text"]
 760            else:
 761                # Otherwise we first need to find the best non-empty prediction.
 762                i = 0
 763                while n_best_preds[i]["text"] == "":
 764                    i += 1
 765                best_non_null_pred = n_best_preds[i]
 766
 767                # Then we compare to the null prediction using the threshold.
 768                score_diff = (
 769                    null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
 770                )
 771                scores_diff_json[example["id"]] = float(score_diff)  # To be JSON-serializable.
 772                if score_diff > null_score_diff_threshold:
 773                    all_predictions[example["id"]] = ""
 774                else:
 775                    all_predictions[example["id"]] = best_non_null_pred["text"]
 776
 777            # Make `n_best_preds` JSON-serializable by casting np.float back to float.
 778            all_nbest_json[example["id"]] = [
 779                {
 780                    k: float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v
 781                    for k, v in pred.items()
 782                }
 783                for pred in n_best_preds
 784            ]
 785
 786        # If we have an output_dir, let's save all those dicts.
 787        if output_dir is not None:
 788            if not os.path.isdir(output_dir):
 789                raise EnvironmentError(f"{output_dir} is not a directory.")
 790
 791            prediction_file = os.path.join(
 792                output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
 793            )
 794            nbest_file = os.path.join(
 795                output_dir,
 796                "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json",
 797            )
 798            if version_2_with_negative:
 799                null_odds_file = os.path.join(
 800                    output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
 801                )
 802
 803            logger.info(f"Saving predictions to {prediction_file}.")
 804            with open(prediction_file, "w") as writer:
 805                writer.write(json.dumps(all_predictions, indent=4) + "\n")
 806            logger.info(f"Saving nbest_preds to {nbest_file}.")
 807            with open(nbest_file, "w") as writer:
 808                writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
 809            if version_2_with_negative:
 810                logger.info(f"Saving null_odds to {null_odds_file}.")
 811                with open(null_odds_file, "w") as writer:
 812                    writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
 813
 814        # Format the result to the format the metric expects.
 815        formatted_predictions = [
 816            {"id": k, "prediction_text": v} for k, v in all_predictions.items()
 817        ]
 818
 819        references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
 820        return EvalPrediction(predictions=formatted_predictions, label_ids=references)
 821
 822    logger.info(f"***** Running {prefix} *****")
 823    logger.info(f"  Num examples = {len(eval_dataset)}")
 824    logger.info(f"  Batch size = {args.per_device_eval_batch_size}")
 825
 826    model.eval()
 827    all_start_logits = []
 828    all_end_logits = []
 829    for _, batch in enumerate(eval_dataloader):
 830        with torch.no_grad():
 831            outputs = model(**batch)
 832            start_logits = outputs.start_logits
 833            end_logits = outputs.end_logits
 834
 835            if (
 836                not args.pad_to_max_length
 837            ):  # necessary to pad predictions and labels for being gathered
 838                start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
 839                end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)
 840
 841            all_start_logits.append(accelerator.gather_for_metrics(start_logits).cpu().numpy())
 842            all_end_logits.append(accelerator.gather_for_metrics(end_logits).cpu().numpy())
 843
 844    max_len = max([x.shape[1] for x in all_start_logits])  # Get the max_length of the tensor
 845
 846    # concatenate the numpy array
 847    start_logits_concat = create_and_fill_np_array(all_start_logits, max_len)
 848    end_logits_concat = create_and_fill_np_array(all_end_logits, max_len)
 849
 850    outputs_numpy = (start_logits_concat, end_logits_concat)
 851    prediction = postprocess_qa_predictions(
 852        examples=eval_examples,
 853        features=eval_dataset,
 854        predictions=outputs_numpy,
 855        n_best_size=args.n_best_size,
 856        max_answer_length=args.max_answer_length,
 857        # output_dir=args.finetuned_model_path,
 858        prefix=prefix,
 859    )
 860
 861    metric = evaluate.load("squad")
 862    eval_metric = metric.compute(
 863        predictions=prediction.predictions, references=prediction.label_ids
 864    )
 865    logger.info(f"{prefix} metrics: {eval_metric}\n")
 866    return eval_metric
 867
 868
 869# Model Optimizer: Define a teacher factory for initializing the distillation model
 870def teacher_factory(model_name_or_path):
 871    return AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)
 872
 873
 874# Model Optimizer: Define a custom distillation loss function that uses start and end logits
 875class StartEndLogitsDistillationLoss(mtd.LogitsDistillationLoss):
 876    def forward(self, outputs_s, outputs_t):
 877        loss_start = super().forward(outputs_s.start_logits, outputs_t.start_logits)
 878        loss_end = super().forward(outputs_s.end_logits, outputs_t.end_logits)
 879        loss = (loss_start + loss_end) / 2.0
 880
 881        return loss
 882
 883
 884def train_and_evaluate_model(
 885    args,
 886    model: nn.Module,
 887    accelerator: Accelerator,
 888    examples: Dict,
 889    dataset: Dict,
 890    dataloader: Dict[str, DataLoader],
 891    answer_column_name,
 892):
 893    # Optimizer
 894    # Split weights in two groups, one with weight decay and the other not.
 895    no_decay = ["bias", "LayerNorm.weight"]
 896    optimizer_grouped_parameters = [
 897        {
 898            "params": [
 899                p
 900                for n, p in model.named_parameters()
 901                if p.requires_grad and not any(nd in n for nd in no_decay)
 902            ],
 903            "weight_decay": args.weight_decay,
 904        },
 905        {
 906            "params": [
 907                p
 908                for n, p in model.named_parameters()
 909                if p.requires_grad and any(nd in n for nd in no_decay)
 910            ],
 911            "weight_decay": 0.0,
 912        },
 913    ]
 914    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
 915
 916    # Scheduler and math around the number of training steps.
 917    overrode_max_train_steps = False
 918    num_update_steps_per_epoch = math.ceil(
 919        len(dataloader["train"]) / args.gradient_accumulation_steps
 920    )
 921    if args.max_train_steps is None:
 922        args.max_train_steps = int(args.num_train_epochs * num_update_steps_per_epoch)
 923        overrode_max_train_steps = True
 924
 925    lr_scheduler = get_scheduler(
 926        name=args.lr_scheduler_type,
 927        optimizer=optimizer,
 928        num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
 929        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
 930    )
 931
 932    # Prepare everything with our `accelerator`.
 933    model, optimizer, dataloader["train"], dataloader["eval"], lr_scheduler = accelerator.prepare(
 934        model, optimizer, dataloader["train"], dataloader["eval"], lr_scheduler
 935    )
 936
 937    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
 938    num_update_steps_per_epoch = math.ceil(
 939        len(dataloader["train"]) / args.gradient_accumulation_steps
 940    )
 941    if overrode_max_train_steps:
 942        args.max_train_steps = int(args.num_train_epochs * num_update_steps_per_epoch)
 943    # Afterwards we recalculate our number of training epochs
 944    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 945
 946    # Figure out how often (in steps) we should save the Accelerator state
 947    checkpointing_steps = args.checkpointing_steps
 948    if checkpointing_steps is not None and checkpointing_steps.isdigit():
 949        checkpointing_steps = int(checkpointing_steps)
 950
 951    # We need to initialize the trackers we use, and also store our configuration.
 952    # The trackers initialize automatically on the main process.
 953    if args.with_tracking:
 954        experiment_config = vars(args)
 955        # TensorBoard cannot log Enums, need the raw value
 956        experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
 957        accelerator.init_trackers("tensorboard", experiment_config)
 958
 959    # Train!
 960    total_batch_size = (
 961        args.per_device_train_batch_size
 962        * accelerator.num_processes
 963        * args.gradient_accumulation_steps
 964    )
 965
 966    logger.info("***** Running training *****")
 967    logger.info(f"  Num examples = {len(dataset['train'])}")
 968    logger.info(f"  Num Epochs = {args.num_train_epochs}")
 969    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
 970    logger.info(
 971        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
 972    )
 973    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
 974    logger.info(f"  Total optimization steps = {args.max_train_steps}")
 975
 976    # Only show the progress bar once on each machine.
 977    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
 978    completed_steps = 0
 979    starting_epoch = 0
 980    resume_step = None
 981
 982    # Potentially load in the weights and states from a previous save
 983    if args.resume_from_last_ckpt:
 984        # Get the most recent checkpoint
 985        dirs = [
 986            f.path
 987            for f in os.scandir(args.finetuned_model_path)
 988            if f.is_dir() and (f.name.startswith("epoch_") or f.name.startswith("step_"))
 989        ]
 990        if len(dirs) == 0:
 991            logger.warning(
 992                f"No checkpoint found in {args.finetuned_model_path}. Training from scratch!"
 993            )
 994        else:
 995            latest_dir = max(dirs, key=os.path.getctime)
 996            accelerator.load_state(latest_dir)
 997
 998            # Extract `epoch_{i}` or `step_{i}`
 999            latest_dir = os.path.basename(latest_dir)
1000            if "epoch" in latest_dir:
1001                starting_epoch = int(latest_dir.replace("epoch_", "")) + 1
1002                completed_steps = starting_epoch * num_update_steps_per_epoch
1003            else:
1004                # need to multiply by `gradient_accumulation_steps` to reflect real steps
1005                resume_step = (
1006                    int(latest_dir.replace("step_", "")) * args.gradient_accumulation_steps
1007                )
1008                starting_epoch = resume_step // len(dataloader["train"])
1009                completed_steps = resume_step // args.gradient_accumulation_steps
1010                resume_step -= starting_epoch * len(dataloader["train"])
1011
1012    # update the progress_bar if load from checkpoint
1013    progress_bar.update(completed_steps)
1014
1015    # Evaluate before training (e.g. PTQ accuracy before QAT)
1016    eval_metric = evaluate_model(
1017        args,
1018        model,
1019        accelerator,
1020        examples["eval"],
1021        dataset["eval"],
1022        dataloader["eval"],
1023        answer_column_name,
1024    )
1025    for epoch in range(starting_epoch, args.num_train_epochs):
1026        model.train()
1027        if args.with_tracking:
1028            total_loss = 0
1029        if args.resume_from_last_ckpt and epoch == starting_epoch and resume_step is not None:
1030            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
1031            active_dataloader = accelerator.skip_first_batches(dataloader["train"], resume_step)
1032        else:
1033            active_dataloader = dataloader["train"]
1034        for batch in active_dataloader:
1035            optimizer.zero_grad()
1036            outputs = model(**batch)
1037
1038            # Model Optimizer: If using distillation, we unwrap the model and extract the custom loss function
1039            if args.do_modelopt_distill:
1040                loss = model.module.compute_kd_loss()
1041            else:
1042                loss, _, _ = outputs.to_tuple()
1043
1044            # We keep track of the loss at each epoch
1045            if args.with_tracking:
1046                total_loss += loss.detach().float()
1047
1048            accelerator.backward(loss)
1049            optimizer.step()
1050            lr_scheduler.step()
1051
1052            # Checks if the accelerator has performed an optimization step behind the scenes
1053            if accelerator.sync_gradients:
1054                progress_bar.update(1)
1055                completed_steps += 1
1056
1057            if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0:
1058                accelerator.save_state(
1059                    os.path.join(args.finetuned_model_path, f"step_{completed_steps}")
1060                )
1061
1062            if completed_steps >= args.max_train_steps:
1063                break
1064
1065        if args.checkpointing_steps == "epoch":
1066            accelerator.save_state(os.path.join(args.finetuned_model_path, f"epoch_{epoch}"))
1067
1068        eval_metric = evaluate_model(
1069            args,
1070            model,
1071            accelerator,
1072            examples["eval"],
1073            dataset["eval"],
1074            dataloader["eval"],
1075            answer_column_name,
1076        )
1077
1078        if args.with_tracking:
1079            log = {
1080                "squad": eval_metric,
1081                "train_loss": total_loss.item() / len(dataloader["train"]),  # type: ignore[attr-defined]
1082                "epoch": epoch,
1083                "step": completed_steps,
1084            }
1085            accelerator.log(log, step=completed_steps)
1086
1087    accelerator.wait_for_everyone()
1088    if accelerator.is_main_process and eval_metric:
1089        logger.info(json.dumps(eval_metric, indent=4))
1090        # Prefix all keys with 'Eval_'
1091        for key in list(eval_metric.keys()):
1092            eval_metric[f"Eval_{key}"] = eval_metric.pop(key)
1093
1094        with open(os.path.join(args.finetuned_model_path, "results.json"), "w") as f:
1095            json.dump(eval_metric, f, indent=4)
1096
1097
1098def main(input_args: Optional[List[str]] = None) -> None:
1099    args = parse_args(input_args)
1100
1101    # Initialize the accelerator
1102    accelerator_log_kwargs = {}
1103    if args.with_tracking:
1104        accelerator_log_kwargs["log_with"] = "tensorboard"
1105        accelerator_log_kwargs["project_dir"] = args.finetuned_model_path
1106    accelerator = Accelerator(
1107        gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs
1108    )
1109
1110    # Setup logging
1111    logging.basicConfig(
1112        format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
1113        datefmt="%m/%d/%Y %H:%M:%S",
1114        level=logging.INFO,
1115    )
1116    logger.info(accelerator.state, main_process_only=False)
1117    if accelerator.is_local_main_process:
1118        datasets.utils.logging.set_verbosity_warning()
1119        transformers.utils.logging.set_verbosity_info()
1120    else:
1121        datasets.utils.logging.set_verbosity_error()
1122        transformers.utils.logging.set_verbosity_error()
1123
1124    # Set the training seed
1125    set_seed(SEED)
1126
1127    accelerator.wait_for_everyone()
1128
1129    # Load pretrained model and tokenizer
1130    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
1131    model = AutoModelForQuestionAnswering.from_pretrained(args.model_name_or_path)
1132    dummy_input = model.dummy_inputs["input_ids"]
1133
1134    # Get datasets
1135    examples, dataset, dataloader, answer_column_name = get_datasets_and_dataloaders(
1136        args, tokenizer, accelerator
1137    )
1138
1139    def save(model, output_path):
1140        if accelerator.is_main_process:
1141            os.makedirs(os.path.dirname(output_path), exist_ok=True)
1142            model = accelerator.unwrap_model(model)
1143            model.save_pretrained(output_path)
1144            tokenizer.save_pretrained(output_path)
1145            logger.info(f"Saved model and tokenizer to {output_path}")
1146
1147    # Model Optimizer: Prune the model to given FLOPS target using GradNAS algorithm
1148    if args.do_modelopt_prune:
1149        logger.info(f"Pruning model to {args.modelopt_prune_flops_percent}% FLOPS")
1150
1151        # NOTE: gradnas does not perform synchronization across data parallel groups
1152        # Use unwrapped model & non-distributed dataloader for gradnas so that all the processes
1153        # in the data parallel group have the same gradients for pruning
1154        model = model.to(accelerator.device)
1155        dummy_input = dummy_input.to(accelerator.device)
1156
1157        # Search for the best pruned model
1158        # To use other NAS algorithms, you can use `mtn.convert` + `mtn.search` here.
1159        model, _ = mtp.prune(
1160            model=model,
1161            mode="gradnas",
1162            constraints={"flops": f"{args.modelopt_prune_flops_percent}%"},
1163            dummy_input=dummy_input,
1164            config={
1165                "data_loader": dataloader["train"],
1166                "collect_func": lambda batch: (batch,),
1167                "loss_func": lambda output, batch: (
1168                    output["loss"] if isinstance(output, dict) else output[0]
1169                ),
1170            },
1171        )
1172        save(model, args.pruned_model_path)
1173
1174    # Model Optimizer: Quantize the model to INT8 precision
1175    if args.modelopt_quantize_cfg:
1176        logger.info(f"Quantizing model with {args.modelopt_quantize_cfg} config")
1177
1178        # NOTE: `mtq.quantize` does not perform synchronization across data parallel groups
1179        # Use unwrapped model & non-distributed dataloader for PTQ calibration so that all the processes
1180        # in the data parallel group have same calibration statistics
1181        model = model.to(accelerator.device)
1182
1183        def forward_loop(model):
1184            num_samples = 256  # Use only 256 samples for PTQ calibration
1185            num_batches = num_samples // args.per_device_train_batch_size
1186            for idx, batch in tqdm(enumerate(dataloader["train"]), total=num_batches):
1187                if idx >= num_batches:  # stop before running an extra batch
1188                    break
1189                batch = {k: v.to(accelerator.device) for k, v in batch.items()}
1190                model(**batch)
1191
1192        model = mtq.quantize(model, getattr(mtq, args.modelopt_quantize_cfg), forward_loop)
1193        torch.cuda.empty_cache()
1194        save(model, args.ptq_model_path)
1195
1196    if args.do_train:
1197        # Handle the finetuned_model_path creation
1198        if accelerator.is_main_process:
1199            os.makedirs(args.finetuned_model_path, exist_ok=True)
1200
1201        # Model Optimizer: Convert to a DistillationModel containing teacher to train with distillation
1202        if args.do_modelopt_distill:
1203            logger.info(f"Using distillation with teacher {args.model_name_or_path}")
1204
1205            kd_config = {
1206                "teacher_model": (teacher_factory, (args.model_name_or_path,), {}),
1207                "criterion": StartEndLogitsDistillationLoss(args.temperature),
1208            }
1209            model = mtd.convert(model, mode=[("kd_loss", kd_config)])
1210
1211        train_and_evaluate_model(
1212            args, model, accelerator, examples, dataset, dataloader, answer_column_name
1213        )
1214
1215        # Model Optimizer: Export the distilled model
1216        if args.do_modelopt_distill:
1217            model = mtd.export(model)
1218
1219        save(model, args.finetuned_model_path)
1220
1221    if accelerator.is_main_process and args.onnx_export_path is not None:
1222        logger.info(f"Exporting ONNX model to {args.onnx_export_path}")
1223
1224        # Move the model and dummy_input to the device
1225        model = model.to(accelerator.device)
1226        dummy_input = dummy_input.to(accelerator.device)
1227
1228        with open(args.onnx_export_path, "wb") as f:
1229            f.write(get_onnx_bytes(model, dummy_input, onnx_opset=14))
1230
1231    logger.info("Done!")
1232
1233
1234if __name__ == "__main__":
1235    main()
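
The span-selection step inside `postprocess_qa_predictions` in the listing above can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the script's implementation: the helper names `top_k_indices` and `rank_answer_spans` are hypothetical, spans are reported as token indices rather than character offsets, and the null-answer and `token_is_max_context` handling are left out.

```python
import math


def top_k_indices(logits, k):
    """Return the indices of the k largest logits, best first."""
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]


def rank_answer_spans(start_logits, end_logits, n_best_size=20, max_answer_length=30):
    """Score candidate (start, end) token spans and attach softmax probabilities.

    Hypothetical helper: mirrors the filtering and scoring idea of the listing,
    but works on a single feature and ignores offsets and null answers.
    """
    candidates = []
    for start in top_k_indices(start_logits, n_best_size):
        for end in top_k_indices(end_logits, n_best_size):
            # Skip spans that end before they start or that are too long.
            if end < start or end - start + 1 > max_answer_length:
                continue
            candidates.append(
                {"start": start, "end": end, "score": start_logits[start] + end_logits[end]}
            )

    # Keep only the best `n_best_size` spans.
    candidates.sort(key=lambda c: c["score"], reverse=True)
    candidates = candidates[:n_best_size]
    if not candidates:
        return []

    # Softmax over span scores; subtracting the max keeps exp() from overflowing.
    max_score = candidates[0]["score"]
    exp_scores = [math.exp(c["score"] - max_score) for c in candidates]
    total = sum(exp_scores)
    for cand, exp_score in zip(candidates, exp_scores):
        cand["probability"] = exp_score / total
    return candidates


spans = rank_answer_spans(
    start_logits=[0.1, 5.0, 0.2, 0.3],
    end_logits=[0.0, 0.1, 4.0, 0.2],
    n_best_size=2,
    max_answer_length=3,
)
```

The first entry of `spans` is the highest-scoring span; in the full script the corresponding character offsets are then used to slice the answer text out of the context.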

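The `StartEndLogitsDistillationLoss` class in the listing above averages a logits-distillation loss over the start and end logits. A minimal pure-Python sketch of that idea follows. The internals of `mtd.LogitsDistillationLoss` are not shown in this example, so the soft-target formulation used here (KL divergence from teacher to student on temperature-softened distributions, scaled by the squared temperature) is an assumption, and the function names are hypothetical.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of temperature-scaled logits."""
    scaled = [x / temperature for x in logits]
    max_val = max(scaled)
    log_norm = max_val + math.log(sum(math.exp(x - max_val) for x in scaled))
    return [x - log_norm for x in scaled]


def soft_target_kl(student_logits, teacher_logits, temperature=1.0):
    """KL(teacher || student) on softened distributions, scaled by T^2 (assumed form)."""
    log_p_teacher = log_softmax(teacher_logits, temperature)
    log_p_student = log_softmax(student_logits, temperature)
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_teacher, log_p_student))
    return kl * temperature**2


def start_end_distillation_loss(student_out, teacher_out, temperature=1.0):
    """Average the distillation loss over the start and end logits."""
    loss_start = soft_target_kl(
        student_out["start_logits"], teacher_out["start_logits"], temperature
    )
    loss_end = soft_target_kl(
        student_out["end_logits"], teacher_out["end_logits"], temperature
    )
    return (loss_start + loss_end) / 2.0


# Toy outputs for one example with three token positions.
teacher = {"start_logits": [1.0, 2.0, 3.0], "end_logits": [0.5, 0.1, -1.0]}
student = {"start_logits": [3.0, 2.0, 1.0], "end_logits": [0.5, 0.1, -1.0]}
loss = start_end_distillation_loss(student, teacher, temperature=2.0)
```

The loss is zero when the student reproduces the teacher's logits exactly and grows as the two distributions diverge, which is what pulls the pruned or quantized student toward the teacher during fine-tuning.
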
Commands

  1. First, prune the Bert large model to 50% FLOPs with the GradNAS algorithm. Then fine-tune the pruned model with distillation from the unpruned teacher model to recover 99+% of the initial F1 score (93.15). We recommend using multiple GPUs for fine-tuning. Note that this step fine-tunes for 15 epochs instead of the 2 epochs originally used to fine-tune Bert without distillation: distillation requires more epochs to converge but achieves much better results.

    1_prune.sh
    #!/bin/bash
    set -ex
    
    MODEL_NAME_OR_PATH=bert-large-uncased-whole-word-masking-finetuned-squad
    FLOPS_PERCENT=50
    BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
    PRUNED_MODEL_PATH=${BASE_DIR}/pruned/
    FINETUNED_MODEL_PATH=${BASE_DIR}/pruned_finetuned/
    
    modelopt_args=""
    if [ ! -d ${PRUNED_MODEL_PATH} ]; then
        modelopt_args="--do_modelopt_prune \
            --modelopt_prune_flops_percent ${FLOPS_PERCENT} \
            --pruned_model_path ${PRUNED_MODEL_PATH}"
    else
        MODEL_NAME_OR_PATH=${PRUNED_MODEL_PATH}
    fi
    
    # Run pruning followed by distributed fine-tuning on the pruned model
    accelerate launch --multi_gpu --mixed_precision bf16 bert_prune_distill_quantize.py \
        --model_name_or_path ${MODEL_NAME_OR_PATH} \
        --finetuned_model_path ${FINETUNED_MODEL_PATH} \
        ${modelopt_args} \
        --do_train \
        --do_modelopt_distill \
        --lr_scheduler_type cosine \
        --learning_rate 1e-4 \
        --per_device_train_batch_size 16 \
        --num_train_epochs 15 \
        --with_tracking \
        --resume_from_last_ckpt
    
  2. Quantize the fine-tuned model to INT8 precision and run calibration (PTQ). PTQ causes a slight drop in F1 score, which QAT then recovers. QAT is also run with distillation from the unpruned teacher model.

    2_int8_quantize.sh
    #!/bin/bash
    set -ex
    
    FLOPS_PERCENT=50
    BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
    PRUNED_MODEL_PATH=${BASE_DIR}/pruned_finetuned/
    PTQ_MODEL_PATH=${BASE_DIR}/int8_ptq/
    QAT_MODEL_PATH=${BASE_DIR}/int8_qat/
    
    modelopt_args=""
    if [ ! -d ${PTQ_MODEL_PATH} ]; then
        modelopt_args="--modelopt_quantize_cfg INT8_DEFAULT_CFG --ptq_model_path ${PTQ_MODEL_PATH}"
        MODEL_NAME_OR_PATH=${PRUNED_MODEL_PATH}
    else
        MODEL_NAME_OR_PATH=${PTQ_MODEL_PATH}
    fi
    
    # Run distributed QAT on the pruned model with a lower LR (5e-6 vs. 1e-4) and fewer epochs
    accelerate launch --multi_gpu --mixed_precision bf16 bert_prune_distill_quantize.py \
        --model_name_or_path ${MODEL_NAME_OR_PATH} \
        --finetuned_model_path ${QAT_MODEL_PATH} \
        ${modelopt_args} \
        --do_train \
        --do_modelopt_distill \
        --lr_scheduler_type cosine \
        --learning_rate 5e-6 \
        --per_device_train_batch_size 16 \
        --num_train_epochs 2 \
        --with_tracking \
        --resume_from_last_ckpt
    
  3. Export the quantized model to ONNX format for deployment with TensorRT.

    3_onnx_export.sh
    #!/bin/bash
    set -ex
    
    FLOPS_PERCENT=50
    BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
    
    # Export to ONNX on a single GPU
    python3 bert_prune_distill_quantize.py \
        --model_name_or_path ${BASE_DIR}/int8_qat/ \
        --onnx_export_path ${BASE_DIR}/pruned_model_int8.onnx \