Add Generative Pseudo Labeling (#2388)

Vladimir Blagojevic 2022-06-02 16:12:47 +02:00 committed by GitHub
parent 61d9429c25
commit e10a3fba74
14 changed files with 733 additions and 11 deletions

View File

@@ -23,7 +23,7 @@ come from earlier in the document.
#### QuestionGenerator.\_\_init\_\_
```python
def __init__(model_name_or_path="valhalla/t5-base-e2e-qg", model_version=None, num_beams=4, max_length=256, no_repeat_ngram_size=3, length_penalty=1.5, early_stopping=True, split_length=50, split_overlap=10, use_gpu=True, prompt="generate questions:", batch_size: Optional[int] = None)
def __init__(model_name_or_path="valhalla/t5-base-e2e-qg", model_version=None, num_beams=4, max_length=256, no_repeat_ngram_size=3, length_penalty=1.5, early_stopping=True, split_length=50, split_overlap=10, use_gpu=True, prompt="generate questions:", num_queries_per_doc=1, batch_size: Optional[int] = None)
```
Uses the valhalla/t5-base-e2e-qg model by default. This class supports any question generation model that is
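For illustration, a minimal sketch of the new `num_queries_per_doc` parameter (the sample document and the printed output key are illustrative assumptions; with beam search, the value should not exceed `num_beams`, which defaults to 4):

```python
from haystack.nodes import QuestionGenerator
from haystack.schema import Document

documents = [Document(content="Berlin is the capital and largest city of Germany.")]
# request up to three questions per text split instead of the default one
qg = QuestionGenerator(num_queries_per_doc=3)
output, _ = qg.run(documents=documents)
print(output["generated_questions"])
```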

View File

@@ -1433,6 +1433,45 @@ Create embeddings for a list of documents.
Embeddings, one per input document
<a id="dense.EmbeddingRetriever.train"></a>
#### EmbeddingRetriever.train
```python
def train(training_data: List[Dict[str, Any]], learning_rate: float = 2e-5, n_epochs: int = 1, num_warmup_steps: int = None, batch_size: int = 16) -> None
```
Trains/adapts the underlying embedding model.
Each training data example is a dictionary with the following keys:
* question: the question string
* pos_doc: the positive document string
* neg_doc: the negative document string
* score: the score margin
**Arguments**:
- `training_data` (`List[Dict[str, Any]]`): The training data
- `learning_rate` (`float`): The learning rate
- `n_epochs` (`int`): The number of epochs
- `num_warmup_steps` (`int`): The number of warmup steps
- `batch_size` (`int (optional)`): The batch size to use for the training, defaults to 16
<a id="dense.EmbeddingRetriever.save"></a>
#### EmbeddingRetriever.save
```python
def save(save_dir: Union[Path, str]) -> None
```
Save the model to the given directory
**Arguments**:
- `save_dir` (`Union[Path, str]`): The directory where the model will be saved
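For illustration, a hedged sketch of the new `train`/`save` API (the model name, the single training example, and the save directory are assumptions; both methods are only implemented for sentence-transformers models):

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever

retriever = EmbeddingRetriever(
    document_store=InMemoryDocumentStore(),
    embedding_model="sentence-transformers/msmarco-distilbert-base-tas-b",
    model_format="sentence_transformers",  # train()/save() require this model format
)
# GPL-style training data: question, positive document, mined negative, margin score
training_data = [
    {
        "question": "What is the capital of Germany?",
        "pos_doc": "Berlin is the capital and largest city of Germany.",
        "neg_doc": "Paris is the capital and most populous city of France.",
        "score": 7.5,
    }
]
# batch_size must not exceed the number of examples: incomplete batches are dropped
retriever.train(training_data=training_data, n_epochs=1, batch_size=1)
retriever.save(save_dir="adapted_retriever")
```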
<a id="text2sparql"></a>
# Module text2sparql

View File

@@ -79,6 +79,7 @@ from haystack.nodes import (
retriever,
summarizer,
translator,
label_generator,
)
# Note that we ignore the ImportError here because if the user did not install

View File

@@ -110,6 +110,9 @@
{
"$ref": "#/definitions/PreProcessorComponent"
},
{
"$ref": "#/definitions/PseudoLabelGeneratorComponent"
},
{
"$ref": "#/definitions/QuestionGeneratorComponent"
},
@@ -2439,6 +2442,75 @@
],
"additionalProperties": false
},
"PseudoLabelGeneratorComponent": {
"type": "object",
"properties": {
"name": {
"title": "Name",
"description": "Custom name for the component. Helpful for visualization and debugging.",
"type": "string"
},
"type": {
"title": "Type",
"description": "Haystack Class name for the component.",
"type": "string",
"const": "PseudoLabelGenerator"
},
"params": {
"title": "Parameters",
"type": "object",
"properties": {
"question_producer": {
"title": "Question Producer",
"anyOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"type": "object",
"additionalProperties": {
"type": "string"
}
}
}
]
},
"retriever": {
"title": "Retriever",
"type": "string"
},
"cross_encoder_model_name_or_path": {
"title": "Cross Encoder Model Name Or Path",
"default": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"type": "string"
},
"total_number_of_questions": {
"title": "Total Number Of Questions",
"default": 9223372036854775807,
"type": "integer"
},
"top_k": {
"title": "Top K",
"default": 10,
"type": "integer"
}
},
"required": [
"question_producer",
"retriever"
],
"additionalProperties": false,
"description": "Each parameter can reference other components defined in the same YAML file."
}
},
"required": [
"type",
"name"
],
"additionalProperties": false
},
"QuestionGeneratorComponent": {
"type": "object",
"properties": {

View File

@@ -113,6 +113,9 @@
{
"$ref": "#/definitions/PreProcessorComponent"
},
{
"$ref": "#/definitions/PseudoLabelGeneratorComponent"
},
{
"$ref": "#/definitions/QuestionGeneratorComponent"
},
@@ -2618,6 +2621,75 @@
],
"additionalProperties": false
},
"PseudoLabelGeneratorComponent": {
"type": "object",
"properties": {
"name": {
"title": "Name",
"description": "Custom name for the component. Helpful for visualization and debugging.",
"type": "string"
},
"type": {
"title": "Type",
"description": "Haystack Class name for the component.",
"type": "string",
"const": "PseudoLabelGenerator"
},
"params": {
"title": "Parameters",
"type": "object",
"properties": {
"question_producer": {
"title": "Question Producer",
"anyOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"type": "object",
"additionalProperties": {
"type": "string"
}
}
}
]
},
"retriever": {
"title": "Retriever",
"type": "string"
},
"cross_encoder_model_name_or_path": {
"title": "Cross Encoder Model Name Or Path",
"default": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"type": "string"
},
"total_number_of_questions": {
"title": "Total Number Of Questions",
"default": 9223372036854775807,
"type": "integer"
},
"top_k": {
"title": "Top K",
"default": 10,
"type": "integer"
}
},
"required": [
"question_producer",
"retriever"
],
"additionalProperties": false,
"description": "Each parameter can reference other components defined in the same YAML file."
}
},
"required": [
"type",
"name"
],
"additionalProperties": false
},
"QuestionGeneratorComponent": {
"type": "object",
"properties": {

View File

@@ -127,6 +127,9 @@
{
"$ref": "#/definitions/PreProcessorComponent"
},
{
"$ref": "#/definitions/PseudoLabelGeneratorComponent"
},
{
"$ref": "#/definitions/QuestionGeneratorComponent"
},
@@ -3193,6 +3196,85 @@
],
"additionalProperties": false
},
"PseudoLabelGeneratorComponent": {
"type": "object",
"properties": {
"name": {
"title": "Name",
"description": "Custom name for the component. Helpful for visualization and debugging.",
"type": "string"
},
"type": {
"title": "Type",
"description": "Haystack Class name for the component.",
"type": "string",
"const": "PseudoLabelGenerator"
},
"params": {
"title": "Parameters",
"type": "object",
"properties": {
"question_producer": {
"title": "Question Producer",
"anyOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"type": "object",
"additionalProperties": {
"type": "string"
}
}
}
]
},
"retriever": {
"title": "Retriever",
"type": "string"
},
"cross_encoder_model_name_or_path": {
"title": "Cross Encoder Model Name Or Path",
"default": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"type": "string"
},
"max_questions_per_document": {
"title": "Max Questions Per Document",
"default": 3,
"type": "integer"
},
"top_k": {
"title": "Top K",
"default": 50,
"type": "integer"
},
"batch_size": {
"title": "Batch Size",
"default": 4,
"type": "integer"
},
"progress_bar": {
"title": "Progress Bar",
"default": true,
"type": "boolean"
}
},
"required": [
"question_producer",
"retriever"
],
"additionalProperties": false,
"description": "Each parameter can reference other components defined in the same YAML file."
}
},
"required": [
"type",
"name"
],
"additionalProperties": false
},
"QuestionGeneratorComponent": {
"type": "object",
"properties": {
@@ -3254,6 +3336,10 @@
"title": "Prompt",
"default": "generate questions:"
},
"num_queries_per_doc": {
"title": "Num Queries Per Doc",
"default": 1
},
"batch_size": {
"title": "Batch Size",
"type": "integer"

View File

@@ -20,6 +20,7 @@ from haystack.nodes.file_converter import (
AzureConverter,
ParsrConverter,
)
from haystack.nodes.label_generator import PseudoLabelGenerator
from haystack.nodes.other import Docs2Answers, JoinDocuments, RouteDocuments, JoinAnswers
from haystack.nodes.preprocessor import BasePreProcessor, PreProcessor
from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier

View File

@@ -0,0 +1 @@
from haystack.nodes.label_generator.pseudo_label_generator import PseudoLabelGenerator

View File

@@ -0,0 +1,233 @@
import random
from typing import Dict, Iterable, List, Optional, Tuple, Union
from sentence_transformers import CrossEncoder
from tqdm.auto import tqdm
from haystack.nodes.base import BaseComponent
from haystack.nodes.question_generator import QuestionGenerator
from haystack.nodes.retriever.base import BaseRetriever
from haystack.schema import Document
class PseudoLabelGenerator(BaseComponent):
"""
The PseudoLabelGenerator is a component that creates Generative Pseudo Labeling (GPL) training data for the
training of dense retrievers.
GPL is an unsupervised domain adaptation method for training dense retrievers. It is based on question
generation and pseudo labeling with powerful cross-encoders. To train a domain-adapted model, it needs access
to an unlabeled target corpus, usually available through a DocumentStore, and a retriever to mine negatives.
For more details, see [https://github.com/UKPLab/gpl](https://github.com/UKPLab/gpl)
For example:
```python
| document_store = DocumentStore(...)
| retriever = Retriever(...)
| qg = QuestionGenerator(model_name_or_path="doc2query/msmarco-t5-base-v1")
| plg = PseudoLabelGenerator(qg, retriever)
| output, output_id = plg.run(documents=document_store.get_all_documents())
```
"""
def __init__(
self,
question_producer: Union[QuestionGenerator, List[Dict[str, str]]],
retriever: BaseRetriever,
cross_encoder_model_name_or_path: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
max_questions_per_document: int = 3,
top_k: int = 50,
batch_size: int = 4,
progress_bar: bool = True,
):
"""
Loads the cross encoder model and prepares PseudoLabelGenerator.
:param question_producer: The question producer used to generate questions, or a list of already produced
question/document pairs in the Dict format {"question": "question text ...", "document": "document text ..."}.
:type question_producer: Union[QuestionGenerator, List[Dict[str, str]]]
:param retriever: The retriever used to query document stores
:type retriever: BaseRetriever
:param cross_encoder_model_name_or_path: The path to the cross encoder model, defaults to
cross-encoder/ms-marco-MiniLM-L-6-v2
:type cross_encoder_model_name_or_path: str (optional)
:param max_questions_per_document: The max number of questions generated per document, defaults to 3
:type max_questions_per_document: int
:param top_k: The number of documents retrieved per question for negative mining, defaults to 50
:type top_k: int (optional)
:param batch_size: The number of documents to process at a time, defaults to 4
:type batch_size: int (optional)
:param progress_bar: Whether to show a progress bar, defaults to True
:type progress_bar: bool (optional)
"""
super().__init__()
self.question_document_pairs = None
self.question_generator = None # type: ignore
if isinstance(question_producer, QuestionGenerator):
self.question_generator = question_producer
elif isinstance(question_producer, list) and len(question_producer) > 0:
example = question_producer[0]
if isinstance(example, dict) and "question" in example and "document" in example:
self.question_document_pairs = question_producer
else:
raise ValueError("question_producer list must contain dicts with keys 'question' and 'document'")
else:
raise ValueError("Provide either a QuestionGenerator or nonempty list of questions/document pairs")
self.retriever = retriever
self.cross_encoder = CrossEncoder(cross_encoder_model_name_or_path)
self.max_questions_per_document = max_questions_per_document
self.top_k = top_k
self.batch_size = batch_size
self.progress_bar = progress_bar
def generate_questions(self, documents: List[Document], batch_size: Optional[int] = None) -> List[Dict[str, str]]:
"""
It takes a list of documents and generates a list of question-document pairs.
:param documents: A list of documents to generate questions from
:type documents: List[Document]
:param batch_size: Number of documents to process at a time.
:type batch_size: Optional[int]
:return: A list of question-document pairs.
"""
question_doc_pairs: List[Dict[str, str]] = []
if self.question_document_pairs:
question_doc_pairs = self.question_document_pairs
else:
batch_size = batch_size if batch_size else self.batch_size
questions: List[List[str]] = self.question_generator.generate_batch( # type: ignore
[d.content for d in documents], batch_size=batch_size
)
for idx, question_list_per_doc in enumerate(questions):
for q in question_list_per_doc[: self.max_questions_per_document]: # type: ignore
question_doc_pairs.append({"question": q.strip(), "document": documents[idx].content})
return question_doc_pairs
def mine_negatives(
self, question_doc_pairs: List[Dict[str, str]], batch_size: Optional[int] = None
) -> List[Dict[str, str]]:
"""
Given a list of question and pos_doc pairs, this function returns a list of question/pos_doc/neg_doc
dictionaries.
:param question_doc_pairs: A list of question/pos_doc pairs
:type question_doc_pairs: List[Dict[str, str]]
:param batch_size: The number of queries to run in a batch
:type batch_size: int (optional)
:return: A list of dictionaries, where each dictionary contains the question, positive document,
and negative document.
"""
question_pos_doc_neg_doc: List[Dict[str, str]] = []
batch_size = batch_size if batch_size else self.batch_size
for i in tqdm(
range(0, len(question_doc_pairs), batch_size), disable=not self.progress_bar, desc="Mine negatives"
):
# query in batches to minimize network latency
i_end = min(i + batch_size, len(question_doc_pairs))
queries: List[str] = [e["question"] for e in question_doc_pairs[i:i_end]]
pos_docs: List[str] = [e["document"] for e in question_doc_pairs[i:i_end]]
docs: List[List[Document]] = self.retriever.retrieve_batch(
queries=queries, top_k=self.top_k, batch_size=batch_size
)
# iterate through queries and find negatives
for question, pos_doc, top_docs in zip(queries, pos_docs, docs):
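# shuffle so the negative is a random document from the top_k results rather than always the top-ranked one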
random.shuffle(top_docs)
for doc_item in top_docs:
neg_doc = doc_item.content
if neg_doc != pos_doc:
question_pos_doc_neg_doc.append({"question": question, "pos_doc": pos_doc, "neg_doc": neg_doc})
break
return question_pos_doc_neg_doc
def generate_margin_scores(
self, mined_negatives: List[Dict[str, str]], batch_size: Optional[int] = None
) -> List[Dict]:
"""
Given a list of mined negatives, predict the score margin between the positive and negative document using
the cross encoder.
The function returns a list of examples, where each example is a dictionary with the following keys:
* question: the question string
* pos_doc: the positive document string
* neg_doc: the negative document string
* score: the score margin
:param mined_negatives: List of mined negatives
:type mined_negatives: List[Dict[str, str]]
:param batch_size: The number of mined negative examples to score in a batch
:type batch_size: int (optional)
:return: A list of dictionaries, each of which has the following keys:
- question: The question string
- pos_doc: The positive document string
- neg_doc: The negative document string
- score: The score margin
"""
examples: List[Dict] = []
batch_size = batch_size if batch_size else self.batch_size
for i in tqdm(range(0, len(mined_negatives), batch_size), disable=not self.progress_bar, desc="Score margin"):
negatives_batch = mined_negatives[i : i + batch_size]
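# interleave (question, pos_doc) and (question, neg_doc) pairs so one cross-encoder pass scores both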
pb = []
for item in negatives_batch:
pb.append([item["question"], item["pos_doc"]])
pb.append([item["question"], item["neg_doc"]])
scores = self.cross_encoder.predict(pb)
for idx, item in enumerate(negatives_batch):
scores_idx = idx * 2
score_margin = scores[scores_idx] - scores[scores_idx + 1]
examples.append(
{
"question": item["question"],
"pos_doc": item["pos_doc"],
"neg_doc": item["neg_doc"],
"score": score_margin,
}
)
return examples
def generate_pseudo_labels(self, documents: List[Document], batch_size: Optional[int] = None) -> Tuple[dict, str]:
"""
Given a list of documents, generate a list of question-document pairs, mine for negatives, and
score positive/negative margin with cross-encoder. The output is the training data for the
adaptation of dense retriever models.
:param documents: A list of documents to generate GPL training data from
:type documents: List[Document]
:param batch_size: The number of documents to process in a batch
:type batch_size: Optional[int]
:return: A dictionary with a single key 'gpl_labels' representing a list of dictionaries, where each
dictionary contains the following keys:
- question: the question
- pos_doc: the positive document for the given question
- neg_doc: the negative document for the given question
- score: the margin score (a float)
"""
# see https://github.com/UKPLab/gpl for more information about the GPL algorithm
batch_size = batch_size if batch_size else self.batch_size
# step 1: generate questions
question_doc_pairs = self.generate_questions(documents=documents, batch_size=batch_size)
# step 2: negative mining
mined_negatives = self.mine_negatives(question_doc_pairs=question_doc_pairs, batch_size=batch_size)
# step 3: pseudo labeling (scoring) with cross-encoder
pseudo_labels: List[Dict] = self.generate_margin_scores(mined_negatives, batch_size=batch_size)
return {"gpl_labels": pseudo_labels}, "output_1"
def run(self, documents: List[Document]) -> Tuple[dict, str]: # type: ignore
return self.generate_pseudo_labels(documents=documents)
def run_batch(self, documents: Union[List[Document], List[List[Document]]]) -> Tuple[dict, str]: # type: ignore
flat_list_of_documents = []
for sub_list_documents in documents:
if isinstance(sub_list_documents, Iterable):
flat_list_of_documents += sub_list_documents
else:
flat_list_of_documents.append(sub_list_documents)
return self.generate_pseudo_labels(documents=flat_list_of_documents)
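Taken together, a hedged end-to-end sketch of the GPL workflow this component enables (corpus contents and model choices are illustrative assumptions; GPL needs more than one document so that negatives exist):

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever, PseudoLabelGenerator, QuestionGenerator
from haystack.schema import Document

# a tiny unlabeled target corpus
document_store = InMemoryDocumentStore()
document_store.write_documents(
    [
        Document(content="Berlin is the capital and largest city of Germany."),
        Document(content="The Rhine is one of the major European rivers."),
    ]
)

retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/msmarco-distilbert-base-tas-b",
    model_format="sentence_transformers",
)
document_store.update_embeddings(retriever)

qg = QuestionGenerator(model_name_or_path="doc2query/msmarco-t5-base-v1")
plg = PseudoLabelGenerator(question_producer=qg, retriever=retriever)

# steps 1-3: question generation, negative mining, margin scoring with the cross-encoder
output, _ = plg.run(documents=document_store.get_all_documents())

# step 4 (performed by the retriever, not this component): adapt and persist the model
retriever.train(output["gpl_labels"], batch_size=1)  # tiny dataset; incomplete batches are dropped
retriever.save("gpl_adapted_retriever")
```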

View File

@@ -37,6 +37,7 @@ class QuestionGenerator(BaseComponent):
split_overlap=10,
use_gpu=True,
prompt="generate questions:",
num_queries_per_doc=1,
batch_size: Optional[int] = None,
):
"""
@@ -65,6 +66,7 @@ class QuestionGenerator(BaseComponent):
self.split_overlap = split_overlap
self.preprocessor = PreProcessor()
self.prompt = prompt
self.num_queries_per_doc = num_queries_per_doc
self.batch_size = batch_size
def run(self, documents: List[Document]): # type: ignore
@@ -122,6 +124,7 @@ class QuestionGenerator(BaseComponent):
no_repeat_ngram_size=self.no_repeat_ngram_size,
length_penalty=self.length_penalty,
early_stopping=self.early_stopping,
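# with beam search, num_return_sequences must not exceed num_beams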
num_return_sequences=self.num_queries_per_doc,
)
string_output = self.tokenizer.batch_decode(tokens_output)
@@ -190,6 +193,7 @@ class QuestionGenerator(BaseComponent):
no_repeat_ngram_size=self.no_repeat_ngram_size,
length_penalty=self.length_penalty,
early_stopping=self.early_stopping,
num_return_sequences=self.num_queries_per_doc,
)
string_output = self.tokenizer.batch_decode(tokens_output)

View File

@@ -1,17 +1,20 @@
from typing import TYPE_CHECKING, Callable, List, Union, Dict
import logging
from abc import abstractmethod
import numpy as np
from tqdm.auto import tqdm
import torch
from torch.utils.data.sampler import SequentialSampler
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Union
from haystack.schema import Document
import numpy as np
import torch
from sentence_transformers import InputExample, losses
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer
from haystack.modeling.data_handler.dataloader import NamedDataLoader
from haystack.modeling.data_handler.dataset import convert_features_to_dataset, flatten_rename
from haystack.modeling.infer import Inferencer
from haystack.modeling.data_handler.dataloader import NamedDataLoader
from haystack.schema import Document
if TYPE_CHECKING:
from haystack.nodes.retriever import EmbeddingRetriever
@@ -41,6 +44,47 @@ class _BaseEmbeddingEncoder:
"""
pass
def train(
self,
training_data: List[Dict[str, Any]],
learning_rate: float = 2e-5,
n_epochs: int = 1,
num_warmup_steps: int = None,
batch_size: int = 16,
):
"""
Trains/adapts the underlying embedding model.
Each training data example is a dictionary with the following keys:
* question: the question string
* pos_doc: the positive document string
* neg_doc: the negative document string
* score: the score margin
:param training_data: The training data
:type training_data: List[Dict[str, Any]]
:param learning_rate: The learning rate
:type learning_rate: float
:param n_epochs: The number of training epochs
:type n_epochs: int
:param num_warmup_steps: The number of warmup steps
:type num_warmup_steps: int
:param batch_size: The batch size to use for the training, defaults to 16
:type batch_size: int (optional)
"""
pass
def save(self, save_dir: Union[Path, str]):
"""
Save the model to the given directory
:param save_dir: The directory where the model will be saved
:type save_dir: Union[Path, str]
"""
pass
class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
def __init__(self, retriever: "EmbeddingRetriever"):
@@ -87,6 +131,19 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
passages = [d.content for d in docs] # type: ignore
return self.embed(passages)
def train(
self,
training_data: List[Dict[str, Any]],
learning_rate: float = 2e-5,
n_epochs: int = 1,
num_warmup_steps: int = None,
batch_size: int = 16,
):
raise NotImplementedError("train method can only be used with sentence-transformers EmbeddingRetriever(s)")
def save(self, save_dir: Union[Path, str]):
raise NotImplementedError("save method can only be used with sentence-transformers EmbeddingRetriever(s)")
class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
def __init__(self, retriever: "EmbeddingRetriever"):
@@ -127,6 +184,33 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
passages = [[d.meta["name"] if d.meta and "name" in d.meta else "", d.content] for d in docs] # type: ignore
return self.embed(passages)
def train(
self,
training_data: List[Dict[str, Any]],
learning_rate: float = 2e-5,
n_epochs: int = 1,
num_warmup_steps: int = None,
batch_size: int = 16,
):
train_examples = [
InputExample(texts=[i["question"], i["pos_doc"], i["neg_doc"]], label=i["score"]) for i in training_data
]
logger.info(f"GPL training/adapting {self.embedding_model} with {len(train_examples)} examples")
train_dataloader = DataLoader(train_examples, batch_size=batch_size, drop_last=True, shuffle=True)
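# MarginMSELoss trains the bi-encoder so that sim(question, pos_doc) - sim(question, neg_doc) approximates the cross-encoder margin score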
train_loss = losses.MarginMSELoss(self.embedding_model)
# Tune the model
self.embedding_model.fit(
train_objectives=[(train_dataloader, train_loss)],
epochs=n_epochs,
optimizer_params={"lr": learning_rate},
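# by default, warm up over the first 10% of one epoch's batches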
warmup_steps=int(len(train_dataloader) * 0.1) if num_warmup_steps is None else num_warmup_steps,
)
def save(self, save_dir: Union[Path, str]):
self.embedding_model.save(path=str(save_dir))
class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
def __init__(self, retriever: "EmbeddingRetriever"):
@@ -208,6 +292,19 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
dataset, tensornames = convert_features_to_dataset(features=features_flat)
return dataset, tensornames
def train(
self,
training_data: List[Dict[str, Any]],
learning_rate: float = 2e-5,
n_epochs: int = 1,
num_warmup_steps: int = None,
batch_size: int = 16,
):
raise NotImplementedError("train method can only be used with sentence-transformers EmbeddingRetriever(s)")
def save(self, save_dir: Union[Path, str]):
raise NotImplementedError("save method can only be used with sentence-transformers EmbeddingRetriever(s)")
_EMBEDDING_ENCODERS: Dict[str, Callable] = {
"farm": _DefaultEmbeddingEncoder,

View File

@@ -1,4 +1,4 @@
from typing import List, Dict, Union, Optional
from typing import List, Dict, Union, Optional, Any
import logging
from pathlib import Path
@@ -1851,3 +1851,50 @@ class EmbeddingRetriever(BaseRetriever):
# Model is neither sentence-transformers nor retribert model -> use _DefaultEmbeddingEncoder
return "farm"
def train(
self,
training_data: List[Dict[str, Any]],
learning_rate: float = 2e-5,
n_epochs: int = 1,
num_warmup_steps: int = None,
batch_size: int = 16,
) -> None:
"""
Trains/adapts the underlying embedding model.
Each training data example is a dictionary with the following keys:
* question: the question string
* pos_doc: the positive document string
* neg_doc: the negative document string
* score: the score margin
:param training_data: The training data
:type training_data: List[Dict[str, Any]]
:param learning_rate: The learning rate
:type learning_rate: float
:param n_epochs: The number of epochs
:type n_epochs: int
:param num_warmup_steps: The number of warmup steps
:type num_warmup_steps: int
:param batch_size: The batch size to use for the training, defaults to 16
:type batch_size: int (optional)
"""
self.embedding_encoder.train(
training_data,
learning_rate=learning_rate,
n_epochs=n_epochs,
num_warmup_steps=num_warmup_steps,
batch_size=batch_size,
)
def save(self, save_dir: Union[Path, str]) -> None:
"""
Save the model to the given directory
:param save_dir: The directory where the model will be saved
:type save_dir: Union[Path, str]
"""
self.embedding_encoder.save(save_dir=save_dir)

View File

@@ -672,6 +672,13 @@ def get_retriever(retriever_type, document_store):
retriever = EmbeddingRetriever(
document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False
)
elif retriever_type == "embedding_sbert":
retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model="sentence-transformers/msmarco-distilbert-base-tas-b",
model_format="sentence_transformers",
use_gpu=False,
)
elif retriever_type == "retribert":
retriever = EmbeddingRetriever(
document_store=document_store, embedding_model="yjernite/retribert-base-uncased", use_gpu=False

View File

@@ -0,0 +1,62 @@
from pathlib import Path
import pytest
from haystack.nodes import QuestionGenerator, EmbeddingRetriever, PseudoLabelGenerator
from test.conftest import DOCS_WITH_EMBEDDINGS
@pytest.mark.slow
@pytest.mark.generator
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["embedding_sbert"], indirect=True)
def test_pseudo_label_generator(
document_store, retriever: EmbeddingRetriever, question_generator: QuestionGenerator, tmp_path: Path
):
document_store.write_documents(DOCS_WITH_EMBEDDINGS)
psg = PseudoLabelGenerator(question_generator, retriever)
train_examples = []
for doc in document_store.get_all_documents():
output, stream = psg.run(documents=[doc])
assert "gpl_labels" in output
for item in output["gpl_labels"]:
assert "question" in item and "pos_doc" in item and "neg_doc" in item and "score" in item
train_examples.append(item)
assert len(train_examples) > 0
retriever.train(train_examples)
retriever.save(tmp_path)
@pytest.mark.slow
@pytest.mark.generator
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["embedding_sbert"], indirect=True)
def test_pseudo_label_generator_using_question_document_pairs(
document_store, retriever: EmbeddingRetriever, tmp_path: Path
):
document_store.write_documents(DOCS_WITH_EMBEDDINGS)
docs = [
{
"question": "What is the capital of Germany?",
"document": "Berlin is the capital and largest city of Germany by both area and population.",
},
{
"question": "What is the largest city in Germany by population and area?",
"document": "Berlin is the capital and largest city of Germany by both area and population.",
},
]
psg = PseudoLabelGenerator(docs, retriever)
train_examples = []
for doc in document_store.get_all_documents():
# the documents passed here are ignored because question/document pairs were supplied in the constructor
output, stream = psg.run(documents=[doc])
assert "gpl_labels" in output
for item in output["gpl_labels"]:
assert "question" in item and "pos_doc" in item and "neg_doc" in item and "score" in item
train_examples.append(item)
assert len(train_examples) > 0
retriever.train(train_examples)
retriever.save(tmp_path)