Active Learning for NLP Classification

In this tutorial, we guide you through using our new HuggingFace Trainer wrapper to do active learning with transformer models. Any model that can be trained by the HuggingFace Trainer and contains Dropout layers can be used in the same manner.

We will use the SST2 dataset and BertForSequenceClassification as the model for the purpose of this tutorial. As usual, we need to first download the dataset.

Note: This tutorial is intended for advanced users. If you are not familiar with BaaL, please refer to other tutorials.

[1]:
from datasets import load_dataset
datasets = load_dataset("glue", "sst2", cache_dir="/tmp")
raw_train_set = datasets['train']
Reusing dataset glue (/tmp/glue/sst2/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)

Active Learning Dataset

In order to create an active learning dataset, we need to wrap the dataset with baal.ActiveLearningDataset. This requires a torch.utils.data.Dataset, so we provide baal.active.active_huggingface_dataset, a helper that takes a HuggingFace dataset and performs the preprocessing steps.

[2]:
from baal.active import active_huggingface_dataset
from transformers import BertTokenizer
pretrained_weights = 'bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_weights)
active_set = active_huggingface_dataset(raw_train_set, tokenizer)

# Let's randomly label 100 samples, so len(active_set) should be 100
active_set.label_randomly(100)
assert len(active_set) == 100
print(len(active_set.pool))
67249
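
As an aside (we do not run this here, to keep the sample counts used in the rest of the tutorial unchanged), the ActiveLearningDataset can also label specific pool indices, e.g. when you already know which samples you want annotated:

# Label specific items; indices are relative to the current pool, not the raw dataset.
# Each labelled item moves from the pool to the training split.
active_set.label([0, 1, 2])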

Active Learning Model

The process of making a model Bayesian is exactly the same as before. In this case, we load the Bert model and use baal.bayesian.dropout.patch_module to make its Dropout layers stochastic at inference time.

[3]:
from copy import deepcopy
import torch
from transformers import BertForSequenceClassification
from baal.bayesian.dropout import patch_module

use_cuda = torch.cuda.is_available()

model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path=pretrained_weights)
model = patch_module(model)
if use_cuda:
    model.cuda()
init_weights = deepcopy(model.state_dict())
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
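
To see what patch_module changed, here is a small sanity check (not part of the original tutorial): after patching, Dropout stays active even in eval mode, so two forward passes on the same input should produce different logits. The example sentence is arbitrary.

model.eval()
encoded = tokenizer("this movie was surprisingly good", return_tensors="pt")
if use_cuda:
    encoded = {k: v.cuda() for k, v in encoded.items()}
with torch.no_grad():
    logits_1 = model(**encoded).logits
    logits_2 = model(**encoded).logits
# Expected to print False: the MC dropout masks differ between the two passes.
print(torch.allclose(logits_1, logits_2))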

Heuristic

BALD is already implemented in BaaL and is a strong default for classification problems, so we continue to use it as our active learning heuristic.

Note: Active learning for NLP tasks is an open and challenging field, so designing a proper heuristic is out of the scope of this tutorial. We welcome any pull request that proposes better heuristics.

[4]:
from baal.active import get_heuristic

heuristic = get_heuristic('bald')
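
To make the interface concrete, here is a small sketch (not from the original tutorial) of what the heuristic consumes: Monte Carlo predictions with shape [n_samples, n_classes, n_iterations]. Calling it returns pool indices ranked from most to least informative according to the BALD score.

import numpy as np

# Fake MC predictions: 5 pool samples, 2 classes, 3 MC dropout iterations.
fake_predictions = np.random.rand(5, 2, 3)
fake_predictions /= fake_predictions.sum(axis=1, keepdims=True)  # normalize into probabilities

ranks = heuristic(fake_predictions)
print(ranks)  # indices ordered by decreasing mutual information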

HuggingFace Trainer Wrapper

If you are not familiar with the HuggingFace Trainer module, please start here. HuggingFace Trainer is one of the most popular libraries for training Transformer models. In order to do active learning, predictions need to be run over every sample in the pool for a number of iterations, so our wrapper baal.BaalTransformersTrainer provides this functionality on top of what the Trainer module already offers. In the rest of this tutorial, we show how to initialize baal.active.active_loop.ActiveLearningLoop and how to do active training.

[5]:
from transformers import TrainingArguments
from baal.transformers_trainer_wrapper import BaalTransformersTrainer
from baal.active.active_loop import ActiveLearningLoop

# Initialization for the HuggingFace Trainer
training_args = TrainingArguments(
    output_dir='.',  # output directory
    num_train_epochs=5,  # total # of training epochs per AL step
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,  # batch size for evaluation
    weight_decay=0.01,  # strength of weight decay
    logging_dir='.',  # directory for storing logs
    )

# create the trainer through Baal Wrapper
baal_trainer = BaalTransformersTrainer(model=model,
                                       args=training_args,
                                       train_dataset=active_set,
                                       tokenizer=None)


# 10: number of samples to label at each active learning step.
# iterations=3: MC dropout iterations used when predicting on the pool.
active_loop = ActiveLearningLoop(active_set,
                                 baal_trainer.predict_on_dataset,
                                 heuristic, 10, iterations=3)

for epoch in range(2):
    baal_trainer.train()

    should_continue = active_loop.step()

    # We reset the model weights to relearn from the new train set.
    baal_trainer.load_state_dict(init_weights)
    baal_trainer.lr_scheduler = None
    if not should_continue:
        break

# At each active learning step we add 10 samples to the labelled data. After the 2 steps above,
# 20 samples have been added, so the labelled part of the training set now contains 120 samples.
print(len(active_set))
[7/7 00:01, Epoch 1/1]
Step Training Loss

[93-MainThread   ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-03-08T20:15:36.980534Z [info     ] Start Predict                  dataset=67249
100%|██████████| 1051/1051 [12:30<00:00,  1.40it/s]
[93-MainThread   ] [baal.transformers_trainer_wrapper:predict_on_dataset_generator:61] 2021-03-08T20:28:15.903378Z [info     ] Start Predict                  dataset=67239
100%|██████████| 1051/1051 [12:29<00:00,  1.40it/s]
[8/8 00:01, Epoch 1/1]
Step Training Loss

120
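
Finally, a hedged sketch of how the pieces connect outside the loop: ActiveLearningLoop.step predicts on the pool with baal_trainer.predict_on_dataset and ranks the results with the heuristic. You can do the same manually, e.g. to inspect which samples would be queried next. The shapes and counts in the comments assume the setup above (binary SST2, 3 MC iterations, 120 labelled samples).

# MC predictions over the remaining pool: shape [len(pool), num_classes, iterations].
predictions = baal_trainer.predict_on_dataset(active_set.pool, iterations=3)
print(predictions.shape)

# The 10 pool indices BALD ranks as most informative; labelling them is what a
# further active_loop.step() would do.
most_informative = heuristic(predictions)[:10]
active_set.label(most_informative)
print(len(active_set))  # expected: 130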