How to use BaaL with PyTorch Lightning

In this notebook, we'll go through an example of how to build a project with BaaL and PyTorch Lightning.

Useful resources:

NOTE The API of ActiveLearningMixin and BaalTrainer is subject to change as we are looking for feedback from the community. If you want to help us make this API better, please come to our Gitter or submit an issue.

[ ]:
import copy
from dataclasses import dataclass, asdict

from baal.active import ActiveLearningDataset
from baal.active.heuristics import BALD
from baal.bayesian.dropout import patch_module
from baal.utils.pytorch_lightning import ActiveLearningMixin, BaalTrainer, BaaLDataModule, ResetCallback
from pytorch_lightning import LightningModule
from torch import optim
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import vgg16
from torchvision.transforms import transforms

Model definition

Below you can see an example using VGG16.

Note the ActiveLearningMixin, which we will use to perform active learning. This mixin expects an active dataset and the following keys in the hparams:

iterations: int # How many MC samples to draw at prediction time.
replicate_in_memory: bool # Whether to perform MC sampling by replicating the batch `iterations` times.

If you want to modify how the MC sampling is performed, you can override predict_step (a sketch is shown after the class definition below).

[ ]:
class VGG16(LightningModule, ActiveLearningMixin):
    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.name = "VGG16"
        self.version = "0.0.1"
        self.criterion = CrossEntropyLoss()
        self._build_model()

    def _build_model(self):
        # We use `patch_module` to swap the Dropout modules in the model
        # for our implementation, which keeps dropout active at inference time and enables MC-Dropout.
        self.vgg16 = patch_module(vgg16(num_classes=self.hparams.num_classes))

    def forward(self, x):
        return self.vgg16(x)

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop
        :param batch:
        :return:
        """
        # forward pass
        x, y = batch
        y_hat = self(x)

        # calculate loss
        loss_val = self.criterion(y_hat, y)

        self.log("train_loss", loss_val, prog_bar=True, on_epoch=True)
        return loss_val

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)

        # calculate loss
        loss_val = self.criterion(y_hat, y)

        self.log("test_loss", loss_val, prog_bar=True, on_epoch=True)
        return loss_val

    def configure_optimizers(self):
        """
        return whatever optimizers we want here
        :return: list of optimizers
        """
        optimizer = optim.SGD(self.parameters(), lr=self.hparams.learning_rate, momentum=0.9, weight_decay=5e-4)
        return [optimizer], []
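
The mixin's default predict_step performs the MC sampling described above. If you want to customize it, a minimal sketch could look like the one below; the exact batch format yielded by the pool DataLoader and the hook signature may vary between BaaL and Lightning versions, so treat this as an illustration rather than the library's implementation.

[ ]:
import torch

class MyVGG16(VGG16):
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # The pool loader may yield inputs alone or (input, target) pairs;
        # keep only the inputs here.
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        # Naive MC sampling: run `iterations` stochastic forward passes
        # (dropout stays active thanks to `patch_module`) and stack them
        # on the last dimension: [batch_size, num_classes, iterations].
        with torch.no_grad():
            out = torch.stack([self(x) for _ in range(self.hparams.iterations)], dim=-1)
        return out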

Hyperparameters

[ ]:
@dataclass
class HParams:
    batch_size: int = 10
    data_root: str = '/tmp'
    num_classes: int = 10
    learning_rate: float = 0.001
    query_size: int = 100
    iterations: int = 20
    replicate_in_memory: bool = True
    gpus: int = 1

hparams = HParams()

DataModule

We support pl.DataModule; here is how you can define it. By using BaaLDataModule, you do not have to implement pool_dataloader, which is the DataLoader that runs on the pool of unlabelled examples (a quick check is shown after the class definition below).

[ ]:
class Cifar10DataModule(BaaLDataModule):
    def __init__(self, data_root, batch_size):
        train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                              transforms.ToTensor()])
        test_transform = transforms.Compose([transforms.ToTensor()])
        active_set = ActiveLearningDataset(
            CIFAR10(data_root, train=True, transform=train_transform, download=True),
            pool_specifics={
                'transform': test_transform
            })
        self.test_set = CIFAR10(data_root, train=False, transform=test_transform, download=True)
        super().__init__(active_dataset=active_set, batch_size=batch_size,
                         train_transforms=train_transform,
                         test_transforms=test_transform)

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        return DataLoader(self.active_dataset, self.batch_size, shuffle=True, num_workers=4)

    def test_dataloader(self, *args, **kwargs) -> DataLoader:
        return DataLoader(self.test_set, self.batch_size, shuffle=False, num_workers=4)
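
As a quick, optional sanity check (not part of the original experiment), you can instantiate the DataModule and confirm that BaaLDataModule already provides the pool_dataloader mentioned above; label_randomly is only called here so that both the labelled set and the pool are non-empty.

[ ]:
dm = Cifar10DataModule(hparams.data_root, hparams.batch_size)
dm.active_dataset.label_randomly(10)
print(len(dm.active_dataset))      # number of labelled examples
print(len(dm.pool_dataloader()))   # number of batches over the unlabelled pool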

Experiment

We now have all the pieces to start our experiment.

Initial labelling

To kickstart active learning, we will randomly select items to be labelled.

[ ]:
data_module = Cifar10DataModule(hparams.data_root, hparams.batch_size)
data_module.active_dataset.label_randomly(10)

Instantiating BALD

This is used to rank the pool examples by uncertainty. More info here.

[ ]:
heuristic = BALD()
model = VGG16(**asdict(hparams))
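
As a rough illustration of what BALD expects (the BaalTrainer below takes care of this step for you), the heuristic is called on per-sample predictions of shape [n_samples, n_classes, n_iterations] (the stacked MC iterations) and returns indices ordered from most to least uncertain. The array below is random and purely illustrative.

[ ]:
import numpy as np

fake_predictions = np.random.rand(32, hparams.num_classes, hparams.iterations)
fake_predictions /= fake_predictions.sum(axis=1, keepdims=True)  # make them look like softmax outputs
ranks = heuristic(fake_predictions)
print(ranks[:5])  # indices of the 5 most uncertain samples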

Create a trainer to generate predictions

Note that we use the BaalTrainer, which inherits from the usual PyTorch Lightning Trainer. The BaalTrainer will take care of the active learning part by performing predict_on_dataset on the pool.

[ ]:
trainer = BaalTrainer(dataset=data_module.active_dataset,
                      heuristic=heuristic,
                      ndata_to_label=hparams.query_size,
                      max_epochs=10, default_root_dir=hparams.data_root,
                      gpus=hparams.gpus,
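                      # ResetCallback reloads these initial weights before each call
                      # to `fit`, so every active learning step retrains from scratch.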
                      callbacks=[ResetCallback(copy.deepcopy(model.state_dict()))])

Training the model and performing active learning

Our experiment steps are as follows:

  1. Train on the labelled dataset.

  2. Evaluate ourselves on a held-out set.

  3. Label the top-k most uncertain examples.

  4. Go back to 1.

[ ]:
AL_STEPS = 100

for al_step in range(AL_STEPS):
    print(f'Step {al_step} Dataset size {len(data_module.active_dataset)}')
    trainer.fit(model, datamodule=data_module)  # Train the model on the labelled set.
    trainer.test(model, datamodule=data_module)  # Get test performance.
    should_continue = trainer.step(model, datamodule=data_module)  # Label the top-k most uncertain examples.
    if not should_continue:
        break