Baal cheat sheet

In the table below, we have a mapping between common equations and the Baal API.

Setup

Here are the types for all variables needed.

model: torch.nn.Module
wrapper: baal.ModelWrapper
dataset: torch.utils.data_utils.Dataset
bald = baal.active.heuristics.BALD()
entropy = baal.active.heuristics.Entropy()


We assume that baal.bayesian.dropout.patch_module has been applied to the model.

model = baal.bayesian.dropout.patch_module(model)

Description Equation Baal
Bayesian Model Averaging $$\hat{T} = p(y \mid x, {\cal D})= \int p(y \mid x, \theta)p(\theta \mid D) d\theta$$ wrapper.predict_on_dataset(dataset, batch_size=B, iterations=I, use_cuda=True).mean(-1)
MC-Dropout $$T = \{p(y\mid x_j, \theta_i)\} \mid x_j \in {\cal D}' ,i \in \{1, \ldots, I\}$$ wrapper.predict_on_dataset(dataset, batch_size=B, iterations=I, use_cuda=True)
BALD $${\cal I}[y, \theta \mid x, {\cal D}] = {\cal H}[y \mid x, {\cal D}] - {\cal E}_{p(\theta \mid {\cal D})}[{\cal H}[y \mid x, \theta]]$$ bald.get_uncertainties(T)
Entropy $$\sum_c \hat{T}_c \log(\hat{T}_c)$$ entropy.get_uncertainties(T)

Contributing

If some equations are missing, please open a PR so that we can make this cheat sheet as useful as possible.