diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md
index 7b02221e9..8d588594f 100644
--- a/docs/_src/api/api/retriever.md
+++ b/docs/_src/api/api/retriever.md
@@ -25,29 +25,31 @@ that get substituted during runtime. The placeholder(${filter_name_1}, ${filter_
names must match with the filters dict supplied in self.retrieve().
::
-An example custom_query:
+**An example custom_query:**
+```python
{
-"size": 10,
-"query": {
-"bool": {
-"should": [{"multi_match": {
-"query": "${question}", // mandatory $question placeholder
-"type": "most_fields",
-"fields": ["text", "title"]}}],
-"filter": [ // optional custom filters
-{"terms": {"year": "${years}"}},
-{"terms": {"quarter": "${quarters}"}},
-{"range": {"date": {"gte": "${date}"}}}
-],
-
-}
-},
+> "size": 10,
+> "query": {
+> "bool": {
+> "should": [{"multi_match": {
+> "query": "${question}", // mandatory $question placeholder
+> "type": "most_fields",
+> "fields": ["text", "title"]}}],
+> "filter": [ // optional custom filters
+> {"terms": {"year": "${years}"}},
+> {"terms": {"quarter": "${quarters}"}},
+> {"range": {"date": {"gte": "${date}"}}}
+> ],
+> }
+> },
}
+```
-For this custom_query, a sample retrieve() could be:
-::
+**For this custom_query, a sample retrieve() could be:**
+```python
self.retrieve(query="Why did the revenue increase?",
-filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
+> filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
+```
## ElasticsearchFilterOnlyRetriever
@@ -92,12 +94,25 @@ Karpukhin, Vladimir, et al. (2020): "Dense Passage Retrieval for Open-Domain Que
#### \_\_init\_\_
```python
- | __init__(document_store: BaseDocumentStore, query_embedding_model: str = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: str = "facebook/dpr-ctx_encoder-single-nq-base", max_seq_len: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, remove_sep_tok_from_untitled_passages: bool = True)
+ | __init__(document_store: BaseDocumentStore, query_embedding_model: str = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: str = "facebook/dpr-ctx_encoder-single-nq-base", max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product")
```
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
The checkpoint format matches huggingface transformers' model format
+**Example:**
+
+```python
+# remote model from FAIR
+DensePassageRetriever(document_store=your_doc_store,
+> query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
+> passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base")
+# or from local path
+DensePassageRetriever(document_store=your_doc_store,
+> query_embedding_model="model_directory/question-encoder",
+> passage_embedding_model="model_directory/context-encoder")
+```
+
**Arguments**:
- `document_store`: An instance of DocumentStore from which to retrieve documents.
@@ -107,7 +122,8 @@ Currently available remote names: ``"facebook/dpr-question_encoder-single-nq-bas
- `passage_embedding_model`: Local path or remote name of passage encoder checkpoint. The format equals the
one used by hugging-face transformers' modelhub models
Currently available remote names: ``"facebook/dpr-ctx_encoder-single-nq-base"``
-- `max_seq_len`: Longest length of each sequence
+- `max_seq_len_query`: Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down."
+- `max_seq_len_passage`: Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down."
- `use_gpu`: Whether to use gpu or not
- `batch_size`: Number of questions or passages to encode at once
- `embed_title`: Whether to concatenate title and passage to a text pair that is then used to create the embedding.
@@ -116,9 +132,6 @@ titles contain meaningful information for retrieval (topic, entities etc.) .
The title is expected to be present in doc.meta["name"] and can be supplied in the documents
before writing them to the DocumentStore like this:
{"text": "my text", "meta": {"name": "my title"}}.
-- `remove_sep_tok_from_untitled_passages`: If embed_title is ``True``, there are different strategies to deal with documents that don't have a title.
-If this param is ``True`` => Embed passage as single text, similar to embed_title = False (i.e [CLS] passage_tok1 ... [SEP]).
-If this param is ``False`` => Embed passage as text pair with empty title (i.e. [CLS] [SEP] passage_tok1 ... [SEP])
#### embed\_queries
@@ -154,6 +167,39 @@ Create embeddings for a list of passages using the passage encoder
Embeddings of documents / passages shape (batch_size, embedding_dim)
+
+#### train
+
+```python
+ | train(data_dir: str, train_filename: str, dev_filename: str = None, test_filename: str = None, batch_size: int = 2, embed_title: bool = True, num_hard_negatives: int = 1, num_negatives: int = 0, n_epochs: int = 3, evaluate_every: int = 1000, n_gpu: int = 1, learning_rate: float = 1e-5, epsilon: float = 1e-08, weight_decay: float = 0.0, num_warmup_steps: int = 100, grad_acc_steps: int = 1, optimizer_name: str = "TransformersAdamW", optimizer_correct_bias: bool = True, save_dir: str = "../saved_models/dpr-tutorial", query_encoder_save_dir: str = "lm1", passage_encoder_save_dir: str = "lm2")
+```
+
+train a DensePassageRetrieval model
+
+**Arguments**:
+
+- `data_dir`: Directory where training file, dev file and test file are present
+- `train_filename`: training filename
+- `dev_filename`: development set filename, file to be used by model in eval step of training
+- `test_filename`: test set filename, file to be used by model in test step after training
+- `batch_size`: total number of samples in 1 batch of data
+- `embed_title`: whether to concatenate passage title with each passage. The default setting in official DPR embeds passage title with the corresponding passage
+- `num_hard_negatives`: number of hard negative passages(passages which are very similar(high score by BM25) to query but do not contain the answer
+- `num_negatives`: number of negative passages(any random passage from dataset which do not contain answer to query)
+- `n_epochs`: number of epochs to train the model on
+- `evaluate_every`: number of training steps after evaluation is run
+- `n_gpu`: number of gpus to train on
+- `learning_rate`: learning rate of optimizer
+- `epsilon`: epsilon parameter of optimizer
+- `weight_decay`: weight decay parameter of optimizer
+- `grad_acc_steps`: number of steps to accumulate gradient over before back-propagation is done
+- `optimizer_name`: what optimizer to use (default: TransformersAdamW)
+- `num_warmup_steps`: number of warmup steps
+- `optimizer_correct_bias`: Whether to correct bias in optimizer
+- `save_dir`: directory where models are saved
+- `query_encoder_save_dir`: directory inside save_dir where query_encoder model files are saved
+- `passage_encoder_save_dir`: directory inside save_dir where passage_encoder model files are saved
+
## EmbeddingRetriever
@@ -271,7 +317,7 @@ that are most relevant to the query.
#### eval
```python
- | eval(label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold_label", top_k: int = 10, open_domain: bool = False) -> dict
+ | eval(label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold_label", top_k: int = 10, open_domain: bool = False, return_preds: bool = False) -> dict
```
Performs evaluation on the Retriever.
@@ -293,4 +339,6 @@ documents a higher rank.
contained in the retrieved docs (common approach in open-domain QA).
If ``False``, retrieval uses a stricter evaluation that checks if the retrieved document ids
are within ids explicitly stated in the labels.
+- `return_preds`: Whether to add predictions in the returned dictionary. If True, the returned dictionary
+contains the keys "predictions" and "metrics".
diff --git a/haystack/retriever/dense.py b/haystack/retriever/dense.py
index 289bb7de1..a1e3f8c2f 100644
--- a/haystack/retriever/dense.py
+++ b/haystack/retriever/dense.py
@@ -49,6 +49,19 @@ class DensePassageRetriever(BaseRetriever):
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
The checkpoint format matches huggingface transformers' model format
+ **Example:**
+
+ ```python
+ # remote model from FAIR
+ DensePassageRetriever(document_store=your_doc_store,
+ > query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
+ > passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base")
+ # or from local path
+ DensePassageRetriever(document_store=your_doc_store,
+ > query_embedding_model="model_directory/question-encoder",
+ > passage_embedding_model="model_directory/context-encoder")
+ ```
+
:param document_store: An instance of DocumentStore from which to retrieve documents.
:param query_embedding_model: Local path or remote name of question encoder checkpoint. The format equals the
one used by hugging-face transformers' modelhub models
diff --git a/haystack/retriever/sparse.py b/haystack/retriever/sparse.py
index 5ecead1b1..ea658325f 100644
--- a/haystack/retriever/sparse.py
+++ b/haystack/retriever/sparse.py
@@ -25,29 +25,31 @@ class ElasticsearchRetriever(BaseRetriever):
names must match with the filters dict supplied in self.retrieve().
::
- An example custom_query:
+ **An example custom_query:**
+ ```python
{
- "size": 10,
- "query": {
- "bool": {
- "should": [{"multi_match": {
- "query": "${question}", // mandatory $question placeholder
- "type": "most_fields",
- "fields": ["text", "title"]}}],
- "filter": [ // optional custom filters
- {"terms": {"year": "${years}"}},
- {"terms": {"quarter": "${quarters}"}},
- {"range": {"date": {"gte": "${date}"}}}
- ],
+ > "size": 10,
+ > "query": {
+ > "bool": {
+ > "should": [{"multi_match": {
+ > "query": "${question}", // mandatory $question placeholder
+ > "type": "most_fields",
+ > "fields": ["text", "title"]}}],
+ > "filter": [ // optional custom filters
+ > {"terms": {"year": "${years}"}},
+ > {"terms": {"quarter": "${quarters}"}},
+ > {"range": {"date": {"gte": "${date}"}}}
+ > ],
+ > }
+ > },
+ }
+ ```
- }
- },
- }
-
- For this custom_query, a sample retrieve() could be:
- ::
- self.retrieve(query="Why did the revenue increase?",
- filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
+ **For this custom_query, a sample retrieve() could be:**
+ ```python
+ self.retrieve(query="Why did the revenue increase?",
+ > filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
+ ```
"""
self.document_store: ElasticsearchDocumentStore = document_store
self.custom_query = custom_query