HF BERT: Prune, Distill & Quantize
This example shows how to compress a Hugging Face BERT large model for Question Answering using a combination of modelopt.torch.prune, modelopt.torch.distill, and modelopt.torch.quantize. More specifically, we will:
Prune the BERT large model to 50% of the original FLOPs with the GradNAS algorithm and fine-tune it with distillation
Quantize the fine-tuned model to INT8 precision with Post-Training Quantization (PTQ) and recover the accuracy with Quantization Aware Training (QAT) with distillation
Export the quantized model to ONNX format for deployment with TensorRT
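At a glance, the three ModelOpt integration points boil down to the following minimal sketch, condensed from the full example below. Here `model`, `dummy_input`, and `train_dataloader` stand in for the objects set up in the script, and `teacher_factory` and `StartEndLogitsDistillationLoss` are the distillation helpers defined there:

import modelopt.torch.distill as mtd
import modelopt.torch.prune as mtp
import modelopt.torch.quantization as mtq

# Assumed to exist as in the full script below: model, dummy_input,
# train_dataloader, teacher_factory, StartEndLogitsDistillationLoss.

# 1. Prune to 50% of the original FLOPs with the GradNAS algorithm
model, _ = mtp.prune(
    model=model,
    mode="gradnas",
    constraints={"flops": "50%"},
    dummy_input=dummy_input,
    config={
        "data_loader": train_dataloader,
        "collect_func": lambda batch: (batch,),
        "loss_func": lambda output, batch: (
            output["loss"] if isinstance(output, dict) else output[0]
        ),
    },
)

# 2. Post-Training Quantization to INT8: calibrate by running a few hundred
# training samples through the model
def forward_loop(model):
    for batch in train_dataloader:
        model(**batch)

model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)

# 3. Wrap the model with a distillation teacher for fine-tuning / QAT
kd_config = {
    "teacher_model": (teacher_factory, ("bert-large-uncased-whole-word-masking-finetuned-squad",), {}),
    "criterion": StartEndLogitsDistillationLoss(2.0),  # temperature of 2.0
}
model = mtd.convert(model, mode=[("kd_loss", kd_config)])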
Prerequisites
Install Model Optimizer with the optional torch and huggingface dependencies:
pip install "nvidia-modelopt[torch,hf]" --extra-index-url https://pypi.nvidia.com
Note
This example has been tested on 8 x 24GB A5000 GPUs with PyTorch 2.4 and CUDA 12.4. It takes about 2 hours to complete all the stages of the optimization. Most of the time is spent on fine-tuning and QAT.
Full code
You can view the full code below with the ModelOpt integration points highlighted. The source code and scripts are also available in the ModelOpt GitHub repository.
1# NOTE: This is adapted from run_qa_no_trainer.py and utils_qa.py from https://github.com/huggingface/
2# transformers/blob/c52b515e948fc12ff58ad773a0385860d0162f61/examples/pytorch/question-answering
3#
4# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
18# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
19# SPDX-License-Identifier: MIT
20
21# Permission is hereby granted, free of charge, to any person obtaining a
22# copy of this software and associated documentation files (the "Software"),
23# to deal in the Software without restriction, including without limitation
24# the rights to use, copy, modify, merge, publish, distribute, sublicense,
25# and/or sell copies of the Software, and to permit persons to whom the
26# Software is furnished to do so, subject to the following conditions:
27
28# The above copyright notice and this permission notice shall be included in
29# all copies or substantial portions of the Software.
30
31# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
34# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
36# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
37# DEALINGS IN THE SOFTWARE.
38"""
39Example showcasing how to do end-to-end optimization of a BERT model on SQuAD using Model Optimizer.
40This includes GradNAS pruning, INT8 quantization, fine-tuning / QAT with distillation, and ONNX export.
41"""
42import argparse
43import collections
44import json
45import logging
46import math
47import os
48import random
49from typing import Any, Dict, List, Optional, Tuple
50
51import datasets
52import evaluate
53import numpy as np
54import torch
55import torch.nn as nn
56import transformers
57from accelerate import Accelerator
58from accelerate.logging import get_logger
59from accelerate.utils import set_seed
60from torch.utils.data import DataLoader
61from tqdm.auto import tqdm
62from transformers import (
63 AutoModelForQuestionAnswering,
64 AutoTokenizer,
65 DataCollatorWithPadding,
66 EvalPrediction,
67 PreTrainedTokenizer,
68 SchedulerType,
69 default_data_collator,
70 get_scheduler,
71)
72
73import modelopt.torch.distill as mtd
74import modelopt.torch.opt as mto
75import modelopt.torch.prune as mtp
76import modelopt.torch.quantization as mtq
77from modelopt.torch._deploy.utils import get_onnx_bytes
78
79logger = get_logger(__name__)
80
81SEED = 123
82
83
84def parse_args(input_args: Optional[List[str]] = None):
85 parser = argparse.ArgumentParser(
86 description="Finetune a transformers model on a Question Answering task"
87 )
88
89 # Training arguments
90 parser.add_argument(
91 "--model_name_or_path",
92 type=str,
93 default="bert-large-uncased-whole-word-masking-finetuned-squad",
94 help="Path to pretrained model or model identifier from huggingface.co/models.",
95 )
96 parser.add_argument(
97 "--do_train",
98 action="store_true",
99 help="Whether to run training / fine-tuning.",
100 )
101 parser.add_argument(
102 "--per_device_train_batch_size",
103 type=int,
104 default=16,
105 help="Batch size (per device) for the training dataloader.",
106 )
107 parser.add_argument(
108 "--per_device_eval_batch_size",
109 type=int,
110 default=64,
111 help="Batch size (per device) for the evaluation dataloader.",
112 )
113 parser.add_argument(
114 "--learning_rate",
115 type=float,
116 default=5e-5,
117 help="Initial learning rate (after the potential warmup period) to use.",
118 )
119 parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
120 parser.add_argument(
121 "--lr_scheduler_type",
122 type=SchedulerType,
123 default="linear",
124 help="The scheduler type to use.",
125 choices=[
126 "linear",
127 "cosine",
128 "cosine_with_restarts",
129 "polynomial",
130 "constant",
131 "constant_with_warmup",
132 ],
133 )
134 parser.add_argument(
135 "--num_warmup_steps",
136 type=int,
137 default=0,
138 help="Number of steps for the warmup in the lr scheduler.",
139 )
140 parser.add_argument(
141 "--num_train_epochs",
142 type=float,
143 default=2.0,
144 help="Total number of training epochs to perform.",
145 )
146 parser.add_argument(
147 "--max_train_steps",
148 type=int,
149 default=None,
150 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
151 )
152 parser.add_argument(
153 "--gradient_accumulation_steps",
154 type=int,
155 default=1,
156 help="Number of updates steps to accumulate before performing a backward/update pass.",
157 )
158 parser.add_argument(
159 "--preprocessing_num_workers",
160 type=int,
161 default=4,
162 help="The number of processes to use for preprocessing the dataset.",
163 )
164
165 # Logging and checkpointing arguments
166 parser.add_argument(
167 "--output_dir", type=str, default="results", help="Where to store the optimized models."
168 )
169 parser.add_argument(
170 "--with_tracking",
171 action="store_true",
172 help="Whether to enable experiment trackers for logging.",
173 )
174 parser.add_argument(
175 "--checkpointing_steps",
176 type=str,
177 default=None,
178 help=(
179 "Whether the various states should be saved at the end of every n steps, or 'epoch' for"
180 " each epoch."
181 ),
182 )
183 parser.add_argument(
184 "--resume_from_last_ckpt",
185 action="store_true",
186 help="If the training should continue from the latest checkpoint in output_dir.",
187 )
188
189 # Misc arguments for Bert (should not be modified in most cases)
190 parser.add_argument(
191 "--max_seq_length",
192 type=int,
193 default=384,
194 help=(
195 "The maximum total input sequence length after tokenization. Sequences longer than this"
196 " will be truncated, and shorter will be padded if `--pad_to_max_length` is passed."
197 ),
198 )
199 parser.add_argument(
200 "--pad_to_max_length",
201 action="store_true",
202 help="If passed, pad all samples to `max_seq_length`. Otherwise, dynamic padding is used.",
203 )
204 parser.add_argument(
205 "--doc_stride",
206 type=int,
207 default=128,
208 help=(
209 "When splitting up a long document into chunks, how much stride to take between chunks."
210 ),
211 )
212 parser.add_argument(
213 "--n_best_size",
214 type=int,
215 default=20,
216 help="The total number of n-best predictions to generate when looking for an answer.",
217 )
218 parser.add_argument(
219 "--max_answer_length",
220 type=int,
221 default=30,
222 help=(
223 "The maximum length of an answer that can be generated. This is needed because the"
224 " start and end predictions are not conditioned on one another."
225 ),
226 )
227
228 # Debugging arguments
229 parser.add_argument(
230 "--max_train_samples",
231 type=int,
232 default=None,
233 help="For debugging purposes or quicker training.",
234 )
235 parser.add_argument(
236 "--max_eval_samples",
237 type=int,
238 default=None,
239 help="For debugging purposes or quicker training.",
240 )
241
242 # Model Optimizer: pruning arguments
243 parser.add_argument(
244 "--do_modelopt_prune",
245 action="store_true",
246 help="Whether or not to use Model Optimizer pruning.",
247 )
248 parser.add_argument(
249 "--modelopt_prune_flops_percent",
250 type=float,
251 default=None,
252 help="The percentage (between 0 and 100) of FLOPs to retain in the pruned model.",
253 )
254
255 # Model Optimizer: quantization arguments
256 parser.add_argument(
257 "--modelopt_quantize_cfg",
258 help="Model Optimizer quantization config.",
259 choices=mtq.config.choices,
260 )
261
262 # Model Optimizer: Distillation arguments
263 parser.add_argument(
264 "--do_modelopt_distill",
265 action="store_true",
266 help="Whether or not to use distillation. A teacher model must be specified.",
267 )
268 parser.add_argument(
269 "--temperature",
270 type=float,
271 default=2.0,
272 help="The temperature to use when distilling.",
273 )
274
275 # Model Optimizer: save and restore arguments
276 parser.add_argument(
277 "--modelopt_save_file",
278 type=str,
279 default=None,
280 help="File (inside output_dir) to save the modelopt modified model to.",
281 )
282 parser.add_argument(
283 "--modelopt_restore_path",
284 type=str,
285 default=None,
286 help="Path to restore the modelopt modified model from.",
287 )
288
289 # ONNX export arguments
290 parser.add_argument(
291 "--onnx_export_file",
292 type=str,
293 default=None,
294 help="File (inside output_dir) to export the ONNX model to.",
295 )
296
297 args = parser.parse_args(input_args)
298
299 # Sanity checks
300 if args.do_modelopt_prune and not args.modelopt_prune_flops_percent:
301 raise ValueError(
302 "Need a `modelopt_prune_flops_percent` when `do_modelopt_prune` is passed."
303 )
304
305 return args
306
307
308def get_datasets_and_dataloaders(args, tokenizer: PreTrainedTokenizer, accelerator: Accelerator):
309 """Get the examples, dataset, dataloader, answer_column_name
310
311 You can either provide your own CSV/JSON/TXT training and evaluation files (see below)
312 or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
313 (the dataset will be downloaded automatically from the datasets Hub).
314
315 For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
316 'text' is found. You can easily tweak this behavior (see below).
317 """
318
319 def prepare_train_features(examples):
320 # Some of the questions have lots of whitespace on the left, which is not useful and will make the
321 # truncation of the context fail (the tokenized question will take up a lot of space). So we remove that
322 # left whitespace
323 examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
324
325 # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
326 # in one example possibly giving several features when a context is long, each of those features having a
327 # context that overlaps a bit with the context of the previous feature.
328 tokenized_examples = tokenizer(
329 examples[question_column_name if pad_on_right else context_column_name],
330 examples[context_column_name if pad_on_right else question_column_name],
331 truncation="only_second" if pad_on_right else "only_first",
332 max_length=max_seq_length,
333 stride=args.doc_stride,
334 return_overflowing_tokens=True,
335 return_offsets_mapping=True,
336 padding="max_length" if args.pad_to_max_length else False,
337 )
338
339 # Since one example might give us several features if it has a long context, we need a map from a feature to
340 # its corresponding example. This key gives us just that.
341 sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
342 # The offset mappings will give us a map from token to character position in the original context. This will
343 # help us compute the start_positions and end_positions.
344 offset_mapping = tokenized_examples.pop("offset_mapping")
345
346 # Let's label those examples!
347 tokenized_examples["start_positions"] = []
348 tokenized_examples["end_positions"] = []
349
350 for i, offsets in enumerate(offset_mapping):
351 # We will label impossible answers with the index of the CLS token.
352 input_ids = tokenized_examples["input_ids"][i]
353 cls_index = input_ids.index(tokenizer.cls_token_id)
354
355 # Grab the sequence corresponding to that example (to know what is the context and what is the question).
356 sequence_ids = tokenized_examples.sequence_ids(i)
357
358 # One example can give several spans, this is the index of the example containing this span of text.
359 sample_index = sample_mapping[i]
360 answers = examples[answer_column_name][sample_index]
361 # If no answers are given, set the cls_index as answer.
362 if len(answers["answer_start"]) == 0:
363 tokenized_examples["start_positions"].append(cls_index)
364 tokenized_examples["end_positions"].append(cls_index)
365 else:
366 # Start/end character index of the answer in the text.
367 start_char = answers["answer_start"][0]
368 end_char = start_char + len(answers["text"][0])
369
370 # Start token index of the current span in the text.
371 token_start_index = 0
372 while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
373 token_start_index += 1
374
375 # End token index of the current span in the text.
376 token_end_index = len(input_ids) - 1
377 while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
378 token_end_index -= 1
379
380 # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
381 if not (
382 offsets[token_start_index][0] <= start_char
383 and offsets[token_end_index][1] >= end_char
384 ):
385 tokenized_examples["start_positions"].append(cls_index)
386 tokenized_examples["end_positions"].append(cls_index)
387 else:
388 # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
389 # Note: we could go after the last offset if the answer is the last word (edge case).
390 while (
391 token_start_index < len(offsets)
392 and offsets[token_start_index][0] <= start_char
393 ):
394 token_start_index += 1
395 tokenized_examples["start_positions"].append(token_start_index - 1)
396 while offsets[token_end_index][1] >= end_char:
397 token_end_index -= 1
398 tokenized_examples["end_positions"].append(token_end_index + 1)
399
400 return tokenized_examples
401
402 def prepare_validation_features(examples):
403 # Some of the questions have lots of whitespace on the left, which is not useful and will make the
404 # truncation of the context fail (the tokenized question will take up a lot of space). So we remove that
405 # left whitespace
406 examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
407
408 # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
409 # in one example possibly giving several features when a context is long, each of those features having a
410 # context that overlaps a bit with the context of the previous feature.
411 tokenized_examples = tokenizer(
412 examples[question_column_name if pad_on_right else context_column_name],
413 examples[context_column_name if pad_on_right else question_column_name],
414 truncation="only_second" if pad_on_right else "only_first",
415 max_length=max_seq_length,
416 stride=args.doc_stride,
417 return_overflowing_tokens=True,
418 return_offsets_mapping=True,
419 padding="max_length" if args.pad_to_max_length else False,
420 )
421
422 # Since one example might give us several features if it has a long context, we need a map from a feature to
423 # its corresponding example. This key gives us just that.
424 sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
425
426 # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
427 # corresponding example_id and we will store the offset mappings.
428 tokenized_examples["example_id"] = []
429
430 for i in range(len(tokenized_examples["input_ids"])):
431 # Grab the sequence corresponding to that example (to know what is the context and what is the question).
432 sequence_ids = tokenized_examples.sequence_ids(i)
433 context_index = 1 if pad_on_right else 0
434
435 # One example can give several spans, this is the index of the example containing this span of text.
436 sample_index = sample_mapping[i]
437 tokenized_examples["example_id"].append(examples["id"][sample_index])
438
439 # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
440 # position is part of the context or not.
441 tokenized_examples["offset_mapping"][i] = [
442 (o if sequence_ids[k] == context_index else None)
443 for k, o in enumerate(tokenized_examples["offset_mapping"][i])
444 ]
445
446 return tokenized_examples
447
448 examples, dataset, dataloader = {}, {}, {}
449
450 # In distributed training, the load_dataset function guarantees that only one local process can concurrently
451 # download the dataset.
452 # Downloading and loading a dataset from the hub.
453 raw_datasets = datasets.load_dataset("squad")
454 # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
455 # https://huggingface.co/docs/datasets/loading_datasets.
456
457 # Preprocessing the datasets.
458 # Preprocessing is slightly different for training and evaluation.
459
460 column_names = raw_datasets["train"].column_names
461
462 question_column_name = "question" if "question" in column_names else column_names[0]
463 context_column_name = "context" if "context" in column_names else column_names[1]
464 answer_column_name = "answers" if "answers" in column_names else column_names[2]
465
466 # Padding side determines if we do (question|context) or (context|question).
467 pad_on_right = tokenizer.padding_side == "right"
468
469 if args.max_seq_length > tokenizer.model_max_length:
470 logger.warning(
471 f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length"
472 f" for the model ({tokenizer.model_max_length}). Using"
473 f" max_seq_length={tokenizer.model_max_length}."
474 )
475
476 max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
477
478 examples["train"] = raw_datasets["train"]
479 if args.max_train_samples is not None:
480 # We will select a sample from the whole data if the argument is specified
481 examples["train"] = examples["train"].select(range(args.max_train_samples))
482
483 # Create train feature from dataset
484 with accelerator.main_process_first():
485 dataset["train"] = examples["train"].map(
486 prepare_train_features,
487 batched=True,
488 num_proc=args.preprocessing_num_workers,
489 remove_columns=column_names,
490 load_from_cache_file=True,
491 desc="Running tokenizer on train dataset",
492 )
493 # if args.max_train_samples is not None:
494 # # Number of samples might increase during Feature Creation, We select only specified max samples
495 # dataset["train"] = dataset["train"].select(range(args.max_train_samples))
496
497 examples["eval"] = raw_datasets["validation"]
498 if args.max_eval_samples is not None:
499 # We will select a sample from the whole data
500 examples["eval"] = examples["eval"].select(range(args.max_eval_samples))
501 # Validation Feature Creation
502 with accelerator.main_process_first():
503 dataset["eval"] = examples["eval"].map(
504 prepare_validation_features,
505 batched=True,
506 num_proc=args.preprocessing_num_workers,
507 remove_columns=column_names,
508 load_from_cache_file=True,
509 desc="Running tokenizer on validation dataset",
510 )
511 # if args.max_eval_samples is not None:
512 # # During Feature creation dataset samples might increase, we will select required samples again
513 # dataset["eval"] = dataset["eval"].select(range(args.max_eval_samples))
514
515 # Log a random sample from the training set:
516 for index in random.sample(range(len(dataset["train"])), 1):
517 logger.info(f"Sample {index} of the training set: {dataset['train'][index]}.")
518
519 # DataLoaders creation:
520 if args.pad_to_max_length:
521 # If padding was already done to max length, we use the default data collator that will just convert everything
522 # to tensors.
523 data_collator = default_data_collator
524 else:
525 # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of
526 # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple
527 # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
528 data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
529
530 dataloader["train"] = DataLoader(
531 dataset["train"],
532 shuffle=True,
533 collate_fn=data_collator,
534 batch_size=args.per_device_train_batch_size,
535 )
536
537 dataloader["eval"] = DataLoader(
538 dataset["eval"].remove_columns(["example_id", "offset_mapping"]),
539 collate_fn=data_collator,
540 batch_size=args.per_device_eval_batch_size,
541 )
542
543 return examples, dataset, dataloader, answer_column_name
544
545
546def evaluate_model(
547 args,
548 model: nn.Module,
549 accelerator: Accelerator,
550 eval_examples: Any,
551 eval_dataset: Any,
552 eval_dataloader: DataLoader,
553 answer_column_name: str,
554 prefix: str = "Eval",
555):
556 def create_and_fill_np_array(start_or_end_logits, max_len):
557 """Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor
558
559 Args:
560 start_or_end_logits: This is the output predictions of the model.
561 We can only enter either start or end logits.
562 max_len: The maximum length of the output tensor. (See the model.eval() part for more details)
563 """
564 step = 0
565 # create a numpy array and fill it with -100.
566 logits_concat = np.full((len(eval_dataset), max_len), -100, dtype=np.float64)
567 # Now since we have created an array, we will populate it with the outputs using accelerator.gather_for_metrics
568 for i, output_logit in enumerate(start_or_end_logits): # populate columns
569 # We have to fill it by copying the whole output tensor into the newly created array
570 # and after every iteration we have to advance the step
571 batch_size = output_logit.shape[0]
572 cols = output_logit.shape[1]
573
574 if step + batch_size < len(eval_dataset):
575 logits_concat[step : step + batch_size, :cols] = output_logit
576 else:
577 logits_concat[step:, :cols] = output_logit[: len(eval_dataset) - step]
578
579 step += batch_size
580
581 return logits_concat
582
583 def postprocess_qa_predictions(
584 examples,
585 features,
586 predictions: Tuple[np.ndarray, np.ndarray],
587 version_2_with_negative: bool = False,
588 n_best_size: int = 20,
589 max_answer_length: int = 30,
590 null_score_diff_threshold: float = 0.0,
591 output_dir: Optional[str] = None,
592 prefix: Optional[str] = None,
593 ) -> EvalPrediction:
594 """Post-processes the predictions of a question-answering model to convert them to answers
595 that are substrings of the original contexts. This is the base postprocessing function for
596 models that only return start and end logits.
597
598 Args:
599 examples: The non-preprocessed dataset.
600 features: The processed dataset.
601 predictions: The predictions of the model: two arrays containing the start logits and the end logits
602 respectively. Its first dimension must match the number of elements of `features`.
603 version_2_with_negative: Whether or not the underlying dataset contains examples with no answers.
604 n_best_size: The total number of n-best predictions to generate when looking for an answer.
605 max_answer_length: The maximum length of an answer that can be generated. This is needed
606 because the start and end predictions are not conditioned on one another.
607 null_score_diff_threshold: The threshold used to select the null answer: if the best answer
608 has a score that is less than the score of the null answer minus this threshold, the
609 null answer is selected for this example (note that the score of the null answer for
610 an example giving several features is the minimum of the scores for the null answer on
611 each feature: all features must be aligned on the fact they `want` to predict a null answer).
612 Only useful when `version_2_with_negative` is `True`.
613 output_dir: If provided, the dictionaries of predictions, n_best predictions (with their scores and logits)
614 and, if `version_2_with_negative=True`, the dictionary of the scores differences between best and null
615 answers, are saved in `output_dir`.
616 prefix: If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
617 """
618 if len(predictions) != 2:
619 raise ValueError(
620 "`predictions` should be a tuple with two elements (start_logits, end_logits)."
621 )
622 all_start_logits, all_end_logits = predictions
623
624 if len(predictions[0]) != len(features):
625 raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
626
627 # Build a map example to its corresponding features.
628 example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
629 features_per_example = collections.defaultdict(list)
630 for i, feature in enumerate(features):
631 features_per_example[example_id_to_index[feature["example_id"]]].append(i)
632
633 # The dictionaries we have to fill.
634 all_predictions = collections.OrderedDict()
635 all_nbest_json = collections.OrderedDict()
636 if version_2_with_negative:
637 scores_diff_json = collections.OrderedDict()
638
639 logger.debug(
640 f"Post-processing {len(examples)} example predictions split into"
641 f" {len(features)} features."
642 )
643
644 # Let's loop over all the examples!
645 for example_index, example in enumerate(examples):
646 # Those are the indices of the features associated to the current example.
647 feature_indices = features_per_example[example_index]
648
649 min_null_prediction = None
650 prelim_predictions = []
651
652 # Looping through all the features associated to the current example.
653 for feature_index in feature_indices:
654 # We grab the predictions of the model for this feature.
655 start_logits = all_start_logits[feature_index]
656 end_logits = all_end_logits[feature_index]
657 # This is what will allow us to map some of the positions in our logits to spans of text in the original
658 # context.
659 offset_mapping = features[feature_index]["offset_mapping"]
660 # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum
661 # context available in the current feature.
662 token_is_max_context = features[feature_index].get("token_is_max_context", None)
663
664 # Update minimum null prediction.
665 feature_null_score = start_logits[0] + end_logits[0]
666 if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
667 min_null_prediction = {
668 "offsets": (0, 0),
669 "score": feature_null_score,
670 "start_logit": start_logits[0],
671 "end_logit": end_logits[0],
672 }
673
674 # Go through all possibilities for the `n_best_size` greatest start and end logits.
675 start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
676 end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
677 for start_index in start_indexes:
678 for end_index in end_indexes:
679 # Don't consider out-of-scope answers, either because the indices are out of bounds or
680 # correspond to part of the input_ids that are not in the context.
681 if (
682 start_index >= len(offset_mapping)
683 or end_index >= len(offset_mapping)
684 or offset_mapping[start_index] is None
685 or len(offset_mapping[start_index]) < 2
686 or offset_mapping[end_index] is None
687 or len(offset_mapping[end_index]) < 2
688 ):
689 continue
690 # Don't consider answers with a length that is either < 0 or > max_answer_length.
691 if (
692 end_index < start_index
693 or end_index - start_index + 1 > max_answer_length
694 ):
695 continue
696 # Don't consider answers that don't have the maximum context available (if such information is
697 # provided).
698 if token_is_max_context is not None and not token_is_max_context.get(
699 str(start_index), False
700 ):
701 continue
702
703 prelim_predictions.append(
704 {
705 "offsets": (
706 offset_mapping[start_index][0],
707 offset_mapping[end_index][1],
708 ),
709 "score": start_logits[start_index] + end_logits[end_index],
710 "start_logit": start_logits[start_index],
711 "end_logit": end_logits[end_index],
712 }
713 )
714 if version_2_with_negative and min_null_prediction is not None:
715 # Add the minimum null prediction
716 prelim_predictions.append(min_null_prediction)
717 null_score = min_null_prediction["score"]
718
719 # Only keep the best `n_best_size` predictions.
720 n_best_preds = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[
721 :n_best_size
722 ]
723
724 # Add back the minimum null prediction if it was removed because of its low score.
725 if (
726 version_2_with_negative
727 and min_null_prediction is not None
728 and not any(p["offsets"] == (0, 0) for p in n_best_preds)
729 ):
730 n_best_preds.append(min_null_prediction)
731
732 # Use the offsets to gather the answer text in the original context.
733 context = example["context"]
734 for pred in n_best_preds:
735 offsets = pred.pop("offsets")
736 pred["text"] = context[offsets[0] : offsets[1]]
737
738 # In the very rare edge case where we don't have a single non-null prediction, we create a fake prediction
739 # to avoid failure.
740 if len(n_best_preds) == 0 or (len(n_best_preds) == 1 and n_best_preds[0]["text"] == ""):
741 n_best_preds.insert(
742 0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}
743 )
744
745 # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file,
746 # using the LogSumExp trick).
747 scores = np.array([pred.pop("score") for pred in n_best_preds])
748 exp_scores = np.exp(scores - np.max(scores))
749 probs = exp_scores / exp_scores.sum()
750
751 # Include the probabilities in our n_best_preds.
752 for prob, pred in zip(probs, n_best_preds):
753 pred["probability"] = prob
754
755 # Pick the best prediction. If the null answer is not possible, this is easy.
756 if not version_2_with_negative:
757 all_predictions[example["id"]] = n_best_preds[0]["text"]
758 else:
759 # Otherwise we first need to find the best non-empty prediction.
760 i = 0
761 while n_best_preds[i]["text"] == "":
762 i += 1
763 best_non_null_pred = n_best_preds[i]
764
765 # Then we compare to the null prediction using the threshold.
766 score_diff = (
767 null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
768 )
769 scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable.
770 if score_diff > null_score_diff_threshold:
771 all_predictions[example["id"]] = ""
772 else:
773 all_predictions[example["id"]] = best_non_null_pred["text"]
774
775 # Make `n_best_preds` JSON-serializable by casting np.float back to float.
776 all_nbest_json[example["id"]] = [
777 {
778 k: float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v
779 for k, v in pred.items()
780 }
781 for pred in n_best_preds
782 ]
783
784 # If we have an output_dir, let's save all those dicts.
785 if output_dir is not None:
786 if not os.path.isdir(output_dir):
787 raise EnvironmentError(f"{output_dir} is not a directory.")
788
789 prediction_file = os.path.join(
790 output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
791 )
792 nbest_file = os.path.join(
793 output_dir,
794 "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json",
795 )
796 if version_2_with_negative:
797 null_odds_file = os.path.join(
798 output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
799 )
800
801 logger.info(f"Saving predictions to {prediction_file}.")
802 with open(prediction_file, "w") as writer:
803 writer.write(json.dumps(all_predictions, indent=4) + "\n")
804 logger.info(f"Saving nbest_preds to {nbest_file}.")
805 with open(nbest_file, "w") as writer:
806 writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
807 if version_2_with_negative:
808 logger.info(f"Saving null_odds to {null_odds_file}.")
809 with open(null_odds_file, "w") as writer:
810 writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
811
812 # Format the result to the format the metric expects.
813 formatted_predictions = [
814 {"id": k, "prediction_text": v} for k, v in all_predictions.items()
815 ]
816
817 references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
818 return EvalPrediction(predictions=formatted_predictions, label_ids=references)
819
820 logger.info(f"***** Running {prefix} *****")
821 logger.info(f" Num examples = {len(eval_dataset)}")
822 logger.info(f" Batch size = {args.per_device_eval_batch_size}")
823
824 model.eval()
825 all_start_logits = []
826 all_end_logits = []
827 for _, batch in enumerate(eval_dataloader):
828 with torch.no_grad():
829 outputs = model(**batch)
830 start_logits = outputs.start_logits
831 end_logits = outputs.end_logits
832
833 if (
834 not args.pad_to_max_length
835 ): # necessary to pad predictions and labels for being gathered
836 start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
837 end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)
838
839 all_start_logits.append(accelerator.gather_for_metrics(start_logits).cpu().numpy())
840 all_end_logits.append(accelerator.gather_for_metrics(end_logits).cpu().numpy())
841
842 # Model Optimizer: clear the intermediate states of the distillation model from the forward passes
843 if args.do_modelopt_distill:
844 model.module.compute_kd_loss()
845
846 max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor
847
848 # concatenate the numpy array
849 start_logits_concat = create_and_fill_np_array(all_start_logits, max_len)
850 end_logits_concat = create_and_fill_np_array(all_end_logits, max_len)
851
852 outputs_numpy = (start_logits_concat, end_logits_concat)
853 prediction = postprocess_qa_predictions(
854 examples=eval_examples,
855 features=eval_dataset,
856 predictions=outputs_numpy,
857 n_best_size=args.n_best_size,
858 max_answer_length=args.max_answer_length,
859 # output_dir=args.output_dir,
860 prefix=prefix,
861 )
862
863 metric = evaluate.load("squad")
864 eval_metric = metric.compute(
865 predictions=prediction.predictions, references=prediction.label_ids
866 )
867 logger.info(f"{prefix} metrics: {eval_metric}\n")
868 return eval_metric
869
870
871# Model Optimizer: Define a teacher factory for initializing the distillation model
872def teacher_factory(model_name_or_path):
873 return AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)
874
875
876# Model Optimizer: Define a custom distillation loss function that uses start and end logits
877class StartEndLogitsDistillationLoss(mtd.LogitsDistillationLoss):
878
879 def forward(self, outputs_s, outputs_t):
880 loss_start = super().forward(outputs_s.start_logits, outputs_t.start_logits)
881 loss_end = super().forward(outputs_s.end_logits, outputs_t.end_logits)
882 loss = (loss_start + loss_end) / 2.0
883
884 return loss
885
886
887def train_and_evaluate_model(
888 args,
889 model: nn.Module,
890 tokenizer: PreTrainedTokenizer,
891 accelerator: Accelerator,
892 examples: Dict,
893 dataset: Dict,
894 dataloader: Dict[str, DataLoader],
895 answer_column_name,
896):
897 # Optimizer
898 # Split weights in two groups, one with weight decay and the other not.
899 no_decay = ["bias", "LayerNorm.weight"]
900 optimizer_grouped_parameters = [
901 {
902 "params": [
903 p
904 for n, p in model.named_parameters()
905 if p.requires_grad and not any(nd in n for nd in no_decay)
906 ],
907 "weight_decay": args.weight_decay,
908 },
909 {
910 "params": [
911 p
912 for n, p in model.named_parameters()
913 if p.requires_grad and any(nd in n for nd in no_decay)
914 ],
915 "weight_decay": 0.0,
916 },
917 ]
918 optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
919
920 # Scheduler and math around the number of training steps.
921 overrode_max_train_steps = False
922 num_update_steps_per_epoch = math.ceil(
923 len(dataloader["train"]) / args.gradient_accumulation_steps
924 )
925 if args.max_train_steps is None:
926 args.max_train_steps = int(args.num_train_epochs * num_update_steps_per_epoch)
927 overrode_max_train_steps = True
928
929 lr_scheduler = get_scheduler(
930 name=args.lr_scheduler_type,
931 optimizer=optimizer,
932 num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
933 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
934 )
935
936 # Prepare everything with our `accelerator`.
937 model, optimizer, dataloader["train"], dataloader["eval"], lr_scheduler = accelerator.prepare(
938 model, optimizer, dataloader["train"], dataloader["eval"], lr_scheduler
939 )
940
941 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
942 num_update_steps_per_epoch = math.ceil(
943 len(dataloader["train"]) / args.gradient_accumulation_steps
944 )
945 if overrode_max_train_steps:
946 args.max_train_steps = int(args.num_train_epochs * num_update_steps_per_epoch)
947 # Afterwards we recalculate our number of training epochs
948 args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
949
950 # Figure out how many steps we should save the Accelerator states
951 checkpointing_steps = args.checkpointing_steps
952 if checkpointing_steps is not None and checkpointing_steps.isdigit():
953 checkpointing_steps = int(checkpointing_steps)
954
955 # We need to initialize the trackers we use, and also store our configuration.
956 # The trackers initialize automatically on the main process.
957 if args.with_tracking:
958 experiment_config = vars(args)
959 # TensorBoard cannot log Enums, need the raw value
960 experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
961 accelerator.init_trackers("tensorboard", experiment_config)
962
963 # Train!
964 total_batch_size = (
965 args.per_device_train_batch_size
966 * accelerator.num_processes
967 * args.gradient_accumulation_steps
968 )
969
970 logger.info("***** Running training *****")
971 logger.info(f" Num examples = {len(dataset['train'])}")
972 logger.info(f" Num Epochs = {args.num_train_epochs}")
973 logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
974 logger.info(
975 f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
976 )
977 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
978 logger.info(f" Total optimization steps = {args.max_train_steps}")
979
980 # Only show the progress bar once on each machine.
981 progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
982 completed_steps = 0
983 starting_epoch = 0
984 resume_step = None
985
986 # Potentially load in the weights and states from a previous save
987 if args.resume_from_last_ckpt:
988 # Get the most recent checkpoint
989 dirs = [
990 f.path
991 for f in os.scandir(args.output_dir)
992 if f.is_dir() and (f.name.startswith("epoch_") or f.name.startswith("step_"))
993 ]
994 if len(dirs) == 0:
995 logger.warning("No checkpoint found in output_dir. Training from scratch!")
996 else:
997 latest_dir = max(dirs, key=os.path.getctime)
998 accelerator.load_state(latest_dir)
999
1000 # Extract `epoch_{i}` or `step_{i}`
1001 latest_dir = os.path.basename(latest_dir)
1002 if "epoch" in latest_dir:
1003 starting_epoch = int(latest_dir.replace("epoch_", "")) + 1
1004 completed_steps = starting_epoch * num_update_steps_per_epoch
1005 else:
1006 # need to multiply by `gradient_accumulation_steps` to reflect real steps
1007 resume_step = (
1008 int(latest_dir.replace("step_", "")) * args.gradient_accumulation_steps
1009 )
1010 starting_epoch = resume_step // len(dataloader["train"])
1011 completed_steps = resume_step // args.gradient_accumulation_steps
1012 resume_step -= starting_epoch * len(dataloader["train"])
1013
1014 # update the progress_bar if load from checkpoint
1015 progress_bar.update(completed_steps)
1016
1017 # Evaluate before training (e.g. PTQ accuracy before QAT)
1018 eval_metric = evaluate_model(
1019 args,
1020 model,
1021 accelerator,
1022 examples["eval"],
1023 dataset["eval"],
1024 dataloader["eval"],
1025 answer_column_name,
1026 )
1027 for epoch in range(starting_epoch, args.num_train_epochs):
1028 model.train()
1029 if args.with_tracking:
1030 total_loss = 0
1031 if args.resume_from_last_ckpt and epoch == starting_epoch and resume_step is not None:
1032 # We skip the first `n` batches in the dataloader when resuming from a checkpoint
1033 active_dataloader = accelerator.skip_first_batches(dataloader["train"], resume_step)
1034 else:
1035 active_dataloader = dataloader["train"]
1036 for batch in active_dataloader:
1037 optimizer.zero_grad()
1038 outputs = model(**batch)
1039
1040 # Model Optimizer: If using distillation, we unwrap the model and extract the custom loss function
1041 if args.do_modelopt_distill:
1042 loss = model.module.compute_kd_loss()
1043 else:
1044 loss, _, _ = outputs.to_tuple()
1045
1046 # We keep track of the loss at each epoch
1047 if args.with_tracking:
1048 total_loss += loss.detach().float()
1049
1050 accelerator.backward(loss)
1051 optimizer.step()
1052 lr_scheduler.step()
1053
1054 # Checks if the accelerator has performed an optimization step behind the scenes
1055 if accelerator.sync_gradients:
1056 progress_bar.update(1)
1057 completed_steps += 1
1058
1059 if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0:
1060 accelerator.save_state(os.path.join(args.output_dir, f"step_{completed_steps}"))
1061
1062 if completed_steps >= args.max_train_steps:
1063 break
1064
1065 if args.checkpointing_steps == "epoch":
1066 accelerator.save_state(os.path.join(args.output_dir, f"epoch_{epoch}"))
1067
1068 eval_metric = evaluate_model(
1069 args,
1070 model,
1071 accelerator,
1072 examples["eval"],
1073 dataset["eval"],
1074 dataloader["eval"],
1075 answer_column_name,
1076 )
1077
1078 if args.with_tracking:
1079 log = {
1080 "squad": eval_metric,
1081 "train_loss": total_loss.item() / len(dataloader["train"]), # type: ignore[attr-defined]
1082 "epoch": epoch,
1083 "step": completed_steps,
1084 }
1085 accelerator.log(log, step=completed_steps)
1086
1087 accelerator.wait_for_everyone()
1088 if accelerator.is_main_process and eval_metric:
1089 logger.info(json.dumps(eval_metric, indent=4))
1090 # Prefix all keys with prefix + '_'
1091 for key in list(eval_metric.keys()):
1092 eval_metric[f"Eval_{key}"] = eval_metric.pop(key)
1093
1094 with open(os.path.join(args.output_dir, "results.json"), "w") as f:
1095 json.dump(eval_metric, f, indent=4)
1096
1097
1098def main(input_args: Optional[List[str]] = None) -> None:
1099 args = parse_args(input_args)
1100
1101 # Initialize the accelerator
1102 accelerator_log_kwargs = {}
1103 if args.with_tracking:
1104 accelerator_log_kwargs["log_with"] = "tensorboard"
1105 accelerator_log_kwargs["project_dir"] = args.output_dir
1106 accelerator = Accelerator(
1107 gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs
1108 )
1109
1110 # Setup logging
1111 logging.basicConfig(
1112 format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
1113 datefmt="%m/%d/%Y %H:%M:%S",
1114 level=logging.INFO,
1115 )
1116 logger.info(accelerator.state, main_process_only=False)
1117 if accelerator.is_local_main_process:
1118 datasets.utils.logging.set_verbosity_warning()
1119 transformers.utils.logging.set_verbosity_info()
1120 else:
1121 datasets.utils.logging.set_verbosity_error()
1122 transformers.utils.logging.set_verbosity_error()
1123
1124 # Set the training seed
1125 set_seed(SEED)
1126
1127 # Handle the output_dir creation
1128 if accelerator.is_main_process:
1129 os.makedirs(args.output_dir, exist_ok=True)
1130 accelerator.wait_for_everyone()
1131
1132 # Load pretrained model and tokenizer
1133 tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
1134 model = AutoModelForQuestionAnswering.from_pretrained(args.model_name_or_path)
1135 dummy_input = model.dummy_inputs["input_ids"]
1136
1137 # Get datasets
1138 examples, dataset, dataloader, answer_column_name = get_datasets_and_dataloaders(
1139 args, tokenizer, accelerator
1140 )
1141
1142 # Model Optimizer: Save the pruned or quantized model
1143 def save_modelopt_model(model):
1144 if accelerator.is_main_process and args.modelopt_save_file is not None:
1145 save_path = os.path.join(args.output_dir, args.modelopt_save_file)
1146 logger.info(f"Saving modelopt optimized model to {save_path}")
1147 mto.save(model, save_path)
1148
1149 # Model Optimizer: Restore the pruned or quantized model from a checkpoint
1150 if args.modelopt_restore_path:
1151 assert os.path.exists(
1152 args.modelopt_restore_path
1153 ), f"{args.modelopt_restore_path} does not exist."
1154 logger.info(f"Restoring model from {args.modelopt_restore_path}")
1155 model = mto.restore(model, args.modelopt_restore_path)
1156
1157 # Model Optimizer: Prune the model to given FLOPS target using GradNAS algorithm
1158 if args.do_modelopt_prune:
1159 logger.info(f"Pruning model to {args.modelopt_prune_flops_percent}% FLOPS")
1160
1161 # NOTE: gradnas does not perform synchronization across data parallel groups
1162 # Use unwrapped model & non-distributed dataloader for gradnas so that all the processes
1163 # in the data parallel group have the same gradients for pruning
1164 model = model.to(accelerator.device)
1165 dummy_input = dummy_input.to(accelerator.device)
1166
1167 # Search for the best pruned model
1168 # To use other NAS algorithms, you can use `mtn.convert` + `mtn.search` here.
1169 model, _ = mtp.prune(
1170 model=model,
1171 mode="gradnas",
1172 constraints={"flops": f"{args.modelopt_prune_flops_percent}%"},
1173 dummy_input=dummy_input,
1174 config={
1175 "data_loader": dataloader["train"],
1176 "collect_func": lambda batch: (batch,),
1177 "loss_func": lambda output, batch: (
1178 output["loss"] if isinstance(output, dict) else output[0]
1179 ),
1180 },
1181 )
1182 save_modelopt_model(model)
1183
1184 # Model Optimizer: Quantize the model to INT8 precision
1185 if args.modelopt_quantize_cfg:
1186 logger.info(f"Quantizing model with {args.modelopt_quantize_cfg} config")
1187
1188 # NOTE: `mtq.quantize` does not perform synchronization across data parallel groups
1189 # Use unwrapped model & non-distributed dataloader for PTQ calibration so that all the processes
1190 # in the data parallel group have the same calibration statistics
1191 model = model.to(accelerator.device)
1192
1193 def forward_loop(model):
1194 num_samples = 256 # Use only 256 samples for PTQ calibration
1195 num_batches = num_samples // args.per_device_train_batch_size
1196 for idx, batch in tqdm(enumerate(dataloader["train"]), total=num_batches):
1197 batch = {k: v.to(accelerator.device) for k, v in batch.items()}
1198 model(**batch)
1199 if idx >= num_batches:
1200 break
1201
1202 model = mtq.quantize(model, getattr(mtq, args.modelopt_quantize_cfg), forward_loop)
1203 torch.cuda.empty_cache()
1204 save_modelopt_model(model)
1205
1206 if args.do_train:
1207 # Model Optimizer: Convert to a DistillationModel containing teacher to train with distillation
1208 if args.do_modelopt_distill:
1209 logger.info(f"Using distillation with teacher {args.model_name_or_path}")
1210
1211 kd_config = {
1212 "teacher_model": (teacher_factory, (args.model_name_or_path,), {}),
1213 "criterion": StartEndLogitsDistillationLoss(args.temperature),
1214 }
1215 model = mtd.convert(model, mode=[("kd_loss", kd_config)])
1216
1217 train_and_evaluate_model(
1218 args,
1219 model,
1220 tokenizer,
1221 accelerator,
1222 examples,
1223 dataset,
1224 dataloader,
1225 answer_column_name,
1226 )
1227
1228 # Model Optimizer: Export the distilled model
1229 if args.do_modelopt_distill:
1230 model = mtd.export(model)
1231
1232 save_modelopt_model(model)
1233
1234 if accelerator.is_main_process and args.onnx_export_file is not None:
1235 save_path = os.path.join(args.output_dir, args.onnx_export_file)
1236 logger.info(f"Exporting ONNX model to {save_path}")
1237
1238 # Move the model and dummy_input to the device
1239 model = model.to(accelerator.device)
1240 dummy_input = dummy_input.to(accelerator.device)
1241
1242 with open(save_path, "wb") as f:
1243 f.write(get_onnx_bytes(model, dummy_input, onnx_opset=14))
1244
1245 logger.info("Done!")
1246
1247
1248if __name__ == "__main__":
1249 main()
Commands
First, we prune the BERT large model to 50% FLOPs with the GradNAS algorithm. Then, we fine-tune the pruned model with distillation from the unpruned teacher model to recover more than 99% of the initial F1 score (93.15). We recommend using multiple GPUs for fine-tuning. Note that we use more epochs for fine-tuning than the 2 epochs originally used to fine-tune BERT without distillation: distillation requires more epochs to converge, but achieves much better results.
#!/bin/bash
set -e
set -x

FLOPS_PERCENT=50
BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent

if [ ! -f "${BASE_DIR}/pruned_model.pth" ]; then
    modelopt_args="--do_modelopt_prune --modelopt_prune_flops_percent ${FLOPS_PERCENT}"
else
    modelopt_args="--modelopt_restore_path ${BASE_DIR}/pruned_model.pth"
fi
modelopt_args="${modelopt_args} --modelopt_save_file pruned_model.pth"

# Run pruning followed by distributed fine-tuning on the pruned model
accelerate launch --multi_gpu --mixed_precision bf16 bert_prune_distill_quantize.py \
    --model_name_or_path bert-large-uncased-whole-word-masking-finetuned-squad \
    --output_dir ${BASE_DIR} \
    ${modelopt_args} \
    --do_train \
    --do_modelopt_distill \
    --lr_scheduler_type cosine \
    --learning_rate 1e-4 \
    --per_device_train_batch_size 16 \
    --num_train_epochs 15 \
    --with_tracking \
    --checkpointing_steps epoch \
    --resume_from_last_ckpt
Next, we quantize the fine-tuned model to INT8 precision and run calibration (PTQ). Note that PTQ results in a slight drop in F1 score, but we can recover it with QAT. We run QAT with distillation from the unpruned teacher model as well.
#!/bin/bash
set -e
set -x

FLOPS_PERCENT=50
BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
OUTPUT_DIR=${BASE_DIR}/int8_quantized

if [ ! -f "${OUTPUT_DIR}/quantized_model.pth" ]; then
    modelopt_args="--modelopt_quantize_cfg INT8_DEFAULT_CFG --modelopt_restore_path ${BASE_DIR}/pruned_model.pth"
else
    modelopt_args="--modelopt_restore_path ${OUTPUT_DIR}/quantized_model.pth"
fi
modelopt_args="${modelopt_args} --modelopt_save_file quantized_model.pth"

# Run distributed QAT on the pruned model with 0.1x the learning rate and fewer epochs
accelerate launch --multi_gpu --mixed_precision bf16 bert_prune_distill_quantize.py \
    --model_name_or_path bert-large-uncased-whole-word-masking-finetuned-squad \
    --output_dir ${OUTPUT_DIR} \
    ${modelopt_args} \
    --do_train \
    --do_modelopt_distill \
    --lr_scheduler_type cosine \
    --learning_rate 5e-6 \
    --per_device_train_batch_size 16 \
    --num_train_epochs 2 \
    --with_tracking \
    --checkpointing_steps epoch \
    --resume_from_last_ckpt
Export the quantized model to ONNX format for deployment with TensorRT.
#!/bin/bash
set -e
set -x

FLOPS_PERCENT=50
BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent

# Export to ONNX on a single GPU
python3 bert_prune_distill_quantize.py \
    --model_name_or_path bert-large-uncased-whole-word-masking-finetuned-squad \
    --output_dir ${BASE_DIR}/onnx \
    --modelopt_restore_path ${BASE_DIR}/int8_quantized/quantized_model.pth \
    --onnx_export_file pruned_model_int8.onnx
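The exported ONNX model can then be built into a TensorRT engine. As a rough sketch (this step is not part of the example's scripts; the engine file name is a placeholder and the flags assume a recent trtexec release):

# Hypothetical deployment step: build an INT8 TensorRT engine from the exported ONNX model
trtexec --onnx=results/bert_large_pruned_50_percent/onnx/pruned_model_int8.onnx \
    --saveEngine=pruned_model_int8.engine \
    --int8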