diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md
index 7b02221e9..8d588594f 100644
--- a/docs/_src/api/api/retriever.md
+++ b/docs/_src/api/api/retriever.md
@@ -25,29 +25,31 @@ that get substituted during runtime. The placeholder(${filter_name_1}, ${filter_
 names must match with the filters dict supplied in self.retrieve().
-::
-An example custom_query:
+**An example custom_query:**
+```python
 {
-"size": 10,
-"query": {
-"bool": {
-"should": [{"multi_match": {
-"query": "${question}", // mandatory $question placeholder
-"type": "most_fields",
-"fields": ["text", "title"]}}],
-"filter": [ // optional custom filters
-{"terms": {"year": "${years}"}},
-{"terms": {"quarter": "${quarters}"}},
-{"range": {"date": {"gte": "${date}"}}}
-],
-
-}
-},
+    "size": 10,
+    "query": {
+        "bool": {
+            "should": [{"multi_match": {
+                "query": "${question}",              // mandatory $question placeholder
+                "type": "most_fields",
+                "fields": ["text", "title"]}}],
+            "filter": [                              // optional custom filters
+                {"terms": {"year": "${years}"}},
+                {"terms": {"quarter": "${quarters}"}},
+                {"range": {"date": {"gte": "${date}"}}}
+            ]
+        }
+    }
 }
+```
+
-For this custom_query, a sample retrieve() could be:
-::
+**For this custom_query, a sample retrieve() could be:**
+```python
 self.retrieve(query="Why did the revenue increase?",
-filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
+              filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
+```
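+
+**A minimal end-to-end sketch** (hypothetical setup: it assumes an ``ElasticsearchDocumentStore`` instance named ``document_store`` already exists, and passes a trimmed version of the template above verbatim as ``custom_query``):
+```python
+custom_query = """{
+    "size": 10,
+    "query": {
+        "bool": {
+            "should": [{"multi_match": {
+                "query": "${question}",
+                "type": "most_fields",
+                "fields": ["text", "title"]}}]
+        }
+    }
+}"""
+retriever = ElasticsearchRetriever(document_store=document_store, custom_query=custom_query)
+results = retriever.retrieve(query="Why did the revenue increase?")
+```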
 
 ## ElasticsearchFilterOnlyRetriever
 
@@ -92,12 +94,25 @@ Karpukhin, Vladimir, et al. (2020): "Dense Passage Retrieval for Open-Domain Que
 
 #### \_\_init\_\_
 
 ```python
- | __init__(document_store: BaseDocumentStore, query_embedding_model: str = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: str = "facebook/dpr-ctx_encoder-single-nq-base", max_seq_len: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, remove_sep_tok_from_untitled_passages: bool = True)
+ | __init__(document_store: BaseDocumentStore, query_embedding_model: str = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: str = "facebook/dpr-ctx_encoder-single-nq-base", max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product")
 ```
 
 Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
 The checkpoint format matches huggingface transformers' model format
 
+**Example:**
+
+```python
+# remote model from FAIR
+DensePassageRetriever(document_store=your_doc_store,
+                      query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
+                      passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base")
+# or from local path
+DensePassageRetriever(document_store=your_doc_store,
+                      query_embedding_model="model_directory/question-encoder",
+                      passage_embedding_model="model_directory/context-encoder")
+```
+
 **Arguments**:
 
 - `document_store`: An instance of DocumentStore from which to retrieve documents.
@@ -107,7 +122,8 @@ Currently available remote names: ``"facebook/dpr-question_encoder-single-nq-bas
 - `passage_embedding_model`: Local path or remote name of passage encoder checkpoint.
 The format equals the one used by hugging-face transformers' modelhub models
 Currently available remote names: ``"facebook/dpr-ctx_encoder-single-nq-base"``
-- `max_seq_len`: Longest length of each sequence
+- `max_seq_len_query`: Maximum number of tokens for the query text. Longer queries will be truncated.
+- `max_seq_len_passage`: Maximum number of tokens for the passage text. Longer passages will be truncated.
 - `use_gpu`: Whether to use gpu or not
 - `batch_size`: Number of questions or passages to encode at once
 - `embed_title`: Whether to concatenate title and passage to a text pair that is then used to create the embedding.
@@ -116,9 +132,6 @@ titles contain meaningful information for retrieval (topic, entities etc.) .
 The title is expected to be present in doc.meta["name"] and can be supplied in the
 documents before writing them to the DocumentStore like this:
 {"text": "my text", "meta": {"name": "my title"}}.
-- `remove_sep_tok_from_untitled_passages`: If embed_title is ``True``, there are different strategies to deal with documents that don't have a title.
-If this param is ``True`` => Embed passage as single text, similar to embed_title = False (i.e [CLS] passage_tok1 ... [SEP]).
-If this param is ``False`` => Embed passage as text pair with empty title (i.e. [CLS] [SEP] passage_tok1 ... [SEP])
 
 #### embed\_queries
 
@@ -154,6 +167,39 @@ Create embeddings for a list of passages using the passage encoder
 
 Embeddings of documents / passages shape (batch_size, embedding_dim)
 
+
+#### train
+
+```python
+ | train(data_dir: str, train_filename: str, dev_filename: str = None, test_filename: str = None, batch_size: int = 2, embed_title: bool = True, num_hard_negatives: int = 1, num_negatives: int = 0, n_epochs: int = 3, evaluate_every: int = 1000, n_gpu: int = 1, learning_rate: float = 1e-5, epsilon: float = 1e-08, weight_decay: float = 0.0, num_warmup_steps: int = 100, grad_acc_steps: int = 1, optimizer_name: str = "TransformersAdamW", optimizer_correct_bias: bool = True, save_dir: str = "../saved_models/dpr-tutorial", query_encoder_save_dir: str = "lm1", passage_encoder_save_dir: str = "lm2")
+```
+
+Train a DensePassageRetriever model.
+
+**Arguments**:
+
+- `data_dir`: Directory where the training, dev and test files are located
+- `train_filename`: Training set filename
+- `dev_filename`: Development set filename; used for evaluation during training
+- `test_filename`: Test set filename; used for the test step after training
+- `batch_size`: Total number of samples per batch
+- `embed_title`: Whether to concatenate the passage title with each passage. The default setting in official DPR embeds the title with the corresponding passage
+- `num_hard_negatives`: Number of hard negative passages (passages that receive a high BM25 score for the query but do not contain the answer)
+- `num_negatives`: Number of negative passages (random passages from the dataset that do not contain the answer to the query)
+- `n_epochs`: Number of epochs to train the model for
+- `evaluate_every`: Number of training steps between evaluation runs
+- `n_gpu`: Number of GPUs to train on
+- `learning_rate`: Learning rate of the optimizer
+- `epsilon`: Epsilon parameter of the optimizer
+- `weight_decay`: Weight decay parameter of the optimizer
+- `num_warmup_steps`: Number of warmup steps
+- `grad_acc_steps`: Number of steps over which to accumulate gradients before back-propagation
+- `optimizer_name`: Which optimizer to use (default: TransformersAdamW)
+- `optimizer_correct_bias`: Whether to correct the bias in the optimizer
+- `save_dir`: Directory where the models are saved
+- `query_encoder_save_dir`: Directory inside save_dir where the query encoder model files are saved
+- `passage_encoder_save_dir`: Directory inside save_dir where the passage encoder model files are saved
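+
+**A minimal training sketch** (hypothetical file names and paths; ``retriever`` is an initialized DensePassageRetriever and ``data_dir`` contains the train/dev/test files):
+```python
+retriever.train(data_dir="data/dpr_training",
+                train_filename="train.json",
+                dev_filename="dev.json",
+                test_filename="test.json",
+                n_epochs=3,
+                batch_size=2,
+                save_dir="saved_models/dpr")
+# With the default query_encoder_save_dir="lm1" and passage_encoder_save_dir="lm2",
+# the trained encoders are written to saved_models/dpr/lm1 and saved_models/dpr/lm2.
+```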
+
 
 ## EmbeddingRetriever
 
@@ -271,7 +317,7 @@ that are most relevant to the query.
 
 #### eval
 
 ```python
- | eval(label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold_label", top_k: int = 10, open_domain: bool = False) -> dict
+ | eval(label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold_label", top_k: int = 10, open_domain: bool = False, return_preds: bool = False) -> dict
 ```
 
 Performs evaluation on the Retriever.
@@ -293,4 +339,6 @@ documents a higher rank.
 contained in the retrieved docs (common approach in open-domain QA).
 If ``False``, retrieval uses a stricter evaluation that checks if the retrieved document ids
 are within ids explicitly stated in the labels.
+- `return_preds`: Whether to add predictions to the returned dictionary. If ``True``, the returned
+dictionary contains the keys "predictions" and "metrics".
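+
+**A minimal evaluation sketch** (it assumes matching labels and documents have already been written to the ``label_index`` and ``doc_index`` named above):
+```python
+results = retriever.eval(label_index="label", doc_index="eval_document",
+                         top_k=10, open_domain=True, return_preds=True)
+print(results["metrics"])      # retrieval metrics
+print(results["predictions"])  # per-query predictions (only present when return_preds=True)
+```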
diff --git a/haystack/retriever/dense.py b/haystack/retriever/dense.py
index 289bb7de1..a1e3f8c2f 100644
--- a/haystack/retriever/dense.py
+++ b/haystack/retriever/dense.py
@@ -49,6 +49,19 @@ class DensePassageRetriever(BaseRetriever):
         Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
         The checkpoint format matches huggingface transformers' model format
 
+        **Example:**
+
+        ```python
+        # remote model from FAIR
+        DensePassageRetriever(document_store=your_doc_store,
+                              query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
+                              passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base")
+        # or from local path
+        DensePassageRetriever(document_store=your_doc_store,
+                              query_embedding_model="model_directory/question-encoder",
+                              passage_embedding_model="model_directory/context-encoder")
+        ```
+
         :param document_store: An instance of DocumentStore from which to retrieve documents.
         :param query_embedding_model: Local path or remote name of question encoder checkpoint. The format equals the
                                       one used by hugging-face transformers' modelhub models
diff --git a/haystack/retriever/sparse.py b/haystack/retriever/sparse.py
index 5ecead1b1..ea658325f 100644
--- a/haystack/retriever/sparse.py
+++ b/haystack/retriever/sparse.py
@@ -25,29 +25,31 @@ class ElasticsearchRetriever(BaseRetriever):
                              names must match with the filters dict supplied in self.retrieve().
-        ::
-        An example custom_query:
+        **An example custom_query:**
+        ```python
         {
-            "size": 10,
-            "query": {
-                "bool": {
-                    "should": [{"multi_match": {
-                        "query": "${question}", // mandatory $question placeholder
-                        "type": "most_fields",
-                        "fields": ["text", "title"]}}],
-                    "filter": [ // optional custom filters
-                        {"terms": {"year": "${years}"}},
-                        {"terms": {"quarter": "${quarters}"}},
-                        {"range": {"date": {"gte": "${date}"}}}
-                    ],
-
-            }
-        },
-        }
+            "size": 10,
+            "query": {
+                "bool": {
+                    "should": [{"multi_match": {
+                        "query": "${question}",              // mandatory $question placeholder
+                        "type": "most_fields",
+                        "fields": ["text", "title"]}}],
+                    "filter": [                              // optional custom filters
+                        {"terms": {"year": "${years}"}},
+                        {"terms": {"quarter": "${quarters}"}},
+                        {"range": {"date": {"gte": "${date}"}}}
+                    ]
+                }
+            }
+        }
+        ```
 
-        For this custom_query, a sample retrieve() could be:
-        ::
-        self.retrieve(query="Why did the revenue increase?",
-        filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
+        **For this custom_query, a sample retrieve() could be:**
+        ```python
+        self.retrieve(query="Why did the revenue increase?",
+                      filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
+        ```
         """
         self.document_store: ElasticsearchDocumentStore = document_store
         self.custom_query = custom_query