HuggingFace Compatibility
baal.transformers_trainer_wrapper.BaalTransformersTrainer
Bases: Trainer
The purpose of this wrapper is to provide extra capabilities to the HuggingFace Trainer so that it can perform several forward passes per sample at prediction time and hence work with Baal. For a more detailed description of the arguments, refer to https://huggingface.co/transformers/v3.0.2/main_classes/trainer.html.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`model` | `PreTrainedModel` | The model to train, evaluate or use for predictions. | `None` |
`replicate_in_memory` | `bool` | If True, performs MC-Dropout in a single forward pass. Faster, but more memory-expensive. | `True` |
`data_collator` | `Optional[Callable]` | The function used to form a batch. | required |
`train_dataset` | `Optional[torch.utils.data.Dataset]` | The dataset to use for training. | required |
`eval_dataset` | `Optional[torch.utils.data.Dataset]` | The dataset to use for evaluation. | required |
`tokenizer` | `Optional[transformers.PreTrainedTokenizer]` | A tokenizer provided by HuggingFace. | required |
`model_init` | `Optional[Dict]` | Model initial weights for fine-tuning. | required |
`compute_metrics` | `Optional[Callable[[EvalPrediction], Dict]]` | The function used to compute metrics at evaluation. | required |
`callbacks` | `Optional[List[transformers.TrainerCallback]]` | A list of callbacks to customize the training loop. | required |
`optimizers` | `Optional[Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]]` | A tuple containing the optimizer and the scheduler to use. | required |
Source code in baal/transformers_trainer_wrapper.py
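The `replicate_in_memory` trade-off can be illustrated with a small numpy sketch (illustrative only; the stochastic `forward` below stands in for a model with MC-Dropout and is not Baal's actual implementation). When True, the batch is tiled `iterations` times and sent through a single forward pass; when False, the same batch goes through `iterations` separate passes.

```python
import numpy as np

rng = np.random.default_rng(0)

def forward(x):
    # Stand-in for a stochastic forward pass (e.g. with dropout active).
    return x * (rng.random(x.shape) > 0.5)

batch, iterations = np.ones((4, 8)), 10  # 4 samples, 8 features

# replicate_in_memory=True: one pass over the batch tiled `iterations` times,
# then reshape so the MC iterations become the trailing dimension.
tiled = np.tile(batch, (iterations, 1))                      # [iterations * 4, 8]
single_pass = forward(tiled).reshape(iterations, 4, 8).transpose(1, 2, 0)

# replicate_in_memory=False: `iterations` independent passes, stacked last.
multi_pass = np.stack([forward(batch) for _ in range(iterations)], axis=-1)

assert single_pass.shape == multi_pass.shape == (4, 8, iterations)
```

Both settings produce predictions of the same shape; the single-pass variant trades peak memory (a batch `iterations` times larger) for fewer forward calls.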
load_state_dict(state_dict, strict=True)
predict_on_dataset(dataset, iterations=1, half=False, ignore_keys=None)
Use the model to predict on a dataset `iterations` times.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`dataset` | `Dataset` | Dataset to predict on. | required |
`iterations` | `int` | Number of iterations per sample. | `1` |
`half` | `bool` | If True, use half precision. | `False` |
`ignore_keys` | `Optional[List[str]]` | A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions. | `None` |
Notes:
The "batch" is made of `batch_size` * `iterations` samples.
Returns:

Type | Description |
---|---|
Array | `[n_samples, n_outputs, ..., n_iterations]` |
Source code in baal/transformers_trainer_wrapper.py
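The returned array stacks the MC iterations in the last dimension; a common next step in active learning is to compute an acquisition score such as the predictive entropy over it. A minimal numpy sketch on a fake prediction array (names and shapes are illustrative, not Baal's API):

```python
import numpy as np

rng = np.random.default_rng(42)

# Fake output of predict_on_dataset: [n_samples, n_classes, n_iterations].
n_samples, n_classes, n_iterations = 100, 3, 20
logits = rng.normal(size=(n_samples, n_classes, n_iterations))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # softmax over classes

# Predictive entropy: entropy of the MC-averaged class distribution.
mean_probs = probs.mean(axis=-1)                          # [n_samples, n_classes]
entropy = -(mean_probs * np.log(mean_probs + 1e-12)).sum(axis=1)

# Rank samples from most to least uncertain, e.g. to pick the next labels.
most_uncertain = np.argsort(-entropy)[:10]
assert most_uncertain.shape == (10,)
```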
predict_on_dataset_generator(dataset, iterations=1, half=False, ignore_keys=None)
Use the model to predict on a dataset `iterations` times.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`dataset` | `Dataset` | Dataset to predict on. | required |
`iterations` | `int` | Number of iterations per sample. | `1` |
`half` | `bool` | If True, use half precision. | `False` |
`ignore_keys` | `Optional[List[str]]` | A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions. | `None` |
Notes:
The "batch" is made of `batch_size` * `iterations` samples.
Returns:

Type | Description |
---|---|
Generator of arrays | `[batch_size, n_classes, ..., n_iterations]` |
Source code in baal/transformers_trainer_wrapper.py
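The generator variant yields one array per batch, which is useful when the full prediction tensor does not fit in memory: each batch can be reduced to a score immediately and then discarded. A pure-Python sketch of consuming such a generator (`fake_generator` is a stand-in for the real method, not Baal's API):

```python
import numpy as np

def fake_generator(n_batches=5, batch_size=8, n_classes=3, n_iterations=20):
    # Stand-in for predict_on_dataset_generator: yields one array per batch
    # with shape [batch_size, n_classes, n_iterations].
    rng = np.random.default_rng(0)
    for _ in range(n_batches):
        yield rng.random((batch_size, n_classes, n_iterations))

# Stream batches and reduce each one to a per-sample score right away,
# instead of materializing the full [n_samples, ...] array.
scores = np.concatenate(
    [batch.std(axis=-1).mean(axis=1) for batch in fake_generator()]
)
assert scores.shape == (40,)  # 5 batches * 8 samples
```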
baal.active.dataset.nlp_datasets.HuggingFaceDatasets
Bases: Dataset
Support for huggingface `datasets` (https://github.com/huggingface/datasets).
The purpose of this wrapper is to separate the labels from the rest of the sample information and make the dataset ready to be used by `baal.active.ActiveLearningDataset`.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`dataset` | `Dataset` | A dataset provided by HuggingFace. | required |
`tokenizer` | `PreTrainedTokenizer` | A tokenizer provided by HuggingFace. | `None` |
`target_key` | `str` | Target key used in the dataset's dictionary. | `'label'` |
`input_key` | `str` | Input key used in the dataset's dictionary. | `'sentence'` |
`max_seq_len` | `int` | Max length of a sequence, used for padding the shorter sequences. | `128` |
Source code in baal/active/dataset/nlp_datasets.py
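The wrapper's core job of separating inputs from labels can be mimicked with a few lines of plain Python (a toy stand-in, not the real class, which also tokenizes and pads to `max_seq_len`):

```python
class ToyTextDataset:
    """Toy stand-in for HuggingFaceDatasets: splits inputs from labels."""

    def __init__(self, rows, input_key="sentence", target_key="label"):
        self.inputs = [row[input_key] for row in rows]
        self.labels = [row[target_key] for row in rows]

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        # The real wrapper returns tokenized tensors; here we return raw text.
        return {"inputs": self.inputs[idx], "label": self.labels[idx]}

rows = [{"sentence": "great movie", "label": 1},
        {"sentence": "terrible plot", "label": 0}]
ds = ToyTextDataset(rows)
assert len(ds) == 2 and ds[0]["label"] == 1
```

Keeping labels in a separate container is what lets `ActiveLearningDataset` treat samples as labelled or unlabelled independently of their inputs.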
label(idx, value)
Label the item.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`idx` | `int` | Index to label. | required |
`value` | `int` | Value to label the index. | required |
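In an active learning loop, `label` is typically called once an oracle has provided the answer for a pool index. A minimal sketch of the semantics (`ToyLabelledDataset` is illustrative, not Baal's implementation):

```python
class ToyLabelledDataset:
    """Toy stand-in showing label() semantics: set a target at an index."""

    def __init__(self, n):
        self.targets = [None] * n  # None marks an unlabelled sample

    def label(self, idx, value):
        self.targets[idx] = value

ds = ToyLabelledDataset(3)
ds.label(0, 1)   # oracle says sample 0 belongs to class 1
ds.label(2, 0)   # sample 2 belongs to class 0
assert ds.targets == [1, None, 0]
```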