Active learning infrastructure objects
Active learning, or interactively choosing which datapoints to request labels for, needs data handling infrastructure that is slightly different from the usual pytorch dataset classes. In particular, a dataset is no longer static: it grows as your experiment or application progresses.
To handle these needs, baal contains:

- ActiveLearningDataset, a pytorch dataset that lets you interactively label data.
- ActiveLearningLoop, which wraps an ActiveLearningDataset and simplifies experiments by letting you call step whenever you want to label data.
The ActiveLearningDataset wraps another pytorch dataset. For an example of how to use it, you can take a look at how we turn the MNIST dataset into an active learning dataset:
from torchvision import transforms, datasets
from baal.active.dataset import ActiveLearningDataset

path = "/tmp"
transform = transforms.Compose([transforms.Grayscale(3), transforms.ToTensor()])
test_transform = transform

active_mnist = ActiveLearningDataset(
    datasets.MNIST(path, train=True, download=True, transform=transform),
    pool_specifics={'transform': test_transform},
)
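The pool_specifics argument deserves a note: it overrides attributes on the pool view of the dataset, so the unlabelled pool can use test-time transforms while the labelled data keeps its training-time settings. A rough sketch of the idea (Toy and make_pool are hypothetical names for illustration, not baal's API):

```python
import copy

# Toy dataset with a 'transform' attribute, standing in for MNIST.
class Toy:
    def __init__(self, transform):
        self.transform = transform

def make_pool(dataset, pool_specifics):
    # A shallow copy of the dataset with selected attributes overridden:
    # the pool view gets test-time settings (e.g. no augmentation)
    # while the training view keeps its own.
    pool = copy.copy(dataset)
    for name, value in pool_specifics.items():
        setattr(pool, name, value)
    return pool

train = Toy(transform="augment")
pool = make_pool(train, {"transform": "plain"})
print(train.transform, pool.transform)  # -> augment plain
```

Here both transforms are identical, so pool_specifics changes nothing in practice; it matters once the training transform includes augmentation you don't want applied when scoring the pool.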
As you can see, this is a fairly thin wrapper around MNIST. But, we can now check several new properties of this dataset:
active_mnist.n_labelled
tensor(0)
active_mnist.n_unlabelled
tensor(60000)
We can also start labelling data. Either randomly, or based on specific indices:
active_mnist.label_randomly(10)
active_mnist.label([55, 56, 50100])
We've just labelled 10 points randomly and 3 points at specific indices. Checking the counts again, we see that 13 points are now labelled:
active_mnist.n_labelled
tensor(13)
active_mnist.n_unlabelled
tensor(59987)
We will also see that when we check the length of this dataset - something that
is done by e.g. pytorch DataLoader
classes - it only gives the length of the
labelled dataset:
len(active_mnist)
13
And, if we try to access an item, it will only allow us to index the labelled datapoints:
active_mnist[0]
(tensor([[[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]], [[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]], [[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]]]), 8)
active_mnist[15]
--------------------------------------------------------------------------- IndexError Traceback (most recent call last) <ipython-input-14-2413857e337e> in <module> ----> 1 active_mnist[15] ~/projects/baal/src/baal/active/dataset.py in __getitem__(self, index) 30 def __getitem__(self, index: int) -> Tuple[torch.Tensor, ...]: 31 """Return stuff from the original dataset.""" ---> 32 return self._dataset[self._labelled_to_oracle_index(index)] 33 34 def __len__(self) -> int: ~/projects/baal/src/baal/active/dataset.py in _labelled_to_oracle_index(self, index) 54 55 def _labelled_to_oracle_index(self, index: int) -> int: ---> 56 return self._labelled.nonzero()[index].squeeze().item() 57 58 def _pool_to_oracle_index(self, index: Union[int, List[int]]) -> List[int]: IndexError: index 15 is out of bounds for dimension 0 with size 13
Instead, if we want to actually use the unlabelled data, we need to use the
pool
attribute of the active learning dataset, which is itself a dataset:
len(active_mnist.pool)
59987
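Under the hood, the labelled view and the pool share one copy of the data and translate indices through the labelled mask; the traceback above shows baal doing exactly this in _labelled_to_oracle_index. A sketch of that index translation, assuming a boolean labelled mask (function names here are illustrative):

```python
import numpy as np

# 20 datapoints, of which indices 2, 5 and 11 are labelled.
labelled = np.zeros(20, dtype=bool)
labelled[[2, 5, 11]] = True

def labelled_to_oracle(i):
    # i-th labelled point -> its index in the full ("oracle") dataset
    return int(np.flatnonzero(labelled)[i])

def pool_to_oracle(i):
    # i-th pool (unlabelled) point -> its index in the full dataset
    return int(np.flatnonzero(~labelled)[i])

print(labelled_to_oracle(1))  # -> 5
print(pool_to_oracle(0))      # -> 0
print(pool_to_oracle(2))      # -> 3
```

Asking for labelled index 3 here would raise an IndexError just like active_mnist[15] did above: there are only three labelled points to index into.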