Add UnlabeledTextProcessor (#2054)

* add UnlabeledTextProcessor

* allow choosing processor when finetuning or distilling

* fix type hint

* Add latest docstring and tutorial changes

* improve segment id computation for UnlabeledTextProcessor

* add text and documentation

* change batch size parameter for intermediate layer distillation

* Add latest docstring and tutorial changes

* fix distillation dim mapping

* remove unnecessary changes

* removed confusing parameter

* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
MichelBartels 2022-01-25 14:54:34 +01:00 committed by GitHub
parent c6f23dce88
commit 5b6b0cef77
5 changed files with 93 additions and 32 deletions
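At a glance, this commit adds an `UnlabeledTextProcessor` plus a `processor` argument on the `FARMReader` fine-tuning and distillation methods, so intermediate-layer distillation can run on plain, unlabeled text instead of SQuAD-style annotations. A minimal usage sketch based on the diffs below; model names and file paths are illustrative placeholders, not taken from this commit:

```python
from haystack.nodes import FARMReader
from haystack.modeling.data_handler.processor import UnlabeledTextProcessor

student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_4L_312D")
teacher = FARMReader(model_name_or_path="bert-base-uncased")

# The new processor reads plain text (one document per line) instead of SQuAD JSON.
processor = UnlabeledTextProcessor(
    tokenizer=teacher.inferencer.processor.tokenizer,
    max_seq_len=128,
    train_filename="unlabeled.txt",  # placeholder path
    data_dir="data",                 # placeholder path
)

# When a processor is passed, it replaces the SquadProcessor that would otherwise
# be built from data_dir/train_filename (those arguments are still required by the signature).
student.distil_intermediate_layers_from(
    teacher_model=teacher,
    data_dir="data",
    train_filename="unlabeled.txt",
    processor=processor,
)
```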


@@ -110,7 +110,6 @@ and that FARM includes no_answer in the sorted list of predictions.
```
Fine-tune a model on a QA dataset. Options:
- Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
- Take a QA model (e.g. `deepset/bert-base-cased-squad2`) and fine-tune it for your domain (e.g. using your labels collected via the haystack annotation tool)
@@ -152,6 +151,7 @@ If any checkpoints are stored, a subsequent run of train() will resume training
- `checkpoints_to_keep`: maximum number of train checkpoints to save.
- `caching`: whether or not to use caching for the preprocessed dataset
- `cache_path`: Path to cache the preprocessed dataset
- `processor`: The processor to use for preprocessing. If None, the default SquadProcessor is used.
**Returns**:
@@ -170,12 +170,10 @@ using a more complex teacher.
Originally proposed in: https://arxiv.org/pdf/1503.02531.pdf
This can also be considered as the second stage of distillation finetuning as described in the TinyBERT paper:
https://arxiv.org/pdf/1909.10351.pdf
**Example**
```python
student = FARMReader(model_name_or_path="prajjwal1/bert-medium")
teacher = FARMReader(model_name_or_path="deepset/bert-large-uncased-whole-word-masking-squad2")
student.distil_prediction_layer_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json",
learning_rate=3e-5, distillation_loss_weight=1.0, temperature=5)
```
@@ -227,6 +225,7 @@ If any checkpoints are stored, a subsequent run of train() will resume training
- `tinybert_epochs`: Number of epochs to train the student model with the TinyBERT loss function. After this many epochs, the student model is trained with the regular distillation loss function.
- `tinybert_learning_rate`: Learning rate to use when training the student model with the TinyBERT loss function.
- `tinybert_train_filename`: Filename of training data to use when training the student model with the TinyBERT loss function. To best follow the original paper, this should be an augmented version of the training data created using the augment_squad.py script. If not specified, the training data from the original training is used.
- `processor`: The processor to use for preprocessing. If None, the default SquadProcessor is used.
**Returns**:
@@ -236,17 +235,15 @@ None
#### distil\_intermediate\_layers\_from
```python
| distil_intermediate_layers_from(teacher_model: "FARMReader", data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, student_batch_size: int = 10, teacher_batch_size: Optional[int] = None, n_epochs: int = 5, learning_rate: float = 5e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"), distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "mse", temperature: float = 1.0)
| distil_intermediate_layers_from(teacher_model: "FARMReader", data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, batch_size: int = 10, n_epochs: int = 5, learning_rate: float = 5e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"), distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "mse", temperature: float = 1.0, processor: Optional[Processor] = None)
```
The first stage of distillation finetuning as described in the TinyBERT paper:
https://arxiv.org/pdf/1909.10351.pdf
**Example**
```python
student = FARMReader(model_name_or_path="prajjwal1/bert-medium")
teacher = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D")
student.distil_intermediate_layers_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json",
learning_rate=3e-5, temperature=5)
```
@@ -294,6 +291,7 @@ If any checkpoints are stored, a subsequent run of train() will resume training
- `distillation_loss_weight`: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
- `distillation_loss`: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
- `temperature`: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
- `processor`: The processor to use for preprocessing. If None, the default SquadProcessor is used.
**Returns**:
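As background for the `temperature` and `distillation_loss` parameters documented above: prediction-layer distillation in the sense of Hinton et al. (https://arxiv.org/pdf/1503.02531.pdf, linked above) compares temperature-softened teacher and student distributions. A standalone sketch of the KL-divergence variant, not Haystack's exact implementation:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor,
                      temperature: float = 1.0) -> torch.Tensor:
    """KL divergence between temperature-softened teacher and student distributions."""
    # A higher temperature flattens the teacher distribution (less certainty).
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # The T^2 factor keeps gradient magnitudes comparable across temperatures.
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature ** 2
```

In the API above this corresponds to `distillation_loss="kl_div"`, with `distillation_loss_weight` controlling how this term is mixed with the regular training loss.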


@@ -14,6 +14,9 @@ from pathlib import Path
from io import StringIO
from typing import Optional, Dict, List, Union, Any, Iterable
import torch
from torch.utils.data import TensorDataset
import pandas as pd
import numpy as np
from haystack.modeling.model.tokenization import (
@@ -1973,6 +1976,38 @@ class InferenceProcessor(TextClassificationProcessor):
)
return features

class UnlabeledTextProcessor(Processor):
    """
    Processor to be used for distilling a teacher model into a student model from scratch. Can only be used with distil_intermediate_layers_from.
    """
    def __init__(self, tokenizer, max_seq_len: int, train_filename: Optional[Union[Path, str]] = None, dev_filename: Optional[Union[Path, str]] = None, test_filename: Optional[Union[Path, str]] = None, dev_split: float = 0, data_dir: Optional[Union[Path, str]] = None, tasks: Dict = {}, proxies: Optional[Dict] = None, multithreading_rust: Optional[bool] = True):
        super().__init__(tokenizer, max_seq_len, train_filename, dev_filename, test_filename, dev_split, data_dir, tasks, proxies, multithreading_rust)
        self.add_task("question_answering", "squad", ["start_token", "end_token"])

    def file_to_dicts(self, file: str) -> List[dict]:
        # one training sample per line of plain, unlabeled text
        dicts = []
        with open(file, "r") as f:
            for line in f:
                dicts.append({"text": line})
        return dicts

    def dataset_from_dicts(self, dicts: List[dict], indices: Optional[List[int]] = None, return_baskets: bool = False):
        if return_baskets:
            raise NotImplementedError("return_baskets is not supported by UnlabeledTextProcessor")
        texts = [dict_["text"] for dict_ in dicts]
        tokens = self.tokenizer.batch_encode_plus(texts, add_special_tokens=True, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_seq_len)
        names = [key for key in tokens]
        inputs = [tokens[key] for key in tokens]
        # rename huggingface tokenizer outputs to the tensor names the modeling code expects
        if "padding_mask" not in names:
            index = names.index("attention_mask")
            names[index] = "padding_mask"
        if "segment_ids" not in names:
            index = names.index("token_type_ids")
            names[index] = "segment_ids"
        dataset = TensorDataset(*inputs)
        return dataset, names, []

    def _create_dataset(self, baskets: List[SampleBasket]):
        raise NotImplementedError("_create_dataset is not supported by UnlabeledTextProcessor")

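A quick way to see what `UnlabeledTextProcessor` produces is to call it directly; a sketch assuming a standard Hugging Face BERT tokenizer (the model name and example text are illustrative):

```python
from transformers import AutoTokenizer
from haystack.modeling.data_handler.processor import UnlabeledTextProcessor

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = UnlabeledTextProcessor(tokenizer=tokenizer, max_seq_len=32)

dicts = [{"text": "Haystack is a framework for building search pipelines."}]
dataset, tensor_names, _ = processor.dataset_from_dicts(dicts)

# attention_mask/token_type_ids are renamed to the padding_mask/segment_ids
# tensor names expected by the modeling code.
print(tensor_names)  # e.g. ['input_ids', 'segment_ids', 'padding_mask']
print(len(dataset))  # one row per input text
```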
# helper fcts
def write_squad_predictions(predictions, out_filename, predictions_filename=None):


@@ -854,7 +854,7 @@ class TinyBERTDistillationTrainer(Trainer):
        for teacher_dim, student_dim in zip(teacher_dims, student_dims):
            if teacher_dim != student_dim:
                self.dim_mappings.append(Linear(student_dim, teacher_dim, bias=False))
                self.dim_mappings.append(Linear(student_dim, teacher_dim, bias=False).to(device))
            else:
                self.dim_mappings.append(None)
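Context for the `.to(device)` fix above: the dimension-mapping layers project student hidden states into the teacher's hidden size during TinyBERT-style intermediate-layer distillation, so they must live on the same device as the activations they transform. A self-contained sketch of the idea, with illustrative shapes rather than Haystack's exact code:

```python
import torch
from torch.nn import Linear
from torch.nn.functional import mse_loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

student_dim, teacher_dim = 312, 768  # e.g. TinyBERT_General_4L_312D vs. bert-base
dim_mapping = Linear(student_dim, teacher_dim, bias=False).to(device)  # same device as the hidden states

student_hidden = torch.rand(4, 128, student_dim, device=device)  # (batch, seq_len, hidden_dim)
teacher_hidden = torch.rand(4, 128, teacher_dim, device=device)

# Project the student representation into the teacher's space and match it with MSE.
loss = mse_loss(dim_mapping(student_hidden), teacher_hidden)
loss.backward()
```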


@@ -8,7 +8,7 @@ from time import perf_counter
import torch
from haystack.modeling.data_handler.data_silo import DataSilo, DistillationDataSilo
from haystack.modeling.data_handler.processor import SquadProcessor
from haystack.modeling.data_handler.processor import SquadProcessor, Processor
from haystack.modeling.data_handler.dataloader import NamedDataLoader
from haystack.modeling.data_handler.inputs import QAInput, Question
from haystack.modeling.infer import QAInferencer
@@ -183,6 +183,7 @@ class FARMReader(BaseReader):
        distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div",
        temperature: float = 1.0,
        tinybert: bool = False,
        processor: Optional[Processor] = None,
    ):
        if dev_filename:
            dev_split = 0
@@ -209,17 +210,18 @@
        # 1. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
        label_list = ["start_token", "end_token"]
        metric = "squad"
        processor = SquadProcessor(
            tokenizer=self.inferencer.processor.tokenizer,
            max_seq_len=max_seq_len,
            label_list=label_list,
            metric=metric,
            train_filename=train_filename,
            dev_filename=dev_filename,
            dev_split=dev_split,
            test_filename=test_filename,
            data_dir=Path(data_dir),
        )
        if processor is None:
            processor = SquadProcessor(
                tokenizer=self.inferencer.processor.tokenizer,
                max_seq_len=max_seq_len,
                label_list=label_list,
                metric=metric,
                train_filename=train_filename,
                dev_filename=dev_filename,
                dev_split=dev_split,
                test_filename=test_filename,
                data_dir=Path(data_dir),
            )

        data_silo: DataSilo
        # 2. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them
@@ -323,11 +325,10 @@
checkpoint_every: Optional[int] = None,
checkpoints_to_keep: int = 3,
caching: bool = False,
cache_path: Path = Path("cache/data_silo")
cache_path: Path = Path("cache/data_silo"),
):
"""
Fine-tune a model on a QA dataset. Options:
- Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
- Take a QA model (e.g. `deepset/bert-base-cased-squad2`) and fine-tune it for your domain (e.g. using your labels collected via the haystack annotation tool)
@@ -367,6 +368,7 @@
:param checkpoints_to_keep: maximum number of train checkpoints to save.
:param caching: whether or not to use caching for the preprocessed dataset
:param cache_path: Path to cache the preprocessed dataset
:param processor: The processor to use for preprocessing. If None, the default SquadProcessor is used.
:return: None
"""
return self._training_procedure(data_dir=data_dir, train_filename=train_filename,
@@ -415,12 +417,10 @@
Originally proposed in: https://arxiv.org/pdf/1503.02531.pdf
This can also be considered as the second stage of distillation finetuning as described in the TinyBERT paper:
https://arxiv.org/pdf/1909.10351.pdf
**Example**
```python
student = FARMReader(model_name_or_path="prajjwal1/bert-medium")
teacher = FARMReader(model_name_or_path="deepset/bert-large-uncased-whole-word-masking-squad2")
student.distil_prediction_layer_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json",
learning_rate=3e-5, distillation_loss_weight=1.0, temperature=5)
```
@@ -470,6 +470,7 @@
:param tinybert_epochs: Number of epochs to train the student model with the TinyBERT loss function. After this many epochs, the student model is trained with the regular distillation loss function.
:param tinybert_learning_rate: Learning rate to use when training the student model with the TinyBERT loss function.
:param tinybert_train_filename: Filename of training data to use when training the student model with the TinyBERT loss function. To best follow the original paper, this should be an augmented version of the training data created using the augment_squad.py script. If not specified, the training data from the original training is used.
:param processor: The processor to use for preprocessing. If None, the default SquadProcessor is used.
:return: None
"""
return self._training_procedure(data_dir=data_dir, train_filename=train_filename,
@@ -493,8 +494,7 @@
dev_filename: Optional[str] = None,
test_filename: Optional[str] = None,
use_gpu: Optional[bool] = None,
student_batch_size: int = 10,
teacher_batch_size: Optional[int] = None,
batch_size: int = 10,
n_epochs: int = 5,
learning_rate: float = 5e-5,
max_seq_len: Optional[int] = None,
@@ -511,16 +511,15 @@
cache_path: Path = Path("cache/data_silo"),
distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "mse",
temperature: float = 1.0,
processor: Optional[Processor] = None,
):
"""
The first stage of distillation finetuning as described in the TinyBERT paper:
https://arxiv.org/pdf/1909.10351.pdf
**Example**
```python
student = FARMReader(model_name_or_path="prajjwal1/bert-medium")
teacher = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D")
student.distil_intermediate_layers_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json",
learning_rate=3e-5, temperature=5)
```
@@ -566,20 +565,22 @@
:param distillation_loss_weight: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
:param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
:param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
:param processor: The processor to use for preprocessing. If None, the default SquadProcessor is used.
:return: None
"""
return self._training_procedure(data_dir=data_dir, train_filename=train_filename,
dev_filename=dev_filename, test_filename=test_filename,
use_gpu=use_gpu, batch_size=student_batch_size,
use_gpu=use_gpu, batch_size=batch_size,
n_epochs=n_epochs, learning_rate=learning_rate,
max_seq_len=max_seq_len, warmup_proportion=warmup_proportion,
dev_split=dev_split, evaluate_every=evaluate_every,
save_dir=save_dir, num_processes=num_processes,
use_amp=use_amp, checkpoint_root_dir=checkpoint_root_dir,
checkpoint_every=checkpoint_every, checkpoints_to_keep=checkpoints_to_keep,
teacher_model=teacher_model, teacher_batch_size=teacher_batch_size,
teacher_model=teacher_model, teacher_batch_size=batch_size,
caching=caching, cache_path=cache_path,
distillation_loss=distillation_loss, temperature=temperature, tinybert=True)
distillation_loss=distillation_loss, temperature=temperature, tinybert=True,
processor=processor)
def update_parameters(
self,


@@ -1,4 +1,5 @@
from haystack.nodes import FARMReader
from haystack.modeling.data_handler.processor import UnlabeledTextProcessor
import torch
def create_checkpoint(model):
@@ -12,7 +13,7 @@ def assert_weight_change(weights, new_weights):
    print([torch.equal(old_weight, new_weight) for old_weight, new_weight in zip(weights, new_weights)])
    assert not any(torch.equal(old_weight, new_weight) for old_weight, new_weight in zip(weights, new_weights))
def test_distillation():
def test_prediction_layer_distillation():
    student = FARMReader(model_name_or_path="prajjwal1/bert-tiny", num_processes=0)
    teacher = FARMReader(model_name_or_path="prajjwal1/bert-small", num_processes=0)
@@ -35,7 +36,7 @@ def test_distillation():
    # check if weights have changed
    assert_weight_change(student_weights, new_student_weights)

def test_tinybert_distillation():
def test_intermediate_layer_distillation():
    student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_4L_312D")
    teacher = FARMReader(model_name_or_path="bert-base-uncased")
@@ -57,5 +58,31 @@ def test_tinybert_distillation():
    new_student_weights.pop(-1)  # last layer is not affected by tinybert loss
    new_student_weights.pop(-1)  # pooler is not updated due to different attention head

    # check if weights have changed
    assert_weight_change(student_weights, new_student_weights)

def test_intermediate_layer_distillation_from_scratch():
    student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_4L_312D")
    teacher = FARMReader(model_name_or_path="bert-base-uncased")

    # create a checkpoint of weights before distillation
    student_weights = create_checkpoint(student)

    assert len(student_weights) == 38

    student_weights.pop(-1)  # last layer is not affected by tinybert loss
    student_weights.pop(-1)  # pooler is not updated due to different attention head

    processor = UnlabeledTextProcessor(tokenizer=teacher.inferencer.processor.tokenizer, max_seq_len=128, train_filename="doc_2.txt", data_dir="samples/docs")
    student.distil_intermediate_layers_from(teacher_model=teacher, data_dir="samples/squad", train_filename="tiny.json", processor=processor)

    # create new checkpoint
    new_student_weights = create_checkpoint(student)

    assert len(new_student_weights) == 38

    new_student_weights.pop(-1)  # last layer is not affected by tinybert loss
    new_student_weights.pop(-1)  # pooler is not updated due to different attention head

    # check if weights have changed
    assert_weight_change(student_weights, new_student_weights)