diff --git a/examples/post_training_quantization/SQuAD/README.md b/examples/post_training_quantization/SQuAD/README.md
new file mode 100644
index 0000000..3e0185f
--- /dev/null
+++ b/examples/post_training_quantization/SQuAD/README.md
@@ -0,0 +1,20 @@
+## Introduction
+- This is a SQuAD demo showing how to apply post-training quantization (PTQ) to BERT.
+
+## Run
+
+### Install Requirements
+- `pip install -r requirements.txt`
+
+### Training
+- Run `python main.py finetuning` to fine-tune BERT on SQuAD and save `checkpoint.pth.tar`
+
+### Post Training Quantization
+- `python main.py postquant qconfig.yaml ./checkpoint.pth.tar`
+
+## Results
+- We report the F1 score on the SQuAD v1.1 dev set; `8w8f` means 8-bit weights and 8-bit features (activations)
+
+model | float | 8w8f |
+--- | --- | --- |
+bert-base-uncased | 88.22 | 87.47 |
diff --git a/examples/post_training_quantization/SQuAD/main.py b/examples/post_training_quantization/SQuAD/main.py
new file mode 100644
index 0000000..2c95b08
--- /dev/null
+++ b/examples/post_training_quantization/SQuAD/main.py
@@ -0,0 +1,351 @@
+import collections
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
+import evaluate
+from datasets import load_dataset
+from transformers import BertTokenizerFast, default_data_collator, get_scheduler
+
+from model import BertModel, BertForQuestionAnswering
+from sparsebit.quantization import QuantModel, parse_qconfig
+
+MAX_LENGTH = 384
+STRIDE = 128
+
+
+def build_train_dataloader(args, raw_datasets):
+
+    def preprocess_training_examples(examples):
+        questions = [q.strip() for q in examples["question"]]
+        inputs = tokenizer(
+            questions,
+            examples["context"],
+            max_length=MAX_LENGTH,
+            truncation="only_second",
+            stride=STRIDE,
+            return_overflowing_tokens=True,
+            return_offsets_mapping=True,
+            padding="max_length",
+        )
+
+        offset_mapping = inputs.pop("offset_mapping")
+        sample_map = inputs.pop("overflow_to_sample_mapping")
+        answers = examples["answers"]
+        start_positions = []
+        end_positions = []
+
+        for i, offset in enumerate(offset_mapping):
+            sample_idx = sample_map[i]
+            answer = answers[sample_idx]
+            start_char = answer["answer_start"][0]
+            end_char = answer["answer_start"][0] + len(answer["text"][0])
+            sequence_ids = inputs.sequence_ids(i)
+
+            # Find the start and end of the context
+            idx = 0
+            while sequence_ids[idx] != 1:
+                idx += 1
+            context_start = idx
+            while sequence_ids[idx] == 1:
+                idx += 1
+            context_end = idx - 1
+
+            # If the answer is not fully inside the context, label is (0, 0)
+            if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
+                start_positions.append(0)
+                end_positions.append(0)
+            else:
+                # Otherwise it's the start and end token positions
+                idx = context_start
+                while idx <= context_end and offset[idx][0] <= start_char:
+                    idx += 1
+                start_positions.append(idx - 1)
+
+                idx = context_end
+                while idx >= context_start and offset[idx][1] >= end_char:
+                    idx -= 1
+                end_positions.append(idx + 1)
+
+        inputs["start_positions"] = start_positions
+        inputs["end_positions"] = end_positions
+        return inputs
+
+    tokenizer = BertTokenizerFast.from_pretrained(args.architecture, do_lower_case=True)
+    train_dataset = raw_datasets["train"].map(
+        preprocess_training_examples,
+        batched=True,
+        remove_columns=raw_datasets["train"].column_names,
+    )
+    train_dataset.set_format("torch")
+    train_dataloader = DataLoader(
+        train_dataset,
+        shuffle=True,
+        collate_fn=default_data_collator,
+        batch_size=args.batch_size,
+    )
+    return 
train_dataloader, train_dataset + + +def build_validation_dataloader(args, raw_datasets): + + def preprocess_validation_examples(examples): + questions = [q.strip() for q in examples["question"]] + inputs = tokenizer( + questions, + examples["context"], + max_length=MAX_LENGTH, + truncation="only_second", + stride=STRIDE, + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding="max_length", + ) + + sample_map = inputs.pop("overflow_to_sample_mapping") + example_ids = [] + + for i in range(len(inputs["input_ids"])): + sample_idx = sample_map[i] + example_ids.append(examples["id"][sample_idx]) + + sequence_ids = inputs.sequence_ids(i) + offset = inputs["offset_mapping"][i] + inputs["offset_mapping"][i] = [ + o if sequence_ids[k] == 1 else None for k, o in enumerate(offset) + ] + + inputs["example_id"] = example_ids + return inputs + + tokenizer = BertTokenizerFast.from_pretrained(args.architecture, do_lower_case=True) + validation_dataset = raw_datasets["validation"].map( + preprocess_validation_examples, + batched=True, + remove_columns=raw_datasets["validation"].column_names, + ) + validation_set = validation_dataset.remove_columns(["example_id", "offset_mapping"]) + validation_set.set_format("torch") + validation_dataloader = DataLoader( + validation_set, collate_fn=default_data_collator, batch_size=2 * args.batch_size, + ) + return validation_dataloader, validation_dataset + + +def compute_metrics(start_logits, end_logits, features, examples): + example_to_features = collections.defaultdict(list) + for idx, feature in enumerate(features): + example_to_features[feature["example_id"]].append(idx) + + n_best = 20 + max_answer_length = 30 + predicted_answers = [] + for example in tqdm(examples): + example_id = example["id"] + context = example["context"] + answers = [] + + # Loop through all features associated with that example + for feature_index in example_to_features[example_id]: + start_logit = start_logits[feature_index] + end_logit = end_logits[feature_index] + offsets = features[feature_index]["offset_mapping"] + + start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist() + end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist() + for start_index in start_indexes: + for end_index in end_indexes: + # Skip answers that are not fully in the context + if offsets[start_index] is None or offsets[end_index] is None: + continue + # Skip answers with a length that is either < 0 or > max_answer_length + if ( + end_index < start_index + or end_index - start_index + 1 > max_answer_length + ): + continue + + answer = { + "text": context[offsets[start_index][0] : offsets[end_index][1]], + "logit_score": start_logit[start_index] + end_logit[end_index], + } + answers.append(answer) + + # Select the answer with the best score + if len(answers) > 0: + best_answer = max(answers, key=lambda x: x["logit_score"]) + predicted_answers.append( + {"id": example_id, "prediction_text": best_answer["text"]} + ) + else: + predicted_answers.append({"id": example_id, "prediction_text": ""}) + + theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples] + metric = evaluate.load("squad") + return metric.compute(predictions=predicted_answers, references=theoretical_answers) + + +def finetuning(args): + device = "cuda" if torch.cuda.is_available() else "cpu" + + raw_datasets = load_dataset('squad') + train_dataloader, _ = build_train_dataloader(args, raw_datasets) + + bert_model = BertModel.from_pretrained("bert-base-uncased", 
add_pooling_layer=False)
+    bert_model.embeddings.seq_length = MAX_LENGTH
+    bert_model.config.num_labels = 2
+    model = BertForQuestionAnswering(bert_model, bert_model.config)
+    model.to(device)
+
+    total_training_steps = args.epochs * len(train_dataloader)
+    optimizer = AdamW(model.parameters(), lr=args.lr)
+    lr_scheduler = get_scheduler(
+        "linear",
+        optimizer=optimizer,
+        num_warmup_steps=0,
+        num_training_steps=total_training_steps,
+    )
+
+    progress_bar = tqdm(range(total_training_steps))
+
+    def calc_loss_from_logits(logits, start_positions, end_positions):
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+        # If we are on multi-GPU, squeeze the extra dimension
+        if len(start_positions.size()) > 1:
+            start_positions = start_positions.squeeze(-1)
+        if len(end_positions.size()) > 1:
+            end_positions = end_positions.squeeze(-1)
+        # Sometimes the start/end positions are outside our model inputs; we ignore these terms
+        ignored_index = start_logits.size(1)
+        start_positions = start_positions.clamp(0, ignored_index)
+        end_positions = end_positions.clamp(0, ignored_index)
+        loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
+        start_loss = loss_fct(start_logits, start_positions)
+        end_loss = loss_fct(end_logits, end_positions)
+        loss = (start_loss + end_loss) / 2
+        return loss
+
+    for epoch in range(args.epochs):
+        # training
+        model.train()
+        for step, batch in enumerate(train_dataloader):
+            batch_data = {k: v.to(device) for k, v in batch.items() if k not in ["start_positions", "end_positions"]}
+            logits = model(**batch_data)
+            # calculate loss
+            loss = calc_loss_from_logits(logits,
+                                         batch["start_positions"].to(device),
+                                         batch["end_positions"].to(device))
+            loss.backward()
+            optimizer.step()
+            lr_scheduler.step()
+            optimizer.zero_grad()
+            progress_bar.update(1)
+        torch.save(
+            {
+                "model": model.state_dict(),
+                "optimizer": optimizer.state_dict(),
+                "lr_scheduler": lr_scheduler.state_dict(),
+                "epoch": epoch,
+            },
+            "checkpoint.pth.tar",
+        )
+        print("Running evaluation for epoch {}".format(epoch))
+        evaluation(args, model, raw_datasets, device)
+
+
+def postquant(args):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    raw_datasets = load_dataset('squad')
+    calib_dataloader, _ = build_train_dataloader(args, raw_datasets)
+    bert_model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False)
+    bert_model.embeddings.seq_length = MAX_LENGTH
+    bert_model.config.num_labels = 2
+    model = BertForQuestionAnswering(bert_model, bert_model.config)
+    model.load_state_dict(torch.load(args.checkpoint)["model"])
+    model.to(device)
+
+    qconfig = parse_qconfig(args.qconfig)
+    qmodel = QuantModel(model, config=qconfig).to(device)
+
+    cudnn.benchmark = True
+    qmodel.prepare_calibration()
+    calibration_size, cur_size = 128, 0
+    for batch in calib_dataloader:
+        batch = {k: v.to(device) for k, v in batch.items() if k not in ["start_positions", "end_positions"]}
+        qmodel(**batch)
+        cur_size += batch["input_ids"].shape[0]
+        if cur_size >= calibration_size:
+            break
+    qmodel.calc_qparams()
+
+    qmodel.set_quant(w_quant=True, a_quant=True)
+    evaluation(args, qmodel, raw_datasets, device)
+
+    # export the quantized model to ONNX
+    dummy_data = []
+    for batch in calib_dataloader:
+        dummy_data.append(batch["input_ids"][:1, :])
+        dummy_data.append(batch["attention_mask"][:1, :])
+        dummy_data.append(batch["token_type_ids"][:1, :])
+        break
+    qmodel.export_onnx(tuple(dummy_data), name="qBERT.onnx")
+
+
+def 
evaluation(args, model, raw_datasets, device):
+    dataloader, dataset = build_validation_dataloader(args, raw_datasets)
+    model.eval()
+    start_logits_list, end_logits_list = [], []
+    for batch in tqdm(dataloader):
+        batch = {k: v.to(device) for k, v in batch.items()}
+        with torch.no_grad():
+            logits = model(**batch)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+        start_logits_list.append(start_logits.cpu().numpy())
+        end_logits_list.append(end_logits.cpu().numpy())
+    # concatenate all batch results to evaluate the F1 score
+    start_logits = np.concatenate(start_logits_list)[: len(dataset)]
+    end_logits = np.concatenate(end_logits_list)[: len(dataset)]
+
+    metrics = compute_metrics(
+        start_logits, end_logits, dataset, raw_datasets["validation"]
+    )
+    print(metrics)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    subparsers = parser.add_subparsers(help="sub-command help")
+
+    # the fine-tuning sub-command
+    parser_finetuning = subparsers.add_parser(
+        "finetuning", help="entry point for BERT fine-tuning"
+    )
+    parser_finetuning.add_argument("--architecture", type=str, help="the architecture of BERT", default="bert-base-uncased")
+    parser_finetuning.add_argument("--batch-size", type=int, default=8)
+    parser_finetuning.add_argument("--epochs", type=int, default=3)
+    parser_finetuning.add_argument("--lr", type=float, default=2e-5)
+    parser_finetuning.set_defaults(func=finetuning)
+
+    parser_postquant = subparsers.add_parser(
+        "postquant", help="entry point for BERT post-training quantization"
+    )
+    parser_postquant.set_defaults(func=postquant)
+    parser_postquant.add_argument("qconfig")
+    parser_postquant.add_argument("checkpoint")
+    parser_postquant.add_argument("--architecture", type=str, help="the architecture of BERT", default="bert-base-uncased")
+    parser_postquant.add_argument("--batch-size", type=int, default=8)
+
+    args = parser.parse_args()
+    args.func(args)
+
diff --git a/examples/post_training_quantization/SQuAD/model.py b/examples/post_training_quantization/SQuAD/model.py
new file mode 100644
index 0000000..5b32439
--- /dev/null
+++ b/examples/post_training_quantization/SQuAD/model.py
@@ -0,0 +1,352 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    BaseModelOutputWithPastAndCrossAttentions,
+    QuestionAnsweringModelOutput,
+)
+from transformers import (
+    BertTokenizer,
+    AdamW,
+    get_linear_schedule_with_warmup,
+    BertPreTrainedModel,
+)
+from transformers.models.bert.modeling_bert import BertPooler, BertLayer
+
+
+class BertEmbeddings(nn.Module):
+    """Construct the embeddings from word, position and token_type embeddings."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(
+            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
+        )
+        self.position_embeddings = nn.Embedding(
+            config.max_position_embeddings, config.hidden_size
+        )
+        self.token_type_embeddings = nn.Embedding(
+            config.type_vocab_size, config.hidden_size
+        )
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        
self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.seq_length = 0 # a workaround for traced + + def forward(self, input_ids, token_type_ids): + position_ids = self.position_ids[:, : self.seq_length] + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = inputs_embeds + token_type_embeddings + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # set None for torch.fx traced + layer_head_mask = None # head_mask[i] if head_mask is not None else None + past_key_value = ( + None # past_key_values[i] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def get_extended_attention_mask( + self, attention_mask, input_shape: Tuple[int], device + ): + extended_attention_mask = attention_mask[:, None, None, :] + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ # extended_attention_mask = extended_attention_mask.to( + # dtype=self.dtype + # ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor], + token_type_ids: Optional[torch.Tensor], + position_ids: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + ): + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = self.config.output_attentions + output_hidden_states = self.config.output_hidden_states + use_cache = False + + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device + ) + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = [None] * self.config.num_hidden_layers + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertForQuestionAnswering(nn.Module): + def __init__(self, traced_model, config=None): + super().__init__() + self.bert = traced_model + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + self.num_labels = config.num_labels + self.config = config + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + #return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + #return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + ) + sequence_output = outputs["last_hidden_state"] + logits = self.qa_outputs(sequence_output) + return logits + + + #start_logits, end_logits = logits.split(1, dim=-1) + #start_logits = start_logits.squeeze(-1).contiguous() + #end_logits = end_logits.squeeze(-1).contiguous() + # + #total_loss = None + #if start_positions is not None and end_positions is not None: + # # If we are on multi-GPU, split add a dimension + # if len(start_positions.size()) > 1: + # start_positions = start_positions.squeeze(-1) + # if len(end_positions.size()) > 1: + # end_positions = end_positions.squeeze(-1) + # # sometimes the start/end positions are outside our model inputs, we ignore these terms + # ignored_index = start_logits.size(1) + # start_positions = start_positions.clamp(0, ignored_index) + # end_positions = end_positions.clamp(0, ignored_index) + # + # loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + # start_loss = loss_fct(start_logits, start_positions) + # end_loss = loss_fct(end_logits, end_positions) + # total_loss = (start_loss + end_loss) / 2 + # + #return QuestionAnsweringModelOutput( + # loss=total_loss, + # start_logits=start_logits, + # end_logits=end_logits, + # hidden_states=outputs.hidden_states, + # attentions=outputs.attentions, + #) diff --git a/examples/post_training_quantization/SQuAD/qconfig.yaml b/examples/post_training_quantization/SQuAD/qconfig.yaml new file mode 100644 index 0000000..6d52320 --- /dev/null +++ b/examples/post_training_quantization/SQuAD/qconfig.yaml @@ -0,0 +1,22 @@ +BACKEND: tensorrt +SCHEDULE: + FUSE_BN: True +W: + QSCHEME: per-channel-symmetric + QUANTIZER: + TYPE: uniform + BIT: 8 + OBSERVER: + TYPE: MINMAX +A: + QSCHEME: per-tensor-symmetric + QUANTIZER: + TYPE: uniform + BIT: 8 + OBSERVER: + TYPE: MSE + LAYOUT: NCHW + SPECIFIC: [{ # bit=0 is disable_quant + "*layer_norm": ["QUANTIZER.DISABLE", True], + "softmax*": ["QUANTIZER.DISABLE", True], + }] diff --git a/examples/post_training_quantization/SQuAD/requirements.txt b/examples/post_training_quantization/SQuAD/requirements.txt new file mode 100644 index 0000000..b7bbcd6 --- /dev/null +++ b/examples/post_training_quantization/SQuAD/requirements.txt @@ -0,0 +1,3 @@ +datasets==2.7.1 +evaluate==0.3.0 +transformers==4.18.0
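
Not part of the diff above: a minimal smoke-test sketch for the exported `qBERT.onnx`, assuming `onnxruntime` is installed. Input names, shapes, and dtypes are read from the session rather than hard-coded, since they depend on how `export_onnx` names the graph inputs.

```python
# Hypothetical smoke test for the exported quantized model (illustrative only).
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("qBERT.onnx", providers=["CPUExecutionProvider"])

# Build a dummy feed from the graph's declared inputs (input_ids,
# attention_mask, token_type_ids are expected, but we do not assume names).
feed = {}
for inp in sess.get_inputs():
    shape = [d if isinstance(d, int) else 1 for d in inp.shape]
    dtype = np.int64 if "int" in inp.type else np.float32
    feed[inp.name] = np.zeros(shape, dtype=dtype)

outputs = sess.run(None, feed)
# The QA head produces [batch, seq_len, num_labels] logits, e.g. (1, 384, 2).
print([o.shape for o in outputs])
```

This only checks that the exported graph loads and runs; accuracy should still be measured with the `evaluation` path in `main.py`.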