Calibration Wrapper
baal.calibration.DirichletCalibrator
Bases: object
Adds a linear layer on top of a classifier after the model has been trained, then trains this new layer until convergence. Together with the linear layer, the model is calibrated.
Source: https://arxiv.org/abs/1910.12656
Code inspired from: https://github.com/dirichletcal/experiments_neurips
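As a rough illustration, the calibration map can be sketched in NumPy. This is not baal's code: it assumes, following the paper, that the linear layer acts on the log-probabilities of the trained classifier and is followed by a softmax, and the helper names softmax and dirichlet_calibrate are hypothetical. Starting from identity weights and a zero bias reproduces the original predictions, so training only has to learn a correction.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def dirichlet_calibrate(probs: np.ndarray, W: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Apply a Dirichlet-style calibration layer to predicted probabilities.

    probs: (n_samples, num_classes) probabilities from the frozen classifier.
    W:     (num_classes, num_classes) weights of the added linear layer.
    b:     (num_classes,) bias of the added linear layer.
    """
    log_p = np.log(np.clip(probs, 1e-12, 1.0))  # clip so log(0) cannot occur
    return softmax(log_p @ W.T + b)             # linear layer, then softmax


num_classes = 3
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])

# Identity weights and a zero bias make the layer a no-op.
W = np.eye(num_classes)
b = np.zeros(num_classes)
unchanged = dirichlet_calibrate(probs, W, b)

# Shrinking the weights flattens the distribution (less confident predictions).
softened = dirichlet_calibrate(probs, 0.5 * W, b)
```

Scaling the identity by 0.5 keeps the predicted class but lowers its probability, which is the kind of correction an overconfident model needs.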
References
@article{kullbeyond, title={Beyond temperature scaling: Obtaining well-calibrated multi-class probabilities with Dirichlet calibration Supplementary material}, author={Kull, Meelis and Perello-Nieto, Miquel and K{\"a}ngsepp, Markus and Silva Filho, Telmo and Song, Hao and Flach, Peter} }
Args:
wrapper (ModelWrapper): Provides training and testing methods.
num_classes (int): Number of classes in classification task.
lr (float): Learning rate.
reg_factor (float): Regularization factor for the linear layer weights.
mu (float): Regularization factor for the linear layer biases.
If not given, it defaults to the value of reg_factor.
Source code in baal/calibration/calibration.py
calibrate(train_set, test_set, batch_size, epoch, use_cuda, double_fit=False, **kwargs)
Trains the linear layer given a training set and a validation set. The training set should be different from the data the model was originally trained on.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
train_set | Dataset | The training set. | required |
test_set | Dataset | The validation set. | required |
batch_size | int | Batch size used. | required |
epoch | int | Number of epochs to train the linear layer for. | required |
use_cuda | bool | If True, the GPU is used. | required |
double_fit | bool | If True, fits twice on the train set. | False |
kwargs | dict | Remaining parameters for baal.ModelWrapper.train_and_test_on_dataset(). | {} |
Returns:
Name | Type | Description |
---|---|---|
loss_history | list[float] | List of loss values for each epoch. |
model.state_dict | dict | Model weights. |
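The procedure that calibrate describes can be approximated without baal in a short NumPy sketch. Everything here is an assumption for illustration: the function name calibrate_sketch, plain mini-batch gradient descent (the optimizer baal uses is not stated in this reference), a squared L2 penalty, and reporting the validation loss after each epoch. The inputs are log-probabilities from the frozen classifier on held-out data, and double_fit simply repeats the whole fit a second time.

```python
import numpy as np


def softmax(z):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def loss_fn(log_p, y, W, b, reg_factor, mu):
    """Mean cross-entropy of the calibrated predictions plus a squared L2 penalty."""
    q = softmax(log_p @ W.T + b)
    nll = -np.mean(np.log(q[np.arange(len(y)), y] + 1e-12))
    return nll + reg_factor * np.sum(W ** 2) + mu * np.sum(b ** 2)


def calibrate_sketch(train, test, num_classes, batch_size, epoch,
                     lr=0.05, reg_factor=1e-3, mu=None, double_fit=False, seed=0):
    """Hypothetical stand-in for calibrate(): fit W and b on held-out data.

    train, test: (log_probs, labels) pairs that the classifier was NOT trained on.
    Returns (loss_history, weights): one validation loss per epoch and the
    final layer parameters.
    """
    mu = reg_factor if mu is None else mu              # mu falls back to reg_factor
    rng = np.random.default_rng(seed)
    W, b = np.eye(num_classes), np.zeros(num_classes)  # start as a no-op layer
    X, y = train
    loss_history = []
    for _fit in range(2 if double_fit else 1):         # double_fit runs a second pass
        for _epoch in range(epoch):
            order = rng.permutation(len(y))
            for start in range(0, len(y), batch_size):
                idx = order[start:start + batch_size]
                xb, yb = X[idx], y[idx]
                grad = softmax(xb @ W.T + b)
                grad[np.arange(len(yb)), yb] -= 1.0    # d(cross-entropy)/d(logits) = q - onehot
                grad /= len(yb)
                W -= lr * (grad.T @ xb + 2 * reg_factor * W)
                b -= lr * (grad.sum(axis=0) + 2 * mu * b)
            loss_history.append(loss_fn(*test, W, b, reg_factor, mu))
    return loss_history, {"weight": W, "bias": b}


# Synthetic held-out data: labels follow softmax(logits), but the frozen
# model reports softmax(3 * logits), so it is overconfident.
rng = np.random.default_rng(42)
n, k = 600, 3
logits = rng.normal(size=(n, k))
y = np.array([rng.choice(k, p=p) for p in softmax(logits)])
log_p = np.log(softmax(3.0 * logits))
train, test = (log_p[:400], y[:400]), (log_p[400:], y[400:])

history, weights = calibrate_sketch(train, test, num_classes=k, batch_size=32, epoch=20)
before = loss_fn(*test, np.eye(k), np.zeros(k), 0.0, 0.0)  # uncalibrated validation loss
```

On this synthetic data the validation loss should end up below the uncalibrated loss: the model's log-probabilities are three times the true logits up to a per-row constant, so shrinking W toward one third of the identity undoes the overconfidence.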
l2_reg()
Uses the trainable layer's parameters for L2 regularization.
Returns:
| Type | Description |
|---|---|
| | The regularization term for the linear layer. |
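A minimal sketch of such a penalty, written in NumPy rather than against baal's internals. The squared-norm form and the standalone signature are assumptions; the class documentation only says that reg_factor regularizes the linear layer weights and mu the biases, and the sketch additionally assumes mu falls back to reg_factor when omitted.

```python
import numpy as np


def l2_reg(W: np.ndarray, b: np.ndarray, reg_factor: float, mu=None) -> float:
    """Hypothetical L2 penalty for a calibration layer with weights W and bias b.

    Assumes squared L2 norms: reg_factor * ||W||_F^2 + mu * ||b||^2,
    with mu falling back to reg_factor when it is not given.
    """
    mu = reg_factor if mu is None else mu
    weight_term = reg_factor * float(np.sum(W ** 2))  # penalizes the linear layer weights
    bias_term = mu * float(np.sum(b ** 2))            # penalizes the linear layer biases
    return weight_term + bias_term


W = np.eye(3)
b = np.array([0.5, -0.5, 0.0])
shared = l2_reg(W, b, reg_factor=0.01)            # mu defaults to reg_factor
separate = l2_reg(W, b, reg_factor=0.01, mu=0.1)  # stronger pull on the bias
```

During calibration, a term like this would be added to the cross-entropy loss of the linear layer at each optimization step.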