HF BERT: Prune, Distill & Quantize
This example shows how to compress a Hugging Face BERT large model for Question Answering using a combination of modelopt.torch.prune, modelopt.torch.distill, and modelopt.torch.quantization. More specifically, we will (a condensed sketch of these API calls follows the list):
Prune the BERT large model to 50% FLOPs with the GradNAS algorithm and fine-tune it with distillation
Quantize the fine-tuned model to INT8 precision with Post-Training Quantization (PTQ), then recover accuracy with Quantization Aware Training (QAT) and distillation
Export the quantized model to ONNX format for deployment with TensorRT
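At a high level, the workflow reduces to three ModelOpt calls, condensed below from the full example that follows. This is a sketch rather than a runnable script on its own: `model`, `dummy_input`, `train_dataloader`, `teacher_factory`, `distill_loss`, `teacher_name_or_path`, and `calib_forward_loop` are placeholders for objects defined in the full code.
# Condensed sketch of the three ModelOpt steps used in this example; the full script
# below shows the real data loading, training loop, checkpointing, and ONNX export.
import modelopt.torch.distill as mtd
import modelopt.torch.prune as mtp
import modelopt.torch.quantization as mtq

# 1. Prune to 50% FLOPs with GradNAS (train_dataloader is a non-distributed dataloader).
model, _ = mtp.prune(
    model=model,
    mode="gradnas",
    constraints={"flops": "50%"},
    dummy_input=dummy_input,
    config={
        "data_loader": train_dataloader,
        "collect_func": lambda batch: (batch,),
        "loss_func": lambda output, batch: output["loss"],
    },
)

# 2. Wrap the pruned model in a DistillationModel with the original model as teacher,
#    fine-tune it using model.compute_kd_loss() as the loss, then export the student.
kd_config = {
    "teacher_model": (teacher_factory, (teacher_name_or_path,), {}),
    "criterion": distill_loss,  # e.g. a LogitsDistillationLoss over start/end logits
}
model = mtd.convert(model, mode=[("kd_loss", kd_config)])
# ... fine-tuning loop (see full code) ...
model = mtd.export(model)

# 3. Post-training quantization to INT8; QAT then reuses the same fine-tuning loop.
model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calib_forward_loop)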
Prerequisites
Install Model Optimizer with optional torch and huggingface dependencies:
pip install "nvidia-modelopt[torch,hf]" --extra-index-url https://pypi.nvidia.com
Note
This example has been tested on 8 x 24GB A5000 GPUs with PyTorch 2.4 and CUDA 12.4. It takes about 2 hours to complete all the stages of the optimization. Most of the time is spent on fine-tuning and QAT.
Full code
You can view the full code below with ModelOpt integration points highlighted. The source code and scripts are also available on ModelOpt GitHub.
1# NOTE: This is adapted from run_qa_no_trainer.py and utils_qa.py from
2# https://github.com/huggingface/transformers/blob/c52b515e/examples/pytorch/question-answering
3#
4# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
18# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
19# SPDX-License-Identifier: MIT
20#
21# Permission is hereby granted, free of charge, to any person obtaining a
22# copy of this software and associated documentation files (the "Software"),
23# to deal in the Software without restriction, including without limitation
24# the rights to use, copy, modify, merge, publish, distribute, sublicense,
25# and/or sell copies of the Software, and to permit persons to whom the
26# Software is furnished to do so, subject to the following conditions:
27#
28# The above copyright notice and this permission notice shall be included in
29# all copies or substantial portions of the Software.
30#
31# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
34# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
36# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
37# DEALINGS IN THE SOFTWARE.
38"""
39Example showcasing how to do end-to-end optimization of a BERT model on SQuAD using Model Optimizer.
40This includes GradNAS pruning, INT8 quantization, fine-tuning / QAT with distillation, and ONNX export.
41"""
42
43import argparse
44import collections
45import json
46import logging
47import math
48import os
49import random
50from typing import Any, Dict, List, Optional, Tuple
51
52import datasets
53import evaluate
54import numpy as np
55import torch
56import torch.nn as nn
57import transformers
58from accelerate import Accelerator
59from accelerate.logging import get_logger
60from accelerate.utils import set_seed
61from torch.utils.data import DataLoader
62from tqdm.auto import tqdm
63from transformers import (
64 AutoModelForQuestionAnswering,
65 AutoTokenizer,
66 DataCollatorWithPadding,
67 EvalPrediction,
68 PreTrainedTokenizer,
69 SchedulerType,
70 default_data_collator,
71 get_scheduler,
72)
73
74# Model Optimizer: imports
75import modelopt.torch.distill as mtd
76import modelopt.torch.opt as mto
77import modelopt.torch.prune as mtp
78import modelopt.torch.quantization as mtq
79from modelopt.torch._deploy.utils import get_onnx_bytes
80
81# Enable automatic save/load of modelopt_state with huggingface checkpointing
82mto.enable_huggingface_checkpointing()
83
84logger = get_logger(__name__)
85
86SEED = 123
87
88
89def parse_args(input_args: Optional[List[str]] = None):
90 parser = argparse.ArgumentParser(
91 description="Finetune a transformers model on a Question Answering task"
92 )
93
94 # Training arguments
95 parser.add_argument(
96 "--model_name_or_path",
97 type=str,
98 default="bert-large-uncased-whole-word-masking-finetuned-squad",
99 help="Path to pretrained model or model identifier from huggingface.co/models.",
100 )
101 parser.add_argument(
102 "--do_train", action="store_true", help="Whether to run training / fine-tuning."
103 )
104 parser.add_argument(
105 "--per_device_train_batch_size",
106 type=int,
107 default=16,
108 help="Batch size (per device) for the training dataloader.",
109 )
110 parser.add_argument(
111 "--per_device_eval_batch_size",
112 type=int,
113 default=64,
114 help="Batch size (per device) for the evaluation dataloader.",
115 )
116 parser.add_argument(
117 "--learning_rate",
118 type=float,
119 default=5e-5,
120 help="Initial learning rate (after the potential warmup period) to use.",
121 )
122 parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
123 parser.add_argument(
124 "--lr_scheduler_type",
125 type=SchedulerType,
126 default="linear",
127 help="The scheduler type to use.",
128 choices=[
129 "linear",
130 "cosine",
131 "cosine_with_restarts",
132 "polynomial",
133 "constant",
134 "constant_with_warmup",
135 ],
136 )
137 parser.add_argument(
138 "--num_warmup_steps",
139 type=int,
140 default=0,
141 help="Number of steps for the warmup in the lr scheduler.",
142 )
143 parser.add_argument(
144 "--num_train_epochs",
145 type=float,
146 default=2.0,
147 help="Total number of training epochs to perform.",
148 )
149 parser.add_argument(
150 "--max_train_steps",
151 type=int,
152 default=None,
153 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
154 )
155 parser.add_argument(
156 "--gradient_accumulation_steps",
157 type=int,
158 default=1,
159 help="Number of updates steps to accumulate before performing a backward/update pass.",
160 )
161 parser.add_argument(
162 "--preprocessing_num_workers",
163 type=int,
164 default=4,
165 help="The number of processes to use for preprocessing the dataset.",
166 )
167
168 # Logging and checkpointing arguments
169 parser.add_argument(
170 "--finetuned_model_path",
171 type=str,
172 default=None,
173 help="Path to save the finetuned (pruned or quantized) model for restoring later with `.from_pretrained()`.",
174 )
175 parser.add_argument(
176 "--with_tracking",
177 action="store_true",
178 help="Whether to enable experiment trackers for logging.",
179 )
180 parser.add_argument(
181 "--checkpointing_steps",
182 type=str,
183 default="epoch",
184 help=(
185 "Whether the various states should be saved at the end of every n steps, or 'epoch' for"
186 " each epoch."
187 ),
188 )
189 parser.add_argument(
190 "--resume_from_last_ckpt",
191 action="store_true",
192 help="If the training should continue from the latest checkpoint in model_name_or_path.",
193 )
194 parser.add_argument(
195 "--onnx_export_path", type=str, default=None, help="Path to export the ONNX model to."
196 )
197
198 # Misc arguments for Bert (should not be modified in most cases)
199 parser.add_argument(
200 "--max_seq_length",
201 type=int,
202 default=384,
203 help=(
204 "The maximum total input sequence length after tokenization. Sequences longer than this"
205 " will be truncated, and shorter will be padded if `--pad_to_max_lengh` is passed."
206 ),
207 )
208 parser.add_argument(
209 "--pad_to_max_length",
210 action="store_true",
211 help="If passed, pad all samples to `max_seq_length`. Otherwise, dynamic padding is used.",
212 )
213 parser.add_argument(
214 "--doc_stride",
215 type=int,
216 default=128,
217 help=(
218 "When splitting up a long document into chunks how much stride to take between chunks."
219 ),
220 )
221 parser.add_argument(
222 "--n_best_size",
223 type=int,
224 default=20,
225 help="The total number of n-best predictions to generate when looking for an answer.",
226 )
227 parser.add_argument(
228 "--max_answer_length",
229 type=int,
230 default=30,
231 help=(
232 "The maximum length of an answer that can be generated. This is needed because the"
233 " start and end predictions are not conditioned on one another."
234 ),
235 )
236
237 # Debugging arguments
238 parser.add_argument(
239 "--max_train_samples",
240 type=int,
241 default=None,
242 help="For debugging purposes or quicker training.",
243 )
244 parser.add_argument(
245 "--max_eval_samples",
246 type=int,
247 default=None,
248 help="For debugging purposes or quicker training.",
249 )
250
251 # Model Optimizer: pruning arguments
252 parser.add_argument(
253 "--do_modelopt_prune",
254 action="store_true",
255 help="Whether or not to use Model Optimizer pruning.",
256 )
257 parser.add_argument(
258 "--modelopt_prune_flops_percent",
259 type=float,
260 default=None,
261 help="The percentage (between 0 and 100) of FLOPs to retain in the pruned model.",
262 )
263 parser.add_argument(
264 "--pruned_model_path",
265 type=str,
266 default=None,
267 help="Path to save the pruned model for further finetuning.",
268 )
269
270 # Model Optimizer: quantization arguments
271 parser.add_argument(
272 "--modelopt_quantize_cfg",
273 help="Model Optimizer quantization config.",
274 choices=mtq.config.choices,
275 )
276
277 # Model Optimizer: Distillation arguments
278 parser.add_argument(
279 "--do_modelopt_distill",
280 action="store_true",
281 help="Whether or not to use distillation. A teacher model must be specified.",
282 )
283 parser.add_argument(
284 "--temperature", type=float, default=2.0, help="The temperature to use when distilling."
285 )
286 parser.add_argument(
287 "--ptq_model_path",
288 type=str,
289 default=None,
290 help="Path to save the PTQ quantized model for further QAT.",
291 )
292
293 args = parser.parse_args(input_args)
294
295 # Sanity checks
296 if args.do_train and not args.finetuned_model_path:
297 raise ValueError("`finetuned_model_path` required when `do_train` is passed.")
298 if args.do_modelopt_prune and not (
299 args.modelopt_prune_flops_percent and args.pruned_model_path
300 ):
301 raise ValueError(
302 "`modelopt_prune_flops_percent` and `pruned_model_path` required when `do_modelopt_prune` is passed."
303 )
304 if args.modelopt_quantize_cfg and not args.ptq_model_path:
305 raise ValueError("`ptq_model_path` required when `modelopt_quantize_cfg` is passed.")
306
307 return args
308
309
310def get_datasets_and_dataloaders(args, tokenizer: PreTrainedTokenizer, accelerator: Accelerator):
311 """Get the examples, dataset, dataloader, answer_column_name
312
313 You can either provide your own CSV/JSON/TXT training and evaluation files (see below)
314 or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
315 (the dataset will be downloaded automatically from the datasets Hub).
316
317 For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
318 'text' is found. You can easily tweak this behavior (see below).
319 """
320
321 def prepare_train_features(examples):
322 # Some of the questions have lots of whitespace on the left, which is not useful and will make the
323 # truncation of the context fail (the tokenized question will take up a lot of space). So we remove that
324 # left whitespace
325 examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
326
327 # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
328 # in one example possibly giving several features when a context is long, each of those features having a
329 # context that overlaps a bit the context of the previous feature.
330 tokenized_examples = tokenizer(
331 examples[question_column_name if pad_on_right else context_column_name],
332 examples[context_column_name if pad_on_right else question_column_name],
333 truncation="only_second" if pad_on_right else "only_first",
334 max_length=max_seq_length,
335 stride=args.doc_stride,
336 return_overflowing_tokens=True,
337 return_offsets_mapping=True,
338 padding="max_length" if args.pad_to_max_length else False,
339 )
340
341 # Since one example might give us several features if it has a long context, we need a map from a feature to
342 # its corresponding example. This key gives us just that.
343 sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
344 # The offset mappings will give us a map from token to character position in the original context. This will
345 # help us compute the start_positions and end_positions.
346 offset_mapping = tokenized_examples.pop("offset_mapping")
347
348 # Let's label those examples!
349 tokenized_examples["start_positions"] = []
350 tokenized_examples["end_positions"] = []
351
352 for i, offsets in enumerate(offset_mapping):
353 # We will label impossible answers with the index of the CLS token.
354 input_ids = tokenized_examples["input_ids"][i]
355 cls_index = input_ids.index(tokenizer.cls_token_id)
356
357 # Grab the sequence corresponding to that example (to know what is the context and what is the question).
358 sequence_ids = tokenized_examples.sequence_ids(i)
359
360 # One example can give several spans, this is the index of the example containing this span of text.
361 sample_index = sample_mapping[i]
362 answers = examples[answer_column_name][sample_index]
363 # If no answers are given, set the cls_index as answer.
364 if len(answers["answer_start"]) == 0:
365 tokenized_examples["start_positions"].append(cls_index)
366 tokenized_examples["end_positions"].append(cls_index)
367 else:
368 # Start/end character index of the answer in the text.
369 start_char = answers["answer_start"][0]
370 end_char = start_char + len(answers["text"][0])
371
372 # Start token index of the current span in the text.
373 token_start_index = 0
374 while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
375 token_start_index += 1
376
377 # End token index of the current span in the text.
378 token_end_index = len(input_ids) - 1
379 while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
380 token_end_index -= 1
381
382 # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
383 if not (
384 offsets[token_start_index][0] <= start_char
385 and offsets[token_end_index][1] >= end_char
386 ):
387 tokenized_examples["start_positions"].append(cls_index)
388 tokenized_examples["end_positions"].append(cls_index)
389 else:
390 # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
391 # Note: we could go after the last offset if the answer is the last word (edge case).
392 while (
393 token_start_index < len(offsets)
394 and offsets[token_start_index][0] <= start_char
395 ):
396 token_start_index += 1
397 tokenized_examples["start_positions"].append(token_start_index - 1)
398 while offsets[token_end_index][1] >= end_char:
399 token_end_index -= 1
400 tokenized_examples["end_positions"].append(token_end_index + 1)
401
402 return tokenized_examples
403
404 def prepare_validation_features(examples):
405 # Some of the questions have lots of whitespace on the left, which is not useful and will make the
406 # truncation of the context fail (the tokenized question will take up a lot of space). So we remove that
407 # left whitespace
408 examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]
409
410 # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
411 # in one example possibly giving several features when a context is long, each of those features having a
412 # context that overlaps a bit the context of the previous feature.
413 tokenized_examples = tokenizer(
414 examples[question_column_name if pad_on_right else context_column_name],
415 examples[context_column_name if pad_on_right else question_column_name],
416 truncation="only_second" if pad_on_right else "only_first",
417 max_length=max_seq_length,
418 stride=args.doc_stride,
419 return_overflowing_tokens=True,
420 return_offsets_mapping=True,
421 padding="max_length" if args.pad_to_max_length else False,
422 )
423
424 # Since one example might give us several features if it has a long context, we need a map from a feature to
425 # its corresponding example. This key gives us just that.
426 sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
427
428 # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
429 # corresponding example_id and we will store the offset mappings.
430 tokenized_examples["example_id"] = []
431
432 for i in range(len(tokenized_examples["input_ids"])):
433 # Grab the sequence corresponding to that example (to know what is the context and what is the question).
434 sequence_ids = tokenized_examples.sequence_ids(i)
435 context_index = 1 if pad_on_right else 0
436
437 # One example can give several spans, this is the index of the example containing this span of text.
438 sample_index = sample_mapping[i]
439 tokenized_examples["example_id"].append(examples["id"][sample_index])
440
441 # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
442 # position is part of the context or not.
443 tokenized_examples["offset_mapping"][i] = [
444 (o if sequence_ids[k] == context_index else None)
445 for k, o in enumerate(tokenized_examples["offset_mapping"][i])
446 ]
447
448 return tokenized_examples
449
450 examples, dataset, dataloader = {}, {}, {}
451
452 # In distributed training, the load_dataset function guarantees that only one local process can concurrently
453 # download the dataset.
454 # Downloading and loading a dataset from the hub.
455 raw_datasets = datasets.load_dataset("squad")
456 # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
457 # https://huggingface.co/docs/datasets/loading_datasets.
458
459 # Preprocessing the datasets.
460 # Preprocessing is slightly different for training and evaluation.
461
462 column_names = raw_datasets["train"].column_names
463
464 question_column_name = "question" if "question" in column_names else column_names[0]
465 context_column_name = "context" if "context" in column_names else column_names[1]
466 answer_column_name = "answers" if "answers" in column_names else column_names[2]
467
468 # Padding side determines if we do (question|context) or (context|question).
469 pad_on_right = tokenizer.padding_side == "right"
470
471 if args.max_seq_length > tokenizer.model_max_length:
472 logger.warning(
473 f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length"
474 f" for the model ({tokenizer.model_max_length}). Using"
475 f" max_seq_length={tokenizer.model_max_length}."
476 )
477
478 max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
479
480 examples["train"] = raw_datasets["train"]
481 if args.max_train_samples is not None:
482 # We will select samples from the whole data if the argument is specified
483 examples["train"] = examples["train"].select(range(args.max_train_samples))
484
485 # Create train feature from dataset
486 with accelerator.main_process_first():
487 dataset["train"] = examples["train"].map(
488 prepare_train_features,
489 batched=True,
490 num_proc=args.preprocessing_num_workers,
491 remove_columns=column_names,
492 load_from_cache_file=True,
493 desc="Running tokenizer on train dataset",
494 )
495 # if args.max_train_samples is not None:
496 # # Number of samples might increase during Feature Creation, We select only specified max samples
497 # dataset["train"] = dataset["train"].select(range(args.max_train_samples))
498
499 examples["eval"] = raw_datasets["validation"]
500 if args.max_eval_samples is not None:
501 # We will select samples from the whole data
502 examples["eval"] = examples["eval"].select(range(args.max_eval_samples))
503 # Validation Feature Creation
504 with accelerator.main_process_first():
505 dataset["eval"] = examples["eval"].map(
506 prepare_validation_features,
507 batched=True,
508 num_proc=args.preprocessing_num_workers,
509 remove_columns=column_names,
510 load_from_cache_file=True,
511 desc="Running tokenizer on validation dataset",
512 )
513 # if args.max_eval_samples is not None:
514 # # During Feature creation dataset samples might increase, we will select required samples again
515 # dataset["eval"] = dataset["eval"].select(range(args.max_eval_samples))
516
517 # Log a random sample from the training set:
518 for index in random.sample(range(len(dataset["train"])), 1):
519 logger.info(f"Sample {index} of the training set: {dataset['train'][index]}.")
520
521 # DataLoaders creation:
522 if args.pad_to_max_length:
523 # If padding was already done to max length, we use the default data collator that will just convert everything
524 # to tensors.
525 data_collator = default_data_collator
526 else:
527 # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of
528 # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple
529 # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
530 data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
531
532 dataloader["train"] = DataLoader(
533 dataset["train"],
534 shuffle=True,
535 collate_fn=data_collator,
536 batch_size=args.per_device_train_batch_size,
537 )
538
539 dataloader["eval"] = DataLoader(
540 dataset["eval"].remove_columns(["example_id", "offset_mapping"]),
541 collate_fn=data_collator,
542 batch_size=args.per_device_eval_batch_size,
543 )
544
545 return examples, dataset, dataloader, answer_column_name
546
547
548def evaluate_model(
549 args,
550 model: nn.Module,
551 accelerator: Accelerator,
552 eval_examples: Any,
553 eval_dataset: Any,
554 eval_dataloader: DataLoader,
555 answer_column_name: str,
556 prefix: str = "Eval",
557):
558 def create_and_fill_np_array(start_or_end_logits, max_len):
559 """Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor
560
561 Args:
562 start_or_end_logits: This is the output predictions of the model.
563 We can only enter either start or end logits.
564 max_len: The maximum length of the output tensor. (See the model.eval() part for more details)
565 """
566 step = 0
567 # create a numpy array and fill it with -100.
568 logits_concat = np.full((len(eval_dataset), max_len), -100, dtype=np.float64)
569 # Now since we have created an array we will populate it with the outputs using accelerator.gather_for_metrics
570 for i, output_logit in enumerate(start_or_end_logits): # populate columns
571 # We fill it by copying each whole output tensor into the newly created array,
572 # and after every iteration we advance the step by the batch size
573 batch_size = output_logit.shape[0]
574 cols = output_logit.shape[1]
575
576 if step + batch_size < len(eval_dataset):
577 logits_concat[step : step + batch_size, :cols] = output_logit
578 else:
579 logits_concat[step:, :cols] = output_logit[: len(eval_dataset) - step]
580
581 step += batch_size
582
583 return logits_concat
584
585 def postprocess_qa_predictions(
586 examples,
587 features,
588 predictions: Tuple[np.ndarray, np.ndarray],
589 version_2_with_negative: bool = False,
590 n_best_size: int = 20,
591 max_answer_length: int = 30,
592 null_score_diff_threshold: float = 0.0,
593 output_dir: Optional[str] = None,
594 prefix: Optional[str] = None,
595 ) -> EvalPrediction:
596 """Post-processes the predictions of a question-answering model to convert them to answers
597 that are substrings of the original contexts. This is the base postprocessing functions for
598 models that only return start and end logits.
599
600 Args:
601 examples: The non-preprocessed dataset.
602 features: The processed dataset.
603 predictions: The predictions of the model: two arrays containing the start logits and the end logits
604 respectively. Its first dimension must match the number of elements of `features`.
605 version_2_with_negative: Whether or not the underlying dataset contains examples with no answers.
606 n_best_size: The total number of n-best predictions to generate when looking for an answer.
607 max_answer_length: The maximum length of an answer that can be generated. This is needed
608 because the start and end predictions are not conditioned on one another.
609 null_score_diff_threshold: The threshold used to select the null answer: if the best answer
610 has a score that is less than the score of the null answer minus this threshold, the
611 null answer is selected for this example (note that the score of the null answer for
612 an example giving several features is the minimum of the scores for the null answer on
613 each feature: all features must be aligned on the fact they `want` to predict a null answer).
614 Only useful when `version_2_with_negative` is `True`.
615 output_dir: If provided, the dictionaries of predictions, n_best predictions (with their scores and logits)
616 and, if `version_2_with_negative=True`, the dictionary of the scores differences between best and null
617 answers, are saved in `output_dir`.
618 prefix: If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
619 """
620 if len(predictions) != 2:
621 raise ValueError(
622 "`predictions` should be a tuple with two elements (start_logits, end_logits)."
623 )
624 all_start_logits, all_end_logits = predictions
625
626 if len(predictions[0]) != len(features):
627 raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
628
629 # Build a map example to its corresponding features.
630 example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
631 features_per_example = collections.defaultdict(list)
632 for i, feature in enumerate(features):
633 features_per_example[example_id_to_index[feature["example_id"]]].append(i)
634
635 # The dictionaries we have to fill.
636 all_predictions = collections.OrderedDict()
637 all_nbest_json = collections.OrderedDict()
638 if version_2_with_negative:
639 scores_diff_json = collections.OrderedDict()
640
641 logger.debug(
642 f"Post-processing {len(examples)} example predictions split into"
643 f" {len(features)} features."
644 )
645
646 # Let's loop over all the examples!
647 for example_index, example in enumerate(examples):
648 # Those are the indices of the features associated to the current example.
649 feature_indices = features_per_example[example_index]
650
651 min_null_prediction = None
652 prelim_predictions = []
653
654 # Looping through all the features associated to the current example.
655 for feature_index in feature_indices:
656 # We grab the predictions of the model for this feature.
657 start_logits = all_start_logits[feature_index]
658 end_logits = all_end_logits[feature_index]
659 # This is what will allow us to map some of the positions in our logits to spans of text in the original
660 # context.
661 offset_mapping = features[feature_index]["offset_mapping"]
662 # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum
663 # context available in the current feature.
664 token_is_max_context = features[feature_index].get("token_is_max_context", None)
665
666 # Update minimum null prediction.
667 feature_null_score = start_logits[0] + end_logits[0]
668 if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
669 min_null_prediction = {
670 "offsets": (0, 0),
671 "score": feature_null_score,
672 "start_logit": start_logits[0],
673 "end_logit": end_logits[0],
674 }
675
676 # Go through all possibilities for the `n_best_size` greatest start and end logits.
677 start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
678 end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
679 for start_index in start_indexes:
680 for end_index in end_indexes:
681 # Don't consider out-of-scope answers, either because the indices are out of bounds or
682 # correspond to part of the input_ids that are not in the context.
683 if (
684 start_index >= len(offset_mapping)
685 or end_index >= len(offset_mapping)
686 or offset_mapping[start_index] is None
687 or len(offset_mapping[start_index]) < 2
688 or offset_mapping[end_index] is None
689 or len(offset_mapping[end_index]) < 2
690 ):
691 continue
692 # Don't consider answers with a length that is either < 0 or > max_answer_length.
693 if (
694 end_index < start_index
695 or end_index - start_index + 1 > max_answer_length
696 ):
697 continue
698 # Don't consider answers that don't have the maximum context available (if such information is
699 # provided).
700 if token_is_max_context is not None and not token_is_max_context.get(
701 str(start_index), False
702 ):
703 continue
704
705 prelim_predictions.append(
706 {
707 "offsets": (
708 offset_mapping[start_index][0],
709 offset_mapping[end_index][1],
710 ),
711 "score": start_logits[start_index] + end_logits[end_index],
712 "start_logit": start_logits[start_index],
713 "end_logit": end_logits[end_index],
714 }
715 )
716 if version_2_with_negative and min_null_prediction is not None:
717 # Add the minimum null prediction
718 prelim_predictions.append(min_null_prediction)
719 null_score = min_null_prediction["score"]
720
721 # Only keep the best `n_best_size` predictions.
722 n_best_preds = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[
723 :n_best_size
724 ]
725
726 # Add back the minimum null prediction if it was removed because of its low score.
727 if (
728 version_2_with_negative
729 and min_null_prediction is not None
730 and not any(p["offsets"] == (0, 0) for p in n_best_preds)
731 ):
732 n_best_preds.append(min_null_prediction)
733
734 # Use the offsets to gather the answer text in the original context.
735 context = example["context"]
736 for pred in n_best_preds:
737 offsets = pred.pop("offsets")
738 pred["text"] = context[offsets[0] : offsets[1]]
739
740 # In the very rare edge case where we do not have a single non-null prediction, we create a fake prediction to avoid
741 # failure.
742 if len(n_best_preds) == 0 or (len(n_best_preds) == 1 and n_best_preds[0]["text"] == ""):
743 n_best_preds.insert(
744 0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}
745 )
746
747 # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file,
748 # using the LogSumExp trick).
749 scores = np.array([pred.pop("score") for pred in n_best_preds])
750 exp_scores = np.exp(scores - np.max(scores))
751 probs = exp_scores / exp_scores.sum()
752
753 # Include the probabilities in our n_best_preds.
754 for prob, pred in zip(probs, n_best_preds):
755 pred["probability"] = prob
756
757 # Pick the best prediction. If the null answer is not possible, this is easy.
758 if not version_2_with_negative:
759 all_predictions[example["id"]] = n_best_preds[0]["text"]
760 else:
761 # Otherwise we first need to find the best non-empty prediction.
762 i = 0
763 while n_best_preds[i]["text"] == "":
764 i += 1
765 best_non_null_pred = n_best_preds[i]
766
767 # Then we compare to the null prediction using the threshold.
768 score_diff = (
769 null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
770 )
771 scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable.
772 if score_diff > null_score_diff_threshold:
773 all_predictions[example["id"]] = ""
774 else:
775 all_predictions[example["id"]] = best_non_null_pred["text"]
776
777 # Make `n_best_preds` JSON-serializable by casting np.float back to float.
778 all_nbest_json[example["id"]] = [
779 {
780 k: float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v
781 for k, v in pred.items()
782 }
783 for pred in n_best_preds
784 ]
785
786 # If we have an output_dir, let's save all those dicts.
787 if output_dir is not None:
788 if not os.path.isdir(output_dir):
789 raise EnvironmentError(f"{output_dir} is not a directory.")
790
791 prediction_file = os.path.join(
792 output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
793 )
794 nbest_file = os.path.join(
795 output_dir,
796 "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json",
797 )
798 if version_2_with_negative:
799 null_odds_file = os.path.join(
800 output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
801 )
802
803 logger.info(f"Saving predictions to {prediction_file}.")
804 with open(prediction_file, "w") as writer:
805 writer.write(json.dumps(all_predictions, indent=4) + "\n")
806 logger.info(f"Saving nbest_preds to {nbest_file}.")
807 with open(nbest_file, "w") as writer:
808 writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
809 if version_2_with_negative:
810 logger.info(f"Saving null_odds to {null_odds_file}.")
811 with open(null_odds_file, "w") as writer:
812 writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
813
814 # Format the result to the format the metric expects.
815 formatted_predictions = [
816 {"id": k, "prediction_text": v} for k, v in all_predictions.items()
817 ]
818
819 references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
820 return EvalPrediction(predictions=formatted_predictions, label_ids=references)
821
822 logger.info(f"***** Running {prefix} *****")
823 logger.info(f" Num examples = {len(eval_dataset)}")
824 logger.info(f" Batch size = {args.per_device_eval_batch_size}")
825
826 model.eval()
827 all_start_logits = []
828 all_end_logits = []
829 for _, batch in enumerate(eval_dataloader):
830 with torch.no_grad():
831 outputs = model(**batch)
832 start_logits = outputs.start_logits
833 end_logits = outputs.end_logits
834
835 if (
836 not args.pad_to_max_length
837 ): # necessary to pad predictions and labels for being gathered
838 start_logits = accelerator.pad_across_processes(start_logits, dim=1, pad_index=-100)
839 end_logits = accelerator.pad_across_processes(end_logits, dim=1, pad_index=-100)
840
841 all_start_logits.append(accelerator.gather_for_metrics(start_logits).cpu().numpy())
842 all_end_logits.append(accelerator.gather_for_metrics(end_logits).cpu().numpy())
843
844 # Model Optimizer: clear the intermediate states of the distillation model from the forward passes
845 if args.do_modelopt_distill:
846 model.module.compute_kd_loss()
847
848 max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor
849
850 # concatenate the numpy array
851 start_logits_concat = create_and_fill_np_array(all_start_logits, max_len)
852 end_logits_concat = create_and_fill_np_array(all_end_logits, max_len)
853
854 outputs_numpy = (start_logits_concat, end_logits_concat)
855 prediction = postprocess_qa_predictions(
856 examples=eval_examples,
857 features=eval_dataset,
858 predictions=outputs_numpy,
859 n_best_size=args.n_best_size,
860 max_answer_length=args.max_answer_length,
861 # output_dir=args.finetuned_model_path,
862 prefix=prefix,
863 )
864
865 metric = evaluate.load("squad")
866 eval_metric = metric.compute(
867 predictions=prediction.predictions, references=prediction.label_ids
868 )
869 logger.info(f"{prefix} metrics: {eval_metric}\n")
870 return eval_metric
871
872
873# Model Optimizer: Define a teacher factory for initializing the distillation model
874def teacher_factory(model_name_or_path):
875 return AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)
876
877
878# Model Optimizer: Define a custom distillation loss function that uses start and end logits
879class StartEndLogitsDistillationLoss(mtd.LogitsDistillationLoss):
880 def forward(self, outputs_s, outputs_t):
881 loss_start = super().forward(outputs_s.start_logits, outputs_t.start_logits)
882 loss_end = super().forward(outputs_s.end_logits, outputs_t.end_logits)
883 loss = (loss_start + loss_end) / 2.0
884
885 return loss
886
887
888def train_and_evaluate_model(
889 args,
890 model: nn.Module,
891 accelerator: Accelerator,
892 examples: Dict,
893 dataset: Dict,
894 dataloader: Dict[str, DataLoader],
895 answer_column_name,
896):
897 # Optimizer
898 # Split weights in two groups, one with weight decay and the other not.
899 no_decay = ["bias", "LayerNorm.weight"]
900 optimizer_grouped_parameters = [
901 {
902 "params": [
903 p
904 for n, p in model.named_parameters()
905 if p.requires_grad and not any(nd in n for nd in no_decay)
906 ],
907 "weight_decay": args.weight_decay,
908 },
909 {
910 "params": [
911 p
912 for n, p in model.named_parameters()
913 if p.requires_grad and any(nd in n for nd in no_decay)
914 ],
915 "weight_decay": 0.0,
916 },
917 ]
918 optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
919
920 # Scheduler and math around the number of training steps.
921 overrode_max_train_steps = False
922 num_update_steps_per_epoch = math.ceil(
923 len(dataloader["train"]) / args.gradient_accumulation_steps
924 )
925 if args.max_train_steps is None:
926 args.max_train_steps = int(args.num_train_epochs * num_update_steps_per_epoch)
927 overrode_max_train_steps = True
928
929 lr_scheduler = get_scheduler(
930 name=args.lr_scheduler_type,
931 optimizer=optimizer,
932 num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
933 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
934 )
935
936 # Prepare everything with our `accelerator`.
937 model, optimizer, dataloader["train"], dataloader["eval"], lr_scheduler = accelerator.prepare(
938 model, optimizer, dataloader["train"], dataloader["eval"], lr_scheduler
939 )
940
941 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
942 num_update_steps_per_epoch = math.ceil(
943 len(dataloader["train"]) / args.gradient_accumulation_steps
944 )
945 if overrode_max_train_steps:
946 args.max_train_steps = int(args.num_train_epochs * num_update_steps_per_epoch)
947 # Afterwards we recalculate our number of training epochs
948 args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
949
950 # Figure out how many steps we should save the Accelerator states
951 checkpointing_steps = args.checkpointing_steps
952 if checkpointing_steps is not None and checkpointing_steps.isdigit():
953 checkpointing_steps = int(checkpointing_steps)
954
955 # We need to initialize the trackers we use, and also store our configuration.
956 # The trackers initialize automatically on the main process.
957 if args.with_tracking:
958 experiment_config = vars(args)
959 # TensorBoard cannot log Enums, need the raw value
960 experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
961 accelerator.init_trackers("tensorboard", experiment_config)
962
963 # Train!
964 total_batch_size = (
965 args.per_device_train_batch_size
966 * accelerator.num_processes
967 * args.gradient_accumulation_steps
968 )
969
970 logger.info("***** Running training *****")
971 logger.info(f" Num examples = {len(dataset['train'])}")
972 logger.info(f" Num Epochs = {args.num_train_epochs}")
973 logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
974 logger.info(
975 f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
976 )
977 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
978 logger.info(f" Total optimization steps = {args.max_train_steps}")
979
980 # Only show the progress bar once on each machine.
981 progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
982 completed_steps = 0
983 starting_epoch = 0
984 resume_step = None
985
986 # Potentially load in the weights and states from a previous save
987 if args.resume_from_last_ckpt:
988 # Get the most recent checkpoint
989 dirs = [
990 f.path
991 for f in os.scandir(args.finetuned_model_path)
992 if f.is_dir() and (f.name.startswith("epoch_") or f.name.startswith("step_"))
993 ]
994 if len(dirs) == 0:
995 logger.warning(
996 f"No checkpoint found in {args.finetuned_model_path}. Training from scratch!"
997 )
998 else:
999 latest_dir = max(dirs, key=os.path.getctime)
1000 accelerator.load_state(latest_dir)
1001
1002 # Extract `epoch_{i}` or `step_{i}`
1003 latest_dir = os.path.basename(latest_dir)
1004 if "epoch" in latest_dir:
1005 starting_epoch = int(latest_dir.replace("epoch_", "")) + 1
1006 completed_steps = starting_epoch * num_update_steps_per_epoch
1007 else:
1008 # need to multiply by `gradient_accumulation_steps` to reflect real steps
1009 resume_step = (
1010 int(latest_dir.replace("step_", "")) * args.gradient_accumulation_steps
1011 )
1012 starting_epoch = resume_step // len(dataloader["train"])
1013 completed_steps = resume_step // args.gradient_accumulation_steps
1014 resume_step -= starting_epoch * len(dataloader["train"])
1015
1016 # update the progress_bar if load from checkpoint
1017 progress_bar.update(completed_steps)
1018
1019 # Evaluate before training (e.g. PTQ accuracy before QAT)
1020 eval_metric = evaluate_model(
1021 args,
1022 model,
1023 accelerator,
1024 examples["eval"],
1025 dataset["eval"],
1026 dataloader["eval"],
1027 answer_column_name,
1028 )
1029 for epoch in range(starting_epoch, args.num_train_epochs):
1030 model.train()
1031 if args.with_tracking:
1032 total_loss = 0
1033 if args.resume_from_last_ckpt and epoch == starting_epoch and resume_step is not None:
1034 # We skip the first `n` batches in the dataloader when resuming from a checkpoint
1035 active_dataloader = accelerator.skip_first_batches(dataloader["train"], resume_step)
1036 else:
1037 active_dataloader = dataloader["train"]
1038 for batch in active_dataloader:
1039 optimizer.zero_grad()
1040 outputs = model(**batch)
1041
1042 # Model Optimizer: If using distillation, we unwrap the model and extract the custom loss function
1043 if args.do_modelopt_distill:
1044 loss = model.module.compute_kd_loss()
1045 else:
1046 loss, _, _ = outputs.to_tuple()
1047
1048 # We keep track of the loss at each epoch
1049 if args.with_tracking:
1050 total_loss += loss.detach().float()
1051
1052 accelerator.backward(loss)
1053 optimizer.step()
1054 lr_scheduler.step()
1055
1056 # Checks if the accelerator has performed an optimization step behind the scenes
1057 if accelerator.sync_gradients:
1058 progress_bar.update(1)
1059 completed_steps += 1
1060
1061 if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0:
1062 accelerator.save_state(
1063 os.path.join(args.finetuned_model_path, f"step_{completed_steps}")
1064 )
1065
1066 if completed_steps >= args.max_train_steps:
1067 break
1068
1069 if args.checkpointing_steps == "epoch":
1070 accelerator.save_state(os.path.join(args.finetuned_model_path, f"epoch_{epoch}"))
1071
1072 eval_metric = evaluate_model(
1073 args,
1074 model,
1075 accelerator,
1076 examples["eval"],
1077 dataset["eval"],
1078 dataloader["eval"],
1079 answer_column_name,
1080 )
1081
1082 if args.with_tracking:
1083 log = {
1084 "squad": eval_metric,
1085 "train_loss": total_loss.item() / len(dataloader["train"]), # type: ignore[attr-defined]
1086 "epoch": epoch,
1087 "step": completed_steps,
1088 }
1089 accelerator.log(log, step=completed_steps)
1090
1091 accelerator.wait_for_everyone()
1092 if accelerator.is_main_process and eval_metric:
1093 logger.info(json.dumps(eval_metric, indent=4))
1094 # Prefix all metric keys with "Eval_"
1095 for key in list(eval_metric.keys()):
1096 eval_metric[f"Eval_{key}"] = eval_metric.pop(key)
1097
1098 with open(os.path.join(args.finetuned_model_path, "results.json"), "w") as f:
1099 json.dump(eval_metric, f, indent=4)
1100
1101
1102def main(input_args: Optional[List[str]] = None) -> None:
1103 args = parse_args(input_args)
1104
1105 # Initialize the accelerator
1106 accelerator_log_kwargs = {}
1107 if args.with_tracking:
1108 accelerator_log_kwargs["log_with"] = "tensorboard"
1109 accelerator_log_kwargs["project_dir"] = args.finetuned_model_path
1110 accelerator = Accelerator(
1111 gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs
1112 )
1113
1114 # Setup logging
1115 logging.basicConfig(
1116 format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
1117 datefmt="%m/%d/%Y %H:%M:%S",
1118 level=logging.INFO,
1119 )
1120 logger.info(accelerator.state, main_process_only=False)
1121 if accelerator.is_local_main_process:
1122 datasets.utils.logging.set_verbosity_warning()
1123 transformers.utils.logging.set_verbosity_info()
1124 else:
1125 datasets.utils.logging.set_verbosity_error()
1126 transformers.utils.logging.set_verbosity_error()
1127
1128 # Set the training seed
1129 set_seed(SEED)
1130
1131 accelerator.wait_for_everyone()
1132
1133 # Load pretrained model and tokenizer
1134 tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
1135 model = AutoModelForQuestionAnswering.from_pretrained(args.model_name_or_path)
1136 dummy_input = model.dummy_inputs["input_ids"]
1137
1138 # Get datasets
1139 examples, dataset, dataloader, answer_column_name = get_datasets_and_dataloaders(
1140 args, tokenizer, accelerator
1141 )
1142
1143 def save(model, output_path):
1144 if accelerator.is_main_process:
1145 os.makedirs(os.path.dirname(output_path), exist_ok=True)
1146 model = accelerator.unwrap_model(model)
1147 model.save_pretrained(output_path)
1148 tokenizer.save_pretrained(output_path)
1149 logger.info(f"Saved model and tokenizer to {output_path}")
1150
1151 # Model Optimizer: Prune the model to given FLOPS target using GradNAS algorithm
1152 if args.do_modelopt_prune:
1153 logger.info(f"Pruning model to {args.modelopt_prune_flops_percent}% FLOPS")
1154
1155 # NOTE: gradnas does not perform synchronization across data parallel groups
1156 # Use unwrapped model & non-distributed dataloader for gradnas so that all the processes
1157 # in the data parallel group have the same gradients for pruning
1158 model = model.to(accelerator.device)
1159 dummy_input = dummy_input.to(accelerator.device)
1160
1161 # Search for the best pruned model
1162 # To use other NAS algorithms, you can use `mtn.convert` + `mtn.search` here.
1163 model, _ = mtp.prune(
1164 model=model,
1165 mode="gradnas",
1166 constraints={"flops": f"{args.modelopt_prune_flops_percent}%"},
1167 dummy_input=dummy_input,
1168 config={
1169 "data_loader": dataloader["train"],
1170 "collect_func": lambda batch: (batch,),
1171 "loss_func": lambda output, batch: (
1172 output["loss"] if isinstance(output, dict) else output[0]
1173 ),
1174 },
1175 )
1176 save(model, args.pruned_model_path)
1177
1178 # Model Optimizer: Quantize the model to INT8 precision
1179 if args.modelopt_quantize_cfg:
1180 logger.info(f"Quantizing model with {args.modelopt_quantize_cfg} config")
1181
1182 # NOTE: `mtq.quantize` does not perform synchronization across data parallel groups
1183 # Use unwrapped model & non-distributed dataloader for PTQ calibration so that all the processes
1184 # in the data parallel group have the same calibration statistics
1185 model = model.to(accelerator.device)
1186
1187 def forward_loop(model):
1188 num_samples = 256 # Use only 256 samples for PTQ calibration
1189 num_batches = num_samples // args.per_device_train_batch_size
1190 for idx, batch in tqdm(enumerate(dataloader["train"]), total=num_batches):
1191 batch = {k: v.to(accelerator.device) for k, v in batch.items()}
1192 model(**batch)
1193 if idx >= num_batches:
1194 break
1195
1196 model = mtq.quantize(model, getattr(mtq, args.modelopt_quantize_cfg), forward_loop)
1197 torch.cuda.empty_cache()
1198 save(model, args.ptq_model_path)
1199
1200 if args.do_train:
1201 # Handle the finetuned_model_path creation
1202 if accelerator.is_main_process:
1203 os.makedirs(args.finetuned_model_path, exist_ok=True)
1204
1205 # Model Optimizer: Convert to a DistillationModel containing teacher to train with distillation
1206 if args.do_modelopt_distill:
1207 logger.info(f"Using distillation with teacher {args.model_name_or_path}")
1208
1209 kd_config = {
1210 "teacher_model": (teacher_factory, (args.model_name_or_path,), {}),
1211 "criterion": StartEndLogitsDistillationLoss(args.temperature),
1212 }
1213 model = mtd.convert(model, mode=[("kd_loss", kd_config)])
1214
1215 train_and_evaluate_model(
1216 args, model, accelerator, examples, dataset, dataloader, answer_column_name
1217 )
1218
1219 # Model Optimizer: Export the distilled model
1220 if args.do_modelopt_distill:
1221 model = mtd.export(model)
1222
1223 save(model, args.finetuned_model_path)
1224
1225 if accelerator.is_main_process and args.onnx_export_path is not None:
1226 logger.info(f"Exporting ONNX model to {args.onnx_export_path}")
1227
1228 # Move the model and dummy_input to the device
1229 model = model.to(accelerator.device)
1230 dummy_input = dummy_input.to(accelerator.device)
1231
1232 with open(args.onnx_export_path, "wb") as f:
1233 f.write(get_onnx_bytes(model, dummy_input, onnx_opset=14))
1234
1235 logger.info("Done!")
1236
1237
1238if __name__ == "__main__":
1239 main()
Commands
First, we prune the BERT large model to 50% FLOPs with the GradNAS algorithm. Then we fine-tune the pruned model with distillation from the unpruned teacher model to recover 99+% of the initial F1 score (93.15). We recommend using multiple GPUs for fine-tuning. Note that we fine-tune for more epochs than the 2 epochs originally used to fine-tune BERT without distillation: distillation takes longer to converge, but it achieves much better results.
#!/bin/bash
set -ex

MODEL_NAME_OR_PATH=bert-large-uncased-whole-word-masking-finetuned-squad
FLOPS_PERCENT=50
BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
PRUNED_MODEL_PATH=${BASE_DIR}/pruned/
FINETUNED_MODEL_PATH=${BASE_DIR}/pruned_finetuned/

modelopt_args=""
if [ ! -d ${PRUNED_MODEL_PATH} ]; then
    modelopt_args="--do_modelopt_prune \
        --modelopt_prune_flops_percent ${FLOPS_PERCENT} \
        --pruned_model_path ${PRUNED_MODEL_PATH}"
else
    MODEL_NAME_OR_PATH=${PRUNED_MODEL_PATH}
fi

# Run pruning followed by distributed fine-tuning on the pruned model
accelerate launch --multi_gpu --mixed_precision bf16 bert_prune_distill_quantize.py \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --finetuned_model_path ${FINETUNED_MODEL_PATH} \
    ${modelopt_args} \
    --do_train \
    --do_modelopt_distill \
    --lr_scheduler_type cosine \
    --learning_rate 1e-4 \
    --per_device_train_batch_size 16 \
    --num_train_epochs 15 \
    --with_tracking \
    --resume_from_last_ckpt
Next, we quantize the fine-tuned model to INT8 precision and run calibration (PTQ). PTQ results in a slight drop in F1 score, which we recover with QAT. The QAT run also uses distillation from the unpruned teacher model.
#!/bin/bash
set -ex

FLOPS_PERCENT=50
BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent
PRUNED_MODEL_PATH=${BASE_DIR}/pruned_finetuned/
PTQ_MODEL_PATH=${BASE_DIR}/int8_ptq/
QAT_MODEL_PATH=${BASE_DIR}/int8_qat/

modelopt_args=""
if [ ! -d ${PTQ_MODEL_PATH} ]; then
    modelopt_args="--modelopt_quantize_cfg INT8_DEFAULT_CFG --ptq_model_path ${PTQ_MODEL_PATH}"
    MODEL_NAME_OR_PATH=${PRUNED_MODEL_PATH}
else
    MODEL_NAME_OR_PATH=${PTQ_MODEL_PATH}
fi

# Run distributed QAT on the pruned model with 0.1x LR and fewer epochs
accelerate launch --multi_gpu --mixed_precision bf16 bert_prune_distill_quantize.py \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --finetuned_model_path ${QAT_MODEL_PATH} \
    ${modelopt_args} \
    --do_train \
    --do_modelopt_distill \
    --lr_scheduler_type cosine \
    --learning_rate 5e-6 \
    --per_device_train_batch_size 16 \
    --num_train_epochs 2 \
    --with_tracking \
    --resume_from_last_ckpt
Export the quantized model to ONNX format for deployment with TensorRT.
#!/bin/bash
set -ex

FLOPS_PERCENT=50
BASE_DIR=results/bert_large_pruned_${FLOPS_PERCENT}_percent

# Export to ONNX on a single GPU
python3 bert_prune_distill_quantize.py \
    --model_name_or_path ${BASE_DIR}/int8_qat/ \
    --onnx_export_path ${BASE_DIR}/pruned_model_int8.onnx
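As a rough sketch of the deployment step (not part of this example's scripts), the exported ONNX file can be built into a TensorRT engine with trtexec; the exact flags depend on your TensorRT version and target precision.
# Sketch only: build a TensorRT engine from the exported quantized ONNX model.
# Paths follow the scripts above; adjust precision flags for your TensorRT version.
trtexec --onnx=results/bert_large_pruned_50_percent/pruned_model_int8.onnx \
        --saveEngine=results/bert_large_pruned_50_percent/pruned_model_int8.engine \
        --int8 --fp16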