HuggingFace: Active Learning for NLP Classification¶
In this tutorial, we guide you through using our new HuggingFace Trainer wrapper to do active learning with transformer models.
Any model that can be trained by the HuggingFace Trainer and has Dropout
layers can be used in the same manner.
We will use the SST2
dataset and BertForSequenceClassification
as the model for the purpose of this tutorial. As usual, we first need to download the dataset.
Note: This tutorial is intended for advanced users. If you are not familiar with Baal, please refer to the other tutorials first.
from datasets import load_dataset
datasets = load_dataset("glue", "sst2", cache_dir="/tmp")
raw_train_set = datasets['train']
Reusing dataset glue (/tmp/glue/sst2/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)
Active Learning Dataset¶
To create an active learning dataset, we need to wrap the dataset with baal.ActiveLearningDataset.
This requires a torch.utils.data.Dataset,
so we provide baal.active.dataset.nlp_datasets.active_huggingface_dataset,
which takes a HuggingFace dataset
and performs the preprocessing steps.
from baal.active.dataset.nlp_datasets import active_huggingface_dataset
from transformers import BertTokenizer
pretrained_weights = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_weights)
active_set = active_huggingface_dataset(raw_train_set, tokenizer)
active_set.can_label = False  # In a research setting the labels already exist, so we only reveal them instead of querying an oracle.
# Let's randomly label 100 samples; len(active_set) should then be 100
active_set.label_randomly(100)
assert len(active_set) == 100
print(len(active_set.pool))
67249
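The numbers above follow from how the wrapper splits one dataset into a labelled part and a pool. Here is a minimal, hypothetical sketch of that bookkeeping (ToyActiveDataset is ours, not Baal's class): a boolean mask marks which samples are labelled, len() counts only those, and the pool is everything else.

```python
import random

class ToyActiveDataset:
    """Sketch of an active-learning dataset: a boolean mask splits one
    underlying dataset into a labelled part and an unlabelled pool."""

    def __init__(self, samples):
        self.samples = samples
        self.labelled = [False] * len(samples)

    def __len__(self):
        # Length counts only the labelled samples.
        return sum(self.labelled)

    @property
    def pool(self):
        # The pool is everything not yet labelled.
        return [s for s, lab in zip(self.samples, self.labelled) if not lab]

    def label_randomly(self, n):
        unlabelled = [i for i, lab in enumerate(self.labelled) if not lab]
        for i in random.sample(unlabelled, n):
            self.labelled[i] = True

# SST2's train split has 67349 samples; labelling 100 leaves 67249 in the pool.
data = ToyActiveDataset(list(range(67349)))
data.label_randomly(100)
print(len(data), len(data.pool))  # 100 67249
```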
Active Learning Model¶
The process of making a model bayesian is exactly the same as before. In this case, we will get the Bert
model and use baal.bayesian.dropout.patch_module
to make the dropout layer stochastic at inference time.
from copy import deepcopy
import torch
from transformers import BertForSequenceClassification
from baal.bayesian.dropout import patch_module
use_cuda = torch.cuda.is_available()
model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path=pretrained_weights)
model = patch_module(model)
if use_cuda:
    model.cuda()
init_weights = deepcopy(model.state_dict())
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Heuristic¶
BALD is already implemented in Baal and works well for classification tasks, so we continue using it
as our active learning heuristic.
Note: Active learning for NLP tasks is an open and challenging field, and designing a proper heuristic is out of the scope of this tutorial. We encourage any pull request proposing a better heuristic.
from baal.active import get_heuristic
heuristic = get_heuristic('bald')
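BALD scores a sample by the mutual information between its prediction and the model posterior: the entropy of the mean predicted distribution minus the mean entropy of the individual MC-dropout predictions. A small self-contained sketch (plain Python, not Baal's vectorized implementation) makes the behaviour concrete: a sample the stochastic passes disagree on outranks one they are consistently unsure about.

```python
from math import log

def entropy(p):
    return -sum(pi * log(pi) for pi in p if pi > 0)

def bald_score(mc_probs):
    """mc_probs: per-iteration class-probability vectors for one sample,
    shape [iterations][classes]. Returns the mutual information."""
    mean_p = [sum(col) / len(mc_probs) for col in zip(*mc_probs)]
    expected_entropy = sum(entropy(p) for p in mc_probs) / len(mc_probs)
    return entropy(mean_p) - expected_entropy

# The passes flip-flop on this sample: high disagreement, high BALD score.
disagreeing = [[0.9, 0.1], [0.1, 0.9], [0.5, 0.5]]
# The passes agree the model is unsure: aleatoric noise, BALD score is 0.
consistent = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
print(bald_score(disagreeing) > bald_score(consistent))  # True
```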
HuggingFace Trainer Wrapper¶
If you are not familiar with the HuggingFace Trainer module, please start here.
HuggingFace Trainer is one of the most popular libraries for training Transformer models.
To do active learning, predictions must be run over every sample in the pool for a number of iterations; our wrapper baal.BaalTransformersTrainer
provides this functionality on top of the functionality already provided by the Trainer
module.
In the rest of this tutorial, we show how to initialize the baal.active.active_loop.ActiveLearningLoop
and how to do active training.
from transformers import TrainingArguments
from baal.transformers_trainer_wrapper import BaalTransformersTrainer
from baal.active.active_loop import ActiveLearningLoop
# Initialization for the HuggingFace trainer
training_args = TrainingArguments(
    output_dir='.',                   # output directory
    num_train_epochs=5,               # total number of training epochs per AL step
    per_device_train_batch_size=16,   # batch size per device during training
    per_device_eval_batch_size=64,    # batch size for evaluation
    weight_decay=0.01,                # strength of weight decay
    logging_dir='.',                  # directory for storing logs
)

# Create the trainer through the Baal wrapper
baal_trainer = BaalTransformersTrainer(model=model,
                                       args=training_args,
                                       train_dataset=active_set,
                                       tokenizer=None)

active_loop = ActiveLearningLoop(active_set,
                                 baal_trainer.predict_on_dataset,
                                 heuristic, 10, iterations=3)
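Each call to active_loop.step() predicts on the pool, scores each sample with the heuristic, and labels the top query_size (here 10) samples. A hedged sketch of that step, with a hypothetical active_step helper rather than Baal's actual API:

```python
# Hypothetical sketch of one active-learning step: rank the pool by
# uncertainty, then move the top `query_size` samples into the labelled set.

def active_step(scores, query_size, label_fn):
    # Rank pool indices from most to least uncertain.
    ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    label_fn(ranked[:query_size])      # the oracle labels the top candidates
    return len(scores) > query_size    # False once the pool would be exhausted

labelled = []
uncertainty = [0.1, 0.9, 0.3, 0.7, 0.2]  # e.g. BALD scores over a tiny pool
should_continue = active_step(uncertainty, 2, labelled.extend)
print(labelled, should_continue)  # [1, 3] True
```

The return value plays the same role as should_continue in the training loop below: it tells the caller whether another step is worthwhile.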
for epoch in range(2):
    baal_trainer.train()
    should_continue = active_loop.step()
    # We reset the model weights to relearn from the new train set.
    baal_trainer.load_state_dict(init_weights)
    baal_trainer.lr_scheduler = None
    if not should_continue:
        break
# At each active learning step we add 10 samples to the labelled data, so after
# two steps the labelled part of the training set should contain 120 samples.
print(len(active_set))
[93-MainThread ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-03-08T20:15:36.980534Z [info ] Start Predict dataset=67249
100%|██████████| 1051/1051 [12:30<00:00, 1.40it/s]
[93-MainThread ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-03-08T20:28:15.903378Z [info ] Start Predict dataset=67239
100%|██████████| 1051/1051 [12:29<00:00, 1.40it/s]
120