
Calibration Wrapper

baal.calibration.DirichletCalibrator

Bases: object

Adds a linear layer on top of a classifier model after that model is trained, and trains this new layer until convergence. Together with the linear layer, the model is then calibrated. Source: https://arxiv.org/abs/1910.12656. Code inspired by: https://github.com/dirichletcal/experiments_neurips

References

Kull, Meelis; Perello-Nieto, Miquel; Kängsepp, Markus; Silva Filho, Telmo; Song, Hao; Flach, Peter. "Beyond temperature scaling: Obtaining well-calibrated multi-class probabilities with Dirichlet calibration." NeurIPS 2019.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `wrapper` | `ModelWrapper` | Provides training and testing methods. | *required* |
| `num_classes` | `int` | Number of classes in the classification task. | *required* |
| `lr` | `float` | Learning rate. | *required* |
| `reg_factor` | `float` | Regularization factor for the linear layer weights. | *required* |
| `mu` | `float` | Regularization factor for the linear layer biases. If not given, defaults to `reg_factor`. | `None` |
Source code in baal/calibration/calibration.py
class DirichletCalibrator(object):
    """
    Adds a linear layer to a classifier model after the model is
    trained and trains this new layer until convergence.
    Together with the linear layer, the model is now calibrated.
    Source: https://arxiv.org/abs/1910.12656
    Code inspired from: https://github.com/dirichletcal/experiments_neurips

    References:
        @article{kullbeyond,
                title={Beyond temperature scaling: Obtaining well-calibrated multi-class
                 probabilities with Dirichlet calibration Supplementary material},
                author={Kull, Meelis and Perello-Nieto,
                 Miquel and K{\"a}ngsepp, Markus and Silva Filho,
                  Telmo and Song, Hao and Flach, Peter}
                }

    Args:

        wrapper (ModelWrapper): Provides training and testing methods.
        num_classes (int): Number of classes in classification task.
        lr (float): Learning rate.
        reg_factor (float): Regularization factor for the linear layer weights.
        mu (float): Regularization factor for the linear layer biases.
            If not given, defaults to reg_factor.

    """

    def __init__(
        self,
        wrapper: ModelWrapper,
        num_classes: int,
        lr: float,
        reg_factor: float,
        mu: float = None,
    ):
        self.num_classes = num_classes
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr
        self.reg_factor = reg_factor
        self.mu = mu or reg_factor
        self.dirichlet_linear = nn.Linear(self.num_classes, self.num_classes)
        self.model = nn.Sequential(wrapper.model, self.dirichlet_linear)
        self.wrapper = ModelWrapper(self.model, self.criterion)

        self.wrapper.add_metric("ece", lambda: ECE())
        self.wrapper.add_metric("ece", lambda: ECE_PerCLs(num_classes))

    def l2_reg(self):
        """Using trainable layer's parameters for l2 regularization.

        Returns:
            The regularization term for the linear layer.
        """
        weight_p, bias_p = self.dirichlet_linear.parameters()
        w_l2_factor = weight_p.norm(2)
        b_l2_factor = bias_p.norm(2)
        return self.reg_factor * w_l2_factor + self.mu * b_l2_factor

    def calibrate(
        self,
        train_set: Dataset,
        test_set: Dataset,
        batch_size: int,
        epoch: int,
        use_cuda: bool,
        double_fit: bool = False,
        **kwargs
    ):
        """
        Train the linear layer given a training set and a validation set.
        The training set should be different from what the model was trained on.

        Args:
            train_set (Dataset): The training set.
            test_set (Dataset): The validation set.
            batch_size (int): Batch size used.
            epoch (int): Number of epochs to train the linear layer for.
            use_cuda (bool): If True, will use the GPU.
            double_fit (bool): If True, fits twice on the train set.
            kwargs (dict): Rest of the parameters for baal.ModelWrapper.train_and_test_on_datasets().

        Returns:
            loss_history (list[float]): List of loss values for each epoch.
            model.state_dict (dict): Model weights.

        """

        # reinitialize the dirichlet calibration layer
        self.dirichlet_linear.weight.data.copy_(torch.eye(self.dirichlet_linear.weight.shape[0]))
        self.dirichlet_linear.bias.data.copy_(torch.zeros(*self.dirichlet_linear.bias.shape))
        if use_cuda:
            self.dirichlet_linear.cuda()

        optimizer = Adam(self.dirichlet_linear.parameters(), lr=self.lr)

        loss_history, weights = self.wrapper.train_and_test_on_datasets(
            train_set,
            test_set,
            optimizer,
            batch_size,
            epoch,
            use_cuda,
            regularizer=self.l2_reg,
            return_best_weights=True,
            patience=None,
            **kwargs
        )
        self.model.load_state_dict(weights)

        if double_fit:
            lr = self.lr / 10
            optimizer = Adam(self.dirichlet_linear.parameters(), lr=lr)
            loss_history, weights = self.wrapper.train_and_test_on_datasets(
                train_set,
                test_set,
                optimizer,
                batch_size,
                epoch,
                use_cuda,
                regularizer=self.l2_reg,
                return_best_weights=True,
                patience=None,
                **kwargs
            )
            self.model.load_state_dict(weights)

        return loss_history, self.model.state_dict()

    @property
    def calibrated_model(self):
        return self.model

    @property
    def metrics(self):
        return self.wrapper.metrics
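
A minimal usage sketch (the names `trained_wrapper`, `calib_set`, `val_set`, and the hyperparameter values are illustrative placeholders, not part of the API):

```python
from baal.calibration import DirichletCalibrator

# `trained_wrapper` is a baal.ModelWrapper whose model is already trained;
# `calib_set` and `val_set` are held-out torch Datasets, distinct from the
# data the model was trained on.
calibrator = DirichletCalibrator(trained_wrapper, num_classes=10, lr=0.001, reg_factor=0.01)
loss_history, weights = calibrator.calibrate(
    calib_set, val_set, batch_size=32, epoch=10, use_cuda=True
)
calibrated_model = calibrator.calibrated_model  # nn.Sequential(model, dirichlet_linear)
```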

calibrate(train_set, test_set, batch_size, epoch, use_cuda, double_fit=False, **kwargs)

Trains the linear layer given a training set and a validation set. The training set should be different from the data the model was trained on.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `train_set` | `Dataset` | The training set. | *required* |
| `test_set` | `Dataset` | The validation set. | *required* |
| `batch_size` | `int` | Batch size used. | *required* |
| `epoch` | `int` | Number of epochs to train the linear layer for. | *required* |
| `use_cuda` | `bool` | If `True`, will use the GPU. | *required* |
| `double_fit` | `bool` | If `True`, fits twice on the train set. | `False` |
| `kwargs` | `dict` | Rest of the parameters for `baal.ModelWrapper.train_and_test_on_datasets()`. | `{}` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `loss_history` | `list[float]` | List of loss values for each epoch. |
| `model.state_dict` | `dict` | Model weights. |

Source code in baal/calibration/calibration.py
def calibrate(
    self,
    train_set: Dataset,
    test_set: Dataset,
    batch_size: int,
    epoch: int,
    use_cuda: bool,
    double_fit: bool = False,
    **kwargs
):
    """
    Train the linear layer given a training set and a validation set.
    The training set should be different from what the model was trained on.

    Args:
        train_set (Dataset): The training set.
        test_set (Dataset): The validation set.
        batch_size (int): Batch size used.
        epoch (int): Number of epochs to train the linear layer for.
        use_cuda (bool): If True, will use the GPU.
        double_fit (bool): If True, fits twice on the train set.
        kwargs (dict): Rest of the parameters for baal.ModelWrapper.train_and_test_on_datasets().

    Returns:
        loss_history (list[float]): List of loss values for each epoch.
        model.state_dict (dict): Model weights.

    """

    # reinitialize the dirichlet calibration layer
    self.dirichlet_linear.weight.data.copy_(torch.eye(self.dirichlet_linear.weight.shape[0]))
    self.dirichlet_linear.bias.data.copy_(torch.zeros(*self.dirichlet_linear.bias.shape))
    if use_cuda:
        self.dirichlet_linear.cuda()

    optimizer = Adam(self.dirichlet_linear.parameters(), lr=self.lr)

    loss_history, weights = self.wrapper.train_and_test_on_datasets(
        train_set,
        test_set,
        optimizer,
        batch_size,
        epoch,
        use_cuda,
        regularizer=self.l2_reg,
        return_best_weights=True,
        patience=None,
        **kwargs
    )
    self.model.load_state_dict(weights)

    if double_fit:
        lr = self.lr / 10
        optimizer = Adam(self.dirichlet_linear.parameters(), lr=lr)
        loss_history, weights = self.wrapper.train_and_test_on_datasets(
            train_set,
            test_set,
            optimizer,
            batch_size,
            epoch,
            use_cuda,
            regularizer=self.l2_reg,
            return_best_weights=True,
            patience=None,
            **kwargs
        )
        self.model.load_state_dict(weights)

    return loss_history, self.model.state_dict()
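
For example, a second fit at a tenth of the learning rate can be requested with `double_fit=True` (a sketch reusing the illustrative `calibrator`, `calib_set`, and `val_set` from above):

```python
# Each call first resets the Dirichlet layer to identity weights and zero
# biases, so repeated calls start from scratch rather than accumulating.
loss_history, state_dict = calibrator.calibrate(
    calib_set,
    val_set,
    batch_size=32,
    epoch=10,
    use_cuda=True,
    double_fit=True,  # refit once more at lr / 10
)
```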

l2_reg()

Computes the L2 regularization term from the trainable layer's parameters.

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The regularization term for the linear layer. |

Source code in baal/calibration/calibration.py
def l2_reg(self):
    """Using trainable layer's parameters for l2 regularization.

    Returns:
        The regularization term for the linear layer.
    """
    weight_p, bias_p = self.dirichlet_linear.parameters()
    w_l2_factor = weight_p.norm(2)
    b_l2_factor = bias_p.norm(2)
    return self.reg_factor * w_l2_factor + self.mu * b_l2_factor
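
In symbols, with `W` and `b` the weight and bias of `dirichlet_linear`, the value returned is

$$
\text{reg\_factor} \cdot \lVert W \rVert_2 + \mu \cdot \lVert b \rVert_2,
$$

where `norm(2)` on the weight matrix is taken over all entries (the Frobenius norm) and `mu` defaults to `reg_factor` when not given.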