mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-27 15:08:43 +00:00
Add UnlabeledTextProcessor (#2054)
* add UnlabeledTextProcessor * allow choosing processor when finetuning or distilling * fix type hint * Add latest docstring and tutorial changes * improve segment id computation for UnlabeledTextProcessor * add text and documentation * change batch size parameter for intermediate layer distillation * Add latest docstring and tutorial changes * fix distillation dim mapping * remove unnecessary changes * removed confusing parameter * Add latest docstring and tutorial changes Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
c6f23dce88
commit
5b6b0cef77
@ -110,7 +110,6 @@ and that FARM includes no_answer in the sorted list of predictions.
|
||||
```
|
||||
|
||||
Fine-tune a model on a QA dataset. Options:
|
||||
|
||||
- Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
|
||||
- Take a QA model (e.g. `deepset/bert-base-cased-squad2`) and fine-tune it for your domain (e.g. using your labels collected via the haystack annotation tool)
|
||||
|
||||
@ -152,6 +151,7 @@ If any checkpoints are stored, a subsequent run of train() will resume training
|
||||
- `checkpoints_to_keep`: maximum number of train checkpoints to save.
|
||||
:param caching whether or not to use caching for preprocessed dataset
|
||||
- `cache_path`: Path to cache the preprocessed dataset
|
||||
- `processor`: The processor to use for preprocessing. If None, the default SquadProcessor is used.
|
||||
|
||||
**Returns**:
|
||||
|
||||
@ -170,12 +170,10 @@ using a more complex teacher.
|
||||
Originally proposed in: https://arxiv.org/pdf/1503.02531.pdf
|
||||
This can also be considered as the second stage of distillation finetuning as described in the TinyBERT paper:
|
||||
https://arxiv.org/pdf/1909.10351.pdf
|
||||
|
||||
**Example**
|
||||
```python
|
||||
student = FARMReader(model_name_or_path="prajjwal1/bert-medium")
|
||||
teacher = FARMReader(model_name_or_path="deepset/bert-large-uncased-whole-word-masking-squad2")
|
||||
|
||||
student.distil_prediction_layer_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json",
|
||||
learning_rate=3e-5, distillation_loss_weight=1.0, temperature=5)
|
||||
```
|
||||
@ -227,6 +225,7 @@ If any checkpoints are stored, a subsequent run of train() will resume training
|
||||
- `tinybert_epochs`: Number of epochs to train the student model with the TinyBERT loss function. After this many epochs, the student model is trained with the regular distillation loss function.
|
||||
- `tinybert_learning_rate`: Learning rate to use when training the student model with the TinyBERT loss function.
|
||||
- `tinybert_train_filename`: Filename of training data to use when training the student model with the TinyBERT loss function. To best follow the original paper, this should be an augmented version of the training data created using the augment_squad.py script. If not specified, the training data from the original training is used.
|
||||
- `processor`: The processor to use for preprocessing. If None, the default SquadProcessor is used.
|
||||
|
||||
**Returns**:
|
||||
|
||||
@ -236,17 +235,15 @@ None
|
||||
#### distil\_intermediate\_layers\_from
|
||||
|
||||
```python
|
||||
| distil_intermediate_layers_from(teacher_model: "FARMReader", data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, student_batch_size: int = 10, teacher_batch_size: Optional[int] = None, n_epochs: int = 5, learning_rate: float = 5e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"), distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "mse", temperature: float = 1.0)
|
||||
| distil_intermediate_layers_from(teacher_model: "FARMReader", data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, batch_size: int = 10, n_epochs: int = 5, learning_rate: float = 5e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"), distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "mse", temperature: float = 1.0, processor: Optional[Processor] = None)
|
||||
```
|
||||
|
||||
The first stage of distillation finetuning as described in the TinyBERT paper:
|
||||
https://arxiv.org/pdf/1909.10351.pdf
|
||||
|
||||
**Example**
|
||||
```python
|
||||
student = FARMReader(model_name_or_path="prajjwal1/bert-medium")
|
||||
teacher = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D")
|
||||
|
||||
student.distil_intermediate_layers_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json",
|
||||
learning_rate=3e-5, distillation_loss_weight=1.0, temperature=5)
|
||||
```
|
||||
@ -294,6 +291,7 @@ If any checkpoints are stored, a subsequent run of train() will resume training
|
||||
- `distillation_loss_weight`: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
|
||||
- `distillation_loss`: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
|
||||
- `temperature`: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
|
||||
- `processor`: The processor to use for preprocessing. If None, the default SquadProcessor is used.
|
||||
|
||||
**Returns**:
|
||||
|
||||
|
||||
@ -14,6 +14,9 @@ from pathlib import Path
|
||||
from io import StringIO
|
||||
from typing import Optional, Dict, List, Union, Any, Iterable
|
||||
|
||||
import torch
|
||||
from torch.utils.data import TensorDataset
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from haystack.modeling.model.tokenization import (
|
||||
@ -1973,6 +1976,38 @@ class InferenceProcessor(TextClassificationProcessor):
|
||||
)
|
||||
return features
|
||||
|
||||
class UnlabeledTextProcessor(Processor):
|
||||
"""
|
||||
Processor to be used for distilling a teacher model into a student model from scratch. Can only be used with distil_intermediate_layers_from.
|
||||
"""
|
||||
def __init__(self, tokenizer, max_seq_len: int, train_filename: Optional[Union[Path, str]] = None, dev_filename: Optional[Union[Path, str]] = None, test_filename: Optional[Union[Path, str]] = None, dev_split: float = 0, data_dir: Optional[Union[Path, str]] = None, tasks: Dict = {}, proxies: Optional[Dict] = None, multithreading_rust: Optional[bool] = True):
|
||||
super().__init__(tokenizer, max_seq_len, train_filename, dev_filename, test_filename, dev_split, data_dir, tasks, proxies, multithreading_rust)
|
||||
self.add_task("question_answering", "squad", ["start_token", "end_token"])
|
||||
def file_to_dicts(self, file: str) -> List[dict]:
|
||||
dicts = []
|
||||
with open(file, "r") as f:
|
||||
for line in f:
|
||||
dicts.append({"text": line})
|
||||
return dicts
|
||||
|
||||
def dataset_from_dicts(self, dicts: List[dict], indices: Optional[List[int]] = None, return_baskets: bool = False):
|
||||
if return_baskets:
|
||||
raise NotImplementedError("return_baskets is not supported by UnlabeledTextProcessor")
|
||||
texts = [dict_["text"] for dict_ in dicts]
|
||||
tokens = self.tokenizer.batch_encode_plus(texts, add_special_tokens=True, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_seq_len)
|
||||
names = [key for key in tokens]
|
||||
inputs = [tokens[key] for key in tokens]
|
||||
if not "padding_mask" in names:
|
||||
index = names.index("attention_mask")
|
||||
names[index] = "padding_mask"
|
||||
if not "segment_ids" in names:
|
||||
index = names.index("token_type_ids")
|
||||
names[index] = "segment_ids"
|
||||
|
||||
dataset = TensorDataset(*inputs)
|
||||
return dataset, names, []
|
||||
def _create_dataset(self, baskets:List[SampleBasket]):
|
||||
raise NotImplementedError("_create_dataset is not supported by UnlabeledTextProcessor")
|
||||
|
||||
# helper fcts
|
||||
def write_squad_predictions(predictions, out_filename, predictions_filename=None):
|
||||
|
||||
@ -854,7 +854,7 @@ class TinyBERTDistillationTrainer(Trainer):
|
||||
|
||||
for teacher_dim, student_dim in zip(teacher_dims, student_dims):
|
||||
if teacher_dim != student_dim:
|
||||
self.dim_mappings.append(Linear(student_dim, teacher_dim, bias=False))
|
||||
self.dim_mappings.append(Linear(student_dim, teacher_dim, bias=False).to(device))
|
||||
else:
|
||||
self.dim_mappings.append(None)
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ from time import perf_counter
|
||||
import torch
|
||||
|
||||
from haystack.modeling.data_handler.data_silo import DataSilo, DistillationDataSilo
|
||||
from haystack.modeling.data_handler.processor import SquadProcessor
|
||||
from haystack.modeling.data_handler.processor import SquadProcessor, Processor
|
||||
from haystack.modeling.data_handler.dataloader import NamedDataLoader
|
||||
from haystack.modeling.data_handler.inputs import QAInput, Question
|
||||
from haystack.modeling.infer import QAInferencer
|
||||
@ -183,6 +183,7 @@ class FARMReader(BaseReader):
|
||||
distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div",
|
||||
temperature: float = 1.0,
|
||||
tinybert: bool = False,
|
||||
processor: Optional[Processor] = None,
|
||||
):
|
||||
if dev_filename:
|
||||
dev_split = 0
|
||||
@ -209,17 +210,18 @@ class FARMReader(BaseReader):
|
||||
# 1. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
|
||||
label_list = ["start_token", "end_token"]
|
||||
metric = "squad"
|
||||
processor = SquadProcessor(
|
||||
tokenizer=self.inferencer.processor.tokenizer,
|
||||
max_seq_len=max_seq_len,
|
||||
label_list=label_list,
|
||||
metric=metric,
|
||||
train_filename=train_filename,
|
||||
dev_filename=dev_filename,
|
||||
dev_split=dev_split,
|
||||
test_filename=test_filename,
|
||||
data_dir=Path(data_dir),
|
||||
)
|
||||
if processor is None:
|
||||
processor = SquadProcessor(
|
||||
tokenizer=self.inferencer.processor.tokenizer,
|
||||
max_seq_len=max_seq_len,
|
||||
label_list=label_list,
|
||||
metric=metric,
|
||||
train_filename=train_filename,
|
||||
dev_filename=dev_filename,
|
||||
dev_split=dev_split,
|
||||
test_filename=test_filename,
|
||||
data_dir=Path(data_dir),
|
||||
)
|
||||
data_silo: DataSilo
|
||||
|
||||
# 2. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them
|
||||
@ -323,11 +325,10 @@ class FARMReader(BaseReader):
|
||||
checkpoint_every: Optional[int] = None,
|
||||
checkpoints_to_keep: int = 3,
|
||||
caching: bool = False,
|
||||
cache_path: Path = Path("cache/data_silo")
|
||||
cache_path: Path = Path("cache/data_silo"),
|
||||
):
|
||||
"""
|
||||
Fine-tune a model on a QA dataset. Options:
|
||||
|
||||
- Take a plain language model (e.g. `bert-base-cased`) and train it for QA (e.g. on SQuAD data)
|
||||
- Take a QA model (e.g. `deepset/bert-base-cased-squad2`) and fine-tune it for your domain (e.g. using your labels collected via the haystack annotation tool)
|
||||
|
||||
@ -367,6 +368,7 @@ class FARMReader(BaseReader):
|
||||
:param checkpoints_to_keep: maximum number of train checkpoints to save.
|
||||
:param caching whether or not to use caching for preprocessed dataset
|
||||
:param cache_path: Path to cache the preprocessed dataset
|
||||
:param processor: The processor to use for preprocessing. If None, the default SquadProcessor is used.
|
||||
:return: None
|
||||
"""
|
||||
return self._training_procedure(data_dir=data_dir, train_filename=train_filename,
|
||||
@ -415,12 +417,10 @@ class FARMReader(BaseReader):
|
||||
Originally proposed in: https://arxiv.org/pdf/1503.02531.pdf
|
||||
This can also be considered as the second stage of distillation finetuning as described in the TinyBERT paper:
|
||||
https://arxiv.org/pdf/1909.10351.pdf
|
||||
|
||||
**Example**
|
||||
```python
|
||||
student = FARMReader(model_name_or_path="prajjwal1/bert-medium")
|
||||
teacher = FARMReader(model_name_or_path="deepset/bert-large-uncased-whole-word-masking-squad2")
|
||||
|
||||
student.distil_prediction_layer_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json",
|
||||
learning_rate=3e-5, distillation_loss_weight=1.0, temperature=5)
|
||||
```
|
||||
@ -470,6 +470,7 @@ class FARMReader(BaseReader):
|
||||
:param tinybert_epochs: Number of epochs to train the student model with the TinyBERT loss function. After this many epochs, the student model is trained with the regular distillation loss function.
|
||||
:param tinybert_learning_rate: Learning rate to use when training the student model with the TinyBERT loss function.
|
||||
:param tinybert_train_filename: Filename of training data to use when training the student model with the TinyBERT loss function. To best follow the original paper, this should be an augmented version of the training data created using the augment_squad.py script. If not specified, the training data from the original training is used.
|
||||
:param processor: The processor to use for preprocessing. If None, the default SquadProcessor is used.
|
||||
:return: None
|
||||
"""
|
||||
return self._training_procedure(data_dir=data_dir, train_filename=train_filename,
|
||||
@ -493,8 +494,7 @@ class FARMReader(BaseReader):
|
||||
dev_filename: Optional[str] = None,
|
||||
test_filename: Optional[str] = None,
|
||||
use_gpu: Optional[bool] = None,
|
||||
student_batch_size: int = 10,
|
||||
teacher_batch_size: Optional[int] = None,
|
||||
batch_size: int = 10,
|
||||
n_epochs: int = 5,
|
||||
learning_rate: float = 5e-5,
|
||||
max_seq_len: Optional[int] = None,
|
||||
@ -511,16 +511,15 @@ class FARMReader(BaseReader):
|
||||
cache_path: Path = Path("cache/data_silo"),
|
||||
distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "mse",
|
||||
temperature: float = 1.0,
|
||||
processor: Optional[Processor] = None,
|
||||
):
|
||||
"""
|
||||
The first stage of distillation finetuning as described in the TinyBERT paper:
|
||||
https://arxiv.org/pdf/1909.10351.pdf
|
||||
|
||||
**Example**
|
||||
```python
|
||||
student = FARMReader(model_name_or_path="prajjwal1/bert-medium")
|
||||
teacher = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D")
|
||||
|
||||
student.distil_intermediate_layers_from(teacher, data_dir="squad2", train_filename="train.json", test_filename="dev.json",
|
||||
learning_rate=3e-5, distillation_loss_weight=1.0, temperature=5)
|
||||
```
|
||||
@ -566,20 +565,22 @@ class FARMReader(BaseReader):
|
||||
:param distillation_loss_weight: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
|
||||
:param distillation_loss: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
|
||||
:param temperature: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
|
||||
:param processor: The processor to use for preprocessing. If None, the default SquadProcessor is used.
|
||||
:return: None
|
||||
"""
|
||||
return self._training_procedure(data_dir=data_dir, train_filename=train_filename,
|
||||
dev_filename=dev_filename, test_filename=test_filename,
|
||||
use_gpu=use_gpu, batch_size=student_batch_size,
|
||||
use_gpu=use_gpu, batch_size=batch_size,
|
||||
n_epochs=n_epochs, learning_rate=learning_rate,
|
||||
max_seq_len=max_seq_len, warmup_proportion=warmup_proportion,
|
||||
dev_split=dev_split, evaluate_every=evaluate_every,
|
||||
save_dir=save_dir, num_processes=num_processes,
|
||||
use_amp=use_amp, checkpoint_root_dir=checkpoint_root_dir,
|
||||
checkpoint_every=checkpoint_every, checkpoints_to_keep=checkpoints_to_keep,
|
||||
teacher_model=teacher_model, teacher_batch_size=teacher_batch_size,
|
||||
teacher_model=teacher_model, teacher_batch_size=batch_size,
|
||||
caching=caching, cache_path=cache_path,
|
||||
distillation_loss=distillation_loss, temperature=temperature, tinybert=True)
|
||||
distillation_loss=distillation_loss, temperature=temperature, tinybert=True,
|
||||
processor=processor)
|
||||
|
||||
def update_parameters(
|
||||
self,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from haystack.nodes import FARMReader
|
||||
from haystack.modeling.data_handler.processor import UnlabeledTextProcessor
|
||||
import torch
|
||||
|
||||
def create_checkpoint(model):
|
||||
@ -12,7 +13,7 @@ def assert_weight_change(weights, new_weights):
|
||||
print([torch.equal(old_weight, new_weight) for old_weight, new_weight in zip(weights, new_weights)])
|
||||
assert not any(torch.equal(old_weight, new_weight) for old_weight, new_weight in zip(weights, new_weights))
|
||||
|
||||
def test_distillation():
|
||||
def test_prediction_layer_distillation():
|
||||
student = FARMReader(model_name_or_path="prajjwal1/bert-tiny", num_processes=0)
|
||||
teacher = FARMReader(model_name_or_path="prajjwal1/bert-small", num_processes=0)
|
||||
|
||||
@ -35,7 +36,7 @@ def test_distillation():
|
||||
# check if weights have changed
|
||||
assert_weight_change(student_weights, new_student_weights)
|
||||
|
||||
def test_tinybert_distillation():
|
||||
def test_intermediate_layer_distillation():
|
||||
student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_4L_312D")
|
||||
teacher = FARMReader(model_name_or_path="bert-base-uncased")
|
||||
|
||||
@ -57,5 +58,31 @@ def test_tinybert_distillation():
|
||||
new_student_weights.pop(-1) # last layer is not affected by tinybert loss
|
||||
new_student_weights.pop(-1) # pooler is not updated due to different attention head
|
||||
|
||||
# check if weights have changed
|
||||
assert_weight_change(student_weights, new_student_weights)
|
||||
|
||||
def test_intermediate_layer_distillation_from_scratch():
|
||||
student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_4L_312D")
|
||||
teacher = FARMReader(model_name_or_path="bert-base-uncased")
|
||||
|
||||
# create a checkpoint of weights before distillation
|
||||
student_weights = create_checkpoint(student)
|
||||
|
||||
assert len(student_weights) == 38
|
||||
|
||||
student_weights.pop(-1) # last layer is not affected by tinybert loss
|
||||
student_weights.pop(-1) # pooler is not updated due to different attention head
|
||||
|
||||
processor = UnlabeledTextProcessor(tokenizer=teacher.inferencer.processor.tokenizer, max_seq_len=128, train_filename="doc_2.txt", data_dir="samples/docs")
|
||||
student.distil_intermediate_layers_from(teacher_model=teacher, data_dir="samples/squad", train_filename="tiny.json", processor=processor)
|
||||
|
||||
# create new checkpoint
|
||||
new_student_weights = create_checkpoint(student)
|
||||
|
||||
assert len(new_student_weights) == 38
|
||||
|
||||
new_student_weights.pop(-1) # last layer is not affected by tinybert loss
|
||||
new_student_weights.pop(-1) # pooler is not updated due to different attention head
|
||||
|
||||
# check if weights have changed
|
||||
assert_weight_change(student_weights, new_student_weights)
|
||||
Loading…
x
Reference in New Issue
Block a user