HF BERT: Prune, Distill & Quantize
This example shows how to compress a Hugging Face BERT large model for Question Answering using a combination of modelopt.torch.prune, modelopt.torch.distill, and modelopt.torch.quantization. More specifically, we will:
Prune the BERT large model to 50% of its original FLOPs with the GradNAS algorithm and fine-tune it with distillation
Quantize the fine-tuned model to INT8 precision with Post-Training Quantization (PTQ) and Quantization Aware Training (QAT) with distillation
Export the quantized model to ONNX format for deployment with TensorRT
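At a glance, these steps map onto a handful of ModelOpt calls. The following is a condensed sketch of the full script shown later on this page; train_dataloader and calib_forward_loop stand in for the SQuAD dataloader and PTQ calibration loop built in the full code, and teacher_factory and StartEndLogitsDistillationLoss are helpers defined there as well:

import modelopt.torch.distill as mtd
import modelopt.torch.prune as mtp
import modelopt.torch.quantization as mtq
from modelopt.torch._deploy.utils import get_onnx_bytes
from transformers import AutoModelForQuestionAnswering

model = AutoModelForQuestionAnswering.from_pretrained(
    "bert-large-uncased-whole-word-masking-finetuned-squad"
)
dummy_input = model.dummy_inputs["input_ids"]

# 1. Prune to 50% of the original FLOPs with the GradNAS algorithm
model, _ = mtp.prune(
    model=model,
    mode="gradnas",
    constraints={"flops": "50%"},
    dummy_input=dummy_input,
    config={
        "data_loader": train_dataloader,
        "collect_func": lambda batch: (batch,),
        "loss_func": lambda output, batch: output["loss"] if isinstance(output, dict) else output[0],
    },
)

# 2. Wrap the pruned student with its unpruned teacher and fine-tune with distillation
kd_config = {
    "teacher_model": (teacher_factory, ("bert-large-uncased-whole-word-masking-finetuned-squad",), {}),
    "criterion": StartEndLogitsDistillationLoss(2.0),  # temperature = 2.0
}
model = mtd.convert(model, mode=[("kd_loss", kd_config)])
# ... training loop: use the distillation loss (compute_kd_loss) instead of the task loss ...
model = mtd.export(model)  # unwrap the student after training

# 3. Post-training quantization (PTQ) to INT8, followed by QAT
#    (QAT reuses the same mtd.convert / train / mtd.export cycle as step 2)
model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calib_forward_loop)

# 4. Export the optimized model to ONNX for deployment with TensorRT
onnx_bytes = get_onnx_bytes(model, dummy_input, onnx_opset=14)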
Prerequisites
Install Model Optimizer with optional torch and huggingface dependencies:
pip install "nvidia-modelopt[torch,hf]" --extra-index-url https://pypi.nvidia.com
Note
This example has been tested on 8 x 24GB A5000 GPUs with PyTorch 2.4 and CUDA 12.4. It takes about 2 hours to complete all the stages of the optimization. Most of the time is spent on fine-tuning and QAT.
Full code
You can view the full code below with the ModelOpt integration points highlighted. The source code and scripts are also available in the ModelOpt GitHub repository.
1# NOTE: This is adapted from run_qa_no_trainer.py and utils_qa.py from
2# https://github.com/huggingface/transformers/blob/c52b515e/examples/pytorch/question-answering
3#
4# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
18# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
19# SPDX-License-Identifier: MIT
20#
21# Permission is hereby granted, free of charge, to any person obtaining a
22# copy of this software and associated documentation files (the "Software"),
23# to deal in the Software without restriction, including without limitation
24# the rights to use, copy, modify, merge, publish, distribute, sublicense,
25# and/or sell copies of the Software, and to permit persons to whom the
26# Software is furnished to do so, subject to the following conditions:
27#
28# The above copyright notice and this permission notice shall be included in
29# all copies or substantial portions of the Software.
30#
31# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
34# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
36# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
37# DEALINGS IN THE SOFTWARE.
38"""
39Example showcasing how to do end-to-end optimization of a BERT model on SQuAD using Model Optimizer.
40This includes GradNAS pruning, INT8 quantization, fine-tuning / QAT with distillation, and ONNX export.
41"""
42
43import argparse
44import collections
45import json
46import logging
47import math
48import os
49import random
50from typing import Any, Dict, List, Optional, Tuple
51
52import datasets
53import evaluate
54import numpy as np
55import torch
56import torch.nn as nn
57import transformers
58from accelerate import Accelerator
59from accelerate.logging import get_logger
60from accelerate.utils import set_seed
61from torch.utils.data import DataLoader
62from tqdm.auto import tqdm
63from transformers import (
64 AutoModelForQuestionAnswering,
65 AutoTokenizer,
66 DataCollatorWithPadding,
67 EvalPrediction,
68 PreTrainedTokenizer,
69 SchedulerType,
70 default_data_collator,
71 get_scheduler,
72)
73
74# Model Optimizer: imports
75import modelopt.torch.distill as mtd
76import modelopt.torch.opt as mto
77import modelopt.torch.prune as mtp
78import modelopt.torch.quantization as mtq
79from modelopt.torch._deploy.utils import get_onnx_bytes
80
81# Enable automatic save/load of modelopt_state with huggingface checkpointing
82mto.enable_huggingface_checkpointing()
83
84logger = get_logger(__name__)
85
86SEED = 123
87
88
89def parse_args(input_args: Optional[List[str]] = None):
90 parser = argparse.ArgumentParser(
91 description="Finetune a transformers model on a Question Answering task"
92 )
93
94 # Training arguments
95 parser.add_argument(
96 "--model_name_or_path",
97 type=str,
98 default="bert-large-uncased-whole-word-masking-finetuned-squad",
99 help="Path to pretrained model or model identifier from huggingface.co/models.",
100 )
101 parser.add_argument(
102 "--do_train", action="store_true", help="Whether to run training / fine-tuning."
103 )
104 parser.add_argument(
105 "--per_device_train_batch_size",
106 type=int,
107 default=16,
108 help="Batch size (per device) for the training dataloader.",
109 )
110 parser.add_argument(
111 "--per_device_eval_batch_size",
112 type=int,
113 default=64,
114 help="Batch size (per device) for the evaluation dataloader.",
115 )
116 parser.add_argument(
117 "--learning_rate",
118 type=float,
119 default=5e-5,
120 help="Initial learning rate (after the potential warmup period) to use.",
121 )
122 parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
123 parser.add_argument(
124 "--lr_scheduler_type",
125 type=SchedulerType,
126 default="linear",
127 help="The scheduler type to use.",
128 choices=[
129 "linear",
130 "cosine",
131 "cosine_with_restarts",
132 "polynomial",
133 "constant",
134 "constant_with_warmup",
135 ],
136 )
137 parser.add_argument(
138 "--num_warmup_steps",
139 type=int,
140 default=0,
141 help="Number of steps for the warmup in the lr scheduler.",
142 )
143 parser.add_argument(
144 "--num_train_epochs",
145 type=float,
146 default=2.0,
147 help="Total number of training epochs to perform.",
148 )
149 parser.add_argument(
150 "--max_train_steps",
151 type=int,
152 default=None,
153 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
154 )
155 parser.add_argument(
156 "--gradient_accumulation_steps",
157 type=int,
158 default=1,
159 help="Number of update steps to accumulate before performing a backward/update pass.",
160 )
161 parser.add_argument(
162 "--preprocessing_num_workers",
163 type=int,
164 default=4,
165 help="The number of processes to use for preprocessing the dataset.",
166 )
167
168 # Logging and checkpointing arguments
169 parser.add_argument(
170 "--finetuned_model_path",
171 type=str,
172 default=None,
173 help="Path to save the finetuned (pruned or quantized) model for restoring later with `.from_pretrained()`.",
174 )
175 parser.add_argument(
176 "--with_tracking",
177 action="store_true",
178 help="Whether to enable experiment trackers for logging.",
179 )
180 parser.add_argument(
181 "--checkpointing_steps",
182 type=str,
183 default="epoch",
184 help=(
185 "Whether the various states should be saved at the end of every n steps, or 'epoch' for"
186 " each epoch."
187 ),
188 )
189 parser.add_argument(
190 "--resume_from_last_ckpt",
191 action="store_true",
192 help="If the training should continue from the latest checkpoint in model_name_or_path.",
193 )
194 parser.add_argument(
195 "--onnx_export_path", type=str, default=None, help="Path to export the ONNX model to."
196 )
197
198 # Misc arguments for Bert (should not be modified in most cases)
199 parser.add_argument(
200 "--max_seq_length",
201 type=int,
202 default=384,
203 help=(
204 "The maximum total input sequence length after tokenization. Sequences longer than this"
205 " will be truncated, and shorter will be padded if `--pad_to_max_length` is passed."
206 ),
207 )
208 parser.add_argument(
209 "--pad_to_max_length",
210 action="store_true",
211 help="If passed, pad all samples to `max_seq_length`. Otherwise, dynamic padding is used.",
212 )
213 parser.add_argument(
214 "--doc_stride",
215 type=int,
216 default=128,
217 help=(
218 "When splitting up a long document into chunks how much stride to take between chunks."
219 ),
220 )
221 parser.add_argument(
222 "--n_best_size",
223 type=int,
224 default=20,
225 help="The total number of n-best predictions to generate when looking for an answer.",
226 )
227 parser.add_argument(
228 "--max_answer_length",
229 type=int,
230 default=30,
231 help=(
232 "The maximum length of an answer that can be generated. This is needed because the"
233 " start and end predictions are not conditioned on one another."
234 ),
235 )
236
237 # Debugging arguments
238 parser.add_argument(
239 "--max_train_samples",
240 type=int,
241 default=None,
242 help="For debugging purposes or quicker training.",
243 )
244 parser.add_argument(
245 "--max_eval_samples",
246 type=int,
247 default=None,
248 help="For debugging purposes or quicker training.",
249 )
250
251 # Model Optimizer: pruning arguments
252 parser.add_argument(
253 "--do_modelopt_prune",
254 action="store_true",
255 help="Whether or not to use Model Optimizer pruning.",
256 )
257 parser.add_argument(
258 "--modelopt_prune_flops_percent",
259 type=float,
260 default=None,
261 help="The percentage (between 0 and 100) of FLOPs to retain in the pruned model.",
262 )
263 parser.add_argument(
264 "--pruned_model_path",
265 type=str,
266 default=None,
267 help="Path to save the pruned model for further finetuning.",
268 )
269
270 # Model Optimizer: quantization arguments
271 parser.add_argument(
272 "--modelopt_quantize_cfg",
273 help="Model Optimizer quantization config.",
274 choices=mtq.config.choices,
275 )
276
277 # Model Optimizer: Distillation arguments
278 parser.add_argument(
279 "--do_modelopt_distill",
280 action="store_true",
281 help="Whether or not to use distillation. A teacher model must be specified.",
282 )
283 parser.add_argument(
284 "--temperature", type=float, default=2.0, help="The temperature to use when distilling."
285 )
286 parser.add_argument(
287 "--ptq_model_path",
288 type=str,
289 default=None,
290 help="Path to save the PTQ quantized model for further QAT.",
291 )
292
293 args = parser.parse_args(input_args)
294
295 # Sanity checks
296 if args.do_train and not args.finetuned_model_path:
297 raise ValueError("`finetuned_model_path` required when `do_train` is passed.")
298 if args.do_modelopt_prune and not (
299 args.modelopt_prune_flops_percent and args.pruned_model_path
300 ):
301 raise ValueError(
302 "`modelopt_prune_flops_percent` and `pruned_model_path` required when `do_modelopt_prune` is passed."
303 )
304 if args.modelopt_quantize_cfg and not args.ptq_model_path:
305 raise ValueError("`ptq_model_path` required when `modelopt_quantize_cfg` is passed.")
306
307 return args
308
309
310def get_datasets_and_dataloaders(args, tokenizer: PreTrainedTokenizer, accelerator: Accelerator):
311 """Get the examples, dataset, dataloader, answer_column_name
312
313 You can either provide your own CSV/JSON/TXT training and evaluation files (see below)
314 or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
315 (the dataset will be downloaded automatically from the datasets Hub).
316
317 For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
318 'text' is found. You can easily tweak this behavior (see below).
319 """
320
321 def prepare_train_features(examples):
322 # Some of the questions have lots of whitespace on the left, which is not useful and will make the
323 # truncation of the context fail (the tokenized question will take up a lot of space). So we remove that
324 # left whitespace
325 examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
326
327 # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
328 # in one example possibly giving several features when a context is long, each of those features having a
329 # context that overlaps a bit with the context of the previous feature.
330 tokenized_examples = tokenizer(
331 examples[question_column_name if pad_on_right else context_column_name],
332 examples[context_column_name if pad_on_right else question_column_name],
333 truncation="only_second" if pad_on_right else "only_first",
334 max_length=max_seq_length,
335 stride=args.doc_stride,
336 return_overflowing_tokens=True,
337 return_offsets_mapping=True,
338 padding="max_length" if args.pad_to_max_length else False,
339 )
340
341 # Since one example might give us several features if it has a long context, we need a map from a feature to
342 # its corresponding example. This key gives us just that.
343 sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
344 # The offset mappings will give us a map from token to character position in the original context. This will
345 # help us compute the start_positions and end_positions.
346 offset_mapping = tokenized_examples.pop("offset_mapping")
347
348 # Let's label those examples!
349 tokenized_examples["start_positions"] = []
350 tokenized_examples["end_positions"] = []
351
352 for i, offsets in enumerate(offset_mapping):
353 # We will label impossible answers with the index of the CLS token.
354 input_ids = tokenized_examples["input_ids"][i]
355 cls_index = input_ids.index(tokenizer.cls_token_id)
356
357 # Grab the sequence corresponding to that example (to know what is the context and what is the question).
358 sequence_ids = tokenized_examples.sequence_ids(i)
359
360 # One example can give several spans, this is the index of the example containing this span of text.
361 sample_index = sample_mapping[i]
362 answers = examples[answer_column_name][sample_index]
363 # If no answers are given, set the cls_index as answer.
364 if len(answers["answer_start"]) == 0:
365 tokenized_examples["start_positions"].append(cls_index)
366 tokenized_examples["end_positions"].append(cls_index)
367 else:
368 # Start/end character index of the answer in the text.
369 start_char = answers["answer_start"][0]
370 end_char = start_char + len(answers["text"][0])
371
372 # Start token index of the current span in the text.
373 token_start_index = 0
374 while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
375 token_start_index += 1
376
377 # End token index of the current span in the text.
378 token_end_index = len(input_ids) - 1
379 while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
380 token_end_index -= 1
381
382 # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
383 if not (
384 offsets[token_start_index][0] <= start_char
385 and offsets[token_end_index][1] >= end_char
386 ):
387 tokenized_examples["start_positions"].append(cls_index)
388 tokenized_examples["end_positions"].append(cls_index)
389 else:
390 # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
391 # Note: we could go after the last offset if the answer is the last word (edge case).
392 while (
393 token_start_index < len(offsets)
394 and offsets[token_start_index][0] <= start_char
395 ):
396 token_start_index += 1
397 tokenized_examples["start_positions"].append(token_start_index - 1)
398 while offsets[token_end_index][1] >= end_char:
399 token_end_index -= 1
400 tokenized_examples["end_positions"].append(token_end_index + 1)
401
402 return tokenized_examples
403
404 def prepare_validation_features(examples):
405 # Some of the questions have lots of whitespace on the left, which is not useful and will make the
406 # truncation of the context fail (the tokenized question will take up a lot of space). So we remove that
407 # left whitespace
408 examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
409
410 # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
411 # in one example possibly giving several features when a context is long, each of those features having a
412 # context that overlaps a bit with the context of the previous feature.
413 tokenized_examples = tokenizer(
414 examples[question_column_name if pad_on_right else context_column_name],
415 examples[context_column_name if pad_on_right else question_column_name],
416 truncation="only_second" if pad_on_right else "only_first",
417 max_length=max_seq_length,
418 stride=args.doc_stride,
419 return_overflowing_tokens=True,
420 return_offsets_mapping=True,
421 padding="max_length" if args.pad_to_max_length else False,
422 )
423
424 # Since one example might give us several features if it has a long context, we need a map from a feature to
425 # its corresponding example. This key gives us just that.
426 sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
427
428 # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
429 # corresponding example_id and we will store the offset mappings.
430 tokenized_examples["example_id"] = []
431
432 for i in range(len(tokenized_examples["input_ids"])):
433 # Grab the sequence corresponding to that example (to know what is the context and what is the question).
434 sequence_ids = tokenized_examples.sequence_ids(i)
435 context_index = 1 if pad_on_right else 0
436
437 # One example can give several spans, this is the index of the example containing this span of text.
438 sample_index = sample_mapping[i]
439 tokenized_examples["example_id"].append(examples["id"][sample_index])
440
441 # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
442 # position is part of the context or not.
443 tokenized_examples["offset_mapping"][i] = [
444 (o if sequence_ids[k] == context_index else None)
445 for k, o in enumerate(tokenized_examples["offset_mapping"][i])
446 ]
447
448 return tokenized_examples
449
450 examples, dataset, dataloader = {}, {}, {}
451
452 # In distributed training, the load_dataset function guarantees that only one local process can concurrently
453 # download the dataset.
454 # Downloading and loading a dataset from the hub.
455 raw_datasets = datasets.load_dataset("squad")
456 # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
457 # https://huggingface.co/docs/datasets/loading_datasets.
458
459 # Preprocessing the datasets.
460 # Preprocessing is slightly different for training and evaluation.
461
462 column_names = raw_datasets["train"].column_names
463
464 question_column_name = "question" if "question" in column_names else column_names[0]
465 context_column_name = "context" if "context" in column_names else column_names[1]
466 answer_column_name = "answers" if "answers" in column_names else column_names[2]
467
468 # Padding side determines if we do (question|context) or (context|question).
469 pad_on_right = tokenizer.padding_side == "right"
470
471 if args.max_seq_length > tokenizer.model_max_length:
472 logger.warning(
473 f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length"
474 f" for the model ({tokenizer.model_max_length}). Using"
475 f" max_seq_length={tokenizer.model_max_length}."
476 )
477
478 max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
479
480 examples["train"] = raw_datasets["train"]
481 if args.max_train_samples is not None:
482 # We will select samples from the whole data if the argument is specified
483 examples["train"] = examples["train"].select(range(args.max_train_samples))
484
485 # Create train feature from dataset
486 with accelerator.main_process_first():
487 dataset["train"] = examples["train"].map(
488 prepare_train_features,
489 batched=True,
490 num_proc=args.preprocessing_num_workers,
491 remove_columns=column_names,
492 load_from_cache_file=True,
493 desc="Running tokenizer on train dataset",
494 )
495 # if args.max_train_samples is not None:
496 # # Number of samples might increase during Feature Creation, We select only specified max samples
497 # dataset["train"] = dataset["train"].select(range(args.max_train_samples))
498
499 examples["eval"] = raw_datasets["validation"]
500 if args.max_eval_samples is not None:
501 # We will select samples from the whole data
502 examples["eval"] = examples["eval"].select(range(args.max_eval_samples))
503 # Validation Feature Creation
504 with accelerator.main_process_first():
505 dataset["eval"] = examples["eval"].map(
506 prepare_validation_features,
507 batched=True,
508 num_proc=args.preprocessing_num_workers,
509 remove_columns=column_names,
510 load_from_cache_file=True,
511 desc="Running tokenizer on validation dataset",
512 )
513 # if args.max_eval_samples is not None:
514 # # During Feature creation dataset samples might increase, we will select required samples again
515 # dataset["eval"] = dataset["eval"].select(range(args.max_eval_samples))
516
517 # Log a random sample from the training set:
518 for index in random.sample(range(len(dataset["train"])), 1):
519 logger.info(f"Sample {index} of the training set: {dataset['train'][index]}.")
520
521 # DataLoaders creation:
522 if args.pad_to_max_length:
523 # If padding was already done to max length, we use the default data collator that will just convert everything
524 # to tensors.
525 data_collator = default_data_collator
526 else:
527 # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of
528 # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple
529 # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
530 data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
531
532 dataloader["train"] = DataLoader(
533 dataset["train"],
534 shuffle=True,
535 collate_fn=data_collator,
536 batch_size=args.per_device_train_batch_size,
537 )
538
539 dataloader["eval"] = DataLoader(
540 dataset["eval"].remove_columns(["example_id", "offset_mapping"]),
541 collate_fn=data_collator,
542 batch_size=args.per_device_eval_batch_size,
543 )
544
545 return examples, dataset, dataloader, answer_column_name
546
547
548def evaluate_model(
549 args,
550 model: nn.Module,
551 accelerator: Accelerator,
552 eval_examples: Any,
553 eval_dataset: Any,
554 eval_dataloader: DataLoader,
555 answer_column_name: str,
556 prefix: str = "Eval",
557):
558 def create_and_fill_np_array(start_or_end_logits, max_len):
559 """Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor
560
561 Args:
562 start_or_end_logits: This is the output predictions of the model.
563 We can only enter either start or end logits.
564 max_len: The maximum length of the output tensor. (See the model.eval() part for more details)
565 """
566 step = 0
567 # create a numpy array and fill it with -100.
568 logits_concat = np.full((len(eval_dataset), max_len), -100, dtype=np.float64)
569 # Now that we have created an array, we will populate it with the outputs using accelerator.gather_for_metrics
570 for i, output_logit in enumerate(start_or_end_logits): # populate columns
571 # We fill it by copying each batch of output logits into the newly created array,
572 # advancing `step` by the batch size after every iteration
573 batch_size = output_logit.shape[0]
574 cols = output_logit.shape[1]
575
576 if step + batch_size < len(eval_dataset):
577 logits_concat[step : step + batch_size, :cols] = output_logit
578 else:
579 logits_concat[step:, :cols] = output_logit[: len(eval_dataset) - step]
580
581 step += batch_size
582
583 return logits_concat
584
585 def postprocess_qa_predictions(
586 examples,
587 features,
588 predictions: Tuple[np.ndarray, np.ndarray],
589 version_2_with_negative: bool = False,
590 n_best_size: int = 20,
591 max_answer_length: int = 30,
592 null_score_diff_threshold: float = 0.0,
593 output_dir: Optional[str] = None,
594 prefix: Optional[str] = None,
595 ) -> EvalPrediction:
596 """Post-processes the predictions of a question-answering model to convert them to answers
597 that are substrings of the original contexts. This is the base postprocessing functions for
598 models that only return start and end logits.
599
600 Args:
601 examples: The non-preprocessed dataset.
602 features: The processed dataset.
603 predictions: The predictions of the model: two arrays containing the start logits and the end logits
604 respectively. Its first dimension must match the number of elements of `features`.
605 version_2_with_negative: Whether or not the underlying dataset contains examples with no answers.
606 n_best_size: The total number of n-best predictions to generate when looking for an answer.
607 max_answer_length: The maximum length of an answer that can be generated. This is needed
608 because the start and end predictions are not conditioned on one another.
609 null_score_diff_threshold: The threshold used to select the null answer: if the best answer
610 has a score that is less than the score of the null answer minus this threshold, the
611 null answer is selected for this example (note that the score of the null answer for
612 an example giving several features is the minimum of the scores for the null answer on
613 each feature: all features must be aligned on the fact they `want` to predict a null answer).
614 Only useful when `version_2_with_negative` is `True`.
615 output_dir: If provided, the dictionaries of predictions, n_best predictions (with their scores and logits)
616 and, if `version_2_with_negative=True`, the dictionary of the scores differences between best and null
617 answers, are saved in `output_dir`.
618 prefix: If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
619 """
620 if len(predictions) != 2:
621 raise ValueError(
622 "`predictions` should be a tuple with two elements (start_logits, end_logits)."
623 )
624 all_start_logits, all_end_logits = predictions
625
626 if len(predictions[0]) != len(features):
627 raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
628
629 # Build a map example to its corresponding features.
630 example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
631 features_per_example = collections.defaultdict(list)
632 for i, feature in enumerate(features):
633 features_per_example[example_id_to_index[feature["example_id"]]].append(i)
634
635 # The dictionaries we have to fill.
636 all_predictions = collections.OrderedDict()
637 all_nbest_json = collections.OrderedDict()
638 if version_2_with_negative:
639 scores_diff_json = collections.OrderedDict()
640
641 logger.debug(
642 f"Post-processing {len(examples)} example predictions split into"
643 f" {len(features)} features."
644 )
645
646 # Let's loop over all the examples!
647 for example_index, example in enumerate(examples):
648 # Those are the indices of the features associated to the current example.
649 feature_indices = features_per_example[example_index]
650
651 min_null_prediction = None
652 prelim_predictions = []
653
654 # Looping through all the features associated to the current example.
655 for feature_index in feature_indices:
656 # We grab the predictions of the model for this feature.
657 start_logits = all_start_logits[feature_index]
658 end_logits = all_end_logits[feature_index]
659 # This is what will allow us to map some of the positions in our logits to spans of text in the original
660 # context.
661 offset_mapping = features[feature_index]["offset_mapping"]
662 # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum
663 # context available in the current feature.
664 token_is_max_context = features[feature_index].get("token_is_max_context", None)
665
666 # Update minimum null prediction.
667 feature_null_score = start_logits[0] + end_logits[0]
668 if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
669 min_null_prediction = {
670 "offsets": (0, 0),
671 "score": feature_null_score,
672 "start_logit": start_logits[0],
673 "end_logit": end_logits[0],
674 }
675
676 # Go through all possibilities for the `n_best_size` greater start and end logits.
677 start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
678 end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
679 for start_index in start_indexes:
680 for end_index in end_indexes:
681 # Don't consider out-of-scope answers, either because the indices are out of bounds or
682 # correspond to part of the input_ids that are not in the context.
683 if (
684 start_index >= len(offset_mapping)
685 or end_index >= len(offset_mapping)
686 or offset_mapping[start_index] is None
687 or len(offset_mapping[start_index]) < 2
688 or offset_mapping[end_index] is None
689 or len(offset_mapping[end_index]) < 2
690 ):
691 continue
692 # Don't consider answers with a length that is either < 0 or > max_answer_length.
693 if (
694 end_index < start_index
695 or end_index - start_index + 1 > max_answer_length
696 ):
697 continue
698 # Don't consider answers that don't have the maximum context available (if such information is
699 # provided).
700 if token_is_max_context is not None and not token_is_max_context.get(
701 str(start_index), False
702 ):
703 continue
704
705 prelim_predictions.append(
706 {
707 "offsets": (
708 offset_mapping[start_index][0],
709 offset_mapping[end_index][1],
710 ),
711 "score": start_logits[start_index] + end_logits[end_index],
712 "start_logit": start_logits[start_index],
713 "end_logit": end_logits[end_index],
714 }
715 )
716 if version_2_with_negative and min_null_prediction is not None:
717 # Add the minimum null prediction
718 prelim_predictions.append(min_null_prediction)
719 null_score = min_null_prediction["score"]
720
721 # Only keep the best `n_best_size` predictions.
722 n_best_preds = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[
723 :n_best_size
724 ]
725
726 # Add back the minimum null prediction if it was removed because of its low score.
727 if (
728 version_2_with_negative
729 and min_null_prediction is not None
730 and not any(p["offsets"] == (0, 0) for p in n_best_preds)
731 ):
732 n_best_preds.append(min_null_prediction)
733
734 # Use the offsets to gather the answer text in the original context.
735 context = example["context"]
736 for pred in n_best_preds:
737 offsets = pred.pop("offsets")
738 pred["text"] = context[offsets[0] : offsets[1]]
739
740 # In the very rare edge case where we do not have a single non-null prediction, we create a fake prediction to avoid
741 # failure.
742 if len(n_best_preds) == 0 or (len(n_best_preds) == 1 and n_best_preds[0]["text"] == ""):
743 n_best_preds.insert(
744 0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}
745 )
746
747 # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file,
748 # using the LogSumExp trick).
749 scores = np.array([pred.pop("score") for pred in n_best_preds])
750 exp_scores = np.exp(scores - np.max(scores))
751 probs = exp_scores / exp_scores.sum()
752
753 # Include the probabilities in our n_best_preds.
754 for prob, pred in zip(probs, n_best_preds):
755 pred["probability"] = prob
756
757 # Pick the best prediction. If the null answer is not possible, this is easy.
758 if not version_2_with_negative:
759 all_predictions[example["id"]] = n_best_preds[0]["text"]
760 else:
761 # Otherwise we first need to find the best non-empty prediction.
762 i = 0
763 while n_best_preds[i]["text"] == "":
764 i += 1
765 best_non_null_pred = n_best_preds[i]
766
767 # Then we compare to the null prediction using the threshold.
768 score_diff = (
769 null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
770 )
771 scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable.
772 if score_diff > null_score_diff_threshold:
773 all_predictions[example["id"]] = ""
774 else:
775 all_predictions[example["id"]] = best_non_null_pred["text"]
776
777 # Make `n_best_preds` JSON-serializable by casting np.float back to float.
778 all_nbest_json[example["id"]] = [
779 {
780 k: float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v
781 for k, v in pred.items()
782 }
783 for pred in n_best_preds
784 ]
785
786 # If we have an output_dir, let's save all those dicts.
787 if output_dir is not None:
788 if not os.path.isdir(output_dir):
789 raise EnvironmentError(f"{output_dir} is not a directory.")
790
791 prediction_file = os.path.join(
792 output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
793 )
794 nbest_file = os.path.join(
795 output_dir,
796 "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json",
797 )
798 if version_2_with_negative:
799 null_odds_file = os.path.join(
800 output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
801 )
802
803 logger.info(f"Saving predictions to {prediction_file}.")
804 with open(prediction_file, "w") as writer:
805 writer.write(json.dumps(all_predictions, indent=4) + "\n")
806 logger.info(f"Saving nbest_preds to {nbest_file}.")
807 with open(nbest_file, "w") as writer:
808 writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
809 if version_2_with_negative:
810 logger.info(f"Saving null_odds to {null_odds_file}.")
811 with open(null_odds_file, "w") as writer:
812 writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
813
814 # Format the result to the format the metric expects.
815 formatted_predictions = [
816 {"id": k, "prediction_text": v} for k, v in all_predictions.items()
817 ]
818
819 references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
820 return EvalPrediction(predictions=formatted_predictions, label_ids=references)
821
822 logger.info(f"***** Running {prefix} *****")
823 logger.info(f" Num examples = {len(eval_dataset)}")
824 logger.info(f" Batch size = {args.per_device_eval_batch_size}")
825
826 model.eval()
827 all_start_logits = []
828 all_end_logits = []
829 for _, batch in enumerate(eval_dataloader):
830 with torch.no_grad():
831 outputs = model(**batch)
832 start_logits = outputs.start_logits
833 end_logits = outputs.end_logits
834
835 if (
836 not args.pad_to_max_length
837 ): # necessary to pad predictions and labels for being gathered
838 start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
839 end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)
840
841 all_start_logits.append(accelerator.gather_for_metrics(start_logits).cpu().numpy())
842 all_end_logits.append(accelerator.gather_for_metrics(end_logits).cpu().numpy())
843
844 max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor
845
846 # concatenate the numpy array
847 start_logits_concat = create_and_fill_np_array(all_start_logits, max_len)
848 end_logits_concat = create_and_fill_np_array(all_end_logits, max_len)
849
850 outputs_numpy = (start_logits_concat, end_logits_concat)
851 prediction = postprocess_qa_predictions(
852 examples=eval_examples,
853 features=eval_dataset,
854 predictions=outputs_numpy,
855 n_best_size=args.n_best_size,
856 max_answer_length=args.max_answer_length,
857 # output_dir=args.finetuned_model_path,
858 prefix=prefix,
859 )
860
861 metric = evaluate.load("squad")
862 eval_metric = metric.compute(
863 predictions=prediction.predictions, references=prediction.label_ids
864 )
865 logger.info(f"{prefix} metrics: {eval_metric}\n")
866 return eval_metric
867
868
869# Model Optimizer: Define a teacher factory for initializing the distillation model
870def teacher_factory(model_name_or_path):
871 return AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)
872
873
874# Model Optimizer: Define a custom distillation loss function that uses start and end logits
875class StartEndLogitsDistillationLoss(mtd.LogitsDistillationLoss):
876 def forward(self, outputs_s, outputs_t):
877 loss_start = super().forward(outputs_s.start_logits, outputs_t.start_logits)
878 loss_end = super().forward(outputs_s.end_logits, outputs_t.end_logits)
879 loss = (loss_start + loss_end) / 2.0
880
881 return loss
882
883
884def train_and_evaluate_model(
885 args,
886 model: nn.Module,
887 accelerator: Accelerator,
888 examples: Dict,
889 dataset: Dict,
890 dataloader: Dict[str, DataLoader],
891 answer_column_name,
892):
893 # Optimizer
894 # Split weights in two groups, one with weight decay and the other not.
895 no_decay = ["bias", "LayerNorm.weight"]
896 optimizer_grouped_parameters = [
897 {
898 "params": [
899 p
900 for n, p in model.named_parameters()
901 if p.requires_grad and not any(nd in n for nd in no_decay)
902 ],
903 "weight_decay": args.weight_decay,
904 },
905 {
906 "params": [
907 p
908 for n, p in model.named_parameters()
909 if p.requires_grad and any(nd in n for nd in no_decay)
910 ],
911 "weight_decay": 0.0,
912 },
913 ]
914 optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
915
916 # Scheduler and math around the number of training steps.
917 overrode_max_train_steps = False
918 num_update_steps_per_epoch = math.ceil(
919 len(dataloader["train"]) / args.gradient_accumulation_steps
920 )
921 if args.max_train_steps is None:
922 args.max_train_steps = int(args.num_train_epochs * num_update_steps_per_epoch)
923 overrode_max_train_steps = True
924
925 lr_scheduler = get_scheduler(
926 name=args.lr_scheduler_type,
927 optimizer=optimizer,
928 num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
929 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
930 )
931
932 # Prepare everything with our `accelerator`.
933 model, optimizer, dataloader["train"], dataloader["eval"], lr_scheduler = accelerator.prepare(
934 model, optimizer, dataloader["train"], dataloader["eval"], lr_scheduler
935 )
936
937 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
938 num_update_steps_per_epoch = math.ceil(
939 len(dataloader["train"]) / args.gradient_accumulation_steps
940 )
941 if overrode_max_train_steps:
942 args.max_train_steps = int(args.num_train_epochs * num_update_steps_per_epoch)
943 # Afterwards we recalculate our number of training epochs
944 args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
945
946 # Figure out how many steps we should save the Accelerator states
947 checkpointing_steps = args.checkpointing_steps
948 if checkpointing_steps is not None and checkpointing_steps.isdigit():
949 checkpointing_steps = int(checkpointing_steps)
950
951 # We need to initialize the trackers we use, and also store our configuration.
952 # The trackers initialize automatically on the main process.
953 if args.with_tracking:
954 experiment_config = vars(args)
955 # TensorBoard cannot log Enums, need the raw value
956 experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
957 accelerator.init_trackers("tensorboard", experiment_config)
958
959 # Train!
960 total_batch_size = (
961 args.per_device_train_batch_size
962 * accelerator.num_processes
963 * args.gradient_accumulation_steps
964 )
965
966 logger.info("***** Running training *****")
967 logger.info(f" Num examples = {len(dataset['train'])}")
968 logger.info(f" Num Epochs = {args.num_train_epochs}")
969 logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
970 logger.info(
971 f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
972 )
973 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
974 logger.info(f" Total optimization steps = {args.max_train_steps}")
975
976 # Only show the progress bar once on each machine.
977 progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
978 completed_steps = 0
979 starting_epoch = 0
980 resume_step = None
981
982 # Potentially load in the weights and states from a previous save
983 if args.resume_from_last_ckpt:
984 # Get the most recent checkpoint
985 dirs = [
986 f.path
987 for f in os.scandir(args.finetuned_model_path)
988 if f.is_dir() and (f.name.startswith("epoch_") or f.name.startswith("step_"))
989 ]
990 if len(dirs) == 0:
991 logger.warning(
992 f"No checkpoint found in {args.finetuned_model_path}. Training from scratch!"
993 )
994 else:
995 latest_dir = max(dirs, key=os.path.getctime)
996 accelerator.load_state(latest_dir)
997
998 # Extract `epoch_{i}` or `step_{i}`
999 latest_dir = os.path.basename(latest_dir)
1000 if "epoch" in latest_dir:
1001 starting_epoch = int(latest_dir.replace("epoch_", "")) + 1
1002 completed_steps = starting_epoch * num_update_steps_per_epoch
1003 else:
1004 # need to multiply by `gradient_accumulation_steps` to reflect real steps
1005 resume_step = (
1006 int(latest_dir.replace("step_", "")) * args.gradient_accumulation_steps
1007 )
1008 starting_epoch = resume_step // len(dataloader["train"])
1009 completed_steps = resume_step // args.gradient_accumulation_steps
1010 resume_step -= starting_epoch * len(dataloader["train"])
1011
1012 # update the progress_bar if load from checkpoint
1013 progress_bar.update(completed_steps)
1014
1015 # Evaluate before training (e.g. PTQ accuracy before QAT)
1016 eval_metric = evaluate_model(
1017 args,
1018 model,
1019 accelerator,
1020 examples["eval"],
1021 dataset["eval"],
1022 dataloader["eval"],
1023 answer_column_name,
1024 )
1025 for epoch in range(starting_epoch, args.num_train_epochs):
1026 model.train()
1027 if args.with_tracking:
1028 total_loss = 0
1029 if args.resume_from_last_ckpt and epoch == starting_epoch and resume_step is not None:
1030 # We skip the first `n` batches in the dataloader when resuming from a checkpoint
1031 active_dataloader = accelerator.skip_first_batches(dataloader["train"], resume_step)
1032 else:
1033 active_dataloader = dataloader["train"]
1034 for batch in active_dataloader:
1035 optimizer.zero_grad()
1036 outputs = model(**batch)
1037
1038 # Model Optimizer: If using distillation, unwrap the DDP-wrapped model and compute the distillation loss
1039 if args.do_modelopt_distill:
1040 loss = model.module.compute_kd_loss()
1041 else:
1042 loss, _, _ = outputs.to_tuple()
1043
1044 # We keep track of the loss at each epoch
1045 if args.with_tracking:
1046 total_loss += loss.detach().float()
1047
1048 accelerator.backward(loss)
1049 optimizer.step()
1050 lr_scheduler.step()
1051
1052 # Checks if the accelerator has performed an optimization step behind the scenes
1053 if accelerator.sync_gradients:
1054 progress_bar.update(1)
1055 completed_steps += 1
1056
1057 if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0:
1058 accelerator.save_state(
1059 os.path.join(args.finetuned_model_path, f"step_{completed_steps}")
1060 )
1061
1062 if completed_steps >= args.max_train_steps:
1063 break
1064
1065 if args.checkpointing_steps == "epoch":
1066 accelerator.save_state(os.path.join(args.finetuned_model_path, f"epoch_{epoch}"))
1067
1068 eval_metric = evaluate_model(
1069 args,
1070 model,
1071 accelerator,
1072 examples["eval"],
1073 dataset["eval"],
1074 dataloader["eval"],
1075 answer_column_name,
1076 )
1077
1078 if args.with_tracking:
1079 log = {
1080 "squad": eval_metric,
1081 "train_loss": total_loss.item() / len(dataloader["train"]), # type: ignore[attr-defined]
1082 "epoch": epoch,
1083 "step": completed_steps,
1084 }
1085 accelerator.log(log, step=completed_steps)
1086
1087 accelerator.wait_for_everyone()
1088 if accelerator.is_main_process and eval_metric:
1089 logger.info(json.dumps(eval_metric, indent=4))
1090 # Prefix all keys with prefix + '_'
1091 for key in list(eval_metric.keys()):
1092 eval_metric[f"Eval_{key}"] = eval_metric.pop(key)
1093
1094 with open(os.path.join(args.finetuned_model_path, "results.json"), "w") as f:
1095 json.dump(eval_metric, f, indent=4)
1096
1097
1098def main(input_args: Optional[List[str]] = None) -> None:
1099 args = parse_args(input_args)
1100
1101 # Initialize the accelerator
1102 accelerator_log_kwargs = {}
1103 if args.with_tracking:
1104 accelerator_log_kwargs["log_with"] = "tensorboard"
1105 accelerator_log_kwargs["project_dir"] = args.finetuned_model_path
1106 accelerator = Accelerator(
1107 gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs
1108 )
1109
1110 # Setup logging
1111 logging.basicConfig(
1112 format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
1113 datefmt="%m/%d/%Y %H:%M:%S",
1114 level=logging.INFO,
1115 )
1116 logger.info(accelerator.state, main_process_only=False)
1117 if accelerator.is_local_main_process:
1118 datasets.utils.logging.set_verbosity_warning()
1119 transformers.utils.logging.set_verbosity_info()
1120 else:
1121 datasets.utils.logging.set_verbosity_error()
1122 transformers.utils.logging.set_verbosity_error()
1123
1124 # Set the training seed
1125 set_seed(SEED)
1126
1127 accelerator.wait_for_everyone()
1128
1129 # Load pretrained model and tokenizer
1130 tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
1131 model = AutoModelForQuestionAnswering.from_pretrained(args.model_name_or_path)
1132 dummy_input = model.dummy_inputs["input_ids"]
1133
1134 # Get datasets
1135 examples, dataset, dataloader, answer_column_name = get_datasets_and_dataloaders(
1136 args, tokenizer, accelerator
1137 )
1138
1139 def save(model, output_path):
1140 if accelerator.is_main_process:
1141 os.makedirs(os.path.dirname(output_path), exist_ok=True)
1142 model = accelerator.unwrap_model(model)
1143 model.save_pretrained(output_path)
1144 tokenizer.save_pretrained(output_path)
1145 logger.info(f"Saved model and tokenizer to {output_path}")
1146
1147 # Model Optimizer: Prune the model to given FLOPS target using GradNAS algorithm
1148 if args.do_modelopt_prune:
1149 logger.info(f"Pruning model to {args.modelopt_prune_flops_percent}% FLOPS")
1150
1151 # NOTE: gradnas does not perform synchronization across data parallel groups
1152 # Use unwrapped model & non-distributed dataloader for gradnas so that all the processes
1153 # in the data parallel group have the same gradients for pruning
1154 model = model.to(accelerator.device)
1155 dummy_input = dummy_input.to(accelerator.device)
1156
1157 # Search for the best pruned model
1158 # To use other NAS algorithms, you can use `mtn.convert` + `mtn.search` here.
1159 model, _ = mtp.prune(
1160 model=model,
1161 mode="gradnas",
1162 constraints={"flops": f"{args.modelopt_prune_flops_percent}%"},
1163 dummy_input=dummy_input,
1164 config={
1165 "data_loader": dataloader["train"],
1166 "collect_func": lambda batch: (batch,),
1167 "loss_func": lambda output, batch: (
1168 output["loss"] if isinstance(output, dict) else output[0]
1169 ),
1170 },
1171 )
1172 save(model, args.pruned_model_path)
1173
1174 # Model Optimizer: Quantize the model to INT8 precision
1175 if args.modelopt_quantize_cfg:
1176 logger.info(f"Quantizing model with {args.modelopt_quantize_cfg} config")
1177
1178 # NOTE: `mtq.quantize` does not perform synchronization across data parallel groups
1179 # Use unwrapped model & non-distributed dataloader for PTQ calibration so that all the processes
1180 # in the data parallel group have the same calibration statistics
1181 model = model.to(accelerator.device)
1182
1183 def forward_loop(model):
1184 num_samples = 256 # Use only 256 samples for PTQ calibration
1185 num_batches = num_samples // args.per_device_train_batch_size
1186 for idx, batch in tqdm(enumerate(dataloader["train"]), total=num_batches):
1187 batch = {k: v.to(accelerator.device) for k, v in batch.items()}
1188 model(**batch)
1189 if idx >= num_batches:
1190 break
1191
1192 model = mtq.quantize(model, getattr(mtq, args.modelopt_quantize_cfg), forward_loop)
1193 torch.cuda.empty_cache()
1194 save(model, args.ptq_model_path)
1195
1196 if args.do_train:
1197 # Handle the finetuned_model_path creation
1198 if accelerator.is_main_process:
1199 os.makedirs(args.finetuned_model_path, exist_ok=True)
1200
1201 # Model Optimizer: Convert to a DistillationModel containing teacher to train with distillation
1202 if args.do_modelopt_distill:
1203 logger.info(f"Using distillation with teacher {args.model_name_or_path}")
1204
1205 kd_config = {
1206 "teacher_model": (teacher_factory, (args.model_name_or_path,), {}),
1207 "criterion": StartEndLogitsDistillationLoss(args.temperature),
1208 }
1209 model = mtd.convert(model, mode=[("kd_loss", kd_config)])
1210
1211 train_and_evaluate_model(
1212 args, model, accelerator, examples, dataset, dataloader, answer_column_name
1213 )
1214
1215 # Model Optimizer: Export the distilled model
1216 if args.do_modelopt_distill:
1217 model = mtd.export(model)
1218
1219 save(model, args.finetuned_model_path)
1220
1221 if accelerator.is_main_process and args.onnx_export_path is not None:
1222 logger.info(f"Exporting ONNX model to {args.onnx_export_path}")
1223
1224 # Move the model and dummy_input to the device
1225 model = model.to(accelerator.device)
1226 dummy_input = dummy_input.to(accelerator.device)
1227
1228 with open(args.onnx_export_path, "wb") as f:
1229 f.write(get_onnx_bytes(model, dummy_input, onnx_opset=14))
1230
1231 logger.info("Done!")
1232
1233
1234if __name__ == "__main__":
1235 main()
Commands
First, we prune the BERT large model to 50% FLOPs with the GradNAS algorithm. Then we fine-tune the pruned model with distillation from the unpruned teacher model to recover over 99% of the initial F1 score (93.15). We recommend using multiple GPUs for fine-tuning. Note that we use more fine-tuning epochs than the 2 epochs originally used to fine-tune BERT without distillation: distillation takes longer to converge but achieves much better results.
#!/bin/bash
set -ex

MODEL_NAME_OR_PATH=bert-large-uncased-whole-word-masking-finetuned-squad
FLOPS_PERCENT=50
BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
PRUNED_MODEL_PATH=${BASE_DIR}/pruned/
FINETUNED_MODEL_PATH=${BASE_DIR}/pruned_finetuned/

modelopt_args=""
if [ ! -d ${PRUNED_MODEL_PATH} ]; then
    modelopt_args="--do_modelopt_prune \
        --modelopt_prune_flops_percent ${FLOPS_PERCENT} \
        --pruned_model_path ${PRUNED_MODEL_PATH}"
else
    MODEL_NAME_OR_PATH=${PRUNED_MODEL_PATH}
fi

# Run pruning followed by distributed fine-tuning on the pruned model
accelerate launch --multi_gpu --mixed_precision bf16 bert_prune_distill_quantize.py \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --finetuned_model_path ${FINETUNED_MODEL_PATH} \
    ${modelopt_args} \
    --do_train \
    --do_modelopt_distill \
    --lr_scheduler_type cosine \
    --learning_rate 1e-4 \
    --per_device_train_batch_size 16 \
    --num_train_epochs 15 \
    --with_tracking \
    --resume_from_last_ckpt
Next, we quantize the fine-tuned model to INT8 precision and run calibration (PTQ). Note that PTQ results in a slight drop in the F1 score, which we then recover with QAT. As with fine-tuning, we run QAT with distillation from the unpruned teacher model.
#!/bin/bash
set -ex

FLOPS_PERCENT=50
BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
PRUNED_MODEL_PATH=${BASE_DIR}/pruned_finetuned/
PTQ_MODEL_PATH=${BASE_DIR}/int8_ptq/
QAT_MODEL_PATH=${BASE_DIR}/int8_qat/

modelopt_args=""
if [ ! -d ${PTQ_MODEL_PATH} ]; then
    modelopt_args="--modelopt_quantize_cfg INT8_DEFAULT_CFG --ptq_model_path ${PTQ_MODEL_PATH}"
    MODEL_NAME_OR_PATH=${PRUNED_MODEL_PATH}
else
    MODEL_NAME_OR_PATH=${PTQ_MODEL_PATH}
fi

# Run distributed QAT on the pruned model with 0.1x LR and fewer epochs
accelerate launch --multi_gpu --mixed_precision bf16 bert_prune_distill_quantize.py \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --finetuned_model_path ${QAT_MODEL_PATH} \
    ${modelopt_args} \
    --do_train \
    --do_modelopt_distill \
    --lr_scheduler_type cosine \
    --learning_rate 5e-6 \
    --per_device_train_batch_size 16 \
    --num_train_epochs 2 \
    --with_tracking \
    --resume_from_last_ckpt
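Because the script calls mto.enable_huggingface_checkpointing(), the ModelOpt state (pruning and quantization) is saved together with the Hugging Face checkpoint, so the optimized model can be restored later directly with .from_pretrained(). A minimal sketch (the path corresponds to QAT_MODEL_PATH above):

import modelopt.torch.opt as mto
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Enable automatic restore of the saved modelopt_state with Hugging Face checkpointing
mto.enable_huggingface_checkpointing()

model_path = "results/bert_large_pruned_50_percent/int8_qat/"
model = AutoModelForQuestionAnswering.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)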
Export the quantized model to ONNX format for deployment with TensorRT.
#!/bin/bash
set -ex

FLOPS_PERCENT=50
BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent

# Export to ONNX on a single GPU
python3 bert_prune_distill_quantize.py \
    --model_name_or_path ${BASE_DIR}/int8_qat/ \
    --onnx_export_path ${BASE_DIR}/pruned_model_int8.onnx
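The exported ONNX model can then be built into a TensorRT engine. Below is a minimal sketch using the TensorRT Python API (TensorRT 8.x-style; an equivalent trtexec invocation also works). The paths follow the script above, and the INT8/FP16 builder flags assume the explicit quantization (Q/DQ nodes) inserted by ModelOpt:

import tensorrt as trt

onnx_path = "results/bert_large_pruned_50_percent/pruned_model_int8.onnx"
engine_path = "results/bert_large_pruned_50_percent/pruned_model_int8.engine"

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

# Parse the ONNX model exported by bert_prune_distill_quantize.py
with open(onnx_path, "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError("Failed to parse the ONNX model")

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.INT8)  # honor the Q/DQ nodes with INT8 kernels
config.set_flag(trt.BuilderFlag.FP16)  # allow FP16 for the remaining layers

engine_bytes = builder.build_serialized_network(network, config)
with open(engine_path, "wb") as f:
    f.write(engine_bytes)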