Bayesian deep learning

In Bayesian active learning, we draw from the posterior distribution to estimate uncertainty.

Example

```python
import torch.nn as nn

from baal.bayesian.dropout import MCDropoutModule, patch_module
from baal.bayesian.weight_drop import MCDropoutConnectModule

model: nn.Module  # your fully specified model

# Keep Dropout layers active at inference time
model = MCDropoutModule(model)
# or
model = patch_module(model)

# Use MC-DropConnect on all linear layers
model = MCDropoutConnectModule(model, layers=["Linear"], weight_dropout=0.5)
```
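
Once dropout stays active at inference, uncertainty is estimated by running several stochastic forward passes and looking at the spread of the predictions. The sketch below shows the idea with plain PyTorch (forcing `nn.Dropout` submodules into train mode rather than using baal's wrappers); the toy model and sample counts are illustrative, not part of baal's API.

```python
import torch
import torch.nn as nn

# Toy model with a Dropout layer (illustrative only)
model = nn.Sequential(nn.Linear(8, 16), nn.Dropout(p=0.5), nn.ReLU(), nn.Linear(16, 2))

# Keep dropout stochastic at inference by forcing Dropout submodules into train mode
model.eval()
for m in model.modules():
    if isinstance(m, nn.Dropout):
        m.train()

x = torch.randn(4, 8)
with torch.no_grad():
    # 20 stochastic forward passes ~ Monte Carlo draws from the
    # approximate posterior predictive distribution
    samples = torch.stack([model(x) for _ in range(20)])  # (20, 4, 2)

mean = samples.mean(dim=0)        # predictive mean, shape (4, 2)
uncertainty = samples.std(dim=0)  # per-output predictive std, shape (4, 2)
```

The standard deviation across samples is a simple uncertainty signal; baal's heuristics (e.g. BALD) compute richer scores from the same stack of samples.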

API

baal.bayesian.dropout.MCDropoutModule

Bases: BayesianModule

Create a module with all dropout layers patched.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| module | Module | A fully specified neural network. | required |
Source code in baal/bayesian/dropout.py
class MCDropoutModule(BayesianModule):
    """Create a module with all dropout layers patched.

    Args:
        module (torch.nn.Module):
            A fully specified neural network.
    """

    patching_function = patch_module
    unpatch_function = unpatch_module
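
Conceptually, patching replaces every `nn.Dropout` with a variant whose forward pass stays stochastic even in eval mode. Here is a minimal sketch of that idea in plain PyTorch; the `AlwaysOnDropout` class and `patch_dropout_` helper are hypothetical names for illustration, not baal's actual implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AlwaysOnDropout(nn.Dropout):
    """Dropout that stays stochastic even in eval mode (illustrative sketch)."""
    def forward(self, x):
        # Ignore self.training and always apply dropout
        return F.dropout(x, p=self.p, training=True)

def patch_dropout_(module: nn.Module) -> nn.Module:
    """Recursively replace nn.Dropout children in-place (illustrative sketch)."""
    for name, child in module.named_children():
        if isinstance(child, nn.Dropout):
            setattr(module, name, AlwaysOnDropout(p=child.p))
        else:
            patch_dropout_(child)
    return module

model = nn.Sequential(nn.Linear(6, 6), nn.Dropout(p=0.5))
patch_dropout_(model)
model.eval()  # dropout remains active despite eval mode
x = torch.randn(3, 6)
out1, out2 = model(x), model(x)  # two passes give different outputs
```

baal's real `patch_module` also handles `Dropout2d` and returns a patched copy rather than mutating in place, so prefer the library function in practice.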

baal.bayesian.weight_drop.MCDropoutConnectModule

Bases: BayesianModule

Create a module with all dropout layers patched. With MCDropoutConnectModule, you can choose which types of modules to replace.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| module | Module | A fully specified neural network. | required |
| layers | list[str] | Name of layers to be replaced from ['Conv', 'Linear', 'LSTM', 'GRU']. | required |
| weight_dropout | float | The probability a weight will be dropped. | required |
Source code in baal/bayesian/weight_drop.py
class MCDropoutConnectModule(BayesianModule):
    """Create a module with all dropout layers patched.
    With MCDropoutConnectModule, you can choose which types of modules to
    replace.

    Args:
        module (torch.nn.Module):
            A fully specified neural network.
        layers (list[str]):
            Name of layers to be replaced from ['Conv', 'Linear', 'LSTM', 'GRU'].
        weight_dropout (float): The probability a weight will be dropped.
    """

    patching_function = patch_module
    unpatch_function = unpatch_module
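
Unlike MC-Dropout, which zeroes activations, MC-DropConnect drops individual weights, sampling a fresh weight mask on every forward pass. The layer below is a minimal sketch of that idea; `WeightDropLinear` is a hypothetical illustrative class, not baal's `baal.bayesian.weight_drop` implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightDropLinear(nn.Linear):
    """Linear layer that drops individual weights each forward pass
    (MC-DropConnect sketch, illustrative only)."""
    def __init__(self, in_features, out_features, weight_dropout=0.5):
        super().__init__(in_features, out_features)
        self.weight_dropout = weight_dropout

    def forward(self, x):
        # Sample a fresh weight mask on every call, even in eval mode,
        # so repeated predictions act as posterior draws
        w = F.dropout(self.weight, p=self.weight_dropout, training=True)
        return F.linear(x, w, self.bias)

layer = WeightDropLinear(4, 3, weight_dropout=0.5)
layer.eval()
x = torch.randn(2, 4)
y1, y2 = layer(x), layer(x)  # different weight masks, different outputs
```

baal's `MCDropoutConnectModule` applies this kind of replacement automatically to the module types listed in `layers`, including recurrent layers.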