Skip to content

Active learning functionality

This module contains all the utilities needed to do active learning.

  1. Dataset management
  2. Active loop implementation

Baal takes care of the dataset split between labelled and unlabelled examples. It also takes care of the active learning loop:

  1. Predict on the unlabelled examples.
  2. Label the most uncertain examples.

Example

from baal.active.dataset import ActiveLearningDataset
al_dataset = ActiveLearningDataset(your_dataset)

# To start, we can select 1000 random examples to be labelled.
al_dataset.label_randomly(1000)

# Our training set now contains 1000 examples.
len(al_dataset)

# We can label examples by their indices (relative to the pool).
al_dataset.label([32, 10, 4])

# Our dataset length is now 1003.
len(al_dataset)

# At initialization, we can also swap attributes for the pool,
# e.g. to remove data augmentation.
al_dataset = ActiveLearningDataset(your_dataset, pool_specifics={"transform": None})
assert al_dataset.pool.transform is None

API

baal.active.ActiveLearningDataset

Bases: SplittedDataset

A dataset that allows for active learning.

Parameters:

Name Type Description Default
dataset Dataset

The baseline dataset.

required
labelled Optional[ndarray]

An array that acts as a mask which is greater than or equal to 1 for every data point that is labelled, and 0 for every data point that is not labelled.

None
make_unlabelled Callable

The function that returns an unlabelled version of a datum so that it can still be used in the DataLoader.

_identity
random_state

Set the random seed for label_randomly().

None
pool_specifics Optional[dict]

Attributes to set when creating the pool. Useful to remove data augmentation.

None
last_active_steps int

If specified, will iterate over the last active steps instead of the full dataset. Useful when doing partial finetuning.

-1
Source code in baal/active/dataset/pytorch_dataset.py
class ActiveLearningDataset(SplittedDataset):
    """A dataset that allows for active learning.

    Args:
        dataset: The baseline dataset.
        labelled: An array that acts as a mask which is greater than or equal
            to 1 for every data point that is labelled, and 0 for every data
            point that is not labelled.
        make_unlabelled: The function that returns an
            unlabelled version of a datum so that it can still be used in the DataLoader.
        random_state: Set the random seed for label_randomly().
        pool_specifics: Attributes to set when creating the pool.
            Useful to remove data augmentation.
        last_active_steps: If specified, will iterate over the last active steps
            instead of the full dataset. Useful when doing partial finetuning.
    """

    def __init__(
        self,
        dataset: Dataset,
        labelled: Optional[np.ndarray] = None,
        make_unlabelled: Callable = _identity,
        random_state=None,
        pool_specifics: Optional[dict] = None,
        last_active_steps: int = -1,
    ) -> None:
        self._dataset = dataset

        # The labelled_map keeps track of the step at which an item has been labelled.
        if labelled is not None:
            labelled_map: np.ndarray = labelled.astype(int)
        else:
            labelled_map = np.zeros(len(self._dataset), dtype=int)

        if pool_specifics is None:
            pool_specifics = {}
        self.pool_specifics: Dict[str, Any] = pool_specifics

        self.make_unlabelled = make_unlabelled
        # For example, FileDataset has a method 'label'. This is useful when we're in prod.
        self.can_label = self.check_dataset_can_label()
        super().__init__(
            labelled=labelled_map, random_state=random_state, last_active_steps=last_active_steps
        )
        self._warn_if_pool_stochastic()

    def check_dataset_can_label(self):
        """Check if a dataset can be labelled.

        Returns:
            Whether the dataset's label can be modified or not.

        Notes:
            To be labelled, a dataset needs a method `label`
            with definition: `label(self, idx, value)` where `value`
            is the label for index `idx`.
        """
        label_attr = getattr(self._dataset, "label", None)
        # Compare against None instead of relying on truthiness: a non-callable
        # `label` attribute may be falsy (e.g. an empty list) or may not support
        # truth-testing at all (e.g. a multi-element NumPy array).
        if label_attr is None:
            return False
        if callable(label_attr):
            return True
        warnings.warn(
            "Dataset has an attribute `label`, but it is not callable. "
            "The Dataset will not be labelled with new labels.",
            UserWarning,
        )
        return False

    def __getitem__(self, index: int) -> Any:
        """Return items from the original dataset based on the labelled index."""
        index = self.get_indices_for_active_step()[index]
        return self._dataset[index]

    class ActiveIter:
        """Iterator over an ActiveLearningDataset."""

        def __init__(self, aldataset):
            self.i = 0
            self.aldataset = aldataset

        def __len__(self):
            # Total length of the wrapped dataset, not the number of items left.
            return len(self.aldataset)

        def __iter__(self):
            # Required by the iterator protocol so the object works with
            # `for`, `list()` and `iter()`.
            return self

        def __next__(self):
            if self.i >= len(self):
                raise StopIteration

            n = self.aldataset[self.i]
            self.i = self.i + 1
            return n

    def __iter__(self):
        return self.ActiveIter(self)

    @property
    def pool(self) -> "ActiveLearningPool":
        """Returns a new Dataset made from unlabelled samples.

        Raises:
            ValueError if a pool specific attribute cannot be set.
        """
        # Work on a copy so that pool-specific attributes do not leak into
        # the dataset used for training.
        current_dataset = deepcopy(self._dataset)

        for attr, new_val in self.pool_specifics.items():
            if hasattr(current_dataset, attr):
                setattr(current_dataset, attr, new_val)
            else:
                raise ValueError(f"{current_dataset} doesn't have {attr}")

        pool_dataset: torchdata.Subset = torchdata.Subset(
            current_dataset, (~self.labelled).nonzero()[0].reshape([-1]).tolist()
        )
        ald = ActiveLearningPool(pool_dataset, make_unlabelled=self.make_unlabelled)
        return ald

    def label(self, index: Union[list, int], value: Optional[Any] = None) -> None:
        """
        Label data points.
        The index should be relative to the pool, not the overall data.

        Args:
            index: one or many indices to label. A single index may be a
                Python ``int`` or a NumPy integer scalar.
            value: The label value. If not provided, no modification
                to the underlying dataset is done.

        Raises:
            ValueError if the indices do not match the values or
             if no `value` is provided and `can_label` is True.
        """
        if isinstance(index, (int, np.integer)):
            # We were provided only the index, we make a list.
            # NumPy integer scalars (e.g. an element of a ranking array) are
            # handled like a plain `int`.
            index_lst = [index]
            value_lst: List[Any] = [value]
        else:
            index_lst = index
            if value is None:
                value_lst = [value]
            else:
                value_lst = value

        if value_lst[0] is not None and len(index_lst) != len(value_lst):
            raise ValueError(
                "Expected `index` and `value` to be of same length when `value` is provided. "
                f"Got index={len(index_lst)} and value={len(value_lst)}"
            )
        indexes = self._pool_to_oracle_index(index_lst)
        # Every item labelled in this call is stamped with the next active step.
        active_step = self.current_al_step + 1
        for idx, val in zip_longest(indexes, value_lst, fillvalue=None):
            if self.can_label and val is not None:
                self._dataset.label(idx, val)  # type: ignore
                self.labelled_map[idx] = active_step
            elif self.can_label and val is None:
                raise ValueError(
                    """The dataset is able to label data, but no label was provided.
                                 If this is a research setting, please set the
                                  `ActiveLearningDataset.can_label` to `False`.
                                  """
                )
            else:
                # Regular research usecase.
                self.labelled_map[idx] = active_step
                if val is not None:
                    # Fetch the sample once; it is only used for the message.
                    sample = self._dataset[idx]
                    warnings.warn(
                        "We will consider the original label of this datasample : {}, {}.".format(
                            sample[0], sample[1]
                        ),
                        UserWarning,
                    )

    def reset_labelled(self):
        """Reset the label map."""
        # The map stores the step at which each item was labelled, so it must
        # stay an integer array. `np.bool` was removed in NumPy 1.24.
        self.labelled_map = np.zeros(len(self._dataset), dtype=int)

    def get_raw(self, idx: int) -> Any:
        """Get a datapoint from the underlying dataset."""
        return self._dataset[idx]

    def state_dict(self) -> Dict:
        """Return the state_dict, ie. the labelled map and random_state."""
        return {"labelled": self.labelled_map, "random_state": self.random_state}

    def load_state_dict(self, state_dict):
        """Load the labelled map and random_state with the given state_dict."""
        self.labelled_map = state_dict["labelled"]
        self.random_state = state_dict["random_state"]

    def _warn_if_pool_stochastic(self):
        # Fetching the same pool item twice should give the same result;
        # otherwise warn the user.
        pool = self.pool
        if len(pool) > 0 and not deep_check(pool[0], pool[0]):
            warnings.warn(
                STOCHASTIC_POOL_WARNING,
                UserWarning,
            )

pool: ActiveLearningPool property

Returns a new Dataset made from unlabelled samples.

ActiveIter

Iterator over an ActiveLearningDataset.

Source code in baal/active/dataset/pytorch_dataset.py
class ActiveIter:
    """Iterator over an ActiveLearningDataset.

    Items are read by position, from index 0 up to ``len(aldataset) - 1``.
    """

    def __init__(self, aldataset):
        self.i = 0
        self.aldataset = aldataset

    def __len__(self):
        # Total length of the wrapped dataset, not the number of items left.
        return len(self.aldataset)

    def __iter__(self):
        # Required by the iterator protocol so the object works with
        # `for`, `list()` and `iter()`.
        return self

    def __next__(self):
        if self.i >= len(self):
            raise StopIteration

        n = self.aldataset[self.i]
        self.i = self.i + 1
        return n

__getitem__(index)

Return items from the original dataset based on the labelled index.

Source code in baal/active/dataset/pytorch_dataset.py
def __getitem__(self, index: int) -> Any:
    """Fetch an item of the underlying dataset through the active-step indices.

    ``index`` selects a position in the sequence returned by
    ``get_indices_for_active_step()``; the value found there is used to index
    the wrapped dataset.
    """
    active_indices = self.get_indices_for_active_step()
    return self._dataset[active_indices[index]]

check_dataset_can_label()

Check if a dataset can be labelled.

Returns:

Type Description

Whether the dataset's label can be modified or not.

Notes

To be labelled, a dataset needs a method label with definition: label(self, idx, value) where value is the label for index idx.

Source code in baal/active/dataset/pytorch_dataset.py
def check_dataset_can_label(self):
    """Check if a dataset can be labelled.

    Returns:
        Whether the dataset's label can be modified or not.

    Notes:
        To be labelled, a dataset needs a method `label`
        with definition: `label(self, idx, value)` where `value`
        is the label for index `idx`.
    """
    label_attr = getattr(self._dataset, "label", None)
    # Compare against None instead of relying on truthiness: a non-callable
    # `label` attribute may be falsy (e.g. an empty list) or may not support
    # truth-testing at all (e.g. a multi-element NumPy array).
    if label_attr is None:
        return False
    if callable(label_attr):
        return True
    warnings.warn(
        "Dataset has an attribute `label`, but it is not callable. "
        "The Dataset will not be labelled with new labels.",
        UserWarning,
    )
    return False

get_raw(idx)

Get a datapoint from the underlying dataset.

Source code in baal/active/dataset/pytorch_dataset.py
def get_raw(self, idx: int) -> Any:
    """Return the item stored at ``idx`` in the wrapped dataset.

    The index is applied directly to the underlying dataset, without any
    remapping.
    """
    raw_dataset = self._dataset
    return raw_dataset[idx]

label(index, value=None)

Label data points. The index should be relative to the pool, not the overall data.

Parameters:

Name Type Description Default
index Union[list, int]

one or many indices to label.

required
value Optional[Any]

The label value. If not provided, no modification to the underlying dataset is done.

None
Source code in baal/active/dataset/pytorch_dataset.py
def label(self, index: Union[list, int], value: Optional[Any] = None) -> None:
    """
    Label data points.
    The index should be relative to the pool, not the overall data.

    Args:
        index: one or many indices to label. A single index may be a
            Python ``int`` or a NumPy integer scalar.
        value: The label value. If not provided, no modification
            to the underlying dataset is done.

    Raises:
        ValueError if the indices do not match the values or
         if no `value` is provided and `can_label` is True.
    """
    if isinstance(index, (int, np.integer)):
        # We were provided only the index, we make a list.
        # NumPy integer scalars (e.g. an element of a ranking array) are
        # handled like a plain `int`.
        index_lst = [index]
        value_lst: List[Any] = [value]
    else:
        index_lst = index
        if value is None:
            value_lst = [value]
        else:
            value_lst = value

    if value_lst[0] is not None and len(index_lst) != len(value_lst):
        raise ValueError(
            "Expected `index` and `value` to be of same length when `value` is provided. "
            f"Got index={len(index_lst)} and value={len(value_lst)}"
        )
    indexes = self._pool_to_oracle_index(index_lst)
    # Every item labelled in this call is stamped with the next active step.
    active_step = self.current_al_step + 1
    for idx, val in zip_longest(indexes, value_lst, fillvalue=None):
        if self.can_label and val is not None:
            self._dataset.label(idx, val)  # type: ignore
            self.labelled_map[idx] = active_step
        elif self.can_label and val is None:
            raise ValueError(
                """The dataset is able to label data, but no label was provided.
                             If this is a research setting, please set the
                              `ActiveLearningDataset.can_label` to `False`.
                              """
            )
        else:
            # Regular research usecase.
            self.labelled_map[idx] = active_step
            if val is not None:
                # Fetch the sample once; it is only used for the message.
                sample = self._dataset[idx]
                warnings.warn(
                    "We will consider the original label of this datasample : {}, {}.".format(
                        sample[0], sample[1]
                    ),
                    UserWarning,
                )

load_state_dict(state_dict)

Load the labelled map and random_state with the given state_dict.

Source code in baal/active/dataset/pytorch_dataset.py
def load_state_dict(self, state_dict):
    """Restore the labelled map and the random state from ``state_dict``.

    Reads the ``"labelled"`` and ``"random_state"`` keys, in that order, and
    stores them on ``labelled_map`` and ``random_state`` respectively.
    """
    restored = (("labelled_map", "labelled"), ("random_state", "random_state"))
    for attribute, key in restored:
        setattr(self, attribute, state_dict[key])

reset_labelled()

Reset the label map.

Source code in baal/active/dataset/pytorch_dataset.py
def reset_labelled(self):
    """Reset the label map so that every item is unlabelled.

    The map is recreated as an integer array of zeros with one entry per item
    of the underlying dataset. An integer dtype is used because entries are
    later overwritten with active-step numbers; the `np.bool` alias used
    previously was removed in NumPy 1.24.
    """
    self.labelled_map = np.zeros(len(self._dataset), dtype=int)

state_dict()

Return the state_dict, ie. the labelled map and random_state.

Source code in baal/active/dataset/pytorch_dataset.py
def state_dict(self) -> Dict:
    """Build a snapshot of the dataset state.

    Returns:
        A dict with the labelled map under ``"labelled"`` and the random state
        under ``"random_state"``.
    """
    return dict(labelled=self.labelled_map, random_state=self.random_state)

baal.active.ActiveLearningLoop

Object that performs the active learning iteration.

Parameters:

Name Type Description Default
dataset ActiveLearningDataset

Dataset with some samples already labelled.

required
get_probabilities Function

Dataset -> **kwargs -> ndarray [n_samples, n_outputs, n_iterations].

required
heuristic Heuristic

Heuristic from baal.active.heuristics.

Random()
query_size int

Number of samples to label per step.

1
max_sample int

Limit the number of samples used (-1 is no limit).

-1
uncertainty_folder Optional[str]

If provided, will store uncertainties on disk.

None
ndata_to_label int

DEPRECATED, please use query_size.

None
**kwargs

Parameters forwarded to get_probabilities.

{}
Source code in baal/active/active_loop.py
class ActiveLearningLoop:
    """Object that performs the active learning iteration.

    Args:
        dataset (ActiveLearningDataset): Dataset with some samples already labelled.
        get_probabilities (Function): Dataset -> **kwargs ->
                                        ndarray [n_samples, n_outputs, n_iterations].
        heuristic (Heuristic): Heuristic from baal.active.heuristics.
        query_size (int): Number of samples to label per step.
        max_sample (int): Limit the number of samples used (-1 is no limit).
        uncertainty_folder (Optional[str]): If provided, will store uncertainties on disk.
        ndata_to_label (int): DEPRECATED, please use `query_size`.
        **kwargs: Parameters forwarded to `get_probabilities`.
    """

    def __init__(
        self,
        dataset: ActiveLearningDataset,
        get_probabilities: Callable,
        # NOTE: the default heuristic is created once, when the class is
        # defined, and is therefore shared by every loop relying on it.
        heuristic: heuristics.AbstractHeuristic = heuristics.Random(),
        query_size: int = 1,
        max_sample=-1,
        uncertainty_folder=None,
        ndata_to_label=None,
        **kwargs,
    ) -> None:
        if ndata_to_label is not None:
            warnings.warn(
                "`ndata_to_label` is deprecated, please use `query_size`.", DeprecationWarning
            )
            # The deprecated argument takes precedence when it is provided.
            query_size = ndata_to_label
        self.query_size = query_size
        self.get_probabilities = get_probabilities
        self.heuristic = heuristic
        self.dataset = dataset
        self.max_sample = max_sample
        self.uncertainty_folder = uncertainty_folder
        self.kwargs = kwargs

    def step(self, pool=None) -> bool:
        """
        Perform an active learning step.

        Args:
            pool (iterable): Optional dataset pool indices.
                             If not set, will use pool from the active set.

        Returns:
            boolean, Flag indicating if we continue training.

        """
        # `indices` maps positions in the (possibly subsampled) pool back to
        # positions in the dataset's pool. It stays None for a user-provided pool.
        indices = None
        if pool is None:
            pool = self.dataset.pool
            if len(pool) > 0:
                # Limit number of samples
                if self.max_sample != -1 and self.max_sample < len(pool):
                    indices = np.random.choice(len(pool), self.max_sample, replace=False)
                    pool = torchdata.Subset(pool, indices)
                else:
                    indices = np.arange(len(pool))

        if len(pool) > 0:
            if isinstance(self.heuristic, heuristics.Random):
                # The Random heuristic gets random scores, so the model is not run.
                probs = np.random.uniform(low=0, high=1, size=(len(pool), 1))
            else:
                probs = self.get_probabilities(pool, **self.kwargs)
            if probs is not None and (isinstance(probs, types.GeneratorType) or len(probs) > 0):
                to_label, uncertainty = self.heuristic.get_ranks(probs)
                if indices is not None:
                    to_label = indices[np.array(to_label)]
                if self.uncertainty_folder is not None:
                    # We save uncertainty in a file.
                    uncertainty_name = (
                        f"uncertainty_pool={len(pool)}" f"_labelled={len(self.dataset)}.pkl"
                    )
                    # Use a context manager so the file is flushed and closed
                    # even if pickling fails.
                    with open(pjoin(self.uncertainty_folder, uncertainty_name), "wb") as f:
                        pickle.dump(
                            {
                                "indices": indices,
                                "uncertainty": uncertainty,
                                "dataset": self.dataset.state_dict(),
                            },
                            f,
                        )
                if len(to_label) > 0:
                    self.dataset.label(to_label[: self.query_size])
                    return True
        return False

step(pool=None)

Perform an active learning step.

Parameters:

Name Type Description Default
pool iterable

Optional dataset pool indices. If not set, will use pool from the active set.

None

Returns:

Type Description
bool

boolean, Flag indicating if we continue training.

Source code in baal/active/active_loop.py
def step(self, pool=None) -> bool:
    """
    Perform an active learning step.

    Args:
        pool (iterable): Optional dataset pool indices.
                         If not set, will use pool from the active set.

    Returns:
        boolean, Flag indicating if we continue training.

    """
    # `indices` maps positions in the (possibly subsampled) pool back to
    # positions in the dataset's pool. It stays None for a user-provided pool.
    indices = None
    if pool is None:
        pool = self.dataset.pool
        if len(pool) > 0:
            # Limit number of samples
            if self.max_sample != -1 and self.max_sample < len(pool):
                indices = np.random.choice(len(pool), self.max_sample, replace=False)
                pool = torchdata.Subset(pool, indices)
            else:
                indices = np.arange(len(pool))

    if len(pool) > 0:
        if isinstance(self.heuristic, heuristics.Random):
            # The Random heuristic gets random scores, so the model is not run.
            probs = np.random.uniform(low=0, high=1, size=(len(pool), 1))
        else:
            probs = self.get_probabilities(pool, **self.kwargs)
        if probs is not None and (isinstance(probs, types.GeneratorType) or len(probs) > 0):
            to_label, uncertainty = self.heuristic.get_ranks(probs)
            if indices is not None:
                to_label = indices[np.array(to_label)]
            if self.uncertainty_folder is not None:
                # We save uncertainty in a file.
                uncertainty_name = (
                    f"uncertainty_pool={len(pool)}" f"_labelled={len(self.dataset)}.pkl"
                )
                # Use a context manager so the file is flushed and closed
                # even if pickling fails.
                with open(pjoin(self.uncertainty_folder, uncertainty_name), "wb") as f:
                    pickle.dump(
                        {
                            "indices": indices,
                            "uncertainty": uncertainty,
                            "dataset": self.dataset.state_dict(),
                        },
                        f,
                    )
            if len(to_label) > 0:
                self.dataset.label(to_label[: self.query_size])
                return True
    return False

baal.active.FileDataset

Bases: Dataset

Dataset object that loads the files and applies a transformation.

Parameters:

Name Type Description Default
files List[str]

The files.

required
lbls List[Any]

The labels, -1 indicates that the label is unknown.

None
transform Optional[Callable]

torchvision.transform pipeline.

None
target_transform Optional[Callable]

Function that modifies the target.

None
image_load_fn Optional[Callable]

Function that loads the image, by default uses PIL.

None
seed Optional[int]

Will set a seed before and between DA.

None
Source code in baal/active/file_dataset.py
class FileDataset(Dataset):
    """
    Dataset object that loads the files and applies a transformation.

    Args:
        files (List[str]): The files.
        lbls (List[Any]): The labels, -1 indicates that the label is unknown.
        transform (Optional[Callable]): torchvision.transform pipeline.
        target_transform (Optional[Callable]): Function that modifies the target.
        image_load_fn (Optional[Callable]): Function that loads the image, by default uses PIL.
        seed (Optional[int]): Will set a seed before and between DA.
    """

    def __init__(
        self,
        files: List[str],
        lbls: Optional[List[Any]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        image_load_fn: Optional[Callable] = None,
        seed=None,
    ):
        self.files = files

        if lbls is None:
            # No labels provided: every sample starts as unknown (-1).
            self.lbls = [-1] * len(self.files)
        else:
            self.lbls = lbls

        self.transform = transform
        self.target_transform = target_transform
        self.image_load_fn = image_load_fn or default_image_load_fn
        self.seed = seed

    def label(self, idx: int, lbl: Any):
        """
        Label the sample `idx` with `lbl`.

        A `UserWarning` is emitted when the sample already has a known label.

        Args:
            idx (int): The sample index.
            lbl (Any): The label to assign.
        """
        current = self.lbls[idx]
        try:
            already_known = bool(current >= 0)
        except (TypeError, ValueError):
            # Labels that cannot be compared to 0 (e.g. strings or arrays) are
            # not the -1 "unknown" marker; treat them as known unless None.
            already_known = current is not None

        if already_known:
            warnings.warn(
                "We're modifying the class of the sample {} that we already know : {}.".format(
                    self.files[idx], current
                ),
                UserWarning,
            )

        self.lbls[idx] = lbl

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        x, y = self.files[idx], self.lbls[idx]

        # Derive a per-item seed from `self.seed` and the index.
        np.random.seed(self.seed)
        batch_seed = np.random.randint(0, 100, 1).item()
        seed_all(batch_seed + idx)

        img = self.image_load_fn(x)
        kwargs = self.get_kwargs(self.transform, image_shape=img.size, idx=idx)

        if self.transform:
            img_t = self.transform(img, **kwargs)
        else:
            img_t = img

        if self.target_transform:
            # Re-seed with the same value before transforming the target,
            # presumably so image and target get the same random augmentation.
            seed_all(batch_seed + idx)
            kwargs = self.get_kwargs(self.target_transform, image_shape=img.size, idx=idx)
            y = self.target_transform(y, **kwargs)
        return img_t, y

    @staticmethod
    def get_kwargs(transform, **kwargs):
        # Only BaaLTransform instances receive extra keyword arguments: the
        # ones they list through `get_requires()`.
        if isinstance(transform, BaaLTransform):
            t_kwargs = {k: kwargs[k] for k in transform.get_requires()}
        else:
            t_kwargs = {}
        return t_kwargs

label(idx, lbl)

Label the sample idx with lbl.

Parameters:

Name Type Description Default
idx int

The sample index.

required
lbl Any

The label to assign.

required
Source code in baal/active/file_dataset.py
def label(self, idx: int, lbl: Any):
    """
    Label the sample `idx` with `lbl`.

    A `UserWarning` is emitted when the sample already has a known label.

    Args:
        idx (int): The sample index.
        lbl (Any): The label to assign.
    """
    current = self.lbls[idx]
    try:
        already_known = bool(current >= 0)
    except (TypeError, ValueError):
        # Labels that cannot be compared to 0 (e.g. strings or arrays) are
        # not the -1 "unknown" marker; treat them as known unless None.
        already_known = current is not None

    if already_known:
        warnings.warn(
            "We're modifying the class of the sample {} that we already know : {}.".format(
                self.files[idx], current
            ),
            UserWarning,
        )

    self.lbls[idx] = lbl