HF BERT: Prune, Distill & Quantize

This example shows how to compress a Hugging Face BERT large model for Question Answering using a combination of modelopt.torch.prune, modelopt.torch.distill, and modelopt.torch.quantization. More specifically, we will (see the API-level sketch after this list):

  1. Prune the BERT large model to 50% of the original FLOPs with the GradNAS algorithm and fine-tune it with distillation

  2. Quantize the fine-tuned model to INT8 precision with Post-Training Quantization (PTQ), then recover accuracy with Quantization Aware Training (QAT) and distillation

  3. Export the quantized model to ONNX format for deployment with TensorRT
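
A condensed view of the ModelOpt API calls behind these steps is sketched below. This is an orientation aid rather than a runnable script: model, dummy_input, train_dataloader, teacher_factory, StartEndLogitsDistillationLoss, and forward_loop refer to objects defined in the full example further down, and the checkpoint file name is only illustrative.

    import modelopt.torch.distill as mtd
    import modelopt.torch.opt as mto
    import modelopt.torch.prune as mtp
    import modelopt.torch.quantization as mtq

    # 1. Prune to a FLOPs target with the GradNAS algorithm and save the ModelOpt state
    model, _ = mtp.prune(
        model=model,
        mode="gradnas",
        constraints={"flops": "50%"},
        dummy_input=dummy_input,
        config={
            "data_loader": train_dataloader,
            "collect_func": lambda batch: (batch,),
            "loss_func": lambda output, batch: output[0],  # see the full script for the exact loss_func
        },
    )
    mto.save(model, "pruned_model.pth")

    # 2. Wrap the pruned model with a frozen teacher for fine-tuning / QAT with distillation;
    #    the teacher is dropped again with mtd.export after training
    kd_config = {
        "teacher_model": (teacher_factory, ("bert-large-uncased-whole-word-masking-finetuned-squad",), {}),
        "criterion": StartEndLogitsDistillationLoss(2.0),  # temperature of 2.0
    }
    model = mtd.convert(model, mode=[("kd_loss", kd_config)])
    # ... fine-tune using the distillation loss from compute_kd_loss() (see the training loop below) ...
    model = mtd.export(model)

    # 3. INT8 post-training quantization (calibration with forward_loop), then QAT and ONNX export
    model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)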

Prerequisites

  1. Install Model Optimizer along with the optional torch and huggingface dependencies:

    pip install "nvidia-modelopt[torch,hf]" --extra-index-url https://pypi.nvidia.com
    

Note

This example has been tested on 8 x 24GB A5000 GPUs with PyTorch 2.4 and CUDA 12.4. It takes about 2 hours to complete all the stages of the optimization. Most of the time is spent on fine-tuning and QAT.

Full code

You can view the full code below with the ModelOpt integration points highlighted. The source code and scripts are also available in the ModelOpt GitHub repository.

bert_prune_distill_quantize.py
   1# NOTE: This is adapted from run_qa_no_trainer.py and utils_qa.py from https://github.com/huggingface/
   2# transformers/blob/c52b515e948fc12ff58ad773a0385860d0162f61/examples/pytorch/question-answering
   3#
   4# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
   5#
   6# Licensed under the Apache License, Version 2.0 (the "License");
   7# you may not use this file except in compliance with the License.
   8# You may obtain a copy of the License at
   9#
  10#     http://www.apache.org/licenses/LICENSE-2.0
  11#
  12# Unless required by applicable law or agreed to in writing, software
  13# distributed under the License is distributed on an "AS IS" BASIS,
  14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15# See the License for the specific language governing permissions and
  16# limitations under the License.
  17
  18# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  19# SPDX-License-Identifier: MIT
  20
  21# Permission is hereby granted, free of charge, to any person obtaining a
  22# copy of this software and associated documentation files (the "Software"),
  23# to deal in the Software without restriction, including without limitation
  24# the rights to use, copy, modify, merge, publish, distribute, sublicense,
  25# and/or sell copies of the Software, and to permit persons to whom the
  26# Software is furnished to do so, subject to the following conditions:
  27
  28# The above copyright notice and this permission notice shall be included in
  29# all copies or substantial portions of the Software.
  30
  31# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  32# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  33# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
  34# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  35# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
  36# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
  37# DEALINGS IN THE SOFTWARE.
  38"""
  39Example showcasing how to do end-to-end optimization of a BERT model on SQuAD using Model Optimizer.
  40This includes GradNAS pruning, INT8 quantization, fine-tuning / QAT with distillation, and ONNX export.
  41"""
  42
  43import argparse
  44import collections
  45import json
  46import logging
  47import math
  48import os
  49import random
  50from typing import Any, Dict, List, Optional, Tuple
  51
  52import datasets
  53import evaluate
  54import numpy as np
  55import torch
  56import torch.nn as nn
  57import transformers
  58from accelerate import Accelerator
  59from accelerate.logging import get_logger
  60from accelerate.utils import set_seed
  61from torch.utils.data import DataLoader
  62from tqdm.auto import tqdm
  63from transformers import (
  64    AutoModelForQuestionAnswering,
  65    AutoTokenizer,
  66    DataCollatorWithPadding,
  67    EvalPrediction,
  68    PreTrainedTokenizer,
  69    SchedulerType,
  70    default_data_collator,
  71    get_scheduler,
  72)
  73
  74import modelopt.torch.distill as mtd
  75import modelopt.torch.opt as mto
  76import modelopt.torch.prune as mtp
  77import modelopt.torch.quantization as mtq
  78from modelopt.torch._deploy.utils import get_onnx_bytes
  79
  80logger = get_logger(__name__)
  81
  82SEED = 123
  83
  84
  85def parse_args(input_args: Optional[List[str]] = None):
  86    parser = argparse.ArgumentParser(
  87        description="Finetune a transformers model on a Question Answering task"
  88    )
  89
  90    # Training arguments
  91    parser.add_argument(
  92        "--model_name_or_path",
  93        type=str,
  94        default="bert-large-uncased-whole-word-masking-finetuned-squad",
  95        help="Path to pretrained model or model identifier from huggingface.co/models.",
  96    )
  97    parser.add_argument(
  98        "--do_train",
  99        action="store_true",
 100        help="Whether to run training / fine-tuning.",
 101    )
 102    parser.add_argument(
 103        "--per_device_train_batch_size",
 104        type=int,
 105        default=16,
 106        help="Batch size (per device) for the training dataloader.",
 107    )
 108    parser.add_argument(
 109        "--per_device_eval_batch_size",
 110        type=int,
 111        default=64,
 112        help="Batch size (per device) for the evaluation dataloader.",
 113    )
 114    parser.add_argument(
 115        "--learning_rate",
 116        type=float,
 117        default=5e-5,
 118        help="Initial learning rate (after the potential warmup period) to use.",
 119    )
 120    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
 121    parser.add_argument(
 122        "--lr_scheduler_type",
 123        type=SchedulerType,
 124        default="linear",
 125        help="The scheduler type to use.",
 126        choices=[
 127            "linear",
 128            "cosine",
 129            "cosine_with_restarts",
 130            "polynomial",
 131            "constant",
 132            "constant_with_warmup",
 133        ],
 134    )
 135    parser.add_argument(
 136        "--num_warmup_steps",
 137        type=int,
 138        default=0,
 139        help="Number of steps for the warmup in the lr scheduler.",
 140    )
 141    parser.add_argument(
 142        "--num_train_epochs",
 143        type=float,
 144        default=2.0,
 145        help="Total number of training epochs to perform.",
 146    )
 147    parser.add_argument(
 148        "--max_train_steps",
 149        type=int,
 150        default=None,
 151        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
 152    )
 153    parser.add_argument(
 154        "--gradient_accumulation_steps",
 155        type=int,
 156        default=1,
 157        help="Number of update steps to accumulate before performing a backward/update pass.",
 158    )
 159    parser.add_argument(
 160        "--preprocessing_num_workers",
 161        type=int,
 162        default=4,
 163        help="The number of processes to use for preprocessing the dataset.",
 164    )
 165
 166    # Logging and checkpointing arguments
 167    parser.add_argument(
 168        "--output_dir", type=str, default="results", help="Where to store the optimized models."
 169    )
 170    parser.add_argument(
 171        "--with_tracking",
 172        action="store_true",
 173        help="Whether to enable experiment trackers for logging.",
 174    )
 175    parser.add_argument(
 176        "--checkpointing_steps",
 177        type=str,
 178        default=None,
 179        help=(
 180            "Whether the various states should be saved at the end of every n steps, or 'epoch' for"
 181            " each epoch."
 182        ),
 183    )
 184    parser.add_argument(
 185        "--resume_from_last_ckpt",
 186        action="store_true",
 187        help="If the training should continue from the latest checkpoint in output_dir.",
 188    )
 189
 190    # Misc arguments for Bert (should not be modified in most cases)
 191    parser.add_argument(
 192        "--max_seq_length",
 193        type=int,
 194        default=384,
 195        help=(
 196            "The maximum total input sequence length after tokenization. Sequences longer than this"
 197            " will be truncated, and shorter will be padded if `--pad_to_max_length` is passed."
 198        ),
 199    )
 200    parser.add_argument(
 201        "--pad_to_max_length",
 202        action="store_true",
 203        help="If passed, pad all samples to `max_seq_length`. Otherwise, dynamic padding is used.",
 204    )
 205    parser.add_argument(
 206        "--doc_stride",
 207        type=int,
 208        default=128,
 209        help=(
 210            "When splitting up a long document into chunks how much stride to take between chunks."
 211        ),
 212    )
 213    parser.add_argument(
 214        "--n_best_size",
 215        type=int,
 216        default=20,
 217        help="The total number of n-best predictions to generate when looking for an answer.",
 218    )
 219    parser.add_argument(
 220        "--max_answer_length",
 221        type=int,
 222        default=30,
 223        help=(
 224            "The maximum length of an answer that can be generated. This is needed because the"
 225            " start and end predictions are not conditioned on one another."
 226        ),
 227    )
 228
 229    # Debugging arguments
 230    parser.add_argument(
 231        "--max_train_samples",
 232        type=int,
 233        default=None,
 234        help="For debugging purposes or quicker training.",
 235    )
 236    parser.add_argument(
 237        "--max_eval_samples",
 238        type=int,
 239        default=None,
 240        help="For debugging purposes or quicker training.",
 241    )
 242
 243    # Model Optimizer: pruning arguments
 244    parser.add_argument(
 245        "--do_modelopt_prune",
 246        action="store_true",
 247        help="Whether or not to use Model Optimizer pruning.",
 248    )
 249    parser.add_argument(
 250        "--modelopt_prune_flops_percent",
 251        type=float,
 252        default=None,
 253        help="The percentage (between 0 and 100) of FLOPs to retain in the pruned model.",
 254    )
 255
 256    # Model Optimizer: quantization arguments
 257    parser.add_argument(
 258        "--modelopt_quantize_cfg",
 259        help="Model Optimizer quantization config.",
 260        choices=mtq.config.choices,
 261    )
 262
 263    # Model Optimizer: Distillation arguments
 264    parser.add_argument(
 265        "--do_modelopt_distill",
 266        action="store_true",
 267        help="Whether or not to use distillation. A teacher model must be specified.",
 268    )
 269    parser.add_argument(
 270        "--temperature",
 271        type=float,
 272        default=2.0,
 273        help="The temperature to use when distilling.",
 274    )
 275
 276    # Model Optimizer: save and restore arguments
 277    parser.add_argument(
 278        "--modelopt_save_file",
 279        type=str,
 280        default=None,
 281        help="File (inside output_dir) to save the modelopt modified model to.",
 282    )
 283    parser.add_argument(
 284        "--modelopt_restore_path",
 285        type=str,
 286        default=None,
 287        help="Path to restore the modelopt modified model from.",
 288    )
 289
 290    # ONNX export arguments
 291    parser.add_argument(
 292        "--onnx_export_file",
 293        type=str,
 294        default=None,
 295        help="File (inside output_dir) to export the ONNX model to.",
 296    )
 297
 298    args = parser.parse_args(input_args)
 299
 300    # Sanity checks
 301    if args.do_modelopt_prune and not args.modelopt_prune_flops_percent:
 302        raise ValueError(
 303            "Need a `modelopt_prune_flops_percent` when `do_modelopt_prune` is passed."
 304        )
 305
 306    return args
 307
 308
 309def get_datasets_and_dataloaders(args, tokenizer: PreTrainedTokenizer, accelerator: Accelerator):
 310    """Get the examples, dataset, dataloader, answer_column_name
 311
 312    You can either provide your own CSV/JSON/TXT training and evaluation files (see below)
 313    or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
 314    (the dataset will be downloaded automatically from the datasets Hub).
 315
 316    For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
 317    'text' is found. You can easily tweak this behavior (see below).
 318    """
 319
 320    def prepare_train_features(examples):
 321        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
 322        # truncation of the context fail (the tokenized question will take a lot of space). So we remove that
 323        # left whitespace
 324        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
 325
 326        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
 327        # in one example possibly giving several features when a context is long, each of those features having a
 328        # context that overlaps a bit with the context of the previous feature.
 329        tokenized_examples = tokenizer(
 330            examples[question_column_name if pad_on_right else context_column_name],
 331            examples[context_column_name if pad_on_right else question_column_name],
 332            truncation="only_second" if pad_on_right else "only_first",
 333            max_length=max_seq_length,
 334            stride=args.doc_stride,
 335            return_overflowing_tokens=True,
 336            return_offsets_mapping=True,
 337            padding="max_length" if args.pad_to_max_length else False,
 338        )
 339
 340        # Since one example might give us several features if it has a long context, we need a map from a feature to
 341        # its corresponding example. This key gives us just that.
 342        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
 343        # The offset mappings will give us a map from token to character position in the original context. This will
 344        # help us compute the start_positions and end_positions.
 345        offset_mapping = tokenized_examples.pop("offset_mapping")
 346
 347        # Let's label those examples!
 348        tokenized_examples["start_positions"] = []
 349        tokenized_examples["end_positions"] = []
 350
 351        for i, offsets in enumerate(offset_mapping):
 352            # We will label impossible answers with the index of the CLS token.
 353            input_ids = tokenized_examples["input_ids"][i]
 354            cls_index = input_ids.index(tokenizer.cls_token_id)
 355
 356            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
 357            sequence_ids = tokenized_examples.sequence_ids(i)
 358
 359            # One example can give several spans, this is the index of the example containing this span of text.
 360            sample_index = sample_mapping[i]
 361            answers = examples[answer_column_name][sample_index]
 362            # If no answers are given, set the cls_index as answer.
 363            if len(answers["answer_start"]) == 0:
 364                tokenized_examples["start_positions"].append(cls_index)
 365                tokenized_examples["end_positions"].append(cls_index)
 366            else:
 367                # Start/end character index of the answer in the text.
 368                start_char = answers["answer_start"][0]
 369                end_char = start_char + len(answers["text"][0])
 370
 371                # Start token index of the current span in the text.
 372                token_start_index = 0
 373                while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
 374                    token_start_index += 1
 375
 376                # End token index of the current span in the text.
 377                token_end_index = len(input_ids) - 1
 378                while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
 379                    token_end_index -= 1
 380
 381                # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
 382                if not (
 383                    offsets[token_start_index][0] <= start_char
 384                    and offsets[token_end_index][1] >= end_char
 385                ):
 386                    tokenized_examples["start_positions"].append(cls_index)
 387                    tokenized_examples["end_positions"].append(cls_index)
 388                else:
 389                    # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
 390                    # Note: we could go after the last offset if the answer is the last word (edge case).
 391                    while (
 392                        token_start_index < len(offsets)
 393                        and offsets[token_start_index][0] <= start_char
 394                    ):
 395                        token_start_index += 1
 396                    tokenized_examples["start_positions"].append(token_start_index - 1)
 397                    while offsets[token_end_index][1] >= end_char:
 398                        token_end_index -= 1
 399                    tokenized_examples["end_positions"].append(token_end_index + 1)
 400
 401        return tokenized_examples
 402
 403    def prepare_validation_features(examples):
 404        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
 405        # truncation of the context fail (the tokenized question will take a lot of space). So we remove that
 406        # left whitespace
 407        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
 408
 409        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
 410        # in one example possibly giving several features when a context is long, each of those features having a
 411        # context that overlaps a bit with the context of the previous feature.
 412        tokenized_examples = tokenizer(
 413            examples[question_column_name if pad_on_right else context_column_name],
 414            examples[context_column_name if pad_on_right else question_column_name],
 415            truncation="only_second" if pad_on_right else "only_first",
 416            max_length=max_seq_length,
 417            stride=args.doc_stride,
 418            return_overflowing_tokens=True,
 419            return_offsets_mapping=True,
 420            padding="max_length" if args.pad_to_max_length else False,
 421        )
 422
 423        # Since one example might give us several features if it has a long context, we need a map from a feature to
 424        # its corresponding example. This key gives us just that.
 425        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
 426
 427        # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
 428        # corresponding example_id and we will store the offset mappings.
 429        tokenized_examples["example_id"] = []
 430
 431        for i in range(len(tokenized_examples["input_ids"])):
 432            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
 433            sequence_ids = tokenized_examples.sequence_ids(i)
 434            context_index = 1 if pad_on_right else 0
 435
 436            # One example can give several spans, this is the index of the example containing this span of text.
 437            sample_index = sample_mapping[i]
 438            tokenized_examples["example_id"].append(examples["id"][sample_index])
 439
 440            # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
 441            # position is part of the context or not.
 442            tokenized_examples["offset_mapping"][i] = [
 443                (o if sequence_ids[k] == context_index else None)
 444                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
 445            ]
 446
 447        return tokenized_examples
 448
 449    examples, dataset, dataloader = {}, {}, {}
 450
 451    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
 452    # download the dataset.
 453    # Downloading and loading a dataset from the hub.
 454    raw_datasets = datasets.load_dataset("squad")
 455    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
 456    # https://huggingface.co/docs/datasets/loading_datasets.
 457
 458    # Preprocessing the datasets.
 459    # Preprocessing is slightly different for training and evaluation.
 460
 461    column_names = raw_datasets["train"].column_names
 462
 463    question_column_name = "question" if "question" in column_names else column_names[0]
 464    context_column_name = "context" if "context" in column_names else column_names[1]
 465    answer_column_name = "answers" if "answers" in column_names else column_names[2]
 466
 467    # Padding side determines if we do (question|context) or (context|question).
 468    pad_on_right = tokenizer.padding_side == "right"
 469
 470    if args.max_seq_length > tokenizer.model_max_length:
 471        logger.warning(
 472            f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length"
 473            f" for the model ({tokenizer.model_max_length}). Using"
 474            f" max_seq_length={tokenizer.model_max_length}."
 475        )
 476
 477    max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
 478
 479    examples["train"] = raw_datasets["train"]
 480    if args.max_train_samples is not None:
 481        # We will select samples from the whole data if the argument is specified
 482        examples["train"] = examples["train"].select(range(args.max_train_samples))
 483
 484    # Create train feature from dataset
 485    with accelerator.main_process_first():
 486        dataset["train"] = examples["train"].map(
 487            prepare_train_features,
 488            batched=True,
 489            num_proc=args.preprocessing_num_workers,
 490            remove_columns=column_names,
 491            load_from_cache_file=True,
 492            desc="Running tokenizer on train dataset",
 493        )
 494        # if args.max_train_samples is not None:
 495        #     # Number of samples might increase during Feature Creation, We select only specified max samples
 496        #     dataset["train"] = dataset["train"].select(range(args.max_train_samples))
 497
 498    examples["eval"] = raw_datasets["validation"]
 499    if args.max_eval_samples is not None:
 500        # We will select samples from the whole data
 501        examples["eval"] = examples["eval"].select(range(args.max_eval_samples))
 502    # Validation Feature Creation
 503    with accelerator.main_process_first():
 504        dataset["eval"] = examples["eval"].map(
 505            prepare_validation_features,
 506            batched=True,
 507            num_proc=args.preprocessing_num_workers,
 508            remove_columns=column_names,
 509            load_from_cache_file=True,
 510            desc="Running tokenizer on validation dataset",
 511        )
 512        # if args.max_eval_samples is not None:
 513        #     # During Feature creation dataset samples might increase, we will select required samples again
 514        #     dataset["eval"] = dataset["eval"].select(range(args.max_eval_samples))
 515
 516    # Log a random sample from the training set:
 517    for index in random.sample(range(len(dataset["train"])), 1):
 518        logger.info(f"Sample {index} of the training set: {dataset['train'][index]}.")
 519
 520    # DataLoaders creation:
 521    if args.pad_to_max_length:
 522        # If padding was already done to max length, we use the default data collator that will just convert everything
 523        # to tensors.
 524        data_collator = default_data_collator
 525    else:
 526        # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of
 527        # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple
 528        # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.0 (Volta).
 529        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
 530
 531    dataloader["train"] = DataLoader(
 532        dataset["train"],
 533        shuffle=True,
 534        collate_fn=data_collator,
 535        batch_size=args.per_device_train_batch_size,
 536    )
 537
 538    dataloader["eval"] = DataLoader(
 539        dataset["eval"].remove_columns(["example_id", "offset_mapping"]),
 540        collate_fn=data_collator,
 541        batch_size=args.per_device_eval_batch_size,
 542    )
 543
 544    return examples, dataset, dataloader, answer_column_name
 545
 546
 547def evaluate_model(
 548    args,
 549    model: nn.Module,
 550    accelerator: Accelerator,
 551    eval_examples: Any,
 552    eval_dataset: Any,
 553    eval_dataloader: DataLoader,
 554    answer_column_name: str,
 555    prefix: str = "Eval",
 556):
 557    def create_and_fill_np_array(start_or_end_logits, max_len):
 558        """Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor
 559
 560        Args:
 561            start_or_end_logits: This is the output predictions of the model.
 562                We can only enter either start or end logits.
 563            max_len: The maximum length of the output tensor. (See the model.eval() part for more details)
 564        """
 565        step = 0
 566        # create a numpy array and fill it with -100.
 567        logits_concat = np.full((len(eval_dataset), max_len), -100, dtype=np.float64)
 568        # Now since we have created an array we will populate it with the outputs using accelerator.gather_for_metrics
 569        for i, output_logit in enumerate(start_or_end_logits):  # populate columns
 570            # We have to fill it such that we take each whole output tensor and place it in the newly created array,
 571            # and after every iteration we advance the step by the batch size
 572            batch_size = output_logit.shape[0]
 573            cols = output_logit.shape[1]
 574
 575            if step + batch_size < len(eval_dataset):
 576                logits_concat[step : step + batch_size, :cols] = output_logit
 577            else:
 578                logits_concat[step:, :cols] = output_logit[: len(eval_dataset) - step]
 579
 580            step += batch_size
 581
 582        return logits_concat
 583
 584    def postprocess_qa_predictions(
 585        examples,
 586        features,
 587        predictions: Tuple[np.ndarray, np.ndarray],
 588        version_2_with_negative: bool = False,
 589        n_best_size: int = 20,
 590        max_answer_length: int = 30,
 591        null_score_diff_threshold: float = 0.0,
 592        output_dir: Optional[str] = None,
 593        prefix: Optional[str] = None,
 594    ) -> EvalPrediction:
 595        """Post-processes the predictions of a question-answering model to convert them to answers
 596        that are substrings of the original contexts. This is the base postprocessing function for
 597        models that only return start and end logits.
 598
 599        Args:
 600            examples: The non-preprocessed dataset.
 601            features: The processed dataset.
 602            predictions: The predictions of the model: two arrays containing the start logits and the end logits
 603                respectively. Its first dimension must match the number of elements of `features`.
 604            version_2_with_negative: Whether or not the underlying dataset contains examples with no answers.
 605            n_best_size: The total number of n-best predictions to generate when looking for an answer.
 606            max_answer_length: The maximum length of an answer that can be generated. This is needed
 607                because the start and end predictions are not conditioned on one another.
 608            null_score_diff_threshold: The threshold used to select the null answer: if the best answer
 609                has a score that is less than the score of the null answer minus this threshold, the
 610                null answer is selected for this example (note that the score of the null answer for
 611                an example giving several features is the minimum of the scores for the null answer on
 612                each feature: all features must be aligned on the fact they `want` to predict a null answer).
 613                Only useful when `version_2_with_negative` is `True`.
 614            output_dir: If provided, the dictionaries of predictions, n_best predictions (with their scores and logits)
 615                and, if `version_2_with_negative=True`, the dictionary of the scores differences between best and null
 616                answers, are saved in `output_dir`.
 617            prefix: If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
 618        """
 619        if len(predictions) != 2:
 620            raise ValueError(
 621                "`predictions` should be a tuple with two elements (start_logits, end_logits)."
 622            )
 623        all_start_logits, all_end_logits = predictions
 624
 625        if len(predictions[0]) != len(features):
 626            raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
 627
 628        # Build a map example to its corresponding features.
 629        example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
 630        features_per_example = collections.defaultdict(list)
 631        for i, feature in enumerate(features):
 632            features_per_example[example_id_to_index[feature["example_id"]]].append(i)
 633
 634        # The dictionaries we have to fill.
 635        all_predictions = collections.OrderedDict()
 636        all_nbest_json = collections.OrderedDict()
 637        if version_2_with_negative:
 638            scores_diff_json = collections.OrderedDict()
 639
 640        logger.debug(
 641            f"Post-processing {len(examples)} example predictions split into"
 642            f" {len(features)} features."
 643        )
 644
 645        # Let's loop over all the examples!
 646        for example_index, example in enumerate(examples):
 647            # Those are the indices of the features associated to the current example.
 648            feature_indices = features_per_example[example_index]
 649
 650            min_null_prediction = None
 651            prelim_predictions = []
 652
 653            # Looping through all the features associated to the current example.
 654            for feature_index in feature_indices:
 655                # We grab the predictions of the model for this feature.
 656                start_logits = all_start_logits[feature_index]
 657                end_logits = all_end_logits[feature_index]
 658                # This is what will allow us to map some of the positions in our logits to spans of text in the original
 659                # context.
 660                offset_mapping = features[feature_index]["offset_mapping"]
 661                # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum
 662                # context available in the current feature.
 663                token_is_max_context = features[feature_index].get("token_is_max_context", None)
 664
 665                # Update minimum null prediction.
 666                feature_null_score = start_logits[0] + end_logits[0]
 667                if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
 668                    min_null_prediction = {
 669                        "offsets": (0, 0),
 670                        "score": feature_null_score,
 671                        "start_logit": start_logits[0],
 672                        "end_logit": end_logits[0],
 673                    }
 674
 675                # Go through all possibilities for the `n_best_size` greatest start and end logits.
 676                start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
 677                end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
 678                for start_index in start_indexes:
 679                    for end_index in end_indexes:
 680                        # Don't consider out-of-scope answers, either because the indices are out of bounds or
 681                        # correspond to part of the input_ids that are not in the context.
 682                        if (
 683                            start_index >= len(offset_mapping)
 684                            or end_index >= len(offset_mapping)
 685                            or offset_mapping[start_index] is None
 686                            or len(offset_mapping[start_index]) < 2
 687                            or offset_mapping[end_index] is None
 688                            or len(offset_mapping[end_index]) < 2
 689                        ):
 690                            continue
 691                        # Don't consider answers with a length that is either < 0 or > max_answer_length.
 692                        if (
 693                            end_index < start_index
 694                            or end_index - start_index + 1 > max_answer_length
 695                        ):
 696                            continue
 697                        # Don't consider answers that don't have the maximum context available (if such information is
 698                        # provided).
 699                        if token_is_max_context is not None and not token_is_max_context.get(
 700                            str(start_index), False
 701                        ):
 702                            continue
 703
 704                        prelim_predictions.append(
 705                            {
 706                                "offsets": (
 707                                    offset_mapping[start_index][0],
 708                                    offset_mapping[end_index][1],
 709                                ),
 710                                "score": start_logits[start_index] + end_logits[end_index],
 711                                "start_logit": start_logits[start_index],
 712                                "end_logit": end_logits[end_index],
 713                            }
 714                        )
 715            if version_2_with_negative and min_null_prediction is not None:
 716                # Add the minimum null prediction
 717                prelim_predictions.append(min_null_prediction)
 718                null_score = min_null_prediction["score"]
 719
 720            # Only keep the best `n_best_size` predictions.
 721            n_best_preds = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[
 722                :n_best_size
 723            ]
 724
 725            # Add back the minimum null prediction if it was removed because of its low score.
 726            if (
 727                version_2_with_negative
 728                and min_null_prediction is not None
 729                and not any(p["offsets"] == (0, 0) for p in n_best_preds)
 730            ):
 731                n_best_preds.append(min_null_prediction)
 732
 733            # Use the offsets to gather the answer text in the original context.
 734            context = example["context"]
 735            for pred in n_best_preds:
 736                offsets = pred.pop("offsets")
 737                pred["text"] = context[offsets[0] : offsets[1]]
 738
 739            # In the very rare edge case where we do not have a single non-null prediction, we create a fake prediction to avoid
 740            # failure.
 741            if len(n_best_preds) == 0 or (len(n_best_preds) == 1 and n_best_preds[0]["text"] == ""):
 742                n_best_preds.insert(
 743                    0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}
 744                )
 745
 746            # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file,
 747            # using the LogSumExp trick).
 748            scores = np.array([pred.pop("score") for pred in n_best_preds])
 749            exp_scores = np.exp(scores - np.max(scores))
 750            probs = exp_scores / exp_scores.sum()
 751
 752            # Include the probabilities in our n_best_preds.
 753            for prob, pred in zip(probs, n_best_preds):
 754                pred["probability"] = prob
 755
 756            # Pick the best prediction. If the null answer is not possible, this is easy.
 757            if not version_2_with_negative:
 758                all_predictions[example["id"]] = n_best_preds[0]["text"]
 759            else:
 760                # Otherwise we first need to find the best non-empty prediction.
 761                i = 0
 762                while n_best_preds[i]["text"] == "":
 763                    i += 1
 764                best_non_null_pred = n_best_preds[i]
 765
 766                # Then we compare to the null prediction using the threshold.
 767                score_diff = (
 768                    null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
 769                )
 770                scores_diff_json[example["id"]] = float(score_diff)  # To be JSON-serializable.
 771                if score_diff > null_score_diff_threshold:
 772                    all_predictions[example["id"]] = ""
 773                else:
 774                    all_predictions[example["id"]] = best_non_null_pred["text"]
 775
 776            # Make `n_best_preds` JSON-serializable by casting np.float back to float.
 777            all_nbest_json[example["id"]] = [
 778                {
 779                    k: float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v
 780                    for k, v in pred.items()
 781                }
 782                for pred in n_best_preds
 783            ]
 784
 785        # If we have an output_dir, let's save all those dicts.
 786        if output_dir is not None:
 787            if not os.path.isdir(output_dir):
 788                raise EnvironmentError(f"{output_dir} is not a directory.")
 789
 790            prediction_file = os.path.join(
 791                output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
 792            )
 793            nbest_file = os.path.join(
 794                output_dir,
 795                "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json",
 796            )
 797            if version_2_with_negative:
 798                null_odds_file = os.path.join(
 799                    output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
 800                )
 801
 802            logger.info(f"Saving predictions to {prediction_file}.")
 803            with open(prediction_file, "w") as writer:
 804                writer.write(json.dumps(all_predictions, indent=4) + "\n")
 805            logger.info(f"Saving nbest_preds to {nbest_file}.")
 806            with open(nbest_file, "w") as writer:
 807                writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
 808            if version_2_with_negative:
 809                logger.info(f"Saving null_odds to {null_odds_file}.")
 810                with open(null_odds_file, "w") as writer:
 811                    writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
 812
 813        # Format the result to the format the metric expects.
 814        formatted_predictions = [
 815            {"id": k, "prediction_text": v} for k, v in all_predictions.items()
 816        ]
 817
 818        references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
 819        return EvalPrediction(predictions=formatted_predictions, label_ids=references)
 820
 821    logger.info(f"***** Running {prefix} *****")
 822    logger.info(f"  Num examples = {len(eval_dataset)}")
 823    logger.info(f"  Batch size = {args.per_device_eval_batch_size}")
 824
 825    model.eval()
 826    all_start_logits = []
 827    all_end_logits = []
 828    for _, batch in enumerate(eval_dataloader):
 829        with torch.no_grad():
 830            outputs = model(**batch)
 831            start_logits = outputs.start_logits
 832            end_logits = outputs.end_logits
 833
 834            if (
 835                not args.pad_to_max_length
 836            ):  # necessary to pad predictions and labels for being gathered
 837                start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
 838                end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)
 839
 840            all_start_logits.append(accelerator.gather_for_metrics(start_logits).cpu().numpy())
 841            all_end_logits.append(accelerator.gather_for_metrics(end_logits).cpu().numpy())
 842
 843    # Model Optimizer: clear the intermediate states of the distillation model from the forward passes
 844    if args.do_modelopt_distill:
 845        model.module.compute_kd_loss()
 846
 847    max_len = max([x.shape[1] for x in all_start_logits])  # Get the max_length of the tensor
 848
 849    # concatenate the numpy array
 850    start_logits_concat = create_and_fill_np_array(all_start_logits, max_len)
 851    end_logits_concat = create_and_fill_np_array(all_end_logits, max_len)
 852
 853    outputs_numpy = (start_logits_concat, end_logits_concat)
 854    prediction = postprocess_qa_predictions(
 855        examples=eval_examples,
 856        features=eval_dataset,
 857        predictions=outputs_numpy,
 858        n_best_size=args.n_best_size,
 859        max_answer_length=args.max_answer_length,
 860        # output_dir=args.output_dir,
 861        prefix=prefix,
 862    )
 863
 864    metric = evaluate.load("squad")
 865    eval_metric = metric.compute(
 866        predictions=prediction.predictions, references=prediction.label_ids
 867    )
 868    logger.info(f"{prefix} metrics: {eval_metric}\n")
 869    return eval_metric
 870
 871
 872# Model Optimizer: Define a teacher factory for initializing the distillation model
 873def teacher_factory(model_name_or_path):
 874    return AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)
 875
 876
 877# Model Optimizer: Define a custom distillation loss function that uses start and end logits
 878class StartEndLogitsDistillationLoss(mtd.LogitsDistillationLoss):
 879    def forward(self, outputs_s, outputs_t):
 880        loss_start = super().forward(outputs_s.start_logits, outputs_t.start_logits)
 881        loss_end = super().forward(outputs_s.end_logits, outputs_t.end_logits)
 882        loss = (loss_start + loss_end) / 2.0
 883
 884        return loss
 885
 886
 887def train_and_evaluate_model(
 888    args,
 889    model: nn.Module,
 890    tokenizer: PreTrainedTokenizer,
 891    accelerator: Accelerator,
 892    examples: Dict,
 893    dataset: Dict,
 894    dataloader: Dict[str, DataLoader],
 895    answer_column_name,
 896):
 897    # Optimizer
 898    # Split weights in two groups, one with weight decay and the other not.
 899    no_decay = ["bias", "LayerNorm.weight"]
 900    optimizer_grouped_parameters = [
 901        {
 902            "params": [
 903                p
 904                for n, p in model.named_parameters()
 905                if p.requires_grad and not any(nd in n for nd in no_decay)
 906            ],
 907            "weight_decay": args.weight_decay,
 908        },
 909        {
 910            "params": [
 911                p
 912                for n, p in model.named_parameters()
 913                if p.requires_grad and any(nd in n for nd in no_decay)
 914            ],
 915            "weight_decay": 0.0,
 916        },
 917    ]
 918    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
 919
 920    # Scheduler and math around the number of training steps.
 921    overrode_max_train_steps = False
 922    num_update_steps_per_epoch = math.ceil(
 923        len(dataloader["train"]) / args.gradient_accumulation_steps
 924    )
 925    if args.max_train_steps is None:
 926        args.max_train_steps = int(args.num_train_epochs * num_update_steps_per_epoch)
 927        overrode_max_train_steps = True
 928
 929    lr_scheduler = get_scheduler(
 930        name=args.lr_scheduler_type,
 931        optimizer=optimizer,
 932        num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
 933        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
 934    )
 935
 936    # Prepare everything with our `accelerator`.
 937    model, optimizer, dataloader["train"], dataloader["eval"], lr_scheduler = accelerator.prepare(
 938        model, optimizer, dataloader["train"], dataloader["eval"], lr_scheduler
 939    )
 940
 941    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
 942    num_update_steps_per_epoch = math.ceil(
 943        len(dataloader["train"]) / args.gradient_accumulation_steps
 944    )
 945    if overrode_max_train_steps:
 946        args.max_train_steps = int(args.num_train_epochs * num_update_steps_per_epoch)
 947    # Afterwards we recalculate our number of training epochs
 948    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 949
 950    # Figure out how many steps we should save the Accelerator states
 951    checkpointing_steps = args.checkpointing_steps
 952    if checkpointing_steps is not None and checkpointing_steps.isdigit():
 953        checkpointing_steps = int(checkpointing_steps)
 954
 955    # We need to initialize the trackers we use, and also store our configuration.
 956    # The trackers initialize automatically on the main process.
 957    if args.with_tracking:
 958        experiment_config = vars(args)
 959        # TensorBoard cannot log Enums, need the raw value
 960        experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
 961        accelerator.init_trackers("tensorboard", experiment_config)
 962
 963    # Train!
 964    total_batch_size = (
 965        args.per_device_train_batch_size
 966        * accelerator.num_processes
 967        * args.gradient_accumulation_steps
 968    )
 969
 970    logger.info("***** Running training *****")
 971    logger.info(f"  Num examples = {len(dataset['train'])}")
 972    logger.info(f"  Num Epochs = {args.num_train_epochs}")
 973    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
 974    logger.info(
 975        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
 976    )
 977    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
 978    logger.info(f"  Total optimization steps = {args.max_train_steps}")
 979
 980    # Only show the progress bar once on each machine.
 981    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
 982    completed_steps = 0
 983    starting_epoch = 0
 984    resume_step = None
 985
 986    # Potentially load in the weights and states from a previous save
 987    if args.resume_from_last_ckpt:
 988        # Get the most recent checkpoint
 989        dirs = [
 990            f.path
 991            for f in os.scandir(args.output_dir)
 992            if f.is_dir() and (f.name.startswith("epoch_") or f.name.startswith("step_"))
 993        ]
 994        if len(dirs) == 0:
 995            logger.warning("No checkpoint found in output_dir. Training from scratch!")
 996        else:
 997            latest_dir = max(dirs, key=os.path.getctime)
 998            accelerator.load_state(latest_dir)
 999
1000            # Extract `epoch_{i}` or `step_{i}`
1001            latest_dir = os.path.basename(latest_dir)
1002            if "epoch" in latest_dir:
1003                starting_epoch = int(latest_dir.replace("epoch_", "")) + 1
1004                completed_steps = starting_epoch * num_update_steps_per_epoch
1005            else:
1006                # need to multiply `gradient_accumulation_steps` to reflect real steps
1007                resume_step = (
1008                    int(latest_dir.replace("step_", "")) * args.gradient_accumulation_steps
1009                )
1010                starting_epoch = resume_step // len(dataloader["train"])
1011                completed_steps = resume_step // args.gradient_accumulation_steps
1012                resume_step -= starting_epoch * len(dataloader["train"])
1013
1014    # update the progress_bar if resuming from a checkpoint
1015    progress_bar.update(completed_steps)
1016
1017    # Evaluate before training (e.g. PTQ accuracy before QAT)
1018    eval_metric = evaluate_model(
1019        args,
1020        model,
1021        accelerator,
1022        examples["eval"],
1023        dataset["eval"],
1024        dataloader["eval"],
1025        answer_column_name,
1026    )
1027    for epoch in range(starting_epoch, args.num_train_epochs):
1028        model.train()
1029        if args.with_tracking:
1030            total_loss = 0
1031        if args.resume_from_last_ckpt and epoch == starting_epoch and resume_step is not None:
1032            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
1033            active_dataloader = accelerator.skip_first_batches(dataloader["train"], resume_step)
1034        else:
1035            active_dataloader = dataloader["train"]
1036        for batch in active_dataloader:
1037            optimizer.zero_grad()
1038            outputs = model(**batch)
1039
1040            # Model Optimizer: If using distillation, compute the distillation loss from the unwrapped model
1041            if args.do_modelopt_distill:
1042                loss = model.module.compute_kd_loss()
1043            else:
1044                loss, _, _ = outputs.to_tuple()
1045
1046            # We keep track of the loss at each epoch
1047            if args.with_tracking:
1048                total_loss += loss.detach().float()
1049
1050            accelerator.backward(loss)
1051            optimizer.step()
1052            lr_scheduler.step()
1053
1054            # Checks if the accelerator has performed an optimization step behind the scenes
1055            if accelerator.sync_gradients:
1056                progress_bar.update(1)
1057                completed_steps += 1
1058
1059            if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0:
1060                accelerator.save_state(os.path.join(args.output_dir, f"step_{completed_steps}"))
1061
1062            if completed_steps >= args.max_train_steps:
1063                break
1064
1065        if args.checkpointing_steps == "epoch":
1066            accelerator.save_state(os.path.join(args.output_dir, f"epoch_{epoch}"))
1067
1068        eval_metric = evaluate_model(
1069            args,
1070            model,
1071            accelerator,
1072            examples["eval"],
1073            dataset["eval"],
1074            dataloader["eval"],
1075            answer_column_name,
1076        )
1077
1078        if args.with_tracking:
1079            log = {
1080                "squad": eval_metric,
1081                "train_loss": total_loss.item() / len(dataloader["train"]),  # type: ignore[attr-defined]
1082                "epoch": epoch,
1083                "step": completed_steps,
1084            }
1085            accelerator.log(log, step=completed_steps)
1086
1087    accelerator.wait_for_everyone()
1088    if accelerator.is_main_process and eval_metric:
1089        logger.info(json.dumps(eval_metric, indent=4))
1090        # Prefix all keys with prefix + '_'
1091        for key in list(eval_metric.keys()):
1092            eval_metric[f"Eval_{key}"] = eval_metric.pop(key)
1093
1094        with open(os.path.join(args.output_dir, "results.json"), "w") as f:
1095            json.dump(eval_metric, f, indent=4)
1096
1097
1098def main(input_args: Optional[List[str]] = None) -> None:
1099    args = parse_args(input_args)
1100
1101    # Initialize the accelerator
1102    accelerator_log_kwargs = {}
1103    if args.with_tracking:
1104        accelerator_log_kwargs["log_with"] = "tensorboard"
1105        accelerator_log_kwargs["project_dir"] = args.output_dir
1106    accelerator = Accelerator(
1107        gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs
1108    )
1109
1110    # Setup logging
1111    logging.basicConfig(
1112        format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
1113        datefmt="%m/%d/%Y %H:%M:%S",
1114        level=logging.INFO,
1115    )
1116    logger.info(accelerator.state, main_process_only=False)
1117    if accelerator.is_local_main_process:
1118        datasets.utils.logging.set_verbosity_warning()
1119        transformers.utils.logging.set_verbosity_info()
1120    else:
1121        datasets.utils.logging.set_verbosity_error()
1122        transformers.utils.logging.set_verbosity_error()
1123
1124    # Set the training seed
1125    set_seed(SEED)
1126
1127    # Handle the output_dir creation
1128    if accelerator.is_main_process:
1129        os.makedirs(args.output_dir, exist_ok=True)
1130    accelerator.wait_for_everyone()
1131
1132    # Load pretrained model and tokenizer
1133    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
1134    model = AutoModelForQuestionAnswering.from_pretrained(args.model_name_or_path)
1135    dummy_input = model.dummy_inputs["input_ids"]
1136
1137    # Get datasets
1138    examples, dataset, dataloader, answer_column_name = get_datasets_and_dataloaders(
1139        args, tokenizer, accelerator
1140    )
1141
1142    # Model Optimizer: Save the pruned or quantized model
1143    def save_modelopt_model(model):
1144        if accelerator.is_main_process and args.modelopt_save_file is not None:
1145            save_path = os.path.join(args.output_dir, args.modelopt_save_file)
1146            logger.info(f"Saving modelopt optimized model to {save_path}")
1147            mto.save(model, save_path)
1148
1149    # Model Optimizer: Restore the pruned or quantized model from a checkpoint
1150    if args.modelopt_restore_path:
1151        assert os.path.exists(
1152            args.modelopt_restore_path
1153        ), f"{args.modelopt_restore_path} does not exist."
1154        logger.info(f"Restoring model from {args.modelopt_restore_path}")
1155        model = mto.restore(model, args.modelopt_restore_path)
1156
1157    # Model Optimizer: Prune the model to the given FLOPs target using the GradNAS algorithm
1158    if args.do_modelopt_prune:
1159        logger.info(f"Pruning model to {args.modelopt_prune_flops_percent}% FLOPS")
1160
1161        # NOTE: gradnas does not perform synchronization across data parallel groups
1162        # Use unwrapped model & non-distributed dataloader for gradnas so that all the processes
1163        # in the data parallel group have the same gradients for pruning
1164        model = model.to(accelerator.device)
1165        dummy_input = dummy_input.to(accelerator.device)
1166
1167        # Search for the best pruned model
1168        # To use other NAS algorithms, you can use `mtn.convert` + `mtn.search` here.
1169        model, _ = mtp.prune(
1170            model=model,
1171            mode="gradnas",
1172            constraints={"flops": f"{args.modelopt_prune_flops_percent}%"},
1173            dummy_input=dummy_input,
1174            config={
1175                "data_loader": dataloader["train"],
1176                "collect_func": lambda batch: (batch,),
1177                "loss_func": lambda output, batch: (
1178                    output["loss"] if isinstance(output, dict) else output[0]
1179                ),
1180            },
1181        )
1182        save_modelopt_model(model)
1183
1184    # Model Optimizer: Quantize the model to INT8 precision
1185    if args.modelopt_quantize_cfg:
1186        logger.info(f"Quantizing model with {args.modelopt_quantize_cfg} config")
1187
1188        # NOTE: `mtq.quantize` does not perform synchronization across data parallel groups
1189        # Use unwrapped model & non-distributed dataloader for PTQ calibration so that all the processes
1190        # in the data parallel group have the same calibration statistics
1191        model = model.to(accelerator.device)
1192
1193        def forward_loop(model):
1194            num_samples = 256  # Use only 256 samples for PTQ calibration
1195            num_batches = num_samples // args.per_device_train_batch_size
1196            for idx, batch in tqdm(enumerate(dataloader["train"]), total=num_batches):
1197                batch = {k: v.to(accelerator.device) for k, v in batch.items()}
1198                model(**batch)
1199                if idx >= num_batches:
1200                    break
1201
1202        model = mtq.quantize(model, getattr(mtq, args.modelopt_quantize_cfg), forward_loop)
1203        torch.cuda.empty_cache()
1204        save_modelopt_model(model)
1205
1206    if args.do_train:
1207        # Model Optimizer: Convert to a DistillationModel containing teacher to train with distillation
1208        if args.do_modelopt_distill:
1209            logger.info(f"Using distillation with teacher {args.model_name_or_path}")
1210
1211            kd_config = {
1212                "teacher_model": (teacher_factory, (args.model_name_or_path,), {}),
1213                "criterion": StartEndLogitsDistillationLoss(args.temperature),
1214            }
1215            model = mtd.convert(model, mode=[("kd_loss", kd_config)])
1216
1217        train_and_evaluate_model(
1218            args,
1219            model,
1220            tokenizer,
1221            accelerator,
1222            examples,
1223            dataset,
1224            dataloader,
1225            answer_column_name,
1226        )
1227
1228        # Model Optimizer: Export the distilled model
1229        if args.do_modelopt_distill:
1230            model = mtd.export(model)
1231
1232        save_modelopt_model(model)
1233
1234    if accelerator.is_main_process and args.onnx_export_file is not None:
1235        save_path = os.path.join(args.output_dir, args.onnx_export_file)
1236        logger.info(f"Exporting ONNX model to {save_path}")
1237
1238        # Move the model and dummy_input to the device
1239        model = model.to(accelerator.device)
1240        dummy_input = dummy_input.to(accelerator.device)
1241
1242        with open(save_path, "wb") as f:
1243            f.write(get_onnx_bytes(model, dummy_input, onnx_opset=14))
1244
1245    logger.info("Done!")
1246
1247
1248if __name__ == "__main__":
1249    main()

Commands

  1. First, we prune the BERT large model to 50% FLOPs with the GradNAS algorithm. Then we fine-tune the pruned model with distillation from the unpruned teacher model, which recovers more than 99% of the initial F1 score (93.15). We recommend using multiple GPUs for fine-tuning. Note that we fine-tune for more epochs than the 2 epochs originally used to fine-tune BERT without distillation: distillation takes longer to converge, but it achieves much better results. A condensed sketch of the distillation setup is shown after the script below.

    1_prune.sh
    #!/bin/bash
    set -e
    set -x
    
    FLOPS_PERCENT=50
    BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
    
    if [ ! -f "${BASE_DIR}/pruned_model.pth" ]; then
        modelopt_args="--do_modelopt_prune --modelopt_prune_flops_percent ${FLOPS_PERCENT}"
    else
        modelopt_args="--modelopt_restore_path ${BASE_DIR}/pruned_model.pth"
    fi
    modelopt_args="${modelopt_args} --modelopt_save_file pruned_model.pth"
    
    # Run pruning followed by distributed fine-tuning on the pruned model
    accelerate launch --multi_gpu --mixed_precision bf16 bert_prune_distill_quantize.py \
        --model_name_or_path bert-large-uncased-whole-word-masking-finetuned-squad \
        --output_dir ${BASE_DIR} \
        ${modelopt_args} \
        --do_train \
        --do_modelopt_distill \
        --lr_scheduler_type cosine \
        --learning_rate 1e-4 \
        --per_device_train_batch_size 16 \
        --num_train_epochs 15 \
        --with_tracking \
        --checkpointing_steps epoch \
        --resume_from_last_ckpt
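
     For reference, the distillation enabled by --do_modelopt_distill is wired up in bert_prune_distill_quantize.py roughly as follows (condensed from the full script above; the teacher checkpoint and temperature are the values used in this example):

     import modelopt.torch.distill as mtd
     from transformers import AutoModelForQuestionAnswering

     def teacher_factory(model_name_or_path):
         return AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)

     # Distill both the start- and end-logit distributions from the unpruned teacher
     class StartEndLogitsDistillationLoss(mtd.LogitsDistillationLoss):
         def forward(self, outputs_s, outputs_t):
             loss_start = super().forward(outputs_s.start_logits, outputs_t.start_logits)
             loss_end = super().forward(outputs_s.end_logits, outputs_t.end_logits)
             return (loss_start + loss_end) / 2.0

     kd_config = {
         "teacher_model": (teacher_factory, ("bert-large-uncased-whole-word-masking-finetuned-squad",), {}),
         "criterion": StartEndLogitsDistillationLoss(2.0),
     }
     model = mtd.convert(model, mode=[("kd_loss", kd_config)])
     # The training loop then uses compute_kd_loss() instead of the task loss,
     # and mtd.export(model) removes the teacher once fine-tuning is done.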
    
  2. Quantize the fine-tuned model to INT8 precision and run calibration (PTQ). PTQ causes a slight drop in the F1 score, but QAT recovers it. We run QAT with distillation from the unpruned teacher model as well. A condensed sketch of the calibration step is shown after the script below.

    2_int8_quantize.sh
    #!/bin/bash
    set -e
    set -x
    
    FLOPS_PERCENT=50
    BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
    OUTPUT_DIR=${BASE_DIR}/int8_quantized
    
    if [ ! -f "${OUTPUT_DIR}/quantized_model.pth" ]; then
        modelopt_args="--modelopt_quantize_cfg INT8_DEFAULT_CFG --modelopt_restore_path ${BASE_DIR}/pruned_model.pth"
    else
        modelopt_args="--modelopt_restore_path ${OUTPUT_DIR}/quantized_model.pth"
    fi
    
    modelopt_args="${modelopt_args} --modelopt_save_file quantized_model.pth"
    
    # Run distributed QAT on the pruned model with a lower LR and fewer epochs
    accelerate launch --multi_gpu --mixed_precision bf16 bert_prune_distill_quantize.py \
        --model_name_or_path bert-large-uncased-whole-word-masking-finetuned-squad \
        --output_dir ${OUTPUT_DIR} \
        ${modelopt_args} \
        --do_train \
        --do_modelopt_distill \
        --lr_scheduler_type cosine \
        --learning_rate 5e-6 \
        --per_device_train_batch_size 16 \
        --num_train_epochs 2 \
        --with_tracking \
        --checkpointing_steps epoch \
        --resume_from_last_ckpt
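
     For reference, the PTQ calibration triggered by --modelopt_quantize_cfg INT8_DEFAULT_CFG boils down to the following (condensed from the full script above; train_dataloader, per_device_train_batch_size, and device stand in for the script's own objects):

     import modelopt.torch.quantization as mtq

     def forward_loop(model):
         # mtq.quantize runs this loop to collect activation statistics for the
         # INT8 quantizers; the script calibrates on roughly 256 training samples.
         num_batches = 256 // per_device_train_batch_size
         for idx, batch in enumerate(train_dataloader):
             batch = {k: v.to(device) for k, v in batch.items()}
             model(**batch)
             if idx >= num_batches:
                 break

     model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)
     # QAT is then just the regular fine-tuning loop run on the quantized model
     # (here with distillation and a lower learning rate, as in the command above).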
    
  3. Export the quantized model to ONNX format for deployment with TensorRT.

    3_onnx_export.sh
    #!/bin/bash
    set -e
    set -x
    
    FLOPS_PERCENT=50
    BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
    
    # Export to ONNX on a single GPU
    python3 bert_prune_distill_quantize.py \
        --model_name_or_path bert-large-uncased-whole-word-masking-finetuned-squad \
        --output_dir ${BASE_DIR}/onnx \
        --modelopt_restore_path ${BASE_DIR}/int8_quantized/quantized_model.pth \
        --onnx_export_file pruned_model_int8.onnx \