PyTorch Lightning Compatibility

baal.utils.pytorch_lightning.ResetCallback

Bases: Callback

Callback to reset the weights between active learning steps.

Parameters:

weights (dict): State dict of the model. Required.
Notes

The weights should be deep copied beforehand.

Source code in baal/utils/pytorch_lightning.py
class ResetCallback(Callback):
    """Callback to reset the weights between active learning steps.

    Args:
        weights (dict): State dict of the model.

    Notes:
        The weights should be deep copied beforehand.

    """

    def __init__(self, weights):
        self.weights = weights

    def on_train_start(self, trainer, module):
        """Will reset the module to its initial weights."""
        module.load_state_dict(self.weights)
        trainer.fit_loop.current_epoch = 0

on_train_start(trainer, module)

Will reset the module to its initial weights.

Source code in baal/utils/pytorch_lightning.py
def on_train_start(self, trainer, module):
    """Will reset the module to its initial weights."""
    module.load_state_dict(self.weights)
    trainer.fit_loop.current_epoch = 0
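
Example

A minimal usage sketch: the initial state dict is deep copied before training so that every active learning step restarts from the same initialization. `MyLightningModule` is a hypothetical placeholder, not part of this module.

import copy
from pytorch_lightning import Trainer
from baal.utils.pytorch_lightning import ResetCallback

model = MyLightningModule()  # hypothetical LightningModule

# Deep copy so that training cannot mutate the saved weights in place.
initial_weights = copy.deepcopy(model.state_dict())

trainer = Trainer(callbacks=[ResetCallback(initial_weights)])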

baal.utils.pytorch_lightning.BaalTrainer

Bases: Trainer

Object that performs the training and active learning iteration.

Parameters:

dataset (ActiveLearningDataset): Dataset with some samples already labelled. Required.
heuristic (Heuristic): Heuristic from baal.active.heuristics. Default: Random().
query_size (int): Number of samples to label per step. Default: 1.
max_sample (int): Limit the number of samples used (-1 means no limit). Required.
**kwargs: Parameters forwarded to get_probabilities and to pytorch_lightning Trainer.__init__. Default: {}.

Source code in baal/utils/pytorch_lightning.py
class BaalTrainer(Trainer):
    """Object that perform the training and active learning iteration.

    Args:
        dataset (ActiveLearningDataset): Dataset with some sample already labelled.
        heuristic (Heuristic): Heuristic from baal.active.heuristics.
        query_size (int): Number of sample to label per step.
        max_sample (int): Limit the number of sample used (-1 is no limit).
        **kwargs: Parameters forwarded to `get_probabilities`
            and to pytorch_ligthning Trainer.__init__
    """

    def __init__(
        self,
        dataset: ActiveLearningDataset,
        heuristic: heuristics.AbstractHeuristic = heuristics.Random(),
        query_size: int = 1,
        **kwargs
    ) -> None:

        super().__init__(**kwargs)
        self.query_size = query_size
        self.heuristic = heuristic
        self.dataset = dataset
        self.kwargs = kwargs

    def predict_on_dataset(self, model=None, dataloader=None, *args, **kwargs):
        "For documentation, see `predict_on_dataset_generator`"
        preds = list(self.predict_on_dataset_generator(model, dataloader))

        if len(preds) > 0 and not isinstance(preds[0], Sequence):
            # Is an Array or a Tensor
            return np.vstack(preds)
        return [np.vstack(pr) for pr in zip(*preds)]

    def predict_on_dataset_generator(
        self, model=None, dataloader: Optional[DataLoader] = None, *args, **kwargs
    ):
        """Predict on the pool loader.

        Args:
            model: Model to be used in prediction. If None, will get the Trainer's model.
            dataloader (Optional[DataLoader]): If provided, will predict on this dataloader.
                                                Otherwise, uses model.pool_dataloader().

        Returns:
            Numpy arrays with all the predictions.
        """
        model = model or self.lightning_module
        model.eval()
        if isinstance(self.accelerator, CUDAAccelerator):
            model.cuda(self.strategy.root_device.index)
        dataloader = dataloader or model.pool_dataloader()
        if len(dataloader) == 0:
            return None

        log.info("Start Predict", dataset=len(dataloader))
        for idx, batch in enumerate(tqdm(dataloader, total=len(dataloader), file=sys.stdout)):
            if isinstance(self.accelerator, CUDAAccelerator):
                batch = to_cuda(batch)
            pred = model.predict_step(batch, idx)
            yield map_on_tensor(lambda x: x.detach().cpu().numpy(), pred)
        # teardown, TODO customize this later?
        model.cpu()

    def step(self, model=None, datamodule: Optional[BaaLDataModule] = None) -> bool:
        """
        Perform an active learning step.

        Args:
            model: Model to be used in prediction. If None, will get the Trainer's model.
            datamodule (Optional[BaaLDataModule]): If provided, its pool_dataloader is used.
                Otherwise, uses model.pool_dataloader().

        Notes:
            This will get the pool from the model pool_dataloader and if max_sample is set, it will
            **require** the data_loader sampler to select `max_pool` samples.

        Returns:
            bool: Flag indicating whether to continue training.

        """
        # High to low
        if datamodule is None:
            pool_dataloader = self.lightning_module.pool_dataloader()  # type: ignore
        else:
            pool_dataloader = datamodule.pool_dataloader()
        model = model if model is not None else self.lightning_module

        if isinstance(pool_dataloader.sampler, torch.utils.data.sampler.RandomSampler):
            log.warning(
                "Your pool_dataloader has `shuffle=True`," " it is best practice to turn this off."
            )

        if len(pool_dataloader) > 0:
            # TODO Add support for max_samples in pool_dataloader
            probs = self.predict_on_dataset_generator(
                model=model, dataloader=pool_dataloader, **self.kwargs
            )
            if probs is not None and (isinstance(probs, types.GeneratorType) or len(probs) > 0):
                to_label = self.heuristic(probs)
                if len(to_label) > 0:
                    self.dataset.label(to_label[: self.query_size])
                    return True
        return False
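
Example

A sketch of constructing a BaalTrainer; `train_set` and the hyperparameter values are assumptions, not part of this module.

from baal.active import ActiveLearningDataset
from baal.active.heuristics import BALD
from baal.utils.pytorch_lightning import BaalTrainer

active_set = ActiveLearningDataset(train_set)  # `train_set` is a hypothetical torch Dataset
active_set.label_randomly(100)  # seed the labelled set before the first training step

trainer = BaalTrainer(
    dataset=active_set,
    heuristic=BALD(),
    query_size=50,
    max_epochs=10,  # forwarded to pytorch_lightning.Trainer
)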

predict_on_dataset(model=None, dataloader=None, *args, **kwargs)

For documentation, see predict_on_dataset_generator

Source code in baal/utils/pytorch_lightning.py
def predict_on_dataset(self, model=None, dataloader=None, *args, **kwargs):
    "For documentation, see `predict_on_dataset_generator`"
    preds = list(self.predict_on_dataset_generator(model, dataloader))

    if len(preds) > 0 and not isinstance(preds[0], Sequence):
        # Is an Array or a Tensor
        return np.vstack(preds)
    return [np.vstack(pr) for pr in zip(*preds)]

predict_on_dataset_generator(model=None, dataloader=None, *args, **kwargs)

Predict on the pool loader.

Parameters:

model: Model to be used in prediction. If None, will get the Trainer's model. Default: None.
dataloader (Optional[DataLoader]): If provided, will predict on this dataloader. Otherwise, uses model.pool_dataloader(). Default: None.

Returns:

Numpy arrays with all the predictions.

Source code in baal/utils/pytorch_lightning.py
def predict_on_dataset_generator(
    self, model=None, dataloader: Optional[DataLoader] = None, *args, **kwargs
):
    """Predict on the pool loader.

    Args:
        model: Model to be used in prediction. If None, will get the Trainer's model.
        dataloader (Optional[DataLoader]): If provided, will predict on this dataloader.
                                            Otherwise, uses model.pool_dataloader().

    Returns:
        Numpy arrays with all the predictions.
    """
    model = model or self.lightning_module
    model.eval()
    if isinstance(self.accelerator, CUDAAccelerator):
        model.cuda(self.strategy.root_device.index)
    dataloader = dataloader or model.pool_dataloader()
    if len(dataloader) == 0:
        return None

    log.info("Start Predict", dataset=len(dataloader))
    for idx, batch in enumerate(tqdm(dataloader, total=len(dataloader), file=sys.stdout)):
        if isinstance(self.accelerator, CUDAAccelerator):
            batch = to_cuda(batch)
        pred = model.predict_step(batch, idx)
        yield map_on_tensor(lambda x: x.detach().cpu().numpy(), pred)
    # teardown, TODO customize this later?
    model.cpu()
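
Example

The generator yields one batch of predictions at a time, which keeps memory bounded on large pools. As `step` does internally, the generator can be passed directly to a heuristic; a consumption sketch, where `model` and `pool_loader` are assumptions:

probs = trainer.predict_on_dataset_generator(model=model, dataloader=pool_loader)
to_label = trainer.heuristic(probs)  # baal heuristics accept generators directly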

step(model=None, datamodule=None)

Perform an active learning step.

Parameters:

model: Model to be used in prediction. If None, will get the Trainer's model.
datamodule (Optional[BaaLDataModule]): If provided, its pool_dataloader is used. Otherwise, uses model.pool_dataloader().

Notes

This will get the pool from the model pool_dataloader and if max_sample is set, it will require the data_loader sampler to select max_pool samples.

Returns:

bool: Flag indicating whether to continue training.

Source code in baal/utils/pytorch_lightning.py
def step(self, model=None, datamodule: Optional[BaaLDataModule] = None) -> bool:
    """
    Perform an active learning step.

    Args:
        model: Model to be used in prediction. If None, will get the Trainer's model.
        datamodule (Optional[BaaLDataModule]): If provided, its pool_dataloader is used.
            Otherwise, uses model.pool_dataloader().

    Notes:
        This will get the pool from the model pool_dataloader and if max_sample is set, it will
        **require** the data_loader sampler to select `max_pool` samples.

    Returns:
        bool: Flag indicating whether to continue training.

    """
    # High to low
    if datamodule is None:
        pool_dataloader = self.lightning_module.pool_dataloader()  # type: ignore
    else:
        pool_dataloader = datamodule.pool_dataloader()
    model = model if model is not None else self.lightning_module

    if isinstance(pool_dataloader.sampler, torch.utils.data.sampler.RandomSampler):
        log.warning(
            "Your pool_dataloader has `shuffle=True`," " it is best practice to turn this off."
        )

    if len(pool_dataloader) > 0:
        # TODO Add support for max_samples in pool_dataloader
        probs = self.predict_on_dataset_generator(
            model=model, dataloader=pool_dataloader, **self.kwargs
        )
        if probs is not None and (isinstance(probs, types.GeneratorType) or len(probs) > 0):
            to_label = self.heuristic(probs)
            if len(to_label) > 0:
                self.dataset.label(to_label[: self.query_size])
                return True
    return False
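
Example

A sketch of the surrounding active learning loop: train, then query and label until `step` reports that nothing is left to label. `MyBaalModule` and `n_al_steps` are hypothetical; in practice a ResetCallback (above) is attached so each fit restarts from the initial weights.

model = MyBaalModule()  # hypothetical LightningModule exposing pool_dataloader()

for _ in range(n_al_steps):  # hypothetical number of active learning steps
    trainer.fit(model)
    if not trainer.step(model):  # labels up to `query_size` new samples
        break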

baal.utils.pytorch_lightning.BaaLDataModule

Bases: LightningDataModule

Source code in baal/utils/pytorch_lightning.py
class BaaLDataModule(LightningDataModule):
    def __init__(self, active_dataset: ActiveLearningDataset, batch_size=1, **kwargs):
        super().__init__(**kwargs)
        self.active_dataset = active_dataset
        self.batch_size = batch_size

    def pool_dataloader(self) -> DataLoader:
        """Create Dataloader for the pool of unlabelled examples."""
        return DataLoader(
            self.active_dataset.pool, batch_size=self.batch_size, num_workers=4, shuffle=False
        )

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        if "active_dataset" in checkpoint:
            self.active_dataset.load_state_dict(checkpoint["active_dataset"])
        else:
            log.warning("'active_dataset' not in checkpoint!")

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]):
        checkpoint["active_dataset"] = self.active_dataset.state_dict()

pool_dataloader()

Create Dataloader for the pool of unlabelled examples.

Source code in baal/utils/pytorch_lightning.py
def pool_dataloader(self) -> DataLoader:
    """Create Dataloader for the pool of unlabelled examples."""
    return DataLoader(
        self.active_dataset.pool, batch_size=self.batch_size, num_workers=4, shuffle=False
    )
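
Example

A minimal sketch of subclassing BaaLDataModule to add a labelled train loader; `train_set` and the batch size are assumptions.

from torch.utils.data import DataLoader
from baal.active import ActiveLearningDataset
from baal.utils.pytorch_lightning import BaaLDataModule

class MyDataModule(BaaLDataModule):
    def train_dataloader(self) -> DataLoader:
        # Train only on the labelled samples held by the active dataset.
        return DataLoader(self.active_dataset, batch_size=self.batch_size, shuffle=True)

dm = MyDataModule(active_dataset=ActiveLearningDataset(train_set), batch_size=16)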