HF BERT: Prune, Distill & Quantize
This example shows how to compress a Hugging Face BERT-large model for Question Answering using a combination of modelopt.torch.prune, modelopt.torch.distill, and modelopt.torch.quantize. More specifically, we will:
Prune the BERT-large model to 50% of its original FLOPs with the GradNAS algorithm and fine-tune it with distillation
Quantize the fine-tuned model to INT8 precision with Post-Training Quantization (PTQ) and recover accuracy with Quantization-Aware Training (QAT) plus distillation
Export the quantized model to ONNX format for deployment with TensorRT
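At a high level, the three steps above map onto three ModelOpt API calls. The condensed sketch below mirrors the full script further down this page; the training dataloader and the PTQ calibration forward loop are passed in as arguments because the full script builds them from the SQuAD dataset:

import modelopt.torch.distill as mtd
import modelopt.torch.prune as mtp
import modelopt.torch.quantization as mtq
from transformers import AutoModelForQuestionAnswering

TEACHER_NAME = "bert-large-uncased-whole-word-masking-finetuned-squad"


def teacher_factory(model_name_or_path):
    # The frozen teacher is simply the original (unpruned, unquantized) model.
    return AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)


class StartEndLogitsDistillationLoss(mtd.LogitsDistillationLoss):
    # Distill start and end logits separately and average the two losses.
    def forward(self, outputs_s, outputs_t):
        loss_start = super().forward(outputs_s.start_logits, outputs_t.start_logits)
        loss_end = super().forward(outputs_s.end_logits, outputs_t.end_logits)
        return (loss_start + loss_end) / 2.0


def optimize_bert_qa(model, dummy_input, train_dataloader, calibration_forward_loop):
    """Prune -> quantize -> wrap for distillation, as the full script below does."""
    # 1) Prune to 50% of the original FLOPs with the GradNAS algorithm.
    model, _ = mtp.prune(
        model=model,
        mode="gradnas",
        constraints={"flops": "50%"},
        dummy_input=dummy_input,
        config={
            "data_loader": train_dataloader,
            "collect_func": lambda batch: (batch,),
            "loss_func": lambda output, batch: (
                output["loss"] if isinstance(output, dict) else output[0]
            ),
        },
    )

    # 2) INT8 post-training quantization; the forward loop feeds calibration batches.
    model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibration_forward_loop)

    # 3) Attach a frozen teacher so fine-tuning / QAT can use the distillation loss.
    kd_config = {
        "teacher_model": (teacher_factory, (TEACHER_NAME,), {}),
        "criterion": StartEndLogitsDistillationLoss(2.0),  # temperature = 2.0
    }
    model = mtd.convert(model, mode=[("kd_loss", kd_config)])
    return model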
Prerequisites
Install Model Optimizer with the optional torch and Hugging Face dependencies:
pip install "nvidia-modelopt[torch,hf]" --extra-index-url https://pypi.nvidia.com
Note
This example has been tested on 8 x 24GB A5000 GPUs with PyTorch 2.4 and CUDA 12.4. It takes about 2 hours to complete all the stages of the optimization. Most of the time is spent on fine-tuning and QAT.
Full code
You can view the full code below with ModelOpt integration points highlighted. The source code and scripts are also available in the ModelOpt GitHub repository.
1# NOTE: This is adapted from run_qa_no_trainer.py and utils_qa.py from https://github.com/huggingface/
2# transformers/blob/c52b515e948fc12ff58ad773a0385860d0162f61/examples/pytorch/question-answering
3#
4# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
18# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
19# SPDX-License-Identifier: MIT
20
21# Permission is hereby granted, free of charge, to any person obtaining a
22# copy of this software and associated documentation files (the "Software"),
23# to deal in the Software without restriction, including without limitation
24# the rights to use, copy, modify, merge, publish, distribute, sublicense,
25# and/or sell copies of the Software, and to permit persons to whom the
26# Software is furnished to do so, subject to the following conditions:
27
28# The above copyright notice and this permission notice shall be included in
29# all copies or substantial portions of the Software.
30
31# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
34# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
36# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
37# DEALINGS IN THE SOFTWARE.
38"""
39Example showcasing how to do end-to-end optimization of a BERT model on SQuAD using Model Optimizer.
40This includes GradNAS pruning, INT8 quantization, fine-tuning / QAT with distillation, and ONNX export.
41"""
42
43import argparse
44import collections
45import json
46import logging
47import math
48import os
49import random
50from typing import Any, Dict, List, Optional, Tuple
51
52import datasets
53import evaluate
54import numpy as np
55import torch
56import torch.nn as nn
57import transformers
58from accelerate import Accelerator
59from accelerate.logging import get_logger
60from accelerate.utils import set_seed
61from torch.utils.data import DataLoader
62from tqdm.auto import tqdm
63from transformers import (
64 AutoModelForQuestionAnswering,
65 AutoTokenizer,
66 DataCollatorWithPadding,
67 EvalPrediction,
68 PreTrainedTokenizer,
69 SchedulerType,
70 default_data_collator,
71 get_scheduler,
72)
73
74import modelopt.torch.distill as mtd
75import modelopt.torch.opt as mto
76import modelopt.torch.prune as mtp
77import modelopt.torch.quantization as mtq
78from modelopt.torch._deploy.utils import get_onnx_bytes
79
80logger = get_logger(__name__)
81
82SEED = 123
83
84
85def parse_args(input_args: Optional[List[str]] = None):
86 parser = argparse.ArgumentParser(
87 description="Finetune a transformers model on a Question Answering task"
88 )
89
90 # Training arguments
91 parser.add_argument(
92 "--model_name_or_path",
93 type=str,
94 default="bert-large-uncased-whole-word-masking-finetuned-squad",
95 help="Path to pretrained model or model identifier from huggingface.co/models.",
96 )
97 parser.add_argument(
98 "--do_train",
99 action="store_true",
100 help="Whether to run training / fine-tuning.",
101 )
102 parser.add_argument(
103 "--per_device_train_batch_size",
104 type=int,
105 default=16,
106 help="Batch size (per device) for the training dataloader.",
107 )
108 parser.add_argument(
109 "--per_device_eval_batch_size",
110 type=int,
111 default=64,
112 help="Batch size (per device) for the evaluation dataloader.",
113 )
114 parser.add_argument(
115 "--learning_rate",
116 type=float,
117 default=5e-5,
118 help="Initial learning rate (after the potential warmup period) to use.",
119 )
120 parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
121 parser.add_argument(
122 "--lr_scheduler_type",
123 type=SchedulerType,
124 default="linear",
125 help="The scheduler type to use.",
126 choices=[
127 "linear",
128 "cosine",
129 "cosine_with_restarts",
130 "polynomial",
131 "constant",
132 "constant_with_warmup",
133 ],
134 )
135 parser.add_argument(
136 "--num_warmup_steps",
137 type=int,
138 default=0,
139 help="Number of steps for the warmup in the lr scheduler.",
140 )
141 parser.add_argument(
142 "--num_train_epochs",
143 type=float,
144 default=2.0,
145 help="Total number of training epochs to perform.",
146 )
147 parser.add_argument(
148 "--max_train_steps",
149 type=int,
150 default=None,
151 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
152 )
153 parser.add_argument(
154 "--gradient_accumulation_steps",
155 type=int,
156 default=1,
157 help="Number of updates steps to accumulate before performing a backward/update pass.",
158 )
159 parser.add_argument(
160 "--preprocessing_num_workers",
161 type=int,
162 default=4,
163 help="The number of processes to use for preprocessing the dataset.",
164 )
165
166 # Logging and checkpointing arguments
167 parser.add_argument(
168 "--output_dir", type=str, default="results", help="Where to store the optimized models."
169 )
170 parser.add_argument(
171 "--with_tracking",
172 action="store_true",
173 help="Whether to enable experiment trackers for logging.",
174 )
175 parser.add_argument(
176 "--checkpointing_steps",
177 type=str,
178 default=None,
179 help=(
180 "Whether the various states should be saved at the end of every n steps, or 'epoch' for"
181 " each epoch."
182 ),
183 )
184 parser.add_argument(
185 "--resume_from_last_ckpt",
186 action="store_true",
187 help="If the training should continue from the latest checkpoint in output_dir.",
188 )
189
190 # Misc arguments for Bert (should not be modified in most cases)
191 parser.add_argument(
192 "--max_seq_length",
193 type=int,
194 default=384,
195 help=(
196 "The maximum total input sequence length after tokenization. Sequences longer than this"
197 " will be truncated, and shorter will be padded if `--pad_to_max_lengh` is passed."
198 ),
199 )
200 parser.add_argument(
201 "--pad_to_max_length",
202 action="store_true",
203 help="If passed, pad all samples to `max_seq_length`. Otherwise, dynamic padding is used.",
204 )
205 parser.add_argument(
206 "--doc_stride",
207 type=int,
208 default=128,
209 help=(
210 "When splitting up a long document into chunks how much stride to take between chunks."
211 ),
212 )
213 parser.add_argument(
214 "--n_best_size",
215 type=int,
216 default=20,
217 help="The total number of n-best predictions to generate when looking for an answer.",
218 )
219 parser.add_argument(
220 "--max_answer_length",
221 type=int,
222 default=30,
223 help=(
224 "The maximum length of an answer that can be generated. This is needed because the"
225 " start and end predictions are not conditioned on one another."
226 ),
227 )
228
229 # Debugging arguments
230 parser.add_argument(
231 "--max_train_samples",
232 type=int,
233 default=None,
234 help="For debugging purposes or quicker training.",
235 )
236 parser.add_argument(
237 "--max_eval_samples",
238 type=int,
239 default=None,
240 help="For debugging purposes or quicker training.",
241 )
242
243 # Model Optimizer: pruning arguments
244 parser.add_argument(
245 "--do_modelopt_prune",
246 action="store_true",
247 help="Whether or not to use Model Optimizer pruning.",
248 )
249 parser.add_argument(
250 "--modelopt_prune_flops_percent",
251 type=float,
252 default=None,
253 help="The percentage (between 0 and 100) of FLOPs to retain in the pruned model.",
254 )
255
256 # Model Optimizer: quantization arguments
257 parser.add_argument(
258 "--modelopt_quantize_cfg",
259 help="Model Optimizer quantization config.",
260 choices=mtq.config.choices,
261 )
262
263 # Model Optimizer: Distillation arguments
264 parser.add_argument(
265 "--do_modelopt_distill",
266 action="store_true",
267 help="Whether or not to use distillation. A teacher model must be specified.",
268 )
269 parser.add_argument(
270 "--temperature",
271 type=float,
272 default=2.0,
273 help="The temperature to use when distilling.",
274 )
275
276 # Model Optimizer: save and restore arguments
277 parser.add_argument(
278 "--modelopt_save_file",
279 type=str,
280 default=None,
281 help="File (inside output_dir) to save the modelopt modified model to.",
282 )
283 parser.add_argument(
284 "--modelopt_restore_path",
285 type=str,
286 default=None,
287 help="Path to restore the modelopt modified model from.",
288 )
289
290 # ONNX export arguments
291 parser.add_argument(
292 "--onnx_export_file",
293 type=str,
294 default=None,
295 help="File (inside output_dir) to export the ONNX model to.",
296 )
297
298 args = parser.parse_args(input_args)
299
300 # Sanity checks
301 if args.do_modelopt_prune and not args.modelopt_prune_flops_percent:
302 raise ValueError(
303 "Need a `modelopt_prune_flops_percent` when `do_modelopt_prune` is passed."
304 )
305
306 return args
307
308
309def get_datasets_and_dataloaders(args, tokenizer: PreTrainedTokenizer, accelerator: Accelerator):
310 """Get the examples, dataset, dataloader, answer_column_name
311
312 You can either provide your own CSV/JSON/TXT training and evaluation files (see below)
313 or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
314 (the dataset will be downloaded automatically from the datasets Hub).
315
316 For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
317 'text' is found. You can easily tweak this behavior (see below).
318 """
319
320 def prepare_train_features(examples):
321 # Some of the questions have lots of whitespace on the left, which is not useful and will make the
322 # truncation of the context fail (the tokenized question will take a lot of space). So we remove that
323 # left whitespace
324 examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
325
326 # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
327 # in one example possibly giving several features when a context is long, each of those features having a
328 # context that overlaps a bit with the context of the previous feature.
329 tokenized_examples = tokenizer(
330 examples[question_column_name if pad_on_right else context_column_name],
331 examples[context_column_name if pad_on_right else question_column_name],
332 truncation="only_second" if pad_on_right else "only_first",
333 max_length=max_seq_length,
334 stride=args.doc_stride,
335 return_overflowing_tokens=True,
336 return_offsets_mapping=True,
337 padding="max_length" if args.pad_to_max_length else False,
338 )
339
340 # Since one example might give us several features if it has a long context, we need a map from a feature to
341 # its corresponding example. This key gives us just that.
342 sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
343 # The offset mappings will give us a map from token to character position in the original context. This will
344 # help us compute the start_positions and end_positions.
345 offset_mapping = tokenized_examples.pop("offset_mapping")
346
347 # Let's label those examples!
348 tokenized_examples["start_positions"] = []
349 tokenized_examples["end_positions"] = []
350
351 for i, offsets in enumerate(offset_mapping):
352 # We will label impossible answers with the index of the CLS token.
353 input_ids = tokenized_examples["input_ids"][i]
354 cls_index = input_ids.index(tokenizer.cls_token_id)
355
356 # Grab the sequence corresponding to that example (to know what is the context and what is the question).
357 sequence_ids = tokenized_examples.sequence_ids(i)
358
359 # One example can give several spans, this is the index of the example containing this span of text.
360 sample_index = sample_mapping[i]
361 answers = examples[answer_column_name][sample_index]
362 # If no answers are given, set the cls_index as answer.
363 if len(answers["answer_start"]) == 0:
364 tokenized_examples["start_positions"].append(cls_index)
365 tokenized_examples["end_positions"].append(cls_index)
366 else:
367 # Start/end character index of the answer in the text.
368 start_char = answers["answer_start"][0]
369 end_char = start_char + len(answers["text"][0])
370
371 # Start token index of the current span in the text.
372 token_start_index = 0
373 while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
374 token_start_index += 1
375
376 # End token index of the current span in the text.
377 token_end_index = len(input_ids) - 1
378 while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
379 token_end_index -= 1
380
381 # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
382 if not (
383 offsets[token_start_index][0] <= start_char
384 and offsets[token_end_index][1] >= end_char
385 ):
386 tokenized_examples["start_positions"].append(cls_index)
387 tokenized_examples["end_positions"].append(cls_index)
388 else:
389 # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
390 # Note: we could go after the last offset if the answer is the last word (edge case).
391 while (
392 token_start_index < len(offsets)
393 and offsets[token_start_index][0] <= start_char
394 ):
395 token_start_index += 1
396 tokenized_examples["start_positions"].append(token_start_index - 1)
397 while offsets[token_end_index][1] >= end_char:
398 token_end_index -= 1
399 tokenized_examples["end_positions"].append(token_end_index + 1)
400
401 return tokenized_examples
402
403 def prepare_validation_features(examples):
404 # Some of the questions have lots of whitespace on the left, which is not useful and will make the
405 # truncation of the context fail (the tokenized question will take a lot of space). So we remove that
406 # left whitespace
407 examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
408
409 # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
410 # in one example possibly giving several features when a context is long, each of those features having a
411 # context that overlaps a bit with the context of the previous feature.
412 tokenized_examples = tokenizer(
413 examples[question_column_name if pad_on_right else context_column_name],
414 examples[context_column_name if pad_on_right else question_column_name],
415 truncation="only_second" if pad_on_right else "only_first",
416 max_length=max_seq_length,
417 stride=args.doc_stride,
418 return_overflowing_tokens=True,
419 return_offsets_mapping=True,
420 padding="max_length" if args.pad_to_max_length else False,
421 )
422
423 # Since one example might give us several features if it has a long context, we need a map from a feature to
424 # its corresponding example. This key gives us just that.
425 sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
426
427 # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
428 # corresponding example_id and we will store the offset mappings.
429 tokenized_examples["example_id"] = []
430
431 for i in range(len(tokenized_examples["input_ids"])):
432 # Grab the sequence corresponding to that example (to know what is the context and what is the question).
433 sequence_ids = tokenized_examples.sequence_ids(i)
434 context_index = 1 if pad_on_right else 0
435
436 # One example can give several spans, this is the index of the example containing this span of text.
437 sample_index = sample_mapping[i]
438 tokenized_examples["example_id"].append(examples["id"][sample_index])
439
440 # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
441 # position is part of the context or not.
442 tokenized_examples["offset_mapping"][i] = [
443 (o if sequence_ids[k] == context_index else None)
444 for k, o in enumerate(tokenized_examples["offset_mapping"][i])
445 ]
446
447 return tokenized_examples
448
449 examples, dataset, dataloader = {}, {}, {}
450
451 # In distributed training, the load_dataset function guarantees that only one local process can concurrently
452 # download the dataset.
453 # Downloading and loading a dataset from the hub.
454 raw_datasets = datasets.load_dataset("squad")
455 # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
456 # https://huggingface.co/docs/datasets/loading_datasets.
457
458 # Preprocessing the datasets.
459 # Preprocessing is slightly different for training and evaluation.
460
461 column_names = raw_datasets["train"].column_names
462
463 question_column_name = "question" if "question" in column_names else column_names[0]
464 context_column_name = "context" if "context" in column_names else column_names[1]
465 answer_column_name = "answers" if "answers" in column_names else column_names[2]
466
467 # Padding side determines if we do (question|context) or (context|question).
468 pad_on_right = tokenizer.padding_side == "right"
469
470 if args.max_seq_length > tokenizer.model_max_length:
471 logger.warning(
472 f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length"
473 f" for the model ({tokenizer.model_max_length}). Using"
474 f" max_seq_length={tokenizer.model_max_length}."
475 )
476
477 max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
478
479 examples["train"] = raw_datasets["train"]
480 if args.max_train_samples is not None:
481 # We will select samples from the whole data if the argument is specified
482 examples["train"] = examples["train"].select(range(args.max_train_samples))
483
484 # Create train feature from dataset
485 with accelerator.main_process_first():
486 dataset["train"] = examples["train"].map(
487 prepare_train_features,
488 batched=True,
489 num_proc=args.preprocessing_num_workers,
490 remove_columns=column_names,
491 load_from_cache_file=True,
492 desc="Running tokenizer on train dataset",
493 )
494 # if args.max_train_samples is not None:
495 # # Number of samples might increase during Feature Creation, We select only specified max samples
496 # dataset["train"] = dataset["train"].select(range(args.max_train_samples))
497
498 examples["eval"] = raw_datasets["validation"]
499 if args.max_eval_samples is not None:
500 # We will select samples from the whole data
501 examples["eval"] = examples["eval"].select(range(args.max_eval_samples))
502 # Validation Feature Creation
503 with accelerator.main_process_first():
504 dataset["eval"] = examples["eval"].map(
505 prepare_validation_features,
506 batched=True,
507 num_proc=args.preprocessing_num_workers,
508 remove_columns=column_names,
509 load_from_cache_file=True,
510 desc="Running tokenizer on validation dataset",
511 )
512 # if args.max_eval_samples is not None:
513 # # During Feature creation dataset samples might increase, we will select required samples again
514 # dataset["eval"] = dataset["eval"].select(range(args.max_eval_samples))
515
516 # Log a random sample from the training set:
517 for index in random.sample(range(len(dataset["train"])), 1):
518 logger.info(f"Sample {index} of the training set: {dataset['train'][index]}.")
519
520 # DataLoaders creation:
521 if args.pad_to_max_length:
522 # If padding was already done to max length, we use the default data collator that will just convert everything
523 # to tensors.
524 data_collator = default_data_collator
525 else:
526 # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of
527 # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple
528 # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
529 data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
530
531 dataloader["train"] = DataLoader(
532 dataset["train"],
533 shuffle=True,
534 collate_fn=data_collator,
535 batch_size=args.per_device_train_batch_size,
536 )
537
538 dataloader["eval"] = DataLoader(
539 dataset["eval"].remove_columns(["example_id", "offset_mapping"]),
540 collate_fn=data_collator,
541 batch_size=args.per_device_eval_batch_size,
542 )
543
544 return examples, dataset, dataloader, answer_column_name
545
546
547def evaluate_model(
548 args,
549 model: nn.Module,
550 accelerator: Accelerator,
551 eval_examples: Any,
552 eval_dataset: Any,
553 eval_dataloader: DataLoader,
554 answer_column_name: str,
555 prefix: str = "Eval",
556):
557 def create_and_fill_np_array(start_or_end_logits, max_len):
558 """Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor
559
560 Args:
561 start_or_end_logits: This is the output predictions of the model.
562 We can only enter either start or end logits.
563 max_len: The maximum length of the output tensor. (See the model.eval() part for more details)
564 """
565 step = 0
566 # create a numpy array and fill it with -100.
567 logits_concat = np.full((len(eval_dataset), max_len), -100, dtype=np.float64)
568 # Now since we have created an array we will populate it with the outputs using accelerator.gather_for_metrics
569 for i, output_logit in enumerate(start_or_end_logits): # populate columns
570 # We have to fill it such that we have to take the whole tensor and replace it on the newly created array
571 # And after every iteration we have to change the step
572 batch_size = output_logit.shape[0]
573 cols = output_logit.shape[1]
574
575 if step + batch_size < len(eval_dataset):
576 logits_concat[step : step + batch_size, :cols] = output_logit
577 else:
578 logits_concat[step:, :cols] = output_logit[: len(eval_dataset) - step]
579
580 step += batch_size
581
582 return logits_concat
583
584 def postprocess_qa_predictions(
585 examples,
586 features,
587 predictions: Tuple[np.ndarray, np.ndarray],
588 version_2_with_negative: bool = False,
589 n_best_size: int = 20,
590 max_answer_length: int = 30,
591 null_score_diff_threshold: float = 0.0,
592 output_dir: Optional[str] = None,
593 prefix: Optional[str] = None,
594 ) -> EvalPrediction:
595 """Post-processes the predictions of a question-answering model to convert them to answers
596 that are substrings of the original contexts. This is the base postprocessing function for
597 models that only return start and end logits.
598
599 Args:
600 examples: The non-preprocessed dataset.
601 features: The processed dataset.
602 predictions: The predictions of the model: two arrays containing the start logits and the end logits
603 respectively. Its first dimension must match the number of elements of `features`.
604 version_2_with_negative: Whether or not the underlying dataset contains examples with no answers.
605 n_best_size: The total number of n-best predictions to generate when looking for an answer.
606 max_answer_length: The maximum length of an answer that can be generated. This is needed
607 because the start and end predictions are not conditioned on one another.
608 null_score_diff_threshold: The threshold used to select the null answer: if the best answer
609 has a score that is less than the score of the null answer minus this threshold, the
610 null answer is selected for this example (note that the score of the null answer for
611 an example giving several features is the minimum of the scores for the null answer on
612 each feature: all features must be aligned on the fact they `want` to predict a null answer).
613 Only useful when `version_2_with_negative` is `True`.
614 output_dir: If provided, the dictionaries of predictions, n_best predictions (with their scores and logits)
615 and, if `version_2_with_negative=True`, the dictionary of the scores differences between best and null
616 answers, are saved in `output_dir`.
617 prefix: If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
618 """
619 if len(predictions) != 2:
620 raise ValueError(
621 "`predictions` should be a tuple with two elements (start_logits, end_logits)."
622 )
623 all_start_logits, all_end_logits = predictions
624
625 if len(predictions[0]) != len(features):
626 raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
627
628 # Build a map example to its corresponding features.
629 example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
630 features_per_example = collections.defaultdict(list)
631 for i, feature in enumerate(features):
632 features_per_example[example_id_to_index[feature["example_id"]]].append(i)
633
634 # The dictionaries we have to fill.
635 all_predictions = collections.OrderedDict()
636 all_nbest_json = collections.OrderedDict()
637 if version_2_with_negative:
638 scores_diff_json = collections.OrderedDict()
639
640 logger.debug(
641 f"Post-processing {len(examples)} example predictions split into"
642 f" {len(features)} features."
643 )
644
645 # Let's loop over all the examples!
646 for example_index, example in enumerate(examples):
647 # Those are the indices of the features associated to the current example.
648 feature_indices = features_per_example[example_index]
649
650 min_null_prediction = None
651 prelim_predictions = []
652
653 # Looping through all the features associated to the current example.
654 for feature_index in feature_indices:
655 # We grab the predictions of the model for this feature.
656 start_logits = all_start_logits[feature_index]
657 end_logits = all_end_logits[feature_index]
658 # This is what will allow us to map some of the positions in our logits to spans of text in the original
659 # context.
660 offset_mapping = features[feature_index]["offset_mapping"]
661 # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum
662 # context available in the current feature.
663 token_is_max_context = features[feature_index].get("token_is_max_context", None)
664
665 # Update minimum null prediction.
666 feature_null_score = start_logits[0] + end_logits[0]
667 if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
668 min_null_prediction = {
669 "offsets": (0, 0),
670 "score": feature_null_score,
671 "start_logit": start_logits[0],
672 "end_logit": end_logits[0],
673 }
674
675 # Go through all possibilities for the `n_best_size` greater start and end logits.
676 start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
677 end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
678 for start_index in start_indexes:
679 for end_index in end_indexes:
680 # Don't consider out-of-scope answers, either because the indices are out of bounds or
681 # correspond to part of the input_ids that are not in the context.
682 if (
683 start_index >= len(offset_mapping)
684 or end_index >= len(offset_mapping)
685 or offset_mapping[start_index] is None
686 or len(offset_mapping[start_index]) < 2
687 or offset_mapping[end_index] is None
688 or len(offset_mapping[end_index]) < 2
689 ):
690 continue
691 # Don't consider answers with a length that is either < 0 or > max_answer_length.
692 if (
693 end_index < start_index
694 or end_index - start_index + 1 > max_answer_length
695 ):
696 continue
697 # Don't consider answers that don't have the maximum context available (if such information is
698 # provided).
699 if token_is_max_context is not None and not token_is_max_context.get(
700 str(start_index), False
701 ):
702 continue
703
704 prelim_predictions.append(
705 {
706 "offsets": (
707 offset_mapping[start_index][0],
708 offset_mapping[end_index][1],
709 ),
710 "score": start_logits[start_index] + end_logits[end_index],
711 "start_logit": start_logits[start_index],
712 "end_logit": end_logits[end_index],
713 }
714 )
715 if version_2_with_negative and min_null_prediction is not None:
716 # Add the minimum null prediction
717 prelim_predictions.append(min_null_prediction)
718 null_score = min_null_prediction["score"]
719
720 # Only keep the best `n_best_size` predictions.
721 n_best_preds = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[
722 :n_best_size
723 ]
724
725 # Add back the minimum null prediction if it was removed because of its low score.
726 if (
727 version_2_with_negative
728 and min_null_prediction is not None
729 and not any(p["offsets"] == (0, 0) for p in n_best_preds)
730 ):
731 n_best_preds.append(min_null_prediction)
732
733 # Use the offsets to gather the answer text in the original context.
734 context = example["context"]
735 for pred in n_best_preds:
736 offsets = pred.pop("offsets")
737 pred["text"] = context[offsets[0] : offsets[1]]
738
739 # In the very rare edge case where we don't have a single non-null prediction, we create a fake prediction to avoid
740 # failure.
741 if len(n_best_preds) == 0 or (len(n_best_preds) == 1 and n_best_preds[0]["text"] == ""):
742 n_best_preds.insert(
743 0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}
744 )
745
746 # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file,
747 # using the LogSumExp trick).
748 scores = np.array([pred.pop("score") for pred in n_best_preds])
749 exp_scores = np.exp(scores - np.max(scores))
750 probs = exp_scores / exp_scores.sum()
751
752 # Include the probabilities in our n_best_preds.
753 for prob, pred in zip(probs, n_best_preds):
754 pred["probability"] = prob
755
756 # Pick the best prediction. If the null answer is not possible, this is easy.
757 if not version_2_with_negative:
758 all_predictions[example["id"]] = n_best_preds[0]["text"]
759 else:
760 # Otherwise we first need to find the best non-empty prediction.
761 i = 0
762 while n_best_preds[i]["text"] == "":
763 i += 1
764 best_non_null_pred = n_best_preds[i]
765
766 # Then we compare to the null prediction using the threshold.
767 score_diff = (
768 null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
769 )
770 scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable.
771 if score_diff > null_score_diff_threshold:
772 all_predictions[example["id"]] = ""
773 else:
774 all_predictions[example["id"]] = best_non_null_pred["text"]
775
776 # Make `n_best_preds` JSON-serializable by casting np.float back to float.
777 all_nbest_json[example["id"]] = [
778 {
779 k: float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v
780 for k, v in pred.items()
781 }
782 for pred in n_best_preds
783 ]
784
785 # If we have an output_dir, let's save all those dicts.
786 if output_dir is not None:
787 if not os.path.isdir(output_dir):
788 raise EnvironmentError(f"{output_dir} is not a directory.")
789
790 prediction_file = os.path.join(
791 output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
792 )
793 nbest_file = os.path.join(
794 output_dir,
795 "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json",
796 )
797 if version_2_with_negative:
798 null_odds_file = os.path.join(
799 output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
800 )
801
802 logger.info(f"Saving predictions to {prediction_file}.")
803 with open(prediction_file, "w") as writer:
804 writer.write(json.dumps(all_predictions, indent=4) + "\n")
805 logger.info(f"Saving nbest_preds to {nbest_file}.")
806 with open(nbest_file, "w") as writer:
807 writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
808 if version_2_with_negative:
809 logger.info(f"Saving null_odds to {null_odds_file}.")
810 with open(null_odds_file, "w") as writer:
811 writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
812
813 # Format the result to the format the metric expects.
814 formatted_predictions = [
815 {"id": k, "prediction_text": v} for k, v in all_predictions.items()
816 ]
817
818 references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
819 return EvalPrediction(predictions=formatted_predictions, label_ids=references)
820
821 logger.info(f"***** Running {prefix} *****")
822 logger.info(f" Num examples = {len(eval_dataset)}")
823 logger.info(f" Batch size = {args.per_device_eval_batch_size}")
824
825 model.eval()
826 all_start_logits = []
827 all_end_logits = []
828 for _, batch in enumerate(eval_dataloader):
829 with torch.no_grad():
830 outputs = model(**batch)
831 start_logits = outputs.start_logits
832 end_logits = outputs.end_logits
833
834 if (
835 not args.pad_to_max_length
836 ): # necessary to pad predictions and labels for being gathered
837 start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
838 end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)
839
840 all_start_logits.append(accelerator.gather_for_metrics(start_logits).cpu().numpy())
841 all_end_logits.append(accelerator.gather_for_metrics(end_logits).cpu().numpy())
842
843 # Model Optimizer: clear the intermediate states of the distillation model from the forward passes
844 if args.do_modelopt_distill:
845 model.module.compute_kd_loss()
846
847 max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor
848
849 # concatenate the numpy array
850 start_logits_concat = create_and_fill_np_array(all_start_logits, max_len)
851 end_logits_concat = create_and_fill_np_array(all_end_logits, max_len)
852
853 outputs_numpy = (start_logits_concat, end_logits_concat)
854 prediction = postprocess_qa_predictions(
855 examples=eval_examples,
856 features=eval_dataset,
857 predictions=outputs_numpy,
858 n_best_size=args.n_best_size,
859 max_answer_length=args.max_answer_length,
860 # output_dir=args.output_dir,
861 prefix=prefix,
862 )
863
864 metric = evaluate.load("squad")
865 eval_metric = metric.compute(
866 predictions=prediction.predictions, references=prediction.label_ids
867 )
868 logger.info(f"{prefix} metrics: {eval_metric}\n")
869 return eval_metric
870
871
872# Model Optimizer: Define a teacher factory for initializing the distillation model
873def teacher_factory(model_name_or_path):
874 return AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)
875
876
877# Model Optimizer: Define a custom distillation loss function that uses start and end logits
878class StartEndLogitsDistillationLoss(mtd.LogitsDistillationLoss):
879 def forward(self, outputs_s, outputs_t):
880 loss_start = super().forward(outputs_s.start_logits, outputs_t.start_logits)
881 loss_end = super().forward(outputs_s.end_logits, outputs_t.end_logits)
882 loss = (loss_start + loss_end) / 2.0
883
884 return loss
885
886
887def train_and_evaluate_model(
888 args,
889 model: nn.Module,
890 tokenizer: PreTrainedTokenizer,
891 accelerator: Accelerator,
892 examples: Dict,
893 dataset: Dict,
894 dataloader: Dict[str, DataLoader],
895 answer_column_name,
896):
897 # Optimizer
898 # Split weights in two groups, one with weight decay and the other not.
899 no_decay = ["bias", "LayerNorm.weight"]
900 optimizer_grouped_parameters = [
901 {
902 "params": [
903 p
904 for n, p in model.named_parameters()
905 if p.requires_grad and not any(nd in n for nd in no_decay)
906 ],
907 "weight_decay": args.weight_decay,
908 },
909 {
910 "params": [
911 p
912 for n, p in model.named_parameters()
913 if p.requires_grad and any(nd in n for nd in no_decay)
914 ],
915 "weight_decay": 0.0,
916 },
917 ]
918 optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
919
920 # Scheduler and math around the number of training steps.
921 overrode_max_train_steps = False
922 num_update_steps_per_epoch = math.ceil(
923 len(dataloader["train"]) / args.gradient_accumulation_steps
924 )
925 if args.max_train_steps is None:
926 args.max_train_steps = int(args.num_train_epochs * num_update_steps_per_epoch)
927 overrode_max_train_steps = True
928
929 lr_scheduler = get_scheduler(
930 name=args.lr_scheduler_type,
931 optimizer=optimizer,
932 num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
933 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
934 )
935
936 # Prepare everything with our `accelerator`.
937 model, optimizer, dataloader["train"], dataloader["eval"], lr_scheduler = accelerator.prepare(
938 model, optimizer, dataloader["train"], dataloader["eval"], lr_scheduler
939 )
940
941 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
942 num_update_steps_per_epoch = math.ceil(
943 len(dataloader["train"]) / args.gradient_accumulation_steps
944 )
945 if overrode_max_train_steps:
946 args.max_train_steps = int(args.num_train_epochs * num_update_steps_per_epoch)
947 # Afterwards we recalculate our number of training epochs
948 args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
949
950 # Figure out how many steps we should save the Accelerator states
951 checkpointing_steps = args.checkpointing_steps
952 if checkpointing_steps is not None and checkpointing_steps.isdigit():
953 checkpointing_steps = int(checkpointing_steps)
954
955 # We need to initialize the trackers we use, and also store our configuration.
956 # The trackers initializes automatically on the main process.
957 if args.with_tracking:
958 experiment_config = vars(args)
959 # TensorBoard cannot log Enums, need the raw value
960 experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
961 accelerator.init_trackers("tensorboard", experiment_config)
962
963 # Train!
964 total_batch_size = (
965 args.per_device_train_batch_size
966 * accelerator.num_processes
967 * args.gradient_accumulation_steps
968 )
969
970 logger.info("***** Running training *****")
971 logger.info(f" Num examples = {len(dataset['train'])}")
972 logger.info(f" Num Epochs = {args.num_train_epochs}")
973 logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
974 logger.info(
975 f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
976 )
977 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
978 logger.info(f" Total optimization steps = {args.max_train_steps}")
979
980 # Only show the progress bar once on each machine.
981 progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
982 completed_steps = 0
983 starting_epoch = 0
984 resume_step = None
985
986 # Potentially load in the weights and states from a previous save
987 if args.resume_from_last_ckpt:
988 # Get the most recent checkpoint
989 dirs = [
990 f.path
991 for f in os.scandir(args.output_dir)
992 if f.is_dir() and (f.name.startswith("epoch_") or f.name.startswith("step_"))
993 ]
994 if len(dirs) == 0:
995 logger.warning("No checkpoint found in output_dir. Training from scratch!")
996 else:
997 latest_dir = max(dirs, key=os.path.getctime)
998 accelerator.load_state(latest_dir)
999
1000 # Extract `epoch_{i}` or `step_{i}`
1001 latest_dir = os.path.basename(latest_dir)
1002 if "epoch" in latest_dir:
1003 starting_epoch = int(latest_dir.replace("epoch_", "")) + 1
1004 completed_steps = starting_epoch * num_update_steps_per_epoch
1005 else:
1006 # need to multiply `gradient_accumulation_steps` to reflect real steps
1007 resume_step = (
1008 int(latest_dir.replace("step_", "")) * args.gradient_accumulation_steps
1009 )
1010 starting_epoch = resume_step // len(dataloader["train"])
1011 completed_steps = resume_step // args.gradient_accumulation_steps
1012 resume_step -= starting_epoch * len(dataloader["train"])
1013
1014 # update the progress_bar if load from checkpoint
1015 progress_bar.update(completed_steps)
1016
1017 # Evaluate before training (e.g. PTQ accuracy before QAT)
1018 eval_metric = evaluate_model(
1019 args,
1020 model,
1021 accelerator,
1022 examples["eval"],
1023 dataset["eval"],
1024 dataloader["eval"],
1025 answer_column_name,
1026 )
1027 for epoch in range(starting_epoch, args.num_train_epochs):
1028 model.train()
1029 if args.with_tracking:
1030 total_loss = 0
1031 if args.resume_from_last_ckpt and epoch == starting_epoch and resume_step is not None:
1032 # We skip the first `n` batches in the dataloader when resuming from a checkpoint
1033 active_dataloader = accelerator.skip_first_batches(dataloader["train"], resume_step)
1034 else:
1035 active_dataloader = dataloader["train"]
1036 for batch in active_dataloader:
1037 optimizer.zero_grad()
1038 outputs = model(**batch)
1039
1040 # Model Optimizer: If using distillation, we unwrap the model and extract the custom loss function
1041 if args.do_modelopt_distill:
1042 loss = model.module.compute_kd_loss()
1043 else:
1044 loss, _, _ = outputs.to_tuple()
1045
1046 # We keep track of the loss at each epoch
1047 if args.with_tracking:
1048 total_loss += loss.detach().float()
1049
1050 accelerator.backward(loss)
1051 optimizer.step()
1052 lr_scheduler.step()
1053
1054 # Checks if the accelerator has performed an optimization step behind the scenes
1055 if accelerator.sync_gradients:
1056 progress_bar.update(1)
1057 completed_steps += 1
1058
1059 if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0:
1060 accelerator.save_state(os.path.join(args.output_dir, f"step_{completed_steps}"))
1061
1062 if completed_steps >= args.max_train_steps:
1063 break
1064
1065 if args.checkpointing_steps == "epoch":
1066 accelerator.save_state(os.path.join(args.output_dir, f"epoch_{epoch}"))
1067
1068 eval_metric = evaluate_model(
1069 args,
1070 model,
1071 accelerator,
1072 examples["eval"],
1073 dataset["eval"],
1074 dataloader["eval"],
1075 answer_column_name,
1076 )
1077
1078 if args.with_tracking:
1079 log = {
1080 "squad": eval_metric,
1081 "train_loss": total_loss.item() / len(dataloader["train"]), # type: ignore[attr-defined]
1082 "epoch": epoch,
1083 "step": completed_steps,
1084 }
1085 accelerator.log(log, step=completed_steps)
1086
1087 accelerator.wait_for_everyone()
1088 if accelerator.is_main_process and eval_metric:
1089 logger.info(json.dumps(eval_metric, indent=4))
1090 # Prefix all keys with prefix + '_'
1091 for key in list(eval_metric.keys()):
1092 eval_metric[f"Eval_{key}"] = eval_metric.pop(key)
1093
1094 with open(os.path.join(args.output_dir, "results.json"), "w") as f:
1095 json.dump(eval_metric, f, indent=4)
1096
1097
1098def main(input_args: Optional[List[str]] = None) -> None:
1099 args = parse_args(input_args)
1100
1101 # Initialize the accelerator
1102 accelerator_log_kwargs = {}
1103 if args.with_tracking:
1104 accelerator_log_kwargs["log_with"] = "tensorboard"
1105 accelerator_log_kwargs["project_dir"] = args.output_dir
1106 accelerator = Accelerator(
1107 gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs
1108 )
1109
1110 # Setup logging
1111 logging.basicConfig(
1112 format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
1113 datefmt="%m/%d/%Y %H:%M:%S",
1114 level=logging.INFO,
1115 )
1116 logger.info(accelerator.state, main_process_only=False)
1117 if accelerator.is_local_main_process:
1118 datasets.utils.logging.set_verbosity_warning()
1119 transformers.utils.logging.set_verbosity_info()
1120 else:
1121 datasets.utils.logging.set_verbosity_error()
1122 transformers.utils.logging.set_verbosity_error()
1123
1124 # Set the training seed
1125 set_seed(SEED)
1126
1127 # Handle the output_dir creation
1128 if accelerator.is_main_process:
1129 os.makedirs(args.output_dir, exist_ok=True)
1130 accelerator.wait_for_everyone()
1131
1132 # Load pretrained model and tokenizer
1133 tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
1134 model = AutoModelForQuestionAnswering.from_pretrained(args.model_name_or_path)
1135 dummy_input = model.dummy_inputs["input_ids"]
1136
1137 # Get datasets
1138 examples, dataset, dataloader, answer_column_name = get_datasets_and_dataloaders(
1139 args, tokenizer, accelerator
1140 )
1141
1142 # Model Optimizer: Save the pruned or quantized model
1143 def save_modelopt_model(model):
1144 if accelerator.is_main_process and args.modelopt_save_file is not None:
1145 save_path = os.path.join(args.output_dir, args.modelopt_save_file)
1146 logger.info(f"Saving modelopt optimized model to {save_path}")
1147 mto.save(model, save_path)
1148
1149 # Model Optimizer: Restore the pruned or quantized model from a checkpoint
1150 if args.modelopt_restore_path:
1151 assert os.path.exists(
1152 args.modelopt_restore_path
1153 ), f"{args.modelopt_restore_path} does not exist."
1154 logger.info(f"Restoring model from {args.modelopt_restore_path}")
1155 model = mto.restore(model, args.modelopt_restore_path)
1156
1157 # Model Optimizer: Prune the model to given FLOPS target using GradNAS algorithm
1158 if args.do_modelopt_prune:
1159 logger.info(f"Pruning model to {args.modelopt_prune_flops_percent}% FLOPS")
1160
1161 # NOTE: gradnas does not perform synchronization across data parallel groups
1162 # Use unwrapped model & non-distributed dataloader for gradnas so that all the processes
1163 # in the data parallel group have the same gradients for pruning
1164 model = model.to(accelerator.device)
1165 dummy_input = dummy_input.to(accelerator.device)
1166
1167 # Search for the best pruned model
1168 # To use other NAS algorithms, you can use `mtn.convert` + `mtn.search` here.
1169 model, _ = mtp.prune(
1170 model=model,
1171 mode="gradnas",
1172 constraints={"flops": f"{args.modelopt_prune_flops_percent}%"},
1173 dummy_input=dummy_input,
1174 config={
1175 "data_loader": dataloader["train"],
1176 "collect_func": lambda batch: (batch,),
1177 "loss_func": lambda output, batch: (
1178 output["loss"] if isinstance(output, dict) else output[0]
1179 ),
1180 },
1181 )
1182 save_modelopt_model(model)
1183
1184 # Model Optimizer: Quantize the model to INT8 precision
1185 if args.modelopt_quantize_cfg:
1186 logger.info(f"Quantizing model with {args.modelopt_quantize_cfg} config")
1187
1188 # NOTE: `mtq.quantize` does not perform synchronization across data parallel groups
1189 # Use unwrapped model & non-distributed dataloader for PTQ calibration so that all the processes
1190 # in the data parallel group have same calibration statistics
1191 model = model.to(accelerator.device)
1192
1193 def forward_loop(model):
1194 num_samples = 256 # Use only 256 samples for PTQ calibration
1195 num_batches = num_samples // args.per_device_train_batch_size
1196 for idx, batch in tqdm(enumerate(dataloader["train"]), total=num_batches):
1197 batch = {k: v.to(accelerator.device) for k, v in batch.items()}
1198 model(**batch)
1199 if idx >= num_batches:
1200 break
1201
1202 model = mtq.quantize(model, getattr(mtq, args.modelopt_quantize_cfg), forward_loop)
1203 torch.cuda.empty_cache()
1204 save_modelopt_model(model)
1205
1206 if args.do_train:
1207 # Model Optimizer: Convert to a DistillationModel containing teacher to train with distillation
1208 if args.do_modelopt_distill:
1209 logger.info(f"Using distillation with teacher {args.model_name_or_path}")
1210
1211 kd_config = {
1212 "teacher_model": (teacher_factory, (args.model_name_or_path,), {}),
1213 "criterion": StartEndLogitsDistillationLoss(args.temperature),
1214 }
1215 model = mtd.convert(model, mode=[("kd_loss", kd_config)])
1216
1217 train_and_evaluate_model(
1218 args,
1219 model,
1220 tokenizer,
1221 accelerator,
1222 examples,
1223 dataset,
1224 dataloader,
1225 answer_column_name,
1226 )
1227
1228 # Model Optimizer: Export the distilled model
1229 if args.do_modelopt_distill:
1230 model = mtd.export(model)
1231
1232 save_modelopt_model(model)
1233
1234 if accelerator.is_main_process and args.onnx_export_file is not None:
1235 save_path = os.path.join(args.output_dir, args.onnx_export_file)
1236 logger.info(f"Exporting ONNX model to {save_path}")
1237
1238 # Move the model and dummy_input to the device
1239 model = model.to(accelerator.device)
1240 dummy_input = dummy_input.to(accelerator.device)
1241
1242 with open(save_path, "wb") as f:
1243 f.write(get_onnx_bytes(model, dummy_input, onnx_opset=14))
1244
1245 logger.info("Done!")
1246
1247
1248if __name__ == "__main__":
1249 main()
Commands
First, we prune the BERT-large model to 50% of its original FLOPs with the GradNAS algorithm. Then, we fine-tune the pruned model with distillation from the unpruned teacher model to recover more than 99% of the initial F1 score (93.15). We recommend using multiple GPUs for fine-tuning. Note that we fine-tune for more epochs than the 2 epochs originally used to fine-tune BERT without distillation: distillation needs more epochs to converge, but it achieves much better results.
#!/bin/bash
set -e
set -x

FLOPS_PERCENT=50
BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent

if [ ! -f "${BASE_DIR}/pruned_model.pth" ]; then
    modelopt_args="--do_modelopt_prune --modelopt_prune_flops_percent ${FLOPS_PERCENT}"
else
    modelopt_args="--modelopt_restore_path ${BASE_DIR}/pruned_model.pth"
fi
modelopt_args="${modelopt_args} --modelopt_save_file pruned_model.pth"

# Run pruning followed by distributed fine-tuning on the pruned model
accelerate launch --multi_gpu --mixed_precision bf16 bert_prune_distill_quantize.py \
    --model_name_or_path bert-large-uncased-whole-word-masking-finetuned-squad \
    --output_dir ${BASE_DIR} \
    ${modelopt_args} \
    --do_train \
    --do_modelopt_distill \
    --lr_scheduler_type cosine \
    --learning_rate 1e-4 \
    --per_device_train_batch_size 16 \
    --num_train_epochs 15 \
    --with_tracking \
    --checkpointing_steps epoch \
    --resume_from_last_ckpt
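If you want to inspect or evaluate the saved checkpoint outside of this script, it can be restored onto a freshly loaded base model with mto.restore, which is also what the script's --modelopt_restore_path option uses internally. A minimal sketch, assuming the directory layout from the command above; the question/context pair is only illustrative:

import torch
import modelopt.torch.opt as mto
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

# Restore the pruned architecture and fine-tuned weights saved by `mto.save`
model = mto.restore(model, "results/bert_large_pruned_50_percent/pruned_model.pth")
model.eval().cuda()

# Quick smoke test on a single question/context pair
question = "Who wrote the play Hamlet?"
context = "Hamlet is a tragedy written by William Shakespeare sometime between 1599 and 1601."
inputs = tokenizer(question, context, return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs)
start = outputs.start_logits.argmax(-1).item()
end = outputs.end_logits.argmax(-1).item()
print(tokenizer.decode(inputs["input_ids"][0][start : end + 1]))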
Next, quantize the fine-tuned model to INT8 precision and run calibration (PTQ). PTQ causes a slight drop in F1 score, but QAT recovers it. As before, we run QAT with distillation from the unpruned teacher model.
#!/bin/bash
set -e
set -x

FLOPS_PERCENT=50
BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
OUTPUT_DIR=${BASE_DIR}/int8_quantized

if [ ! -f "${OUTPUT_DIR}/quantized_model.pth" ]; then
    modelopt_args="--modelopt_quantize_cfg INT8_DEFAULT_CFG --modelopt_restore_path ${BASE_DIR}/pruned_model.pth"
else
    modelopt_args="--modelopt_restore_path ${OUTPUT_DIR}/quantized_model.pth"
fi
modelopt_args="${modelopt_args} --modelopt_save_file quantized_model.pth"

# Run distributed QAT on the pruned model with a lower LR and fewer epochs
accelerate launch --multi_gpu --mixed_precision bf16 bert_prune_distill_quantize.py \
    --model_name_or_path bert-large-uncased-whole-word-masking-finetuned-squad \
    --output_dir ${OUTPUT_DIR} \
    ${modelopt_args} \
    --do_train \
    --do_modelopt_distill \
    --lr_scheduler_type cosine \
    --learning_rate 5e-6 \
    --per_device_train_batch_size 16 \
    --num_train_epochs 2 \
    --with_tracking \
    --checkpointing_steps epoch \
    --resume_from_last_ckpt
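For reference, this is roughly what --modelopt_quantize_cfg INT8_DEFAULT_CFG triggers inside the script before training starts: PTQ calibration on a few hundred training samples, after which QAT is simply ordinary fine-tuning of the calibrated model. A standalone sketch, assuming the pruned checkpoint from the previous step and a SQuAD train dataloader (calib_dataloader) built as in the full script:

import torch
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from transformers import AutoModelForQuestionAnswering


def ptq_int8(model, calib_dataloader, device="cuda", num_samples=256):
    """Insert INT8 quantizers and calibrate them on a few hundred samples."""
    model = model.to(device)

    def forward_loop(model):
        seen = 0
        for batch in calib_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                model(**batch)
            seen += batch["input_ids"].shape[0]
            if seen >= num_samples:
                break

    # mtq.quantize runs the forward loop to collect activation ranges (calibration)
    return mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)


# Start from the pruned + fine-tuned checkpoint produced by the previous command
model = AutoModelForQuestionAnswering.from_pretrained(
    "bert-large-uncased-whole-word-masking-finetuned-squad"
)
model = mto.restore(model, "results/bert_large_pruned_50_percent/pruned_model.pth")
# `calib_dataloader` would be the SQuAD train dataloader built as in the full script:
# model = ptq_int8(model, calib_dataloader)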
Export the quantized model to ONNX format for deployment with TensorRT.
#!/bin/bash
set -e
set -x

FLOPS_PERCENT=50
BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent

# Export to ONNX on a single GPU
python3 bert_prune_distill_quantize.py \
    --model_name_or_path bert-large-uncased-whole-word-masking-finetuned-squad \
    --output_dir ${BASE_DIR}/onnx \
    --modelopt_restore_path ${BASE_DIR}/int8_quantized/quantized_model.pth \
    --onnx_export_file pruned_model_int8.onnx
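The exported file carries the INT8 quantization (Q/DQ) information that TensorRT can build into an engine, for example via the trtexec CLI. Before that, you may want to sanity-check the ONNX file; the sketch below uses ONNX Runtime, which is not part of this example, and the input handling is an assumption based on the input_ids dummy input used for export:

import numpy as np
import onnxruntime as ort

onnx_path = "results/bert_large_pruned_50_percent/onnx/pruned_model_int8.onnx"
session = ort.InferenceSession(
    onnx_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)

# Inspect the graph inputs; names, shapes, and dtypes depend on how the model was exported
for inp in session.get_inputs():
    print(inp.name, inp.shape, inp.type)

# Feed random token ids, replacing any dynamic dimensions with a small fixed size
input_meta = session.get_inputs()[0]
shape = [d if isinstance(d, int) else 8 for d in input_meta.shape]
dummy = np.random.randint(0, 30522, size=shape, dtype=np.int64)  # 30522 = BERT vocab size
outputs = session.run(None, {input_meta.name: dummy})
print("Output shapes:", [o.shape for o in outputs])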