How to do research and visualize progress¶
In this tutorial, we will show how to use Baal for research, i.e., when the labels are known. We will introduce notions such as dataset management, MC-Dropout, and BALD. If you need more documentation, be sure to check our Additional resources section below!
BaaL can be used in a variety of research domains:
- Active Learning
- Uncertainty estimation
- Fairness, Accountability, Transparency and Ethics (FATE)
- And more!
Today we will focus on a simple example with CIFAR10 and we will animate the progress of active learning!
Additional resources¶
- More info on the inner workings of Active Learning Dataset here.
- To know more about Bayesian deep learning please see our Literature review.
Let's do this!¶
# Let's start with a bunch of imports.
import random
from copy import deepcopy
from dataclasses import dataclass
import numpy as np
import torch
import torch.backends
import torch.utils.data as torchdata
from torch import optim
from torch.hub import load_state_dict_from_url
from torch.nn import CrossEntropyLoss
from torchvision import datasets
from torchvision import models
from torchvision.transforms import transforms
from tqdm.autonotebook import tqdm
from baal.active import get_heuristic, ActiveLearningDataset
from baal.active.active_loop import ActiveLearningLoop
from baal.bayesian.dropout import patch_module
from baal.modelwrapper import ModelWrapper
def vgg16(num_classes):
    model = models.vgg16(pretrained=False, num_classes=num_classes)
    weights = load_state_dict_from_url('https://download.pytorch.org/models/vgg16-397923af.pth')
    # Drop the final classifier layer's weights, which do not match our number of classes.
    weights = {k: v for k, v in weights.items() if 'classifier.6' not in k}
    model.load_state_dict(weights, strict=False)
    return model
Dataset management and the pool¶
In many places in our library, we mention the pool, but what is it? The pool is simply the set of unlabelled examples: the examples that have not yet been selected for labelling.
In BaaL, the object that manages this is called baal.active.ActiveLearningDataset.
You supply it with your dataset and, by default, everything is unlabelled.
al_dataset = ActiveLearningDataset(your_dataset)
You can then label an initial set of examples at random. Let's label 100 examples.
al_dataset.label_randomly(100)
When iterating over al_dataset, you will only get the labelled examples. If you need to work on the pool, you can call al_dataset.pool, which returns it.
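To make the bookkeeping concrete, here is a minimal, self-contained sketch of how such a dataset could track which examples are labelled and which remain in the pool. This is plain Python, not the baal API; ToyActiveDataset is a hypothetical stand-in for illustration only.

```python
import random

class ToyActiveDataset:
    """Conceptual stand-in for an active-learning dataset (not the baal API)."""

    def __init__(self, n_items):
        # One flag per example: True once a label has been acquired.
        self.labelled = [False] * n_items

    def label_randomly(self, k):
        # Pick k examples uniformly at random from the pool and label them.
        unlabelled = [i for i, flag in enumerate(self.labelled) if not flag]
        for i in random.sample(unlabelled, k):
            self.labelled[i] = True

    def __len__(self):
        # Iterating / len() only sees labelled examples.
        return sum(self.labelled)

    @property
    def pool_size(self):
        # The pool is everything that is not yet labelled.
        return len(self.labelled) - len(self)

ds = ToyActiveDataset(1000)
ds.label_randomly(100)
print(len(ds), ds.pool_size)  # 100 900
```

Labelling examples moves them from the pool to the training set, which is exactly the accounting ActiveLearningDataset does for you.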
How can I disable data augmentation when iterating over the pool?
Disabling data augmentation when computing uncertainty is preferable: we wouldn't want augmentation to disrupt the uncertainty estimation. Fortunately, BaaL can help you with that. ActiveLearningDataset accepts a pool_specifics argument, a dictionary of attributes to override on the pool. For example, if my Dataset has an attribute transform which applies the data augmentation, I can override it with:
ActiveLearningDataset(your_dataset, pool_specifics={'transform': test_transform})
where test_transform is the test version of transform without data augmentation.
Here we define our experiment configuration. This could come from your favorite experiment manager, such as MLFlow; BaaL does not expect any particular format, as all arguments are supplied explicitly.
@dataclass
class ExperimentConfig:
    epoch: int = 20000 // 100
    batch_size: int = 32
    initial_pool: int = 512
    query_size: int = 100
    lr: float = 0.001
    heuristic: str = 'bald'
    iterations: int = 40
    training_duration: int = 10
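Because the configuration is a plain dataclass, per-experiment variants are easy to derive with dataclasses.replace. A small sketch, using a hypothetical DemoConfig that mirrors a couple of the fields above:

```python
from dataclasses import dataclass, replace

@dataclass
class DemoConfig:
    # Hypothetical stand-in mirroring two of ExperimentConfig's fields.
    query_size: int = 100
    iterations: int = 40

base = DemoConfig()
# Derive a cheaper run without mutating the base configuration.
quick_run = replace(base, query_size=25)
print(base.query_size, quick_run.query_size)  # 100 25
```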
Problem definition¶
We will perform active learning on a toy dataset, CIFAR-3 where we only keep dogs, cats and airplanes. This will make visualization easier.
def get_datasets(initial_pool):
    """
    Let's create a subset of CIFAR10 named CIFAR3, so that we can visualize things better.
    We will only keep the classes airplane, cat and dog.

    Args:
        initial_pool: Amount of labels to start with.

    Returns:
        ActiveLearningDataset, Dataset, the training and test set.
    """

    class TransformAdapter(torchdata.Subset):
        # We need a custom Subset class as we need to override "transform" as well.
        # This shouldn't be needed for your experiments.
        @property
        def transform(self):
            if hasattr(self.dataset, 'transform'):
                return self.dataset.transform
            else:
                raise AttributeError()

        @transform.setter
        def transform(self, transform):
            if hasattr(self.dataset, 'transform'):
                self.dataset.transform = transform

    # airplane, cat, dog
    classes_to_keep = [0, 3, 5]
    transform = transforms.Compose(
        [transforms.Resize((32, 32)),
         transforms.RandomHorizontalFlip(),
         transforms.RandomRotation(30),
         transforms.ToTensor(),
         transforms.Normalize(3 * [0.5], 3 * [0.5])])
    test_transform = transforms.Compose(
        [
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(3 * [0.5], 3 * [0.5]),
        ]
    )
    train_ds = datasets.CIFAR10('.', train=True,
                                transform=transform, target_transform=None, download=True)
    train_mask = np.where([y in classes_to_keep for y in train_ds.targets])[0]
    train_ds = TransformAdapter(train_ds, train_mask)

    # In a real application, you will want a validation set here.
    test_set = datasets.CIFAR10('.', train=False,
                                transform=test_transform, target_transform=None, download=True)
    test_mask = np.where([y in classes_to_keep for y in test_set.targets])[0]
    test_set = TransformAdapter(test_set, test_mask)

    # Here we set `pool_specifics`, overriding the transform attribute for the pool.
    active_set = ActiveLearningDataset(train_ds, pool_specifics={'transform': test_transform})

    # We start by labelling `initial_pool` examples at random.
    active_set.label_randomly(initial_pool)
    return active_set, test_set
Creating our experiment¶
We are now ready to instantiate all of our components:
- Our ActiveLearningDataset and a test dataset.
- Our heuristic (BALD).
- Our model and its criterion.
BaaL simplifies your experiments by providing ModelWrapper and ActiveLearningLoop.
- ModelWrapper
  - Performs MC sampling efficiently.
  - Provides training/testing loops.
- ActiveLearningLoop
  - Makes predictions on the pool and labels the most uncertain examples.
hyperparams = ExperimentConfig()
use_cuda = torch.cuda.is_available()
torch.backends.cudnn.benchmark = True
random.seed(1337)
torch.manual_seed(1337)
if not use_cuda:
    print("Warning: this experiment will take a very long time to run on CPU.")
# Get datasets
active_set, test_set = get_datasets(hyperparams.initial_pool)
# Get our model.
heuristic = get_heuristic(hyperparams.heuristic)
criterion = CrossEntropyLoss()
model = vgg16(num_classes=10)
# change dropout layer to MCDropout
model = patch_module(model)
if use_cuda:
model.cuda()
optimizer = optim.SGD(model.parameters(), lr=hyperparams.lr, momentum=0.9)
# Wraps the model into a usable API.
model = ModelWrapper(model, criterion)
# For ActiveLearningLoop we use a smaller batch size,
# since we will stack predictions to perform MC-Dropout.
active_loop = ActiveLearningLoop(active_set,
                                 model.predict_on_dataset,
                                 heuristic,
                                 hyperparams.query_size,
                                 batch_size=1,
                                 iterations=hyperparams.iterations,
                                 use_cuda=use_cuda,
                                 verbose=False)
# We will reset the weights at each active learning step so we make a copy.
init_weights = deepcopy(model.state_dict())
What is an active learning loop¶
An active learning loop is the process of:
- Training the model on the labelled set.
- Estimating uncertainty on the pool.
- Labelling the most uncertain examples.
labelling_progress = active_set._labelled.copy().astype(np.uint16)
for epoch in tqdm(range(hyperparams.epoch)):
    # Load the initial weights.
    model.load_state_dict(init_weights)

    # Train the model on the currently labelled dataset.
    _ = model.train_on_dataset(active_set, optimizer=optimizer, batch_size=hyperparams.batch_size,
                               use_cuda=use_cuda, epoch=hyperparams.training_duration)

    # Get test NLL!
    model.test_on_dataset(test_set, hyperparams.batch_size, use_cuda,
                          average_predictions=hyperparams.iterations)
    metrics = model.metrics

    # We can now label the most uncertain samples according to our heuristic.
    should_continue = active_loop.step()
    # Keep track of progress.
    labelling_progress += active_set._labelled.astype(np.uint16)
    if not should_continue:
        break

    test_loss = metrics['test_loss'].value
    logs = {
        "test_nll": test_loss,
        "epoch": epoch,
        "Next Training set size": len(active_set)
    }
We will now save our progress to disk.
model_weight = model.state_dict()
dataset = active_set.state_dict()
torch.save({'model': model_weight, 'dataset': dataset, 'labelling_progress': labelling_progress},
           'checkpoint.pth')
print(model.state_dict().keys(), dataset.keys(), labelling_progress)
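Since everything is stored in a single dict, resuming later is just a matter of loading it back with torch.load. A minimal round-trip sketch (torch only; a toy tensor stands in for the real checkpoint contents, and restoring the model and dataset state via their respective load_state_dict methods is left out):

```python
import os
import tempfile
import torch

# Save a checkpoint-style dict to a temporary path, then load it back.
ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
torch.save({"labelling_progress": torch.arange(4)}, ckpt_path)

restored = torch.load(ckpt_path)
print(restored["labelling_progress"].tolist())  # [0, 1, 2, 3]
```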
# Modify our model to get features.
from torch import nn
from torch.utils.data import DataLoader

# Make a feature extractor from our trained model.
class FeatureExtractor(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return torch.flatten(self.model.features(x), 1)

features = FeatureExtractor(model.model)
acc = []
for x, y in DataLoader(active_set._dataset, batch_size=10):
    acc.append((features(x.cuda()).detach().cpu().numpy(), y.detach().cpu().numpy()))
xs, ys = zip(*acc)
from sklearn.manifold import TSNE
# Compute t-SNE on the extracted features.
tsne = TSNE(n_jobs=4)
transformed = tsne.fit_transform(np.vstack(xs))
labels = np.concatenate(ys)
labels.shape
(15000,)
To make the animation, BaaL has baal.utils.plot_utils.make_animation_from_data, which takes a set of features, their labels, and the labelling-progress array we created earlier.
from baal.utils.plot_utils import make_animation_from_data
# Create frames to animate the process.
frames = make_animation_from_data(transformed, labels, labelling_progress, ["airplane", "cat", "dog"])
from IPython.display import HTML
import matplotlib.pyplot as plt
from matplotlib import animation
def plot_images(img_list):
    def init():
        img.set_data(img_list[0])
        return (img,)

    def animate(i):
        img.set_data(img_list[i])
        return (img,)

    fig = plt.Figure(figsize=(10, 10))
    ax = fig.gca()
    img = ax.imshow(img_list[0])
    anim = animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=len(img_list), interval=60, blit=True)
    return anim
HTML(plot_images(frames).to_jshtml())
Animation size has reached 21159316 bytes, exceeding the limit of 20971520.0. If you're sure you want a larger animation embedded, set the animation.embed_limit rc parameter to a larger value (in MB). This and further frames will be dropped.
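If the inline animation exceeds the embed limit, one option is to write it to disk instead of embedding it in the notebook. A minimal sketch: random frames stand in for the real ones, and PillowWriter assumes the Pillow package is installed.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen, no display needed
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

# Stand-in frames; in the tutorial these would come from make_animation_from_data.
rng = np.random.default_rng(0)
frames = [rng.random((8, 8)) for _ in range(5)]

fig, ax = plt.subplots()
img = ax.imshow(frames[0])

def animate(i):
    img.set_data(frames[i])
    return (img,)

anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=60, blit=True)
# Write a GIF to disk; no size limit applies, unlike inline embedding.
anim.save("progress.gif", writer=animation.PillowWriter(fps=15))
```

Alternatively, raising the matplotlib rc parameter animation.embed_limit keeps the inline display working at the cost of a larger notebook.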