# Utilities

## Metrics

To work with `baal.modelwrapper.ModelWrapper`, we provide `Metrics`. Starting with Baal 1.7.0, TorchMetrics can be used as well.
### Examples

```python
from baal.modelwrapper import ModelWrapper
from baal.utils.metrics import Accuracy
from torchmetrics import F1Score

wrapper: ModelWrapper = ...

# You can add any metric from `baal.utils.metrics` or `torchmetrics`.
wrapper.add_metric(name='accuracy', initializer=lambda: Accuracy())
wrapper.add_metric(name='f1', initializer=lambda: F1Score())

# Metrics are automatically updated during training and evaluation.
wrapper.train_on_dataset(...)
wrapper.test_on_dataset(...)

print(wrapper.get_metrics())
"""
>>> {'dataset_size': 200,
     'test_accuracy': 0.2603,
     'test_f1': 0.1945,
     'test_loss': 2.1901,
     'train_accuracy': 0.3214,
     'train_f1': 0.2531,
     'train_loss': 2.1795}
"""

# Get metrics per dataset_size (state is kept for the entire active learning loop).
print(wrapper.active_learning_metrics)
"""
>>> {200: {'dataset_size': 200,
           'test_accuracy': 0.26038339734077454,
           'test_loss': 2.190103769302368,
           'train_accuracy': 0.3214285671710968,
           'train_loss': 2.1795670986175537},
     ...
"""
```
## baal.utils.metrics
### Accuracy

Bases: `Metrics`

Computes the top-1 and top-5 accuracy for the model, batch by batch.

Parameters:

Name | Type | Description
---|---|---
average | bool | If True, average results over several trials to output a single value.
topk | tuple | The values of k for computing top-k accuracy.
Source code in baal/utils/metrics.py
#### update(output=None, target=None)

Update TP and support.

Parameters:

Name | Type | Description | Default
---|---|---|---
output | tensor | Predictions of the model. | None
target | tensor | Labels. | None
Source code in baal/utils/metrics.py
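The top-k accuracy tracked by `Accuracy` can be sketched in plain Python. This is a simplified, illustrative stand-in (Baal's implementation operates on tensors and accumulates state across batches):

```python
def topk_accuracy(scores, targets, k=1):
    """Fraction of samples whose true label is among the k highest scores.

    scores: list of per-class score lists, one per sample.
    targets: list of true class indices.
    """
    hits = 0
    for row, label in zip(scores, targets):
        # Indices of the k largest scores for this sample.
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += label in topk
    return hits / len(targets)

scores = [[0.1, 0.7, 0.2],   # highest score: class 1
          [0.5, 0.3, 0.4],   # highest score: class 0
          [0.2, 0.3, 0.5]]   # highest score: class 2
targets = [1, 2, 2]
print(topk_accuracy(scores, targets, k=1))  # 2 of 3 correct
print(topk_accuracy(scores, targets, k=2))  # 1.0
```

With `k=2`, the second sample's true class (2) is among its two highest scores, so top-2 accuracy is higher than top-1.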
### ClassificationReport

Bases: `Metrics`

Compute a classification report as a metric.

Parameters:

Name | Type | Description | Default
---|---|---|---
num_classes | int | The number of classes. | required
Source code in baal/utils/metrics.py
#### update(output=None, target=None)

Update the confusion matrix according to output and target.

Parameters:

Name | Type | Description | Default
---|---|---|---
output | tensor | Predictions of the model. | None
target | tensor | Labels. | None
Source code in baal/utils/metrics.py
### ECE

Bases: `Metrics`

Expected Calibration Error (ECE).

Parameters:

Name | Type | Description | Default
---|---|---|---
n_bins | int | Number of bins used to discretize the uncertainty. | 10

References

https://arxiv.org/pdf/1706.04599.pdf
Source code in baal/utils/metrics.py
#### plot(pth=None)

Plot each bin; ideally, the result would be a diagonal line.

Parameters:

Name | Type | Description | Default
---|---|---|---
pth | str | If provided, the figure is saved under the given path. | None
Source code in baal/utils/metrics.py
#### update(output=None, target=None)

Update the true positives (tp) and the number of samples in each bin.

Parameters:

Name | Type | Description | Default
---|---|---|---
output | tensor | Logits or predictions of the model. | None
target | tensor | Labels. | None
Source code in baal/utils/metrics.py
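The binning behind ECE can be sketched in a few lines of plain Python. This is a simplified sketch of the standard ECE formula from the referenced paper, not Baal's tensor-based implementation:

```python
def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE: weighted average of |accuracy - confidence| over confidence bins.

    confidences: list of max predicted probabilities, one per sample.
    correct: list of 0/1 flags, 1 if the prediction was correct.
    """
    n = len(confidences)
    ece = 0.0
    for b in range(n_bins):
        lo, hi = b / n_bins, (b + 1) / n_bins
        # Samples whose confidence falls in the half-open bin (lo, hi].
        idx = [i for i, c in enumerate(confidences)
               if lo < c <= hi or (b == 0 and c == 0.0)]
        if not idx:
            continue
        acc = sum(correct[i] for i in idx) / len(idx)
        conf = sum(confidences[i] for i in idx) / len(idx)
        # Each bin contributes its calibration gap, weighted by its size.
        ece += len(idx) / n * abs(acc - conf)
    return ece

# Slightly over-confident model: 0.95 confidence but 100% accuracy in one bin,
# 0.55 confidence but 50% accuracy in the other.
confs = [0.95, 0.95, 0.95, 0.95, 0.55, 0.55]
correct = [1, 1, 1, 1, 1, 0]
print(expected_calibration_error(confs, correct))  # ≈ 0.05
```

A perfectly calibrated model has confidence equal to empirical accuracy in every bin, giving an ECE of 0.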
### ECE_PerCLs

Bases: `Metrics`

Expected Calibration Error (ECE), computed per class.

Parameters:

Name | Type | Description | Default
---|---|---|---
n_cls | int | Number of target classes. | required
n_bins | int | Number of bins used to discretize the uncertainty. | 10

References

https://arxiv.org/pdf/1706.04599.pdf
Source code in baal/utils/metrics.py
#### calculate_result()

Calculate the ECE per class.

Returns:

Name | Type | Description
---|---|---
ece | array | ECE value per class.
Source code in baal/utils/metrics.py
#### plot(pth=None)

Plot each bin; ideally, the result would be a diagonal line.

Parameters:

Name | Type | Description | Default
---|---|---|---
pth | str | If provided, the figure is saved under the given path. | None
Source code in baal/utils/metrics.py
#### update(output=None, target=None)

Update the true positives (tp) and the number of samples in each bin.

Parameters:

Name | Type | Description | Default
---|---|---|---
output | tensor | Logits or predictions of the model. | None
target | tensor | Labels. | None
Source code in baal/utils/metrics.py
### Loss

Bases: `Metrics`

Parameters:

Name | Type | Description | Default
---|---|---|---
average | bool | If True, average results over several trials to output a single value. | True
Source code in baal/utils/metrics.py
### Metrics

Abstract base class for metrics.

Parameters:

Name | Type | Description
---|---|---
average | bool | If True, average results over several trials to output a single value.
Source code in baal/utils/metrics.py
#### standard_dev (property)

Return the standard deviation of the metric.

#### value (property)

Output the metric results (array shape), or average over the results to output a single float.

Returns:

Name | Type | Description
---|---|---
result | array / float | Final metric result.
#### calculate_result()

#### reset()

#### update(output=None, target=None)

Main calculation of the metric, updating its internal values accordingly.

Parameters:

Name | Type | Description | Default
---|---|---|---
output | tensor | Predictions of the model. | None
target | tensor | Labels. | None
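The `Metrics` contract can be illustrated with a minimal custom metric. The sketch below mimics the interface described above (`update`, `reset`, `calculate_result`, `value`) without subclassing Baal's actual base class, so it runs standalone:

```python
class MeanAbsoluteError:
    """Toy metric following the Metrics-style interface: accumulate state in
    update(), clear it with reset(), and expose the result via value."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear accumulated state between train/test phases.
        self._total = 0.0
        self._count = 0

    def update(self, output=None, target=None):
        # Accumulate |output - target| over the batch.
        for o, t in zip(output, target):
            self._total += abs(o - t)
            self._count += 1

    def calculate_result(self):
        return self._total / max(self._count, 1)

    @property
    def value(self):
        return self.calculate_result()

metric = MeanAbsoluteError()
metric.update(output=[2.0, 3.0], target=[1.0, 5.0])  # errors: 1.0 and 2.0
print(metric.value)  # 1.5
```

Keeping all state inside `update`/`reset` is what lets the wrapper snapshot metrics per phase and per active learning step.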
### PRAuC

Bases: `Metrics`

Precision-Recall area under the curve.

Parameters:

Name | Type | Description | Default
---|---|---|---
num_classes | int | Number of classes. | required
n_bins | int | Number of confidence thresholds to evaluate on. | required
average | bool | If True, return the mean AuC over all classes. | required
Source code in baal/utils/metrics.py
#### update(output=None, target=None)

Update the confusion matrix according to output and target.

Parameters:

Name | Type | Description | Default
---|---|---|---
output | tensor | Predictions of the model. | None
target | tensor | Labels. | None
Source code in baal/utils/metrics.py
### Precision

Bases: `Metrics`

Computes the precision for each class over epochs.

Parameters:

Name | Type | Description
---|---|---
num_classes | int | Number of classes.
average | bool | If True, average results over several trials to output a single value.
Source code in baal/utils/metrics.py
#### update(output=None, target=None)

Update tp, fp, and support according to output and target.

Parameters:

Name | Type | Description | Default
---|---|---|---
output | tensor | Predictions of the model. | None
target | tensor | Labels. | None
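The tp/fp bookkeeping behind per-class precision can be sketched in plain Python. This is an illustrative simplification (Baal's `Precision` works on tensors and accumulates across batches and epochs):

```python
def per_class_precision(preds, targets, num_classes):
    """Precision per class: tp / (tp + fp), from predicted and true labels."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    for p, t in zip(preds, targets):
        if p == t:
            tp[p] += 1   # predicted class p, and it was correct
        else:
            fp[p] += 1   # predicted class p, but the true class differed
    # Guard against classes that were never predicted (tp + fp == 0).
    return [tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0
            for c in range(num_classes)]

preds   = [0, 0, 1, 1, 1]
targets = [0, 1, 1, 1, 0]
print(per_class_precision(preds, targets, num_classes=2))  # class 0: 1/2, class 1: 2/3
```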