Add Generative Pseudo Labeling (#2388)
parent 61d9429c25
commit e10a3fba74
@@ -23,7 +23,7 @@ come from earlier in the document.

#### QuestionGenerator.\_\_init\_\_

```python
-def __init__(model_name_or_path="valhalla/t5-base-e2e-qg", model_version=None, num_beams=4, max_length=256, no_repeat_ngram_size=3, length_penalty=1.5, early_stopping=True, split_length=50, split_overlap=10, use_gpu=True, prompt="generate questions:", batch_size: Optional[int] = None)
+def __init__(model_name_or_path="valhalla/t5-base-e2e-qg", model_version=None, num_beams=4, max_length=256, no_repeat_ngram_size=3, length_penalty=1.5, early_stopping=True, split_length=50, split_overlap=10, use_gpu=True, prompt="generate questions:", num_queries_per_doc=1, batch_size: Optional[int] = None)
```

Uses the valhalla/t5-base-e2e-qg model by default. This class supports any question generation model that is
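
To illustrate the new `num_queries_per_doc` parameter, a minimal sketch (not part of the commit; the sample document is invented):

```python
from haystack.nodes import QuestionGenerator
from haystack.schema import Document

# num_queries_per_doc maps to the underlying generate() call's
# num_return_sequences, so up to three beam-search sequences are
# returned per text split instead of one
qg = QuestionGenerator(num_queries_per_doc=3)

docs = [Document(content="Python is a high-level, general-purpose programming language.")]
results, _ = qg.run(documents=docs)
print(results["generated_questions"])
```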

@@ -1433,6 +1433,45 @@ Create embeddings for a list of documents.

Embeddings, one per input document

<a id="dense.EmbeddingRetriever.train"></a>

#### EmbeddingRetriever.train

```python
def train(training_data: List[Dict[str, Any]], learning_rate: float = 2e-5, n_epochs: int = 1, num_warmup_steps: int = None, batch_size: int = 16) -> None
```

Trains/adapts the underlying embedding model.

Each training data example is a dictionary with the following keys:

* question: the question string
* pos_doc: the positive document string
* neg_doc: the negative document string
* score: the score margin

**Arguments**:

- `training_data` (`List[Dict[str, Any]]`): The training data
- `learning_rate` (`float`): The learning rate
- `n_epochs` (`int`): The number of epochs
- `num_warmup_steps` (`int`): The number of warmup steps
- `batch_size` (`int (optional)`): The batch size to use for the training, defaults to 16
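
As a rough sketch (not part of the commit) of how `train` might be called — the retriever setup mirrors the `embedding_sbert` test fixture added in this commit, and the margin `score` is a made-up placeholder that would normally come from `PseudoLabelGenerator`:

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever

# train() is only implemented for sentence-transformers models;
# other model formats raise NotImplementedError
retriever = EmbeddingRetriever(
    document_store=InMemoryDocumentStore(),
    embedding_model="sentence-transformers/msmarco-distilbert-base-tas-b",
    model_format="sentence_transformers",
)

training_data = [
    {
        "question": "What is the capital of Germany?",
        "pos_doc": "Berlin is the capital of Germany.",
        "neg_doc": "Paris is the capital of France.",
        "score": 8.3,  # placeholder margin: CE(q, pos_doc) - CE(q, neg_doc)
    }
]

# batch_size=1 here only because a single example is shown: incomplete
# batches are dropped during training, so a real dataset needs at least
# batch_size examples
retriever.train(training_data, learning_rate=2e-5, n_epochs=1, batch_size=1)
```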

<a id="dense.EmbeddingRetriever.save"></a>

#### EmbeddingRetriever.save

```python
def save(save_dir: Union[Path, str]) -> None
```

Save the model to the given directory.

**Arguments**:

- `save_dir` (`Union[Path, str]`): The directory where the model will be saved

<a id="text2sparql"></a>

# Module text2sparql

@@ -79,6 +79,7 @@ from haystack.nodes import (
    retriever,
    summarizer,
    translator,
    label_generator,
)

# Note that we ignore the ImportError here because if the user did not install

@@ -110,6 +110,9 @@
      {
        "$ref": "#/definitions/PreProcessorComponent"
      },
      {
        "$ref": "#/definitions/PseudoLabelGeneratorComponent"
      },
      {
        "$ref": "#/definitions/QuestionGeneratorComponent"
      },

@@ -2439,6 +2442,75 @@
      ],
      "additionalProperties": false
    },
    "PseudoLabelGeneratorComponent": {
      "type": "object",
      "properties": {
        "name": {
          "title": "Name",
          "description": "Custom name for the component. Helpful for visualization and debugging.",
          "type": "string"
        },
        "type": {
          "title": "Type",
          "description": "Haystack Class name for the component.",
          "type": "string",
          "const": "PseudoLabelGenerator"
        },
        "params": {
          "title": "Parameters",
          "type": "object",
          "properties": {
            "question_producer": {
              "title": "Question Producer",
              "anyOf": [
                {
                  "type": "string"
                },
                {
                  "type": "array",
                  "items": {
                    "type": "object",
                    "additionalProperties": {
                      "type": "string"
                    }
                  }
                }
              ]
            },
            "retriever": {
              "title": "Retriever",
              "type": "string"
            },
            "cross_encoder_model_name_or_path": {
              "title": "Cross Encoder Model Name Or Path",
              "default": "cross-encoder/ms-marco-MiniLM-L-6-v2",
              "type": "string"
            },
            "total_number_of_questions": {
              "title": "Total Number Of Questions",
              "default": 9223372036854775807,
              "type": "integer"
            },
            "top_k": {
              "title": "Top K",
              "default": 10,
              "type": "integer"
            }
          },
          "required": [
            "question_producer",
            "retriever"
          ],
          "additionalProperties": false,
          "description": "Each parameter can reference other components defined in the same YAML file."
        }
      },
      "required": [
        "type",
        "name"
      ],
      "additionalProperties": false
    },
    "QuestionGeneratorComponent": {
      "type": "object",
      "properties": {

@@ -113,6 +113,9 @@
      {
        "$ref": "#/definitions/PreProcessorComponent"
      },
      {
        "$ref": "#/definitions/PseudoLabelGeneratorComponent"
      },
      {
        "$ref": "#/definitions/QuestionGeneratorComponent"
      },

@@ -2618,6 +2621,75 @@
      ],
      "additionalProperties": false
    },
    "PseudoLabelGeneratorComponent": {
      "type": "object",
      "properties": {
        "name": {
          "title": "Name",
          "description": "Custom name for the component. Helpful for visualization and debugging.",
          "type": "string"
        },
        "type": {
          "title": "Type",
          "description": "Haystack Class name for the component.",
          "type": "string",
          "const": "PseudoLabelGenerator"
        },
        "params": {
          "title": "Parameters",
          "type": "object",
          "properties": {
            "question_producer": {
              "title": "Question Producer",
              "anyOf": [
                {
                  "type": "string"
                },
                {
                  "type": "array",
                  "items": {
                    "type": "object",
                    "additionalProperties": {
                      "type": "string"
                    }
                  }
                }
              ]
            },
            "retriever": {
              "title": "Retriever",
              "type": "string"
            },
            "cross_encoder_model_name_or_path": {
              "title": "Cross Encoder Model Name Or Path",
              "default": "cross-encoder/ms-marco-MiniLM-L-6-v2",
              "type": "string"
            },
            "total_number_of_questions": {
              "title": "Total Number Of Questions",
              "default": 9223372036854775807,
              "type": "integer"
            },
            "top_k": {
              "title": "Top K",
              "default": 10,
              "type": "integer"
            }
          },
          "required": [
            "question_producer",
            "retriever"
          ],
          "additionalProperties": false,
          "description": "Each parameter can reference other components defined in the same YAML file."
        }
      },
      "required": [
        "type",
        "name"
      ],
      "additionalProperties": false
    },
    "QuestionGeneratorComponent": {
      "type": "object",
      "properties": {

@@ -127,6 +127,9 @@
      {
        "$ref": "#/definitions/PreProcessorComponent"
      },
      {
        "$ref": "#/definitions/PseudoLabelGeneratorComponent"
      },
      {
        "$ref": "#/definitions/QuestionGeneratorComponent"
      },

@@ -3193,6 +3196,85 @@
      ],
      "additionalProperties": false
    },
    "PseudoLabelGeneratorComponent": {
      "type": "object",
      "properties": {
        "name": {
          "title": "Name",
          "description": "Custom name for the component. Helpful for visualization and debugging.",
          "type": "string"
        },
        "type": {
          "title": "Type",
          "description": "Haystack Class name for the component.",
          "type": "string",
          "const": "PseudoLabelGenerator"
        },
        "params": {
          "title": "Parameters",
          "type": "object",
          "properties": {
            "question_producer": {
              "title": "Question Producer",
              "anyOf": [
                {
                  "type": "string"
                },
                {
                  "type": "array",
                  "items": {
                    "type": "object",
                    "additionalProperties": {
                      "type": "string"
                    }
                  }
                }
              ]
            },
            "retriever": {
              "title": "Retriever",
              "type": "string"
            },
            "cross_encoder_model_name_or_path": {
              "title": "Cross Encoder Model Name Or Path",
              "default": "cross-encoder/ms-marco-MiniLM-L-6-v2",
              "type": "string"
            },
            "max_questions_per_document": {
              "title": "Max Questions Per Document",
              "default": 3,
              "type": "integer"
            },
            "top_k": {
              "title": "Top K",
              "default": 50,
              "type": "integer"
            },
            "batch_size": {
              "title": "Batch Size",
              "default": 4,
              "type": "integer"
            },
            "progress_bar": {
              "title": "Progress Bar",
              "default": true,
              "type": "boolean"
            }
          },
          "required": [
            "question_producer",
            "retriever"
          ],
          "additionalProperties": false,
          "description": "Each parameter can reference other components defined in the same YAML file."
        }
      },
      "required": [
        "type",
        "name"
      ],
      "additionalProperties": false
    },
    "QuestionGeneratorComponent": {
      "type": "object",
      "properties": {

@@ -3254,6 +3336,10 @@
            "title": "Prompt",
            "default": "generate questions:"
          },
          "num_queries_per_doc": {
            "title": "Num Queries Per Doc",
            "default": 1
          },
          "batch_size": {
            "title": "Batch Size",
            "type": "integer"

@@ -20,6 +20,7 @@ from haystack.nodes.file_converter import (
    AzureConverter,
    ParsrConverter,
)
from haystack.nodes.label_generator import PseudoLabelGenerator
from haystack.nodes.other import Docs2Answers, JoinDocuments, RouteDocuments, JoinAnswers
from haystack.nodes.preprocessor import BasePreProcessor, PreProcessor
from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier
1 haystack/nodes/label_generator/__init__.py Normal file
@@ -0,0 +1 @@
from haystack.nodes.label_generator.pseudo_label_generator import PseudoLabelGenerator
233 haystack/nodes/label_generator/pseudo_label_generator.py Normal file
@@ -0,0 +1,233 @@
import random
from typing import Dict, Iterable, List, Optional, Tuple, Union

from sentence_transformers import CrossEncoder
from tqdm.auto import tqdm

from haystack.nodes.base import BaseComponent
from haystack.nodes.question_generator import QuestionGenerator
from haystack.nodes.retriever.base import BaseRetriever
from haystack.schema import Document


class PseudoLabelGenerator(BaseComponent):
    """
    The PseudoLabelGenerator is a component that creates Generative Pseudo Labeling (GPL) training data for the
    training of dense retrievers.

    GPL is an unsupervised domain adaptation method for the training of dense retrievers. It is based on question
    generation and pseudo labelling with powerful cross-encoders. To train a domain-adapted model, it needs access
    to an unlabeled target corpus, usually via a DocumentStore, and a retriever to mine for negatives.

    For more details see [https://github.com/UKPLab/gpl](https://github.com/UKPLab/gpl)

    For example:

    ```python
    |   document_store = DocumentStore(...)
    |   retriever = Retriever(...)
    |   qg = QuestionGenerator(model_name_or_path="doc2query/msmarco-t5-base-v1")
    |   plg = PseudoLabelGenerator(qg, retriever)
    |   output, output_id = plg.run(documents=document_store.get_all_documents())
    ```
    """

    def __init__(
        self,
        question_producer: Union[QuestionGenerator, List[Dict[str, str]]],
        retriever: BaseRetriever,
        cross_encoder_model_name_or_path: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        max_questions_per_document: int = 3,
        top_k: int = 50,
        batch_size: int = 4,
        progress_bar: bool = True,
    ):
        """
        Loads the cross encoder model and prepares PseudoLabelGenerator.

        :param question_producer: The question producer used to generate questions, or a list of already produced
            question/document pairs in Dict format {"question": "question text ...", "document": "document text ..."}.
        :type question_producer: Union[QuestionGenerator, List[Dict[str, str]]]
        :param retriever: The retriever used to query document stores.
        :type retriever: BaseRetriever
        :param cross_encoder_model_name_or_path: The path to the cross encoder model, defaults to
            cross-encoder/ms-marco-MiniLM-L-6-v2.
        :type cross_encoder_model_name_or_path: str (optional)
        :param max_questions_per_document: The max number of questions generated per document, defaults to 3.
        :type max_questions_per_document: int
        :param top_k: The number of answers retrieved for each question, defaults to 50.
        :type top_k: int (optional)
        :param batch_size: The number of documents to process at a time.
        :type batch_size: int (optional)
        """
        super().__init__()
        self.question_document_pairs = None
        self.question_generator = None  # type: ignore
        if isinstance(question_producer, QuestionGenerator):
            self.question_generator = question_producer
        elif isinstance(question_producer, list) and len(question_producer) > 0:
            example = question_producer[0]
            if isinstance(example, dict) and "question" in example and "document" in example:
                self.question_document_pairs = question_producer
            else:
                raise ValueError("question_producer list must contain dicts with keys 'question' and 'document'")
        else:
            raise ValueError("Provide either a QuestionGenerator or a non-empty list of question/document pairs")

        self.retriever = retriever
        self.cross_encoder = CrossEncoder(cross_encoder_model_name_or_path)
        self.max_questions_per_document = max_questions_per_document
        self.top_k = top_k
        self.batch_size = batch_size
        self.progress_bar = progress_bar

    def generate_questions(self, documents: List[Document], batch_size: Optional[int] = None) -> List[Dict[str, str]]:
        """
        It takes a list of documents and generates a list of question-document pairs.

        :param documents: A list of documents to generate questions from.
        :type documents: List[Document]
        :param batch_size: The number of documents to process at a time.
        :type batch_size: Optional[int]
        :return: A list of question-document pairs.
        """
        question_doc_pairs: List[Dict[str, str]] = []
        if self.question_document_pairs:
            question_doc_pairs = self.question_document_pairs
        else:
            batch_size = batch_size if batch_size else self.batch_size
            questions: List[List[str]] = self.question_generator.generate_batch(  # type: ignore
                [d.content for d in documents], batch_size=batch_size
            )
            for idx, question_list_per_doc in enumerate(questions):
                for q in question_list_per_doc[: self.max_questions_per_document]:  # type: ignore
                    question_doc_pairs.append({"question": q.strip(), "document": documents[idx].content})
        return question_doc_pairs

    def mine_negatives(
        self, question_doc_pairs: List[Dict[str, str]], batch_size: Optional[int] = None
    ) -> List[Dict[str, str]]:
        """
        Given a list of question and pos_doc pairs, this function returns a list of question/pos_doc/neg_doc
        dictionaries.

        :param question_doc_pairs: A list of question/pos_doc pairs.
        :type question_doc_pairs: List[Dict[str, str]]
        :param batch_size: The number of queries to run in a batch.
        :type batch_size: int (optional)
        :return: A list of dictionaries, where each dictionary contains the question, positive document,
            and negative document.
        """
        question_pos_doc_neg_doc: List[Dict[str, str]] = []
        batch_size = batch_size if batch_size else self.batch_size

        for i in tqdm(
            range(0, len(question_doc_pairs), batch_size), disable=not self.progress_bar, desc="Mine negatives"
        ):
            # query in batches to minimize network latency
            i_end = min(i + batch_size, len(question_doc_pairs))
            queries: List[str] = [e["question"] for e in question_doc_pairs[i:i_end]]
            pos_docs: List[str] = [e["document"] for e in question_doc_pairs[i:i_end]]

            docs: List[List[Document]] = self.retriever.retrieve_batch(
                queries=queries, top_k=self.top_k, batch_size=batch_size
            )

            # iterate through queries and find negatives: the first shuffled
            # retrieval hit that is not the positive document itself
            for question, pos_doc, top_docs in zip(queries, pos_docs, docs):
                random.shuffle(top_docs)
                for doc_item in top_docs:
                    neg_doc = doc_item.content
                    if neg_doc != pos_doc:
                        question_pos_doc_neg_doc.append({"question": question, "pos_doc": pos_doc, "neg_doc": neg_doc})
                        break
        return question_pos_doc_neg_doc

    def generate_margin_scores(
        self, mined_negatives: List[Dict[str, str]], batch_size: Optional[int] = None
    ) -> List[Dict]:
        """
        Given a list of mined negatives, predict the score margin between the positive and negative document using
        the cross encoder.

        The function returns a list of examples, where each example is a dictionary with the following keys:

        * question: the question string
        * pos_doc: the positive document string
        * neg_doc: the negative document string
        * score: the score margin

        :param mined_negatives: The list of mined negatives.
        :type mined_negatives: List[Dict[str, str]]
        :param batch_size: The number of mined negative lists to run in a batch.
        :type batch_size: int (optional)
        :return: A list of dictionaries, each of which has the following keys:
            - question: The question string
            - pos_doc: The positive document string
            - neg_doc: The negative document string
            - score: The score margin
        """
        examples: List[Dict] = []
        batch_size = batch_size if batch_size else self.batch_size
        for i in tqdm(range(0, len(mined_negatives), batch_size), disable=not self.progress_bar, desc="Score margin"):
            negatives_batch = mined_negatives[i : i + batch_size]
            pb = []
            for item in negatives_batch:
                # each example contributes two (question, passage) pairs:
                # (q, pos_doc) at index 2*idx and (q, neg_doc) at 2*idx + 1
                pb.append([item["question"], item["pos_doc"]])
                pb.append([item["question"], item["neg_doc"]])
            scores = self.cross_encoder.predict(pb)
            for idx, item in enumerate(negatives_batch):
                scores_idx = idx * 2
                score_margin = scores[scores_idx] - scores[scores_idx + 1]
                examples.append(
                    {
                        "question": item["question"],
                        "pos_doc": item["pos_doc"],
                        "neg_doc": item["neg_doc"],
                        "score": score_margin,
                    }
                )
        return examples

    def generate_pseudo_labels(self, documents: List[Document], batch_size: Optional[int] = None) -> Tuple[dict, str]:
        """
        Given a list of documents, generate a list of question-document pairs, mine for negatives, and
        score the positive/negative margin with the cross-encoder. The output is the training data for the
        adaptation of dense retriever models.

        :param documents: The list of documents to mine negatives from.
        :type documents: List[Document]
        :param batch_size: The number of documents to process in a batch.
        :type batch_size: Optional[int]
        :return: A dictionary with a single key 'gpl_labels' representing a list of dictionaries, where each
            dictionary contains the following keys:
            - question: the question
            - pos_doc: the positive document for the given question
            - neg_doc: the negative document for the given question
            - score: the margin score (a float)
        """
        # see https://github.com/UKPLab/gpl for more information about the GPL algorithm
        batch_size = batch_size if batch_size else self.batch_size

        # step 1: generate questions
        question_doc_pairs = self.generate_questions(documents=documents, batch_size=batch_size)

        # step 2: negative mining
        mined_negatives = self.mine_negatives(question_doc_pairs=question_doc_pairs, batch_size=batch_size)

        # step 3: pseudo labeling (scoring) with the cross-encoder
        pseudo_labels: List[Dict[str, str]] = self.generate_margin_scores(mined_negatives, batch_size=batch_size)
        return {"gpl_labels": pseudo_labels}, "output_1"

    def run(self, documents: List[Document]) -> Tuple[dict, str]:  # type: ignore
        return self.generate_pseudo_labels(documents=documents)

    def run_batch(self, documents: Union[List[Document], List[List[Document]]]) -> Tuple[dict, str]:  # type: ignore
        flat_list_of_documents = []
        for sub_list_documents in documents:
            if isinstance(sub_list_documents, Iterable):
                flat_list_of_documents += sub_list_documents
            else:
                flat_list_of_documents.append(sub_list_documents)
        return self.generate_pseudo_labels(documents=flat_list_of_documents)
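
For orientation (not part of the commit), a minimal end-to-end sketch of the GPL flow this file enables; the document text is invented and the model names follow the defaults used elsewhere in this commit:

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever, PseudoLabelGenerator, QuestionGenerator
from haystack.schema import Document

document_store = InMemoryDocumentStore()
document_store.write_documents([Document(content="Berlin is the capital and largest city of Germany.")])

retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/msmarco-distilbert-base-tas-b",
    model_format="sentence_transformers",
)
document_store.update_embeddings(retriever)

qg = QuestionGenerator(model_name_or_path="doc2query/msmarco-t5-base-v1")
plg = PseudoLabelGenerator(qg, retriever)
output, _ = plg.run(documents=document_store.get_all_documents())

# the GPL labels feed directly into EmbeddingRetriever.train (added in this commit)
retriever.train(output["gpl_labels"])
retriever.save("adapted_retriever")
```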

@@ -37,6 +37,7 @@ class QuestionGenerator(BaseComponent):
        split_overlap=10,
        use_gpu=True,
        prompt="generate questions:",
        num_queries_per_doc=1,
        batch_size: Optional[int] = None,
    ):
        """

@@ -65,6 +66,7 @@ class QuestionGenerator(BaseComponent):
        self.split_overlap = split_overlap
        self.preprocessor = PreProcessor()
        self.prompt = prompt
        self.num_queries_per_doc = num_queries_per_doc
        self.batch_size = batch_size

    def run(self, documents: List[Document]):  # type: ignore

@@ -122,6 +124,7 @@ class QuestionGenerator(BaseComponent):
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            length_penalty=self.length_penalty,
            early_stopping=self.early_stopping,
            num_return_sequences=self.num_queries_per_doc,
        )

        string_output = self.tokenizer.batch_decode(tokens_output)

@@ -190,6 +193,7 @@ class QuestionGenerator(BaseComponent):
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            length_penalty=self.length_penalty,
            early_stopping=self.early_stopping,
            num_return_sequences=self.num_queries_per_doc,
        )

        string_output = self.tokenizer.batch_decode(tokens_output)
@@ -1,17 +1,20 @@
-from typing import TYPE_CHECKING, Callable, List, Union, Dict
-
 import logging
 from abc import abstractmethod
-import numpy as np
-from tqdm.auto import tqdm
-import torch
-from torch.utils.data.sampler import SequentialSampler
-from transformers import AutoTokenizer, AutoModel
 from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Union

-from haystack.schema import Document
+import numpy as np
+import torch
+from sentence_transformers import InputExample, losses
+from torch.utils.data import DataLoader
+from torch.utils.data.sampler import SequentialSampler
+from tqdm.auto import tqdm
+from transformers import AutoModel, AutoTokenizer
+
+from haystack.modeling.data_handler.dataloader import NamedDataLoader
 from haystack.modeling.data_handler.dataset import convert_features_to_dataset, flatten_rename
 from haystack.modeling.infer import Inferencer
-from haystack.modeling.data_handler.dataloader import NamedDataLoader
+from haystack.schema import Document

 if TYPE_CHECKING:
     from haystack.nodes.retriever import EmbeddingRetriever
@@ -41,6 +44,47 @@ class _BaseEmbeddingEncoder:
        """
        pass

    def train(
        self,
        training_data: List[Dict[str, Any]],
        learning_rate: float = 2e-5,
        n_epochs: int = 1,
        num_warmup_steps: int = None,
        batch_size: int = 16,
    ):
        """
        Trains/adapts the underlying embedding model.

        Each training data example is a dictionary with the following keys:

        * question: the question string
        * pos_doc: the positive document string
        * neg_doc: the negative document string
        * score: the score margin

        :param training_data: The training data
        :type training_data: List[Dict[str, Any]]
        :param learning_rate: The learning rate
        :type learning_rate: float
        :param n_epochs: The number of training epochs
        :type n_epochs: int
        :param num_warmup_steps: The number of warmup steps
        :type num_warmup_steps: int
        :param batch_size: The batch size to use for the training, defaults to 16
        :type batch_size: int (optional)
        """
        pass

    def save(self, save_dir: Union[Path, str]):
        """
        Save the model to the given directory.

        :param save_dir: The directory where the model will be saved
        :type save_dir: Union[Path, str]
        """
        pass


class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
    def __init__(self, retriever: "EmbeddingRetriever"):

@@ -87,6 +131,19 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
        passages = [d.content for d in docs]  # type: ignore
        return self.embed(passages)

    def train(
        self,
        training_data: List[Dict[str, Any]],
        learning_rate: float = 2e-5,
        n_epochs: int = 1,
        num_warmup_steps: int = None,
        batch_size: int = 16,
    ):
        raise NotImplementedError("train method can only be used with sentence-transformers EmbeddingRetriever(s)")

    def save(self, save_dir: Union[Path, str]):
        raise NotImplementedError("save method can only be used with sentence-transformers EmbeddingRetriever(s)")


class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
    def __init__(self, retriever: "EmbeddingRetriever"):

@@ -127,6 +184,33 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
        passages = [[d.meta["name"] if d.meta and "name" in d.meta else "", d.content] for d in docs]  # type: ignore
        return self.embed(passages)

    def train(
        self,
        training_data: List[Dict[str, Any]],
        learning_rate: float = 2e-5,
        n_epochs: int = 1,
        num_warmup_steps: int = None,
        batch_size: int = 16,
    ):
        # each example carries (question, pos_doc, neg_doc) plus the cross-encoder margin as the label
        train_examples = [
            InputExample(texts=[i["question"], i["pos_doc"], i["neg_doc"]], label=i["score"]) for i in training_data
        ]
        logger.info(f"GPL training/adapting {self.embedding_model} with {len(train_examples)} examples")
        train_dataloader = DataLoader(train_examples, batch_size=batch_size, drop_last=True, shuffle=True)
        train_loss = losses.MarginMSELoss(self.embedding_model)

        # Tune the model; warmup defaults to 10% of the training steps
        self.embedding_model.fit(
            train_objectives=[(train_dataloader, train_loss)],
            epochs=n_epochs,
            optimizer_params={"lr": learning_rate},
            warmup_steps=int(len(train_dataloader) * 0.1) if num_warmup_steps is None else num_warmup_steps,
        )

    def save(self, save_dir: Union[Path, str]):
        self.embedding_model.save(path=str(save_dir))


class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
    def __init__(self, retriever: "EmbeddingRetriever"):

@@ -208,6 +292,19 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
        dataset, tensornames = convert_features_to_dataset(features=features_flat)
        return dataset, tensornames

    def train(
        self,
        training_data: List[Dict[str, Any]],
        learning_rate: float = 2e-5,
        n_epochs: int = 1,
        num_warmup_steps: int = None,
        batch_size: int = 16,
    ):
        raise NotImplementedError("train method can only be used with sentence-transformers EmbeddingRetriever(s)")

    def save(self, save_dir: Union[Path, str]):
        raise NotImplementedError("save method can only be used with sentence-transformers EmbeddingRetriever(s)")


_EMBEDDING_ENCODERS: Dict[str, Callable] = {
    "farm": _DefaultEmbeddingEncoder,
@@ -1,4 +1,4 @@
-from typing import List, Dict, Union, Optional
+from typing import List, Dict, Union, Optional, Any

 import logging
 from pathlib import Path
@@ -1851,3 +1851,50 @@ class EmbeddingRetriever(BaseRetriever):

        # Model is neither sentence-transformers nor retribert model -> use _DefaultEmbeddingEncoder
        return "farm"

    def train(
        self,
        training_data: List[Dict[str, Any]],
        learning_rate: float = 2e-5,
        n_epochs: int = 1,
        num_warmup_steps: int = None,
        batch_size: int = 16,
    ) -> None:
        """
        Trains/adapts the underlying embedding model.

        Each training data example is a dictionary with the following keys:

        * question: the question string
        * pos_doc: the positive document string
        * neg_doc: the negative document string
        * score: the score margin

        :param training_data: The training data
        :type training_data: List[Dict[str, Any]]
        :param learning_rate: The learning rate
        :type learning_rate: float
        :param n_epochs: The number of epochs
        :type n_epochs: int
        :param num_warmup_steps: The number of warmup steps
        :type num_warmup_steps: int
        :param batch_size: The batch size to use for the training, defaults to 16
        :type batch_size: int (optional)
        """
        self.embedding_encoder.train(
            training_data,
            learning_rate=learning_rate,
            n_epochs=n_epochs,
            num_warmup_steps=num_warmup_steps,
            batch_size=batch_size,
        )

    def save(self, save_dir: Union[Path, str]) -> None:
        """
        Save the model to the given directory.

        :param save_dir: The directory where the model will be saved
        :type save_dir: Union[Path, str]
        """
        self.embedding_encoder.save(save_dir=save_dir)
@@ -672,6 +672,13 @@ def get_retriever(retriever_type, document_store):
        retriever = EmbeddingRetriever(
            document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False
        )
    elif retriever_type == "embedding_sbert":
        retriever = EmbeddingRetriever(
            document_store=document_store,
            embedding_model="sentence-transformers/msmarco-distilbert-base-tas-b",
            model_format="sentence_transformers",
            use_gpu=False,
        )
    elif retriever_type == "retribert":
        retriever = EmbeddingRetriever(
            document_store=document_store, embedding_model="yjernite/retribert-base-uncased", use_gpu=False
62 test/nodes/test_label_generator.py Normal file
@@ -0,0 +1,62 @@
from pathlib import Path

import pytest

from haystack.nodes import QuestionGenerator, EmbeddingRetriever, PseudoLabelGenerator
from test.conftest import DOCS_WITH_EMBEDDINGS


@pytest.mark.slow
@pytest.mark.generator
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["embedding_sbert"], indirect=True)
def test_pseudo_label_generator(
    document_store, retriever: EmbeddingRetriever, question_generator: QuestionGenerator, tmp_path: Path
):
    document_store.write_documents(DOCS_WITH_EMBEDDINGS)
    psg = PseudoLabelGenerator(question_generator, retriever)
    train_examples = []
    for idx, doc in enumerate(document_store):
        output, stream = psg.run(documents=[doc])
        assert "gpl_labels" in output
        for item in output["gpl_labels"]:
            assert "question" in item and "pos_doc" in item and "neg_doc" in item and "score" in item
            train_examples.append(item)

    assert len(train_examples) > 0
    retriever.train(train_examples)
    retriever.save(tmp_path)


@pytest.mark.slow
@pytest.mark.generator
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["embedding_sbert"], indirect=True)
def test_pseudo_label_generator_using_question_document_pairs(
    document_store, retriever: EmbeddingRetriever, tmp_path: Path
):
    document_store.write_documents(DOCS_WITH_EMBEDDINGS)
    docs = [
        {
            "question": "What is the capital of Germany?",
            "document": "Berlin is the capital and largest city of Germany by both area and population.",
        },
        {
            "question": "What is the largest city in Germany by population and area?",
            "document": "Berlin is the capital and largest city of Germany by both area and population.",
        },
    ]
    psg = PseudoLabelGenerator(docs, retriever)
    train_examples = []
    for idx, doc in enumerate(document_store):
        # the documents passed here are ignored as we provided source documents in the constructor
        output, stream = psg.run(documents=[doc])
        assert "gpl_labels" in output
        for item in output["gpl_labels"]:
            assert "question" in item and "pos_doc" in item and "neg_doc" in item and "score" in item
            train_examples.append(item)

    assert len(train_examples) > 0

    retriever.train(train_examples)
    retriever.save(tmp_path)