DeBERTa-v3-base-mnli-fever-anli
Model description
This model was trained on the MultiNLI, Fever-NLI and Adversarial-NLI (ANLI) datasets, which together comprise 763 913 NLI hypothesis-premise pairs. This base model outperforms almost all large models on the ANLI benchmark. The underlying model is DeBERTa-v3-base from Microsoft. The v3 variant of DeBERTa substantially outperforms previous versions of the model because it uses a different pre-training objective; see annex 11 of the original DeBERTa paper.
For the highest performance (at the cost of lower speed), I recommend using https://huggingface.co/MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli.
How to use the model
Simple zero-shot classification pipeline
```python
#!pip install transformers[sentencepiece]
from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli")
sequence_to_classify = "Angela Merkel is a politician in Germany and leader of the CDU"
candidate_labels = ["politics", "economy", "entertainment", "environment"]
output = classifier(sequence_to_classify, candidate_labels, multi_label=False)
print(output)
```
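Under the hood, the zero-shot pipeline turns each candidate label into an NLI hypothesis and scores it against the input text as the premise. The sketch below is a simplified re-implementation of that scoring step as I understand the Transformers pipeline to do it. It assumes the label order 0=entailment, 1=neutral, 2=contradiction used in the NLI example below, and the pipeline's default hypothesis template "This example is {}.". The logits are invented for illustration and are not real model output.

```python
import math

# Default hypothesis template of the Transformers zero-shot pipeline.
HYPOTHESIS_TEMPLATE = "This example is {}."
# Assumed label indices for this checkpoint: 0=entailment, 1=neutral, 2=contradiction.
ENTAILMENT, CONTRADICTION = 0, 2


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def build_pairs(sequence, labels, template=HYPOTHESIS_TEMPLATE):
    """Return one (premise, hypothesis) pair per candidate label."""
    return [(sequence, template.format(label)) for label in labels]


def zero_shot_scores(nli_logits, labels, multi_label=False):
    """nli_logits: one [entailment, neutral, contradiction] logit row per label."""
    if multi_label:
        # Each label is judged on its own: entailment vs. contradiction.
        scores = [softmax([row[CONTRADICTION], row[ENTAILMENT]])[1] for row in nli_logits]
    else:
        # Labels compete: softmax over the entailment logits of all labels.
        scores = softmax([row[ENTAILMENT] for row in nli_logits])
    return dict(zip(labels, scores))


labels = ["politics", "economy", "entertainment", "environment"]
for premise, hypothesis in build_pairs("Angela Merkel is a politician in Germany and leader of the CDU", labels):
    print(hypothesis)

# Invented logits, purely for illustration.
fake_logits = [[3.2, -0.5, -2.4], [-0.8, 0.6, 0.1], [-2.0, 0.3, 1.5], [-1.6, 0.2, 1.1]]
print(zero_shot_scores(fake_logits, labels))
print(zero_shot_scores(fake_logits, labels, multi_label=True))
```

With `multi_label=False` the scores sum to 1 across labels. With `multi_label=True` each label gets an independent probability, which suits texts that can belong to several categories.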
NLI use-case
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_name = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Move the model to the same device as the inputs
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
model.eval()

premise = "I first thought that I liked the movie, but upon second thought it was actually disappointing."
hypothesis = "The movie was good."

# Encode premise and hypothesis as a single sentence pair
inputs = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt").to(device)
with torch.no_grad():  # inference only, no gradients needed
    output = model(**inputs)
prediction = torch.softmax(output.logits[0], -1).tolist()
label_names = ["entailment", "neutral", "contradiction"]
prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
print(prediction)
```
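Instead of hard-coding the label order, you can read it from `model.config.id2label`, a standard attribute of Transformers model configs. The helper below is a hypothetical sketch that works on a plain list of logits, so it runs without downloading the model. It assumes the checkpoint's `id2label` maps 0, 1 and 2 to entailment, neutral and contradiction, as the hard-coded list above implies.

```python
import math


def logits_to_percentages(logits, id2label):
    """Softmax over one row of logits, returned as {label: percentage}."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return {id2label[i]: round(e / total * 100, 1) for i, e in enumerate(exps)}


# With the model loaded, the call would be (not executed here):
#   logits_to_percentages(output.logits[0].tolist(), model.config.id2label)
example_id2label = {0: "entailment", 1: "neutral", 2: "contradiction"}  # assumed mapping
print(logits_to_percentages([-1.2, 0.4, 2.5], example_id2label))  # invented logits
```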
Training data
DeBERTa-v3-base-mnli-fever-anli was trained on the MultiNLI, Fever-NLI and Adversarial-NLI (ANLI) datasets, which comprise 763 913 NLI hypothesis-premise pairs.
Training procedure
DeBERTa-v3-base-mnli-fever-anli was trained using the Hugging Face Trainer with the following hyperparameters:
```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",           # hypothetical path, not stated in this card; required by older Transformers versions
    num_train_epochs=3,               # total number of training epochs
    learning_rate=2e-05,
    per_device_train_batch_size=32,   # batch size per device during training
    per_device_eval_batch_size=32,    # batch size for evaluation
    warmup_ratio=0.1,                 # fraction of all training steps used for learning-rate warmup
    weight_decay=0.06,                # strength of weight decay
    fp16=True,                        # mixed precision training
)
```
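`warmup_ratio` is a fraction of all optimisation steps, not a step count. The sketch below works out what it implies for this dataset. It assumes a single device, no gradient accumulation, a final partial batch that is kept, and the Trainer's default linear schedule: warmup from 0 to the peak learning rate, then linear decay to 0. The number of GPUs is not stated in this card, so treat the step counts as illustrative.

```python
import math


def schedule_lengths(num_examples, batch_size, epochs, warmup_ratio):
    """Steps per epoch, total steps and warmup steps (single device, no accumulation)."""
    steps_per_epoch = math.ceil(num_examples / batch_size)  # last partial batch kept
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return steps_per_epoch, total_steps, warmup_steps


def linear_warmup_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate at a given step under linear warmup followed by linear decay."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


steps_per_epoch, total, warmup = schedule_lengths(763_913, 32, 3, 0.1)
print(steps_per_epoch, total, warmup)              # 23873 71619 7162
print(linear_warmup_lr(warmup, 2e-05, warmup, total))  # 2e-05 (peak learning rate)
```

On more devices the effective batch size grows and all three step counts shrink proportionally.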
Eval results
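The model was evaluated on the test sets of MultiNLI and ANLI and on the dev set of Fever-NLI. The metric is accuracy. In the table, mnli-m and mnli-mm are the matched and mismatched MultiNLI sets, anli-all covers all ANLI rounds combined and anli-r3 covers round 3 only.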
mnli-m | mnli-mm | fever-nli | anli-all | anli-r3 |
---|---|---|---|---|
0.903 | 0.903 | 0.777 | 0.579 | 0.495 |
Limitations and bias
Please consult the original DeBERTa paper and the literature on the individual NLI datasets for potential biases.
Citation
If you use this model, please cite: Laurer, Moritz, Wouter van Atteveldt, Andreu Salleras Casas, and Kasper Welbers. 2022. "Less Annotating, More Classifying: Addressing the Data Scarcity Issue of Supervised Machine Learning with Deep Transfer Learning and BERT-NLI". Preprint, June. Open Science Framework. https://osf.io/74b8k.
Ideas for cooperation or questions?
If you have questions or ideas for cooperation, contact me at m{dot}laurer{at}vu{dot}nl or via LinkedIn.
Debugging and issues
Note that DeBERTa-v3 was released on 6 December 2021, and older versions of HF Transformers seem to have problems running the model (for example, tokenizer errors). Using Transformers>=4.13 might solve some of these issues.
Also make sure sentencepiece is installed to avoid tokenizer errors. Run `pip install transformers[sentencepiece]` or `pip install sentencepiece`.
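A quick way to check your environment against these notes before loading the model is sketched below. It uses only the standard library (`importlib.metadata`). The helper names are hypothetical, and the 4.13 threshold is taken from the note above rather than from any official compatibility table.

```python
import re
from importlib.metadata import PackageNotFoundError, version

MIN_TRANSFORMERS = (4, 13)  # minimum suggested in the note above


def major_minor(v):
    """'4.13.0.dev0' -> (4, 13); only the first two numeric components are used."""
    nums = [int(n) for n in re.findall(r"\d+", v)] + [0, 0]
    return nums[0], nums[1]


def environment_problems():
    """Return a list of human-readable problems (empty if everything looks fine)."""
    problems = []
    try:
        installed = version("transformers")
        if major_minor(installed) < MIN_TRANSFORMERS:
            problems.append(f"transformers {installed} is older than 4.13")
    except PackageNotFoundError:
        problems.append("transformers is not installed")
    try:
        version("sentencepiece")
    except PackageNotFoundError:
        problems.append("sentencepiece is not installed (pip install sentencepiece)")
    return problems


for problem in environment_problems():
    print("Warning:", problem)
```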
Model Recycling
Evaluation on 36 datasets using MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli as a base model yields an average score of 79.69, compared with 79.04 for microsoft/deberta-v3-base.
The model ranks 2nd among all tested models for the microsoft/deberta-v3-base architecture as of 09/01/2023.
Results:
20_newsgroup | ag_news | amazon_reviews_multi | anli | boolq | cb | cola | copa | dbpedia | esnli | financial_phrasebank | imdb | isear | mnli | mrpc | multirc | poem_sentiment | qnli | qqp | rotten_tomatoes | rte | sst2 | sst_5bins | stsb | trec_coarse | trec_fine | tweet_ev_emoji | tweet_ev_emotion | tweet_ev_hate | tweet_ev_irony | tweet_ev_offensive | tweet_ev_sentiment | wic | wnli | wsc | yahoo_answers |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
85.8072 | 90.4333 | 67.32 | 59.625 | 85.107 | 91.0714 | 85.8102 | 67 | 79.0333 | 91.6327 | 82.5 | 94.02 | 71.6428 | 89.5749 | 89.7059 | 64.1708 | 88.4615 | 93.575 | 91.4148 | 89.6811 | 86.2816 | 94.6101 | 57.0588 | 91.5508 | 97.6 | 91.2 | 45.264 | 82.6179 | 54.5455 | 74.3622 | 84.8837 | 71.6949 | 71.0031 | 69.0141 | 68.2692 | 71.3333 |
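The reported average of 79.69 can be reproduced from the table, assuming it is an unweighted mean over the 36 datasets. That is an assumption about how Model Recycling aggregates, though the arithmetic does match. The 79.04 for microsoft/deberta-v3-base cannot be checked here because its per-dataset scores are not listed.

```python
from statistics import mean

# Per-dataset scores copied from the Model Recycling table above.
scores = {
    "20_newsgroup": 85.8072, "ag_news": 90.4333, "amazon_reviews_multi": 67.32,
    "anli": 59.625, "boolq": 85.107, "cb": 91.0714, "cola": 85.8102, "copa": 67,
    "dbpedia": 79.0333, "esnli": 91.6327, "financial_phrasebank": 82.5, "imdb": 94.02,
    "isear": 71.6428, "mnli": 89.5749, "mrpc": 89.7059, "multirc": 64.1708,
    "poem_sentiment": 88.4615, "qnli": 93.575, "qqp": 91.4148, "rotten_tomatoes": 89.6811,
    "rte": 86.2816, "sst2": 94.6101, "sst_5bins": 57.0588, "stsb": 91.5508,
    "trec_coarse": 97.6, "trec_fine": 91.2, "tweet_ev_emoji": 45.264,
    "tweet_ev_emotion": 82.6179, "tweet_ev_hate": 54.5455, "tweet_ev_irony": 74.3622,
    "tweet_ev_offensive": 84.8837, "tweet_ev_sentiment": 71.6949, "wic": 71.0031,
    "wnli": 69.0141, "wsc": 68.2692, "yahoo_answers": 71.3333,
}

avg = mean(scores.values())  # unweighted mean (assumed aggregation)
print(f"{len(scores)} datasets, average {avg:.2f}")  # 36 datasets, average 79.69
```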
For more information, see: Model Recycling
Evaluation results
Dataset (verified) | Accuracy | Precision macro | Precision micro | Precision weighted | Recall macro | Recall micro | Recall weighted | F1 macro | F1 micro | F1 weighted | Loss |
---|---|---|---|---|---|---|---|---|---|---|---|
anli | 0.495 | 0.498 | 0.495 | 0.498 | 0.495 | 0.495 | 0.495 | 0.494 | 0.495 | 0.494 | 1.879 |
anli | 0.712 | 0.713 | 0.712 | 0.713 | 0.712 | 0.712 | 0.712 | 0.712 | 0.712 | 0.712 | 1.011 |
multi_nli | 0.903 | 0.902 | 0.903 | 0.903 | 0.902 | 0.903 | 0.903 | 0.902 | 0.903 | 0.903 | 0.328 |
anli | 0.737 | 0.738 | 0.737 | 0.738 | 0.737 | 0.737 | 0.737 | 0.737 | 0.737 | 0.737 | 0.935 |
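In the evaluation results above, every micro-averaged precision, recall and F1 equals the accuracy. That is expected for single-label classification. Each wrong prediction counts as one false positive (for the predicted class) and one false negative (for the true class), so the pooled precision and recall both reduce to correct/total. Macro and weighted averages are computed per class first and can differ. The sketch below illustrates this on invented labels. To my understanding it follows the same definitions as scikit-learn's `average="macro"`, `"micro"` and `"weighted"` options, but it is not that library's code.

```python
from collections import Counter


def nli_metrics(y_true, y_pred):
    """Accuracy plus macro, weighted and micro (precision, recall, F1)."""
    labels = sorted(set(y_true) | set(y_pred))
    support = Counter(y_true)
    n = len(y_true)
    per_class = {}
    tp_sum = fp_sum = fn_sum = 0
    for c in labels:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        tp_sum, fp_sum, fn_sum = tp_sum + tp, fp_sum + fp, fn_sum + fn
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        per_class[c] = (prec, rec, f1)
    # Macro: plain mean over classes; weighted: mean weighted by true-class support.
    macro = [sum(m[i] for m in per_class.values()) / len(labels) for i in range(3)]
    weighted = [sum(per_class[c][i] * support[c] for c in labels) / n for i in range(3)]
    # Micro: pool TP/FP/FN over all classes before dividing.
    micro_p = tp_sum / (tp_sum + fp_sum)
    micro_r = tp_sum / (tp_sum + fn_sum)
    micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r) if micro_p + micro_r else 0.0
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / n
    return {"accuracy": accuracy, "macro": macro, "weighted": weighted,
            "micro": [micro_p, micro_r, micro_f1]}


# Invented gold labels and predictions, for illustration only.
y_true = ["entailment", "neutral", "contradiction", "entailment", "neutral", "contradiction"]
y_pred = ["entailment", "contradiction", "contradiction", "neutral", "neutral", "contradiction"]
print(nli_metrics(y_true, y_pred))
```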