Active learning infrastructure objects

Active learning, that is, interactively choosing which datapoints to request labels for, needs data handling infrastructure that differs slightly from the standard PyTorch dataset classes. In particular, a dataset is no longer static: the labelled set grows as you progress through your experiment or your application.

To handle these needs, baal contains:

  • ActiveLearningDataset, a PyTorch dataset that lets you label data interactively.

  • ActiveLearningLoop, which wraps an ActiveLearningDataset and simplifies experiments by letting you call step whenever you want to label more data.
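To make the idea of a labelling step concrete, here is a minimal sketch in plain Python. It is not baal's implementation: `toy_step`, the `scores` list, and the stopping rule are all hypothetical names and choices made for illustration, and the random scores stand in for a model's uncertainty estimates.

```python
import random


def toy_step(labelled, scores, query_size):
    """Label the `query_size` highest-scoring unlabelled items.

    Returns False when nothing is left to label, True otherwise.
    (Hypothetical helper, not part of baal.)
    """
    pool = [i for i, is_labelled in enumerate(labelled) if not is_labelled]
    if not pool:
        return False
    # Rank the unlabelled items from most to least "uncertain".
    ranked = sorted(pool, key=lambda i: scores[i], reverse=True)
    for i in ranked[:query_size]:
        labelled[i] = True
    return True


rng = random.Random(0)
scores = [rng.random() for _ in range(20)]  # stand-in for model uncertainty
labelled = [False] * 20

steps = 0
while toy_step(labelled, scores, query_size=6):
    # A real experiment would retrain the model and recompute scores here.
    steps += 1

print(steps)  # 4 (6 + 6 + 6 + 2 items)
```

In a real experiment the scores would be recomputed after every step, since the model changes as more labels arrive.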

The ActiveLearningDataset wraps another PyTorch dataset. As an example of how to use it, here is how we turn the MNIST dataset into an active learning dataset:

[2]:
path = "/Users/jan/datasets/mnist/"
[4]:
from torchvision import transforms, datasets
from baal.active.dataset import ActiveLearningDataset

transform = transforms.Compose([transforms.Grayscale(3), transforms.ToTensor()])
test_transform = transform

active_mnist = ActiveLearningDataset(
    datasets.MNIST(path, train=True, transform=transform),
    pool_specifics={'transform': test_transform},
)

As you can see, this is a fairly thin wrapper around MNIST. However, the wrapped dataset now exposes several new properties that we can check:

[5]:
active_mnist.n_labelled
[5]:
tensor(0)
[6]:
active_mnist.n_unlabelled
[6]:
tensor(60000)

We can also start labelling data, either randomly or based on specific indices:

[7]:
active_mnist.label_randomly(10)
active_mnist.label([55, 56, 50100])

We’ve just labelled 10 points randomly and 3 points by index. If we now check the counts, we see that 13 points are labelled and 59987 remain unlabelled:

[8]:
active_mnist.n_labelled
[8]:
tensor(13)
[9]:
active_mnist.n_unlabelled
[9]:
tensor(59987)
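The bookkeeping behind these counts can be sketched with a plain boolean mask. This is an illustration only: the `label` and `label_randomly` functions below are toy stand-ins written for this sketch, and they assume indices refer to positions in the wrapped dataset, which may differ from how baal interprets them. The explicit indices are labelled first so that the random draw cannot collide with them.

```python
import random


def label(mask, indices):
    """Mark the given dataset positions as labelled (toy helper)."""
    for i in indices:
        mask[i] = True


def label_randomly(mask, n, rng):
    """Mark `n` currently unlabelled positions, chosen uniformly at random."""
    unlabelled = [i for i, done in enumerate(mask) if not done]
    label(mask, rng.sample(unlabelled, n))


mask = [False] * 60000  # same size as the MNIST training set
label(mask, [55, 56, 50100])
label_randomly(mask, 10, random.Random(0))

n_labelled = sum(mask)
n_unlabelled = len(mask) - n_labelled
print(n_labelled, n_unlabelled)  # 13 59987
```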

We will also see that when we check the length of this dataset (something that PyTorch DataLoader classes do, for example), it only gives the number of labelled datapoints:

[12]:
len(active_mnist)
[12]:
13

And if we try to access an item, we can only index into the labelled datapoints; any index beyond them raises an IndexError:

[13]:
active_mnist[0]
[13]:
(tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]), 8)
[14]:
active_mnist[15]
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-14-2413857e337e> in <module>
----> 1 active_mnist[15]

~/projects/baal/src/baal/active/dataset.py in __getitem__(self, index)
     30     def __getitem__(self, index: int) -> Tuple[torch.Tensor, ...]:
     31         """Return stuff from the original dataset."""
---> 32         return self._dataset[self._labelled_to_oracle_index(index)]
     33
     34     def __len__(self) -> int:

~/projects/baal/src/baal/active/dataset.py in _labelled_to_oracle_index(self, index)
     54
     55     def _labelled_to_oracle_index(self, index: int) -> int:
---> 56         return self._labelled.nonzero()[index].squeeze().item()
     57
     58     def _pool_to_oracle_index(self, index: Union[int, List[int]]) -> List[int]:

IndexError: index 15 is out of bounds for dimension 0 with size 13
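The traceback shows the lookup going through the nonzero entries of a labelled mask, which is why an index past the last labelled item fails. The same idea can be sketched in plain Python; `labelled_to_original` is a hypothetical name used here for illustration, and the list-based mask is a simplification of the tensor used in the library.

```python
def labelled_to_original(mask, index):
    """Return the position in the wrapped dataset of the index-th labelled item."""
    labelled_positions = [i for i, done in enumerate(mask) if done]
    # Indexing past the last labelled item raises IndexError.
    return labelled_positions[index]


mask = [False] * 10
for i in (2, 5, 7):
    mask[i] = True

first = labelled_to_original(mask, 0)  # position 2
last = labelled_to_original(mask, 2)   # position 7

try:
    labelled_to_original(mask, 3)      # only 3 items are labelled
    out_of_range = False
except IndexError:
    out_of_range = True

print(first, last, out_of_range)  # 2 7 True
```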

Instead, if we want to actually use the unlabelled data, we need to use the pool attribute of the active learning dataset, which is itself a dataset:

[15]:
len(active_mnist.pool)
[15]:
59987
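One way to picture the pool is as a view over the same underlying data that exposes only the unlabelled items, so it shrinks as labels are added. The sketch below assumes that design; `ToyPool` is a hypothetical class written for illustration and does not reflect how baal builds its pool.

```python
class ToyPool:
    """Read-only view over the unlabelled items of `data` (illustrative only)."""

    def __init__(self, data, mask):
        self._data = data
        self._mask = mask  # shared with the owner, so later labelling is visible

    def _positions(self):
        # Positions in `data` that are still unlabelled.
        return [i for i, done in enumerate(self._mask) if not done]

    def __len__(self):
        return len(self._positions())

    def __getitem__(self, index):
        return self._data[self._positions()[index]]


data = ["a", "b", "c", "d", "e"]
mask = [False, True, False, False, True]  # "b" and "e" are already labelled
pool = ToyPool(data, mask)

size_before = len(pool)  # 3
first_item = pool[0]     # "a"

mask[0] = True           # label "a"
size_after = len(pool)   # 2
new_first = pool[0]      # "c"
```

Because the view recomputes its positions on every access, pool indices shift whenever an item is labelled.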