Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-09-25 16:15:35 +00:00
Update docstrings for GPL (#2633)
* Update docstrings
* Update Documentation & Code Style
* Update wrong param description

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
parent c178f60e3a
commit f90649fab1
@@ -11,14 +11,14 @@ from haystack.schema import Document

 class PseudoLabelGenerator(BaseComponent):
     """
-    The PseudoLabelGenerator is a component that creates Generative Pseudo Labeling (GPL) training data for the
+    PseudoLabelGenerator is a component that creates Generative Pseudo Labeling (GPL) training data for the
     training of dense retrievers.

     GPL is an unsupervised domain adaptation method for the training of dense retrievers. It is based on question
     generation and pseudo labelling with powerful cross-encoders. To train a domain-adapted model, it needs access
-    to an unlabeled target corpus, usually via DocumentStore and a retriever to mine for negatives.
+    to an unlabeled target corpus, usually through DocumentStore and a Retriever to mine for negatives.

-    For more details see [https://github.com/UKPLab/gpl](https://github.com/UKPLab/gpl)
+    For more details, see [GPL](https://github.com/UKPLab/gpl).

     For example:

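
The `For example:` block itself lies outside this hunk. As a rough orientation only, a minimal setup sketch might look like the following; the import paths, the embedding model name, and the sample document are illustrative assumptions, not taken from this commit.

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever, PseudoLabelGenerator, QuestionGenerator
from haystack.schema import Document

# Unlabeled target corpus indexed in a DocumentStore (illustrative content).
document_store = InMemoryDocumentStore()
document_store.write_documents([Document(content="Berlin is the capital of Germany.")])

# Dense retriever that mines negatives here and gets adapted later.
retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/msmarco-distilbert-base-tas-b",  # assumed model
)
document_store.update_embeddings(retriever)

# GPL data generator: questions come from a QuestionGenerator, negatives from the retriever.
psg = PseudoLabelGenerator(question_producer=QuestionGenerator(), retriever=retriever)
```
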
@@ -43,21 +43,21 @@ class PseudoLabelGenerator(BaseComponent):
         progress_bar: bool = True,
     ):
         """
-        Loads the cross encoder model and prepares PseudoLabelGenerator.
+        Loads the cross-encoder model and prepares PseudoLabelGenerator.

         :param question_producer: The question producer used to generate questions or a list of already produced
-        questions/document pairs in Dict format {"question": "question text ...", "document": "document text ..."}.
+        questions/document pairs in a Dictionary format {"question": "question text ...", "document": "document text ..."}.
         :type question_producer: Union[QuestionGenerator, List[Dict[str, str]]]
-        :param retriever: The retriever used to query document stores
+        :param retriever: The Retriever used to query document stores.
         :type retriever: BaseRetriever
         :param cross_encoder_model_name_or_path: The path to the cross encoder model, defaults to
-        cross-encoder/ms-marco-MiniLM-L-6-v2
+        `cross-encoder/ms-marco-MiniLM-L-6-v2`.
         :type cross_encoder_model_name_or_path: str (optional)
-        :param max_questions_per_document: The max number of questions generated per document, defaults to 3
+        :param max_questions_per_document: The max number of questions generated per document, defaults to 3.
         :type max_questions_per_document: int
-        :param top_k: The number of answers retrieved for each question, defaults to 50
+        :param top_k: The number of answers retrieved for each question, defaults to 50.
         :type top_k: int (optional)
-        :param batch_size: Number of documents to process at a time
+        :param batch_size: The number of documents to process at a time.
         :type batch_size: int (optional)
         """

@@ -71,9 +71,11 @@ class PseudoLabelGenerator(BaseComponent):
             if isinstance(example, dict) and "question" in example and "document" in example:
                 self.question_document_pairs = question_producer
             else:
-                raise ValueError("question_producer list must contain dicts with keys 'question' and 'document'")
+                raise ValueError(
+                    "The question_producer list must contain dictionaries with keys 'question' and 'document'."
+                )
         else:
-            raise ValueError("Provide either a QuestionGenerator or nonempty list of questions/document pairs")
+            raise ValueError("Provide either a QuestionGenerator or a non-empty list of questions/document pairs.")

         self.retriever = retriever
         self.cross_encoder = CrossEncoder(cross_encoder_model_name_or_path)
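
For reference, the branch above accepts pre-generated pairs instead of a QuestionGenerator. A hypothetical list in the expected shape (the questions and documents are illustrative):

```python
# Each entry carries the "question" and "document" keys that the validation above checks for.
question_document_pairs = [
    {"question": "What is the capital of Germany?", "document": "Berlin is the capital of Germany."},
    {"question": "What is the largest city in Germany?", "document": "Berlin is the largest city in Germany."},
]
psg = PseudoLabelGenerator(question_producer=question_document_pairs, retriever=retriever)
```
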
@@ -86,9 +88,9 @@ class PseudoLabelGenerator(BaseComponent):
         """
         It takes a list of documents and generates a list of question-document pairs.

-        :param documents: A list of documents to generate questions from
+        :param documents: A list of documents to generate questions from.
         :type documents: List[Document]
-        :param batch_size: Number of documents to process at a time.
+        :param batch_size: The number of documents to process at a time.
         :type batch_size: Optional[int]
         :return: A list of question-document pairs.
         """
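
Continuing the earlier sketch (not part of this commit), the call documented in this hunk would look like:

```python
# Produces up to max_questions_per_document questions per document.
question_doc_pairs = psg.generate_questions(documents=document_store.get_all_documents())
# Each pair has the shape {"question": "...", "document": "..."}.
```
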
@@ -109,12 +111,12 @@ class PseudoLabelGenerator(BaseComponent):
         self, question_doc_pairs: List[Dict[str, str]], batch_size: Optional[int] = None
     ) -> List[Dict[str, str]]:
         """
-        Given a list of question and pos_doc pairs, this function returns a list of question/pos_doc/neg_doc
+        Given a list of question and positive document pairs, this function returns a list of question/positive document/negative document
         dictionaries.

-        :param question_doc_pairs: A list of question/pos_doc pairs
+        :param question_doc_pairs: A list of question/positive document pairs.
         :type question_doc_pairs: List[Dict[str, str]]
-        :param batch_size: The number of queries to run in a batch
+        :param batch_size: The number of queries to run in a batch.
         :type batch_size: int (optional)
         :return: A list of dictionaries, where each dictionary contains the question, positive document,
         and negative document.
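
Again as an illustrative continuation of the sketch, negative mining consumes those pairs and adds one retrieved document per question that is not the positive one:

```python
# The retriever queries the DocumentStore with each question (top_k candidates),
# and a hit that is not the positive document is kept as the hard negative.
mined_negatives = psg.mine_negatives(question_doc_pairs=question_doc_pairs)
# Each entry has the shape {"question": "...", "pos_doc": "...", "neg_doc": "..."}.
```
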
@@ -148,24 +150,24 @@ class PseudoLabelGenerator(BaseComponent):
         self, mined_negatives: List[Dict[str, str]], batch_size: Optional[int] = None
     ) -> List[Dict]:
         """
-        Given a list of mined negatives, predict the score margin between the positive and negative document using
-        the cross encoder.
+        Given a list of mined negatives, this function predicts the score margin between the positive and negative document using
+        the cross-encoder.

         The function returns a list of examples, where each example is a dictionary with the following keys:

-        * question: the question string
-        * pos_doc: the positive document string
-        * neg_doc: the negative document string
-        * score: the score margin
+        * question: The question string.
+        * pos_doc: Positive document string (the document containing the answer).
+        * neg_doc: Negative document string (the document that doesn't contain the answer).
+        * score: The margin between the score for question-positive document pair and the score for question-negative document pair.

-        :param mined_negatives: List of mined negatives
+        :param mined_negatives: The list of mined negatives.
         :type mined_negatives: List[Dict[str, str]]
-        :param batch_size: The number of mined negative lists to run in a batch
+        :param batch_size: The number of mined negative lists to run in a batch.
         :type batch_size: int (optional)
         :return: A list of dictionaries, each of which has the following keys:
             - question: The question string
-            - pos_doc: The positive document string
-            - neg_doc: The negative document string
+            - pos_doc: Positive document string
+            - neg_doc: Negative document string
             - score: The score margin
         """
         examples: List[Dict] = []
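
The score margin described above is the difference between the cross-encoder score of the question/positive-document pair and the question/negative-document pair. A minimal sketch of that computation with sentence-transformers (the model name is the default from this diff; the texts are made up):

```python
from sentence_transformers import CrossEncoder

cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

question = "What is the capital of Germany?"
pos_doc = "Berlin is the capital of Germany."
neg_doc = "Munich is the capital of Bavaria."

# Score both pairs; the margin is positive when the positive document
# is judged more relevant to the question than the negative one.
pos_score, neg_score = cross_encoder.predict([(question, pos_doc), (question, neg_doc)])
margin = pos_score - neg_score
```
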
@@ -192,20 +194,20 @@ class PseudoLabelGenerator(BaseComponent):

     def generate_pseudo_labels(self, documents: List[Document], batch_size: Optional[int] = None) -> Tuple[dict, str]:
         """
-        Given a list of documents, generate a list of question-document pairs, mine for negatives, and
-        score positive/negative margin with cross-encoder. The output is the training data for the
+        Given a list of documents, this function generates a list of question-document pairs, mines for negatives, and
+        scores a positive/negative margin with cross-encoder. The output is the training data for the
         adaptation of dense retriever models.

-        :param documents: List[Document] = List of documents to mine negatives from
+        :param documents: List[Document] = The list of documents to mine negatives from.
         :type documents: List[Document]
-        :param batch_size: The number of documents to process in a batch
+        :param batch_size: The number of documents to process in a batch.
         :type batch_size: Optional[int]
         :return: A dictionary with a single key 'gpl_labels' representing a list of dictionaries, where each
         dictionary contains the following keys:
-            - question: the question
-            - pos_doc: the positive document for the given question
-            - neg_doc: the negative document for the given question
-            - score: the margin score (a float)
+            - question: The question string.
+            - pos_doc: Positive document for the given question.
+            - neg_doc: Negative document for the given question.
+            - score: The margin between the score for question-positive document pair and the score for question-negative document pair.
         """
         # see https://github.com/UKPLab/gpl for more information about GPL algorithm
         batch_size = batch_size if batch_size else self.batch_size
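
Tying the sketch together, the end-to-end call documented in this hunk returns the GPL training data under the 'gpl_labels' key (output shape as per the docstring; the score value is made up):

```python
output, _ = psg.generate_pseudo_labels(documents=document_store.get_all_documents())
gpl_labels = output["gpl_labels"]
# Each entry looks like:
# {"question": "...", "pos_doc": "...", "neg_doc": "...", "score": 7.3}
```
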
@@ -53,34 +53,34 @@ class _BaseEmbeddingEncoder:
         batch_size: int = 16,
     ):
         """
-        Trains/adapts the underlying embedding model.
+        Trains or adapts the underlying embedding model.

         Each training data example is a dictionary with the following keys:

-        * question: the question string
-        * pos_doc: the positive document string
-        * neg_doc: the negative document string
-        * score: the score margin
+        * question: The question string.
+        * pos_doc: Positive document string (the document containing the answer).
+        * neg_doc: Negative document string (the document that doesn't contain the answer).
+        * score: The score margin the answer must fall within.


-        :param training_data: The training data
+        :param training_data: The training data in a dictionary format. Required.
         :type training_data: List[Dict[str, Any]]
-        :param learning_rate: The learning rate
+        :param learning_rate: The speed at which the model learns. Required. We recommend that you leave the default `2e-5` value.
         :type learning_rate: float
-        :param n_epochs: The number of training epochs
+        :param n_epochs: The number of epochs (complete passes of the training data through the algorithm) that you want the model to go through. Required.
         :type n_epochs: int
-        :param num_warmup_steps: The number of warmup steps
+        :param num_warmup_steps: The number of warmup steps for the model. Warmup steps are epochs when the learning rate is very low. You can use them at the beginning of the training to prevent early overfitting of your model. Required.
         :type num_warmup_steps: int
-        :param batch_size: The batch size to use for the training, defaults to 16
+        :param batch_size: The batch size to use for the training. Optional. The default value is 16.
         :type batch_size: int (optional)
         """
         pass

     def save(self, save_dir: Union[Path, str]):
         """
-        Save the model to the given directory
+        Save the model to the directory you specify.

-        :param save_dir: The directory where the model will be saved
+        :param save_dir: The directory where the model is saved. Required.
         :type save_dir: Union[Path, str]
         """
         pass
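
The keys listed above match the 'gpl_labels' entries produced by PseudoLabelGenerator, so domain adaptation can be finished roughly as follows. As the hunks below point out, this only works for sentence-transformers EmbeddingRetrievers; the parameter values are illustrative (2e-5 is the learning rate the docstring recommends), and `adapted_retriever` is just a hypothetical output directory.

```python
# Adapt the sentence-transformers EmbeddingRetriever on the GPL training data.
retriever.train(
    training_data=gpl_labels,  # list of {"question", "pos_doc", "neg_doc", "score"} dicts
    learning_rate=2e-5,
    n_epochs=1,
    batch_size=16,
)
retriever.save(save_dir="adapted_retriever")
```
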
@@ -139,10 +139,14 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
         num_warmup_steps: int = None,
         batch_size: int = 16,
     ):
-        raise NotImplementedError("train method can only be used with sentence-transformers EmbeddingRetriever(s)")
+        raise NotImplementedError(
+            "You can't train this retriever. You can only use the `train` method with sentence-transformers EmbeddingRetrievers."
+        )

     def save(self, save_dir: Union[Path, str]):
-        raise NotImplementedError("save method can only be used with sentence-transformers EmbeddingRetriever(s)")
+        raise NotImplementedError(
+            "You can't save your record as `save` only works for sentence-transformers EmbeddingRetrievers."
+        )


 class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
@@ -300,10 +304,14 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
         num_warmup_steps: int = None,
         batch_size: int = 16,
     ):
-        raise NotImplementedError("train method can only be used with sentence-transformers EmbeddingRetriever(s)")
+        raise NotImplementedError(
+            "You can't train this retriever. You can only use the `train` method with sentence-transformers EmbeddingRetrievers."
+        )

     def save(self, save_dir: Union[Path, str]):
-        raise NotImplementedError("save method can only be used with sentence-transformers EmbeddingRetriever(s)")
+        raise NotImplementedError(
+            "You can't save your record as `save` only works for sentence-transformers EmbeddingRetrievers."
+        )


 _EMBEDDING_ENCODERS: Dict[str, Callable] = {