PyTorch Lightning Compatibility
baal.utils.pytorch_lightning.ResetCallback
Bases: Callback
Callback to reset the weights between active learning steps.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| weights | dict | State dict of the model. | required |
Notes
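The weights should be deep-copied beforehand (for example with copy.deepcopy). Otherwise, the stored state dict shares memory with the model's parameters and changes as they are trained.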
Source code in baal/utils/pytorch_lightning.py
on_train_start(trainer, module)
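Why the deep copy matters: in PyTorch, `model.state_dict()` returns tensors that share storage with the live parameters, so a snapshot taken without copying keeps changing as the optimizer updates the model, and resetting to it has no effect. The sketch below reproduces that behaviour in plain Python. `ToyModel` and `ToyResetCallback` are hypothetical stand-ins (not baal or Lightning classes) that use mutable lists in place of tensors. With the real classes, the intended pattern would presumably be `ResetCallback(copy.deepcopy(model.state_dict()))`.

```python
import copy
from typing import Callable, Dict, List

State = Dict[str, List[float]]


class ToyModel:
    """Hypothetical stand-in for a LightningModule; parameters are mutable lists."""

    def __init__(self) -> None:
        self.params: State = {"w": [0.0, 0.0]}

    def state_dict(self) -> State:
        # Like torch's state_dict(), hand back references to the live parameters.
        return dict(self.params)

    def load_state_dict(self, state: State) -> None:
        # Copy the saved values into the existing parameters, in place.
        for name, values in state.items():
            self.params[name][:] = values

    def train_step(self) -> None:
        # Simulate an in-place optimizer update.
        for values in self.params.values():
            for i in range(len(values)):
                values[i] += 1.0


class ToyResetCallback:
    """Mirrors the idea behind ResetCallback: restore saved weights at train start."""

    def __init__(self, weights: State) -> None:
        self.weights = weights

    def on_train_start(self, trainer, module: ToyModel) -> None:
        module.load_state_dict(self.weights)


def run_two_steps(snapshot: Callable[[State], State]) -> List[float]:
    """Train for two active learning steps, resetting the weights before each one."""
    model = ToyModel()
    callback = ToyResetCallback(snapshot(model.state_dict()))
    for _ in range(2):
        callback.on_train_start(None, model)
        model.train_step()
    return model.params["w"]


shared = run_two_steps(lambda state: state)  # no copy: the snapshot drifts with training
copied = run_two_steps(copy.deepcopy)        # deep copy: each step restarts from [0.0, 0.0]
print(shared, copied)  # [2.0, 2.0] [1.0, 1.0]
```

Without the copy, the second step continues from the first step's weights instead of starting over.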
baal.utils.pytorch_lightning.BaalTrainer
Bases: Trainer
Object that performs the training and active learning iterations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dataset | ActiveLearningDataset | Dataset with some samples already labelled. | required |
| heuristic | Heuristic | Heuristic from baal.active.heuristics. | Random() |
| query_size | int | Number of samples to label per step. | 1 |
| max_sample | int | Limit on the number of samples used (-1 means no limit). | required |
| `**kwargs` | | Parameters forwarded to | {} |
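The max_sample convention (-1 means no limit) can be illustrated with a small helper. `limit_pool` is hypothetical and not part of baal, and drawing a uniform random subset is only an assumption about how a sampler might pick the limited pool.

```python
import random
from typing import List, Optional, Sequence


def limit_pool(pool_indices: Sequence[int], max_sample: int,
               seed: Optional[int] = None) -> List[int]:
    """Hypothetical helper: keep at most `max_sample` pool indices.

    max_sample == -1 means "no limit", as in the BaalTrainer parameter table.
    Picking a uniform random subset is an illustrative assumption.
    """
    if max_sample == -1 or max_sample >= len(pool_indices):
        return list(pool_indices)
    if max_sample < 0:
        raise ValueError("max_sample must be -1 (no limit) or a non-negative integer")
    rng = random.Random(seed)
    return sorted(rng.sample(list(pool_indices), max_sample))


pool = list(range(10))
print(len(limit_pool(pool, -1)), len(limit_pool(pool, 3, seed=0)))  # 10 3
```

Passing a seed keeps the subset reproducible across runs, which helps when comparing heuristics on the same limited pool.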
predict_on_dataset(model=None, dataloader=None, *args, **kwargs)
For documentation, see predict_on_dataset_generator.
predict_on_dataset_generator(model=None, dataloader=None, *args, **kwargs)
Predict on the pool loader.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | | Model used for prediction. If None, the Trainer's model is used. | None |
| dataloader | Optional[DataLoader] | If provided, predictions are made on this dataloader. Otherwise, model.pool_dataloader() is used. | None |
Returns:
| Type | Description |
|---|---|
| | NumPy arrays with the predictions, yielded one batch at a time. |
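The split between predict_on_dataset_generator and predict_on_dataset suggests a common pattern: the generator yields predictions batch by batch, so only one batch of outputs has to be held at a time, while the non-generator version presumably collects everything at once. A plain-Python sketch of that pattern follows. `predict_batches`, `predict_all`, and the toy `two_class_probs` model are hypothetical, and lists stand in for the NumPy arrays the real methods produce (there, the collected batches would be stacked, for example with `numpy.concatenate`).

```python
from typing import Callable, Iterable, Iterator, List, Sequence

Batch = Sequence[float]
Rows = List[List[float]]


def predict_batches(predict_fn: Callable[[Batch], Rows],
                    dataloader: Iterable[Batch]) -> Iterator[Rows]:
    """Yield the predictions for one batch at a time."""
    for batch in dataloader:
        yield predict_fn(batch)


def predict_all(predict_fn: Callable[[Batch], Rows],
                dataloader: Iterable[Batch]) -> Rows:
    """Consume the generator and gather every row into a single list."""
    rows: Rows = []
    for batch_rows in predict_batches(predict_fn, dataloader):
        rows.extend(batch_rows)
    return rows


def two_class_probs(batch: Batch) -> Rows:
    """Toy model: treat each input as P(class 1) and return [P(0), P(1)]."""
    return [[1.0 - x, x] for x in batch]


pool_loader = [[0.1, 0.2], [0.7]]  # two batches of unlabelled inputs
per_batch = list(predict_batches(two_class_probs, pool_loader))
everything = predict_all(two_class_probs, pool_loader)
print(len(per_batch), len(everything))  # 2 3
```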
step(model=None, datamodule=None)
Perform an active learning step.
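Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | | Model used for prediction. If None, the Trainer's model is used. | None |
| datamodule | Optional[BaaLDataModule] | If provided, the pool is taken from datamodule.pool_dataloader(). Otherwise, model.pool_dataloader() is used. | None |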
Notes
This gets the pool from the model's pool_dataloader(). If max_sample is set, the dataloader's sampler is required to select max_pool samples.
Returns:
| Type | Description |
|---|---|
| bool | Flag indicating whether training should continue. |
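The flow of step (predict on the pool, rank it with the heuristic, label the top query_size samples, report whether to continue) can be sketched in plain Python. `ToyActiveDataset`, `active_step`, and `predict` are hypothetical. Predictive entropy is used as the ranking heuristic purely for illustration, whereas BaalTrainer defaults to Random() and accepts any heuristic from baal.active.heuristics. Returning False once the pool is empty is likewise an assumption about when training should stop.

```python
import math
from typing import Callable, List, Sequence, Set


class ToyActiveDataset:
    """Hypothetical stand-in for ActiveLearningDataset: tracks labelled indices."""

    def __init__(self, size: int, initially_labelled: Sequence[int] = ()) -> None:
        self.size = size
        self.labelled: Set[int] = set(initially_labelled)

    def pool_indices(self) -> List[int]:
        """Indices of samples that are still unlabelled."""
        return [i for i in range(self.size) if i not in self.labelled]

    def label(self, indices: Sequence[int]) -> None:
        self.labelled.update(indices)


def entropy(probs: Sequence[float]) -> float:
    """Predictive entropy; higher means the model is less certain."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def active_step(dataset: ToyActiveDataset,
                predict: Callable[[int], List[float]],
                query_size: int) -> bool:
    """Label the `query_size` most uncertain pool samples.

    Returns False when the pool is already empty, True otherwise.
    """
    pool = dataset.pool_indices()
    if not pool:
        return False
    scored = [(entropy(predict(i)), i) for i in pool]
    scored.sort(reverse=True)  # most uncertain first
    dataset.label([i for _, i in scored[:query_size]])
    return True


def predict(index: int) -> List[float]:
    """Toy predictions: the model is unsure only about sample 3."""
    return [0.5, 0.5] if index == 3 else [0.9, 0.1]


dataset = ToyActiveDataset(size=5, initially_labelled=[0])
steps = 0
while active_step(dataset, predict, query_size=2):
    steps += 1  # a real loop would retrain the model here
print(steps, sorted(dataset.labelled))  # 2 [0, 1, 2, 3, 4]
```

Using the boolean return value as the loop condition matches the documented meaning of the flag: keep alternating training and labelling until step reports that there is nothing left to do.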
baal.utils.pytorch_lightning.BaaLDataModule
Bases: LightningDataModule
pool_dataloader()
Creates a DataLoader for the pool of unlabelled examples.
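What pool_dataloader hands back can be pictured as batches drawn only from the unlabelled part of the dataset. The sketch below is a hypothetical plain-Python stand-in: `ToyDataModule` and its `batch_size` argument are assumptions, not the BaaLDataModule signature. With PyTorch, the equivalent would be a `torch.utils.data.DataLoader` over the pool. Keeping the order fixed (no shuffling) is a deliberate choice here, so that the i-th prediction can be mapped back to the i-th pool sample when a heuristic's ranking is used for labelling.

```python
from typing import Iterator, List, Sequence, Set


class ToyDataModule:
    """Hypothetical stand-in for BaaLDataModule over an in-memory dataset."""

    def __init__(self, data: Sequence[str], labelled: Set[int], batch_size: int = 2) -> None:
        self.data = data
        self.labelled = labelled
        self.batch_size = batch_size

    def pool_dataloader(self) -> Iterator[List[str]]:
        """Yield batches of unlabelled examples in a fixed, unshuffled order."""
        pool = [x for i, x in enumerate(self.data) if i not in self.labelled]
        for start in range(0, len(pool), self.batch_size):
            yield pool[start:start + self.batch_size]


module = ToyDataModule(["a", "b", "c", "d", "e"], labelled={1})
print(list(module.pool_dataloader()))  # [['a', 'c'], ['d', 'e']]
```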