HuggingFace Compatibility
baal.transformers_trainer_wrapper.BaalTransformersTrainer
Bases: Trainer
The purpose of this wrapper is to provide extra capabilities to the HuggingFace Trainer so that it can output several forward passes per sample at prediction time and hence work with Baal. For a more detailed description of the arguments, refer to https://huggingface.co/transformers/v3.0.2/main_classes/trainer.html.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | transformers.PreTrainedModel | The model to train, evaluate, or use for predictions. | required |
data_collator | Optional[Callable] | The function to use to form a batch. | required |
train_dataset | Optional[torch.utils.data.Dataset] | The dataset to use for training. | required |
eval_dataset | Optional[torch.utils.data.Dataset] | The dataset to use for evaluation. | required |
tokenizer | Optional[transformers.PreTrainedTokenizer] | A tokenizer provided by HuggingFace. | required |
model_init | Optional[Dict] | Model initial weights for fine-tuning. | required |
compute_metrics | Optional[Callable[[EvalPrediction], Dict]] | The function that will be used to compute metrics at evaluation. | required |
callbacks | Optional[List[transformers.TrainerCallback]] | A list of callbacks to customize the training loop. | required |
optimizers | Optional[Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]] | A tuple containing the optimizer and the scheduler to use. | required |
Source code in baal/transformers_trainer_wrapper.py
load_state_dict(state_dict, strict=True)
predict_on_dataset(dataset, iterations=1, half=False, ignore_keys=None)
Use the model to predict on a dataset `iterations` times.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Dataset | Dataset to predict on. | required |
iterations | int | Number of iterations per sample. | 1 |
half | bool | If True, use half precision. | False |
ignore_keys | Optional[List[str]] | A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions. | None |
Notes
The "batch" is made of `batch_size * iterations` samples.
Returns:
Type | Description |
---|---|
Array | [n_samples, n_outputs, ..., n_iterations]. |
Source code in baal/transformers_trainer_wrapper.py
predict_on_dataset_generator(dataset, iterations=1, half=False, ignore_keys=None)
Use the model to predict on a dataset `iterations` times.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Dataset | Dataset to predict on. | required |
iterations | int | Number of iterations per sample. | 1 |
half | bool | If True, use half precision. | False |
ignore_keys | Optional[List[str]] | A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions. | None |
Notes
The "batch" is made of `batch_size * iterations` samples.
Returns:
Type | Description |
---|---|
Generator | Yields arrays of shape [batch_size, n_classes, ..., n_iterations]. |
Source code in baal/transformers_trainer_wrapper.py
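The generator variant yields one stacked batch at a time instead of a single array, which keeps memory bounded on large unlabelled pools. Consuming such a generator can be sketched in plain Python (the generator below is a toy stand-in, not the actual method):

```python
def toy_prediction_generator(n_samples, batch_size, n_classes, iterations):
    # Toy stand-in: yields batches shaped [batch, n_classes, n_iterations],
    # the way predict_on_dataset_generator yields one stacked batch at a time.
    for start in range(0, n_samples, batch_size):
        batch = min(batch_size, n_samples - start)
        yield [[[0.0] * iterations for _ in range(n_classes)]
               for _ in range(batch)]

def collect(generator):
    # Concatenating the batches along the first axis recovers the full
    # [n_samples, n_classes, ..., n_iterations] structure that
    # predict_on_dataset would return directly.
    out = []
    for batch in generator:
        out.extend(batch)
    return out

preds = collect(toy_prediction_generator(10, batch_size=4, n_classes=2, iterations=3))
print(len(preds))  # 10
```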
baal.active.dataset.nlp_datasets.HuggingFaceDatasets
Bases: Dataset
Support for huggingface.datasets (https://github.com/huggingface/datasets).
The purpose of this wrapper is to separate the labels from the rest of the sample information and make the dataset ready to be used by baal.active.ActiveLearningDataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Dataset | A dataset provided by HuggingFace. | required |
tokenizer | transformers.PreTrainedTokenizer | A tokenizer provided by HuggingFace. | None |
target_key | str | Target key used in the dataset's dictionary. | 'label' |
input_key | str | Input key used in the dataset's dictionary. | 'sentence' |
max_seq_len | int | Max length of a sequence, used to pad shorter sequences. | 128 |
Source code in baal/active/dataset/nlp_datasets.py
label(idx, value)
Label the item.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idx | int | Index to label. | required |
value | int | Value to label the index with. | required |
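The label-separation contract described above can be sketched in plain Python. The class below is a toy stand-in, not the actual HuggingFaceDatasets implementation (tokenization and padding are omitted):

```python
class ToyActiveDataset:
    # Toy stand-in mirroring the wrapper's contract: inputs and targets are
    # read from the raw records by configurable keys, and label() writes a
    # new target so the sample can move into the labelled pool.
    def __init__(self, records, target_key="label", input_key="sentence"):
        self.records = records
        self.target_key = target_key
        self.input_key = input_key

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        record = self.records[idx]
        return {"inputs": record[self.input_key],
                "label": record[self.target_key]}

    def label(self, idx, value):
        # Label the item at `idx` with `value`.
        self.records[idx][self.target_key] = value

ds = ToyActiveDataset([{"sentence": "good movie", "label": -1}])
ds.label(0, 1)
print(ds[0]["label"])  # 1
```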