Use Baal in production (Image classification)
In this tutorial, we will show you how to use Baal during your labeling task.
NOTE: In this tutorial, we assume that we do not know the labels!
Install baal
pip install baal
We will first need a dataset! For this demo, we use an image classification dataset, but Baal is not limited to computer vision: as long as we can estimate the uncertainty of a prediction, Baal can be used.
We will use the Natural Images Dataset.
Please extract the data to /tmp/natural_images.
from glob import glob
import os
from sklearn.model_selection import train_test_split
files = glob('/tmp/natural_images/*/*.jpg')
classes = os.listdir('/tmp/natural_images')
train, test = train_test_split(files, random_state=1337) # Split 75% train, 25% validation
print(f"Train: {len(train)}, Valid: {len(test)}, Num. classes : {len(classes)}")
Train: 5174, Valid: 1725, Num. classes : 8
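The sizes printed above follow from the default 25% test fraction of scikit-learn's `train_test_split`. As far as we know, it rounds the test count up and gives the remainder to training. The stdlib-only sketch below reproduces that arithmetic for the 6,899 images implied by the output (5,174 + 1,725); `split_sizes` is a hypothetical helper, not part of scikit-learn or Baal.

```python
import math
from typing import Tuple


def split_sizes(n_samples: int, test_size: float = 0.25) -> Tuple[int, int]:
    """Return (n_train, n_test) for a fractional test_size.

    Assumption: the test count is rounded up with math.ceil and the
    training set receives all remaining samples.
    """
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    return n_train, n_test


print(split_sizes(6899))  # (5174, 1725)
```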
Introducing baal.active.FileDataset and baal.active.ActiveLearningDataset
FileDataset is simply an object that loads data and implements def label(self, idx: int, lbl: Any).
This method is necessary to label items in the dataset. You can set any value you want for unlabelled items; in our example, we use -1.
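To make the `label` contract concrete, here is a toy, stdlib-only stand-in. `ToyFileDataset` is a hypothetical class, not Baal's: it stores file paths with a parallel list of labels, uses -1 as the "unlabelled" sentinel, and overwrites the entry when `label` is called. Baal's real `FileDataset` also loads images and applies transforms, which this sketch omits.

```python
from typing import Any, List

UNLABELLED = -1  # sentinel value, matching the -1 used in this tutorial


class ToyFileDataset:
    """Hypothetical minimal dataset exposing label(idx, lbl)."""

    def __init__(self, files: List[str], lbls: List[Any]):
        if len(files) != len(lbls):
            raise ValueError("files and lbls must have the same length")
        self.files = list(files)
        self.lbls = list(lbls)

    def __len__(self) -> int:
        return len(self.files)

    def label(self, idx: int, lbl: Any) -> None:
        # Record the label provided by the oracle for item `idx`.
        self.lbls[idx] = lbl

    def is_labelled(self, idx: int) -> bool:
        return self.lbls[idx] != UNLABELLED


ds = ToyFileDataset(["data/dog/1.jpg", "data/cat/2.jpg"], [UNLABELLED] * 2)
ds.label(1, 3)
print([ds.is_labelled(i) for i in range(len(ds))])  # [False, True]
```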
ActiveLearningDataset is a wrapper around a Dataset that performs data management.
When you iterate over it, it returns labelled items only.
To learn more about dataset management, visit this notebook.
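Conceptually, the wrapper keeps a labelled/unlabelled mask over the underlying dataset: its length and iteration cover labelled items only, and the complementary "pool" holds the rest. The hypothetical `ToyActiveDataset` below sketches that bookkeeping with plain lists; it is an illustration, not Baal's implementation.

```python
from typing import Iterator, List, Sequence, Tuple


class ToyActiveDataset:
    """Hypothetical sketch: a labelled/unlabelled mask over a base dataset."""

    def __init__(self, items: Sequence[str]):
        self.items = list(items)
        self.labelled = [False] * len(self.items)
        self.lbls: List[int] = [-1] * len(self.items)

    def __len__(self) -> int:
        # The length counts labelled items only.
        return sum(self.labelled)

    def __iter__(self) -> Iterator[Tuple[str, int]]:
        # Iteration yields labelled (item, label) pairs only.
        for i, is_lab in enumerate(self.labelled):
            if is_lab:
                yield self.items[i], self.lbls[i]

    def pool_indices(self) -> List[int]:
        # Dataset indices of the items that still need a label.
        return [i for i, is_lab in enumerate(self.labelled) if not is_lab]

    def label_dataset_index(self, idx: int, lbl: int) -> None:
        self.labelled[idx] = True
        self.lbls[idx] = lbl


al = ToyActiveDataset(["img0", "img1", "img2", "img3"])
al.label_dataset_index(2, 5)
print(len(al), list(al), al.pool_indices())  # 1 [('img2', 5)] [0, 1, 3]
```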
from baal.active import FileDataset, ActiveLearningDataset
from torchvision import transforms
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
transforms.Resize(224),
transforms.RandomCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
# We use -1 to specify that the data is unlabeled.
train_dataset = FileDataset(train, [-1] * len(train), train_transform)
test_transform = transforms.Compose([transforms.Resize(224),
transforms.CenterCrop(224),  # Deterministic crop for evaluation and pool predictions.
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
# We use -1 to specify that the data is unlabeled.
test_dataset = FileDataset(test, [-1] * len(test), test_transform)
active_learning_ds = ActiveLearningDataset(train_dataset, pool_specifics={'transform': test_transform})
We now have two unlabeled datasets: train and validation. We encapsulate the training dataset in an ActiveLearningDataset object, which takes care of the split between labeled and unlabeled samples.
We are now ready to use Active Learning.
We will use a technique called MC-Dropout. Baal supports other techniques (see the README) and provides a similar API for each of them.
When using MC-Dropout with Baal, you can use any model as long as it contains Dropout layers. These layers are essential to compute the uncertainty of the model.
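The idea behind MC-Dropout can be shown without a deep-learning framework. The toy below uses a hypothetical linear layer (made-up weights, two inputs, three classes) with dropout on its inputs. Dropout stays active at prediction time, and the same input is run several times so that each pass samples a different sub-network. The outputs are stacked as [sample][class][iteration]. We believe this matches the layout of Baal's predictions (classes before iterations), but treat that as an assumption.

```python
import math
import random
from typing import List

random.seed(0)

# Hypothetical weights: 3 classes x 2 input features.
WEIGHTS = [[1.0, -0.5], [0.2, 0.8], [-0.7, 0.3]]
DROP_P = 0.5  # dropout probability


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def forward_with_dropout(x: List[float]) -> List[float]:
    # Inverted dropout: zero each input with prob DROP_P and rescale survivors.
    kept = [xi / (1 - DROP_P) if random.random() >= DROP_P else 0.0 for xi in x]
    logits = [sum(w * xi for w, xi in zip(row, kept)) for row in WEIGHTS]
    return softmax(logits)


def mc_dropout_predict(batch: List[List[float]], iterations: int) -> List[List[List[float]]]:
    """Return probabilities indexed as [sample][class][iteration]."""
    out = []
    for x in batch:
        passes = [forward_with_dropout(x) for _ in range(iterations)]
        # Transpose [iteration][class] -> [class][iteration].
        out.append([list(col) for col in zip(*passes)])
    return out


preds = mc_dropout_predict([[1.0, 2.0], [0.5, -1.0]], iterations=15)
print(len(preds), len(preds[0]), len(preds[0][0]))  # 2 3 15
```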
Baal provides several models, but it also supports custom models through baal.bayesian.dropout.MCDropoutModule.
In this example, we will use VGG-16, a popular model from torchvision.
import torch
from torch import nn, optim
from baal.modelwrapper import ModelWrapper
from torchvision.models import vgg16
from baal.bayesian.dropout import MCDropoutModule
USE_CUDA = torch.cuda.is_available()
model = vgg16(pretrained=False, num_classes=len(classes))
# This will modify all Dropout layers to be usable at test time which is
# required to perform Active Learning.
model = MCDropoutModule(model)
if USE_CUDA:
    model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
# ModelWrapper is an object similar to keras.Model.
baal_model = ModelWrapper(model, criterion)
Heuristics
To rank uncertainty, we use a heuristic. For classification and segmentation, BALD is the recommended heuristic. We also add noise to the heuristic to lower the selection bias introduced by the active learning (AL) process. This is done by specifying shuffle_prop in the heuristic constructor.
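BALD scores an item by how much the stochastic predictions disagree: score(x) = H[mean_t p_t(y|x)] - mean_t H[p_t(y|x)], where p_t is the prediction of MC-Dropout pass t and H is the entropy. The stdlib sketch below computes that score from probabilities laid out as [sample][class][iteration] and ranks items from most to least uncertain. `bald_scores` and `rank` are hypothetical helpers. The noise step in `rank` (permuting a random fraction `shuffle_prop` of the ranked positions) is our own simple version and may differ from how Baal applies `shuffle_prop`.

```python
import math
import random
from typing import List, Optional


def entropy(p: List[float]) -> float:
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def bald_scores(preds: List[List[List[float]]]) -> List[float]:
    """BALD score per sample from probabilities indexed [sample][class][iteration]."""
    scores = []
    for sample in preds:
        n_iter = len(sample[0])
        # Mean predictive distribution over the MC iterations.
        mean_p = [sum(cls) / n_iter for cls in sample]
        # Average entropy of the individual stochastic predictions.
        per_iter = [[cls[t] for cls in sample] for t in range(n_iter)]
        expected_h = sum(entropy(p) for p in per_iter) / n_iter
        scores.append(entropy(mean_p) - expected_h)
    return scores


def rank(scores: List[float], shuffle_prop: float = 0.0,
         rng: Optional[random.Random] = None) -> List[int]:
    """Indices from most to least uncertain, optionally with some positions shuffled."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    if shuffle_prop > 0:
        rng = rng or random.Random()
        k = int(len(order) * shuffle_prop)
        positions = rng.sample(range(len(order)), k)
        picked = [order[p] for p in positions]
        rng.shuffle(picked)
        for p, idx in zip(positions, picked):
            order[p] = idx
    return order


# Sample 0: two confident but conflicting passes. Sample 1: two identical 50/50 passes.
preds = [
    [[0.9, 0.1], [0.1, 0.9]],
    [[0.5, 0.5], [0.5, 0.5]],
]
print([round(s, 3) for s in bald_scores(preds)])  # [0.368, 0.0]
print(rank(bald_scores(preds)))  # [0, 1]
```

Note that sample 1 has maximal predictive entropy yet a BALD score of zero: the passes agree, so more labels would not resolve model uncertainty. This is the distinction BALD is designed to capture.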
from baal.active.heuristics import BALD
heuristic = BALD(shuffle_prop=0.1)
Oracle
When the AL process requires a new item to be labeled, we need to provide an Oracle. In your case, the Oracle will most likely be a human labeler. For this example, we're lucky: the class label is in the image path!
# This function would do the work that a human would do.
def get_label(img_path):
    # The parent folder of each image is its class name.
    return classes.index(img_path.split('/')[-2])
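`get_label` relies on `/` as the path separator. A variant using `pathlib` follows the path conventions of the platform it runs on. `get_label_portable` and the example class list and file name are hypothetical. The class list is passed in explicitly here so that the snippet stands alone.

```python
from pathlib import PurePath
from typing import List


def get_label_portable(img_path: str, classes: List[str]) -> int:
    # The class name is the name of the folder that contains the image.
    return classes.index(PurePath(img_path).parent.name)


classes = ["airplane", "car", "cat"]  # hypothetical subset of class folders
print(get_label_portable("/tmp/natural_images/cat/cat_0001.jpg", classes))  # 2
```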
Labeling process
The labeling will go like this:
1. Label all of the test set and some samples from the training set.
2. Train the model for a few epochs on the training set.
3. Select the top-K most uncertain samples according to the heuristic.
4. Label those samples.
5. If not done, go back to step 2.
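The steps above can be condensed into a framework-agnostic loop. In this sketch, `active_learning_loop` and all of its callables (`train`, `score_pool`, `oracle`) are hypothetical placeholders. In practice, you would back them with Baal's `ModelWrapper`, a heuristic, and your human labelers. The toy run at the end scores items by their id, which is enough to show the control flow.

```python
from typing import Callable, Dict, List


def active_learning_loop(
    pool: List[int],
    labelled: Dict[int, int],
    train: Callable[[Dict[int, int]], None],
    score_pool: Callable[[List[int]], List[float]],
    oracle: Callable[[int], int],
    query_size: int,
    max_steps: int,
) -> Dict[int, int]:
    """Hypothetical skeleton of steps 2-5; step 1 happens before the call."""
    for _ in range(max_steps):
        if not pool:
            break  # everything is labelled
        train(labelled)                                    # step 2
        scores = score_pool(pool)                          # step 3
        ranked = sorted(range(len(pool)), key=lambda i: scores[i], reverse=True)
        chosen = [pool[i] for i in ranked[:query_size]]
        for item in chosen:                                # step 4
            labelled[item] = oracle(item)
        pool = [item for item in pool if item not in labelled]  # step 5: repeat
    return labelled


result = active_learning_loop(
    pool=list(range(10)),
    labelled={},
    train=lambda labelled: None,                 # no-op stand-in for training
    score_pool=lambda p: [float(i) for i in p],  # toy "uncertainty": larger id ranks first
    oracle=lambda i: i % 2,                      # toy oracle
    query_size=3,
    max_steps=2,
)
print(sorted(result))  # [4, 5, 6, 7, 8, 9]
```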
import numpy as np
# 1. Label all the test set and some samples from the training set.
for idx in range(len(test_dataset)):
img_path = test_dataset.files[idx]
test_dataset.label(idx, get_label(img_path))
# Let's label 100 training examples randomly first.
# Note: the indices here are relative to the pool of unlabelled items! Since nothing
# is labelled yet, pool indices coincide with train_dataset indices.
train_idxs = np.random.permutation(np.arange(len(train_dataset)))[:100].tolist()
labels = [get_label(train_dataset.files[idx]) for idx in train_idxs]
active_learning_ds.label(train_idxs, labels)
print(f"Num. labeled: {len(active_learning_ds)}/{len(train_dataset)}")
Num. labeled: 100/5174
# 2. Train the model for a few epochs on the training set.
baal_model.train_on_dataset(active_learning_ds, optimizer, batch_size=16, epoch=5, use_cuda=USE_CUDA)
baal_model.test_on_dataset(test_dataset, batch_size=16, use_cuda=USE_CUDA)
print("Metrics:", {k:v.avg for k,v in baal_model.metrics.items()})
[103-MainThread ] [baal.modelwrapper:train_on_dataset:109] 2021-07-28T14:47:48.133213Z [info ] Starting training dataset=100 epoch=5
[103-MainThread ] [baal.modelwrapper:train_on_dataset:119] 2021-07-28T14:48:07.477011Z [info ] Training complete train_loss=2.058176279067993
[103-MainThread ] [baal.modelwrapper:test_on_dataset:147] 2021-07-28T14:48:07.479793Z [info ] Starting evaluating dataset=1725
[103-MainThread ] [baal.modelwrapper:test_on_dataset:156] 2021-07-28T14:48:21.277716Z [info ] Evaluation complete test_loss=2.0671451091766357
Metrics: {'test_loss': 2.0671451091766357, 'train_loss': 2.058176279067993}
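A quick sanity check on these losses: a classifier that spreads its probability uniformly over the 8 classes has a cross-entropy of ln 8 ≈ 2.079 on every example. The logged losses are only slightly below that value. After 5 epochs on 100 labelled images, a VGG-16 trained from scratch is therefore still close to chance, which is unsurprising at this stage. The snippet below does that arithmetic with the values copied from the log.

```python
import math

NUM_CLASSES = 8
uniform_loss = -math.log(1 / NUM_CLASSES)  # cross-entropy of a uniform prediction
print(round(uniform_loss, 4))  # 2.0794

# Losses copied from the log above.
logged = {"train_loss": 2.058176279067993, "test_loss": 2.0671451091766357}
for name, loss in logged.items():
    # How far below the chance-level loss each value is.
    print(name, round(uniform_loss - loss, 4))
```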
# 3. Select the K-top uncertain samples according to the heuristic.
pool = active_learning_ds.pool
if len(pool) == 0:
    raise ValueError("We're done!")
# We make 15 MCDropout iterations to approximate the uncertainty.
predictions = baal_model.predict_on_dataset(pool, batch_size=16, iterations=15, use_cuda=USE_CUDA, verbose=False)
# We will label the 10 most uncertain samples.
top_uncertainty = heuristic(predictions)[:10]
[103-MainThread ] [baal.modelwrapper:predict_on_dataset_generator:241] 2021-07-28T14:48:21.291851Z [info ] Start Predict dataset=5074
# 4. Label those samples.
oracle_indices = active_learning_ds._pool_to_oracle_index(top_uncertainty)
labels = [get_label(train_dataset.files[idx]) for idx in oracle_indices]
print(list(zip(labels, oracle_indices)))
active_learning_ds.label(top_uncertainty, labels)
[(3, 1429), (4, 2971), (2, 1309), (4, 5), (3, 3761), (4, 2708), (6, 4679), (7, 160), (7, 1638), (6, 73)]
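`_pool_to_oracle_index` translates positions in the current pool back to positions in the full `train_dataset`, which is what `get_label` needs. The order of operations in the cell above matters. The mapping is computed before `label` is called, because labelling shrinks the pool and shifts every later pool index. The hypothetical helper below reproduces such a mapping from a labelled mask to illustrate this point; it is not Baal's code.

```python
from typing import List


def pool_to_oracle_index(labelled: List[bool], pool_idxs: List[int]) -> List[int]:
    """Map indices relative to the unlabelled pool to dataset indices."""
    pool = [i for i, is_lab in enumerate(labelled) if not is_lab]
    return [pool[p] for p in pool_idxs]


labelled = [True, False, True, False, False]  # dataset items 1, 3, 4 are in the pool
print(pool_to_oracle_index(labelled, [0, 2]))  # [1, 4]

# After labelling dataset item 1, the pool is [3, 4], so pool index 0 now means item 3.
labelled[1] = True
print(pool_to_oracle_index(labelled, [0]))  # [3]
```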
# 5. If not done, go back to 2.
for step in range(5):  # 5 active learning steps!
    # 2. Train the model for a few epochs on the training set.
    print(f"Training on {len(active_learning_ds)} items!")
    baal_model.train_on_dataset(active_learning_ds, optimizer, batch_size=16, epoch=5, use_cuda=USE_CUDA)
    baal_model.test_on_dataset(test_dataset, batch_size=16, use_cuda=USE_CUDA)
    print("Metrics:", {k: v.avg for k, v in baal_model.metrics.items()})
    # 3. Select the K-top uncertain samples according to the heuristic.
    pool = active_learning_ds.pool
    if len(pool) == 0:
        print("We're done!")
        break
    predictions = baal_model.predict_on_dataset(pool, batch_size=16, iterations=15, use_cuda=USE_CUDA, verbose=False)
    top_uncertainty = heuristic(predictions)[:10]
    # 4. Label those samples.
    oracle_indices = active_learning_ds._pool_to_oracle_index(top_uncertainty)
    labels = [get_label(train_dataset.files[idx]) for idx in oracle_indices]
    active_learning_ds.label(top_uncertainty, labels)
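It helps to know how many labels this schedule consumes. The tutorial starts with 100 random labels, makes one query of 10 before the loop, and adds 10 more at each of the 5 loop steps. The loop therefore trains on 110, 120, 130, 140 and 150 items and ends with 160 labelled images, about 3% of the 5,174-image training pool. The hypothetical helper below, intended for planning only, makes that explicit.

```python
from typing import List, Tuple


def label_schedule(initial: int, query_size: int, steps: int) -> Tuple[List[int], int]:
    """Return the training-set size at each loop step and the final label count.

    Assumes one query was already made before the loop, as in this tutorial,
    and that the pool never runs out.
    """
    start = initial + query_size
    sizes = [start + step * query_size for step in range(steps)]
    return sizes, start + steps * query_size


sizes, final = label_schedule(initial=100, query_size=10, steps=5)
print(sizes, final)  # [110, 120, 130, 140, 150] 160
print(f"{final / 5174:.1%} of the training pool labelled")
```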
And we're done! Be sure to save the dataset and the model.
torch.save({
'active_dataset': active_learning_ds.state_dict(),
'model': baal_model.state_dict(),
'metrics': {k:v.avg for k,v in baal_model.metrics.items()}
}, '/tmp/baal_output.pth')
Support
Submit an issue or reach out to us on our Slack!