Use BaaL in production (Classification)

In this tutorial, we will show you how to use BaaL during your labeling task.

NOTE: In this tutorial, we assume that we do not know the labels!

Install baal

pip install baal

First, we will need a dataset! For the purpose of this demo, we will use an image classification dataset, but BaaL is not limited to computer vision: as long as we can estimate the uncertainty of a prediction, BaaL can be used.

We will use the Natural Images Dataset. Please extract the data to /tmp/natural_images.

[1]:
from glob import glob
import os
from sklearn.model_selection import train_test_split
files = glob('/tmp/natural_images/*/*.jpg')
classes = os.listdir('/tmp/natural_images')
train, test = train_test_split(files, random_state=1337)  # Split 75% train, 25% validation
print(f"Train: {len(train)}, Valid: {len(test)}, Num. classes : {len(classes)}")

Train: 5174, Valid: 1725, Num. classes: 8

Introducing baal.active.FileDataset and baal.active.ActiveLearningDataset

FileDataset is simply an object that loads data and implements def label(self, idx: int, lbl: Any). This method is required to label items in the dataset. You can use any value you want for unlabelled items; in our example we use -1.

ActiveLearningDataset is a wrapper around a Dataset that performs data management. When you iterate over it, it will return labelled items only.

To learn more about dataset management, visit this notebook.
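
For reference, here is a minimal sketch of a custom dataset that satisfies the same contract. MyUnlabelledDataset is a hypothetical example, not part of BaaL:

from typing import Any
from torch.utils.data import Dataset

class MyUnlabelledDataset(Dataset):  # hypothetical example, not part of BaaL
  def __init__(self, items):
    self.items = items
    self.labels = [-1] * len(items)  # -1 marks "not labelled yet"

  def label(self, idx: int, lbl: Any):
    # Called when the oracle provides a label for item `idx`.
    self.labels[idx] = lbl

  def __len__(self):
    return len(self.items)

  def __getitem__(self, idx):
    return self.items[idx], self.labels[idx]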

[2]:
from baal.active import FileDataset, ActiveLearningDataset
from torchvision import transforms

train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.Resize(224),
                                      transforms.RandomCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

# We use -1 to specify that the data is unlabeled.
train_dataset = FileDataset(train, [-1] * len(train), train_transform)

test_transform = transforms.Compose([transforms.Resize(224),
                                     transforms.RandomCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

# We use -1 to specify that the data is unlabeled.
test_dataset = FileDataset(test, [-1] * len(test), test_transform)
active_learning_ds = ActiveLearningDataset(train_dataset, pool_specifics={'transform': test_transform})
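
At this point nothing is labelled yet: iterating over active_learning_ds yields no items, and every sample sits in the unlabelled pool. A quick sanity check, using only attributes that appear later in this tutorial:

print(len(active_learning_ds))       # 0 labelled items so far
print(len(active_learning_ds.pool))  # 5174, i.e. the whole training set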

We now have two unlabeled datasets: train and validation. We encapsulate the training dataset in an ActiveLearningDataset object, which will take care of the split between labeled and unlabeled samples. We are now ready to use Active Learning. We will use a technique called MC-Dropout; BaaL supports other techniques (see the README) and proposes a similar API for each of them. When using MC-Dropout with BaaL, you can use any model as long as it contains some Dropout layers. These layers are essential to compute the uncertainty of the model.

BaaL proposes several models, but it also supports custom models via baal.bayesian.dropout.MCDropoutModule.

In this example, we will use VGG-16, a popular model from torchvision.

[3]:
import torch
from torch import nn, optim
from baal.modelwrapper import ModelWrapper
from torchvision.models import vgg16
from baal.bayesian.dropout import MCDropoutModule
USE_CUDA = torch.cuda.is_available()
model = vgg16(pretrained=False, num_classes=len(classes))
# This will modify all Dropout layers to be usable at test time which is
# required to perform Active Learning.
model = MCDropoutModule(model)
if USE_CUDA:
  model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)

# ModelWrapper is an object similar to keras.Model.
baal_model = ModelWrapper(model, criterion)
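
To see what MCDropoutModule changes, here is a minimal, self-contained sketch with a toy model: dropout stays stochastic even in eval mode, which is what lets us draw several Monte Carlo samples at test time.

import torch
from torch import nn
from baal.bayesian.dropout import MCDropoutModule

toy = MCDropoutModule(nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5)))
toy.eval()  # dropout would normally be disabled in eval mode...
x = torch.randn(1, 8)
# ...but with MCDropoutModule, two forward passes still differ (with high probability).
print(torch.equal(toy(x), toy(x)))  # False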


Heuristics

To rank uncertainty, we will use a heuristic. For classification and segmentation, BALD is the recommended one. We will also add noise to the ranking to lower the selection bias introduced by the AL process; this is done by specifying shuffle_prop in the heuristic constructor.

[4]:
from baal.active.heuristics import BALD
heuristic = BALD(shuffle_prop=0.1)
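
To illustrate what the heuristic consumes and returns, here is a hedged sketch on random data. We assume the [num_samples, num_classes, num_iterations] layout, which matches what predict_on_dataset produces later in this tutorial; the heuristic returns indices ordered from most to least uncertain.

import numpy as np
# Fake MC predictions: 32 samples, 8 classes, 15 stochastic iterations.
fake_predictions = np.random.rand(32, 8, 15)
fake_predictions /= fake_predictions.sum(axis=1, keepdims=True)  # normalize to probabilities
ranks = heuristic(fake_predictions)
print(ranks[:10])  # indices of the 10 most uncertain samples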

Oracle

When the AL process requires a new item to be labeled, we need to provide an Oracle. In your case, the Oracle will most likely be a human labeler. For this example, we're lucky: the class label is in the image path!

[5]:
# This function would do the work that a human would do.
def get_label(img_path):
  return classes.index(img_path.split('/')[-2])


Labeling process

The labeling process will go like this:

1. Label all of the test set and some samples from the training set.
2. Train the model for a few epochs on the training set.
3. Select the top-K most uncertain samples according to the heuristic.
4. Label those samples.
5. If not done, go back to 2.

[6]:
import numpy as np
# 1. Label all the test set and some samples from the training set.
for idx in range(len(test_dataset)):
  img_path = test_dataset.files[idx]
  test_dataset.label(idx, get_label(img_path))

# Let's label 100 training examples randomly first.
# Note: the indices here are relative to the pool of unlabelled items!
train_idxs = np.random.permutation(np.arange(len(train_dataset)))[:100].tolist()
labels = [get_label(train_dataset.files[idx]) for idx in train_idxs]
active_learning_ds.label(train_idxs, labels)

print(f"Num. labeled: {len(active_learning_ds)}/{len(train_dataset)}")

Num. labeled: 100/5174
[7]:
# 2. Train the model for a few epochs on the training set.
baal_model.train_on_dataset(active_learning_ds, optimizer, batch_size=16, epoch=5, use_cuda=USE_CUDA)
baal_model.test_on_dataset(test_dataset, batch_size=16, use_cuda=USE_CUDA)

print("Metrics:", {k:v.avg for k,v in baal_model.metrics.items()})

[103-MainThread  ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:47:48.133213Z [info     ] Starting training              dataset=100 epoch=5
[103-MainThread  ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:48:07.477011Z [info     ] Training complete              train_loss=2.058176279067993
[103-MainThread  ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:48:07.479793Z [info     ] Starting evaluating            dataset=1725
[103-MainThread  ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:48:21.277716Z [info     ] Evaluation complete            test_loss=2.0671451091766357
Metrics: {'test_loss': 2.0671451091766357, 'train_loss': 2.058176279067993}
[8]:
# 3. Select the top-K most uncertain samples according to the heuristic.
pool = active_learning_ds.pool
if len(pool) == 0:
  raise ValueError("We're done!")

# We make 15 MC-Dropout iterations to approximate the uncertainty.
predictions = baal_model.predict_on_dataset(pool, batch_size=16, iterations=15, use_cuda=USE_CUDA, verbose=False)
# We will label the 10 most uncertain samples.
top_uncertainty = heuristic(predictions)[:10]

[103-MainThread  ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:48:21.291851Z [info     ] Start Predict                  dataset=5074
[9]:
# 4. Label those samples.
labels = [get_label(train_dataset.files[idx]) for idx in top_uncertainty]
print(list(zip(labels, top_uncertainty)))
active_learning_ds.label(top_uncertainty, labels)


[(3, 1429), (4, 2971), (2, 1309), (4, 5), (3, 3761), (4, 2708), (6, 4679), (7, 160), (7, 1638), (6, 73)]
[10]:
# 5. If not done, go back to 2.
for step in range(5):  # 5 active learning steps!
  # 2. Train the model for a few epochs on the training set.
  print(f"Training on {len(active_learning_ds)} items!")
  baal_model.train_on_dataset(active_learning_ds, optimizer, batch_size=16, epoch=5, use_cuda=USE_CUDA)
  baal_model.test_on_dataset(test_dataset, batch_size=16, use_cuda=USE_CUDA)

  print("Metrics:", {k:v.avg for k,v in baal_model.metrics.items()})

  # 3. Select the top-K most uncertain samples according to the heuristic.
  pool = active_learning_ds.pool
  if len(pool) == 0:
    print("We're done!")
    break
  predictions = baal_model.predict_on_dataset(pool, batch_size=16, iterations=15, use_cuda=USE_CUDA, verbose=False)
  top_uncertainty = heuristic(predictions)[:10]
  # 4. Label those samples.
  labels = [get_label(train_dataset.files[idx]) for idx in top_uncertainty]
  active_learning_ds.label(top_uncertainty, labels)
Training on 110 items!
[103-MainThread  ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:50:02.089160Z [info     ] Starting training              dataset=110 epoch=5
[103-MainThread  ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:50:19.678241Z [info     ] Training complete              train_loss=1.9793428182601929
[103-MainThread  ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:50:19.681509Z [info     ] Starting evaluating            dataset=1725
[103-MainThread  ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:50:33.777658Z [info     ] Evaluation complete            test_loss=2.013453960418701
Metrics: {'test_loss': 2.013453960418701, 'train_loss': 1.9793428182601929}
[103-MainThread  ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:50:33.784990Z [info     ] Start Predict                  dataset=5064
Training on 120 items!
[103-MainThread  ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:52:14.295969Z [info     ] Starting training              dataset=120 epoch=5
[103-MainThread  ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:52:32.482238Z [info     ] Training complete              train_loss=1.8900309801101685
[103-MainThread  ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:52:32.484473Z [info     ] Starting evaluating            dataset=1725
[103-MainThread  ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:52:46.287436Z [info     ] Evaluation complete            test_loss=1.8315811157226562
Metrics: {'test_loss': 1.8315811157226562, 'train_loss': 1.8900309801101685}
[103-MainThread  ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:52:46.367016Z [info     ] Start Predict                  dataset=5054
Training on 130 items!
[103-MainThread  ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:54:26.794349Z [info     ] Starting training              dataset=130 epoch=5
[103-MainThread  ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:54:44.481490Z [info     ] Training complete              train_loss=1.961772084236145
[103-MainThread  ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:54:44.483477Z [info     ] Starting evaluating            dataset=1725
[103-MainThread  ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:54:58.268424Z [info     ] Evaluation complete            test_loss=1.859472393989563
Metrics: {'test_loss': 1.859472393989563, 'train_loss': 1.961772084236145}
[103-MainThread  ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:54:58.276565Z [info     ] Start Predict                  dataset=5044
Training on 140 items!
[103-MainThread  ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:56:38.406344Z [info     ] Starting training              dataset=140 epoch=5
[103-MainThread  ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:56:57.088064Z [info     ] Training complete              train_loss=1.8688158988952637
[103-MainThread  ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:56:57.091358Z [info     ] Starting evaluating            dataset=1725
[103-MainThread  ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:57:10.968456Z [info     ] Evaluation complete            test_loss=1.7242822647094727
Metrics: {'test_loss': 1.7242822647094727, 'train_loss': 1.8688158988952637}
[103-MainThread  ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:57:10.977104Z [info     ] Start Predict                  dataset=5034
Training on 150 items!
[103-MainThread  ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:58:51.197386Z [info     ] Starting training              dataset=150 epoch=5
[103-MainThread  ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:59:09.779341Z [info     ] Training complete              train_loss=1.8381125926971436
[103-MainThread  ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:59:09.782580Z [info     ] Starting evaluating            dataset=1725
[103-MainThread  ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:59:23.176680Z [info     ] Evaluation complete            test_loss=1.7318601608276367
Metrics: {'test_loss': 1.7318601608276367, 'train_loss': 1.8381125926971436}
[103-MainThread  ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:59:23.184444Z [info     ] Start Predict                  dataset=5024

And we’re done! Be sure to save the dataset and the model.

[11]:
torch.save({
  'active_dataset': active_learning_ds.state_dict(),
  'model': baal_model.state_dict(),
  'metrics': {k:v.avg for k,v in baal_model.metrics.items()}
}, '/tmp/baal_output.pth')
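
To resume labeling later, the checkpoint can be restored. A hedged sketch, assuming load_state_dict counterparts exist for ActiveLearningDataset and ModelWrapper, mirroring the state_dict calls above:

checkpoint = torch.load('/tmp/baal_output.pth')
# Assumes both wrappers expose load_state_dict, the counterpart of state_dict.
active_learning_ds.load_state_dict(checkpoint['active_dataset'])
baal_model.load_state_dict(checkpoint['model'])
print("Restored metrics:", checkpoint['metrics'])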

Support

Submit an issue or reach out to us on Gitter!