Added new formatting for examples in docstrings (#555)

Markus Paff 2020-11-05 15:50:08 +01:00 committed by GitHub
parent 727767388a
commit 40c5c8edb4
3 changed files with 109 additions and 46 deletions


@@ -25,29 +25,31 @@ that get substituted during runtime. The placeholder(${filter_name_1}, ${filter_
names must match with the filters dict supplied in self.retrieve().
**An example custom_query:**
```python
{
    "size": 10,
    "query": {
        "bool": {
            "should": [{"multi_match": {
                "query": "${question}",  # mandatory $question placeholder
                "type": "most_fields",
                "fields": ["text", "title"]}}],
            "filter": [  # optional custom filters
                {"terms": {"year": "${years}"}},
                {"terms": {"quarter": "${quarters}"}},
                {"range": {"date": {"gte": "${date}"}}}
            ],
        }
    },
}
```
**For this custom_query, a sample retrieve() could be:**
```python
self.retrieve(query="Why did the revenue increase?",
              filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
```
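As a minimal end-to-end sketch of how custom_query and retrieve() fit together (the import paths and the `ElasticsearchDocumentStore` setup are assumptions about this Haystack version, not part of the diff):

```python
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.retriever.sparse import ElasticsearchRetriever

# The template is passed once at construction time; ${question} and the
# filter placeholders are substituted on every retrieve() call.
custom_query = """{
    "size": 10,
    "query": {
        "bool": {
            "should": [{"multi_match": {
                "query": "${question}",
                "type": "most_fields",
                "fields": ["text", "title"]}}],
            "filter": [
                {"terms": {"year": "${years}"}},
                {"terms": {"quarter": "${quarters}"}}
            ]
        }
    }
}"""

document_store = ElasticsearchDocumentStore(host="localhost", index="document")
retriever = ElasticsearchRetriever(document_store=document_store, custom_query=custom_query)

# The keys of `filters` must match the placeholder names in the template.
docs = retriever.retrieve(query="Why did the revenue increase?",
                          filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
```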
<a name="sparse.ElasticsearchFilterOnlyRetriever"></a>
## ElasticsearchFilterOnlyRetriever
@@ -92,12 +94,25 @@ Karpukhin, Vladimir, et al. (2020): "Dense Passage Retrieval for Open-Domain Que
#### \_\_init\_\_
```python
- | __init__(document_store: BaseDocumentStore, query_embedding_model: str = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: str = "facebook/dpr-ctx_encoder-single-nq-base", max_seq_len: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, remove_sep_tok_from_untitled_passages: bool = True)
+ | __init__(document_store: BaseDocumentStore, query_embedding_model: str = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: str = "facebook/dpr-ctx_encoder-single-nq-base", max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product")
```
Initialize the Retriever, including the two encoder models, from a local or remote model checkpoint.
The checkpoint format matches the Hugging Face Transformers model format.
**Example:**
```python
# remote model from FAIR
DensePassageRetriever(document_store=your_doc_store,
                      query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                      passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base")
# or from local path
DensePassageRetriever(document_store=your_doc_store,
                      query_embedding_model="model_directory/question-encoder",
                      passage_embedding_model="model_directory/context-encoder")
```
**Arguments**:
- `document_store`: An instance of DocumentStore from which to retrieve documents.
@@ -107,7 +122,8 @@ Currently available remote names: ``"facebook/dpr-question_encoder-single-nq-bas
- `passage_embedding_model`: Local path or remote name of passage encoder checkpoint. The format matches that of
Hugging Face Transformers model hub models.
Currently available remote names: ``"facebook/dpr-ctx_encoder-single-nq-base"``
- - `max_seq_len`: Longest length of each sequence
+ - `max_seq_len_query`: Maximum number of tokens for the query text. Longer ones will be cut down.
+ - `max_seq_len_passage`: Maximum number of tokens for the passage/context text. Longer ones will be cut down.
- `use_gpu`: Whether to use a GPU or not
- `batch_size`: Number of questions or passages to encode at once
- `embed_title`: Whether to concatenate title and passage to a text pair that is then used to create the embedding.
@@ -116,9 +132,6 @@ titles contain meaningful information for retrieval (topic, entities etc.).
The title is expected to be present in doc.meta["name"] and can be supplied in the documents
before writing them to the DocumentStore like this:
{"text": "my text", "meta": {"name": "my title"}}.
- - `remove_sep_tok_from_untitled_passages`: If embed_title is ``True``, there are different strategies to deal with documents that don't have a title.
- If this param is ``True`` => Embed passage as single text, similar to embed_title = False (i.e. [CLS] passage_tok1 ... [SEP]).
- If this param is ``False`` => Embed passage as text pair with empty title (i.e. [CLS] [SEP] passage_tok1 ... [SEP])
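To make the `embed_title` note above concrete, here is a short sketch of supplying titles via `meta["name"]` before indexing; `your_doc_store` and the document contents are hypothetical placeholders:

```python
from haystack.retriever.dense import DensePassageRetriever

# embed_title=True only has an effect if the title is stored in meta["name"].
docs = [
    {"text": "Revenue increased due to strong Q1 sales ...", "meta": {"name": "Annual Report 2019"}},
    {"text": "The new product line launched in March ...", "meta": {"name": "Press Release March"}},
]
your_doc_store.write_documents(docs)  # any already-initialized DocumentStore

retriever = DensePassageRetriever(document_store=your_doc_store,
                                  query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                  passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                  embed_title=True)
```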
<a name="dense.DensePassageRetriever.embed_queries"></a>
#### embed\_queries
@@ -154,6 +167,39 @@ Create embeddings for a list of passages using the passage encoder
Embeddings of documents / passages. Shape: (batch_size, embedding_dim)
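Assuming `embed_queries` accepts a list of texts in this version, inspecting the returned embeddings could look roughly like this (reusing the `retriever` from the sketch above):

```python
# One dense vector per query, produced by the question encoder.
query_embs = retriever.embed_queries(texts=["Why did the revenue increase?"])
print(len(query_embs), query_embs[0].shape)  # e.g. 1 (768,) for the DPR base models
```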
<a name="dense.DensePassageRetriever.train"></a>
#### train
```python
| train(data_dir: str, train_filename: str, dev_filename: str = None, test_filename: str = None, batch_size: int = 2, embed_title: bool = True, num_hard_negatives: int = 1, num_negatives: int = 0, n_epochs: int = 3, evaluate_every: int = 1000, n_gpu: int = 1, learning_rate: float = 1e-5, epsilon: float = 1e-08, weight_decay: float = 0.0, num_warmup_steps: int = 100, grad_acc_steps: int = 1, optimizer_name: str = "TransformersAdamW", optimizer_correct_bias: bool = True, save_dir: str = "../saved_models/dpr-tutorial", query_encoder_save_dir: str = "lm1", passage_encoder_save_dir: str = "lm2")
```
Train a DensePassageRetriever model.
**Arguments**:
- `data_dir`: Directory where training file, dev file and test file are present
- `train_filename`: training filename
- `dev_filename`: development set filename, file to be used by model in eval step of training
- `test_filename`: test set filename, file to be used by model in test step after training
- `batch_size`: total number of samples in 1 batch of data
- `embed_title`: whether to concatenate passage title with each passage. The default setting in official DPR embeds passage title with the corresponding passage
- `num_hard_negatives`: number of hard negative passages (passages which are very similar to the query, e.g. with a high BM25 score, but do not contain the answer)
- `num_negatives`: number of negative passages (random passages from the dataset which do not contain the answer to the query)
- `n_epochs`: number of epochs to train the model on
- `evaluate_every`: number of training steps after which evaluation is run
- `n_gpu`: number of gpus to train on
- `learning_rate`: learning rate of optimizer
- `epsilon`: epsilon parameter of optimizer
- `weight_decay`: weight decay parameter of optimizer
- `grad_acc_steps`: number of steps to accumulate gradient over before back-propagation is done
- `optimizer_name`: what optimizer to use (default: TransformersAdamW)
- `num_warmup_steps`: number of warmup steps
- `optimizer_correct_bias`: Whether to correct bias in optimizer
- `save_dir`: directory where models are saved
- `query_encoder_save_dir`: directory inside save_dir where query_encoder model files are saved
- `passage_encoder_save_dir`: directory inside save_dir where passage_encoder model files are saved
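Putting the training arguments above together, a hedged sketch of a training call, again reusing the `retriever` from the earlier sketch; the directory and file names are placeholders, and the files must follow the format expected by the training routine:

```python
# Fine-tune both encoders on labeled query/passage pairs.
retriever.train(data_dir="data/dpr_training",
                train_filename="train.json",
                dev_filename="dev.json",
                test_filename="test.json",
                batch_size=2,
                embed_title=True,
                num_hard_negatives=1,
                n_epochs=3,
                evaluate_every=1000,
                n_gpu=1,
                learning_rate=1e-5,
                save_dir="saved_models/dpr",
                query_encoder_save_dir="query_encoder",
                passage_encoder_save_dir="passage_encoder")
```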
<a name="dense.EmbeddingRetriever"></a>
## EmbeddingRetriever
@@ -271,7 +317,7 @@ that are most relevant to the query.
#### eval
```python
- | eval(label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold_label", top_k: int = 10, open_domain: bool = False) -> dict
+ | eval(label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold_label", top_k: int = 10, open_domain: bool = False, return_preds: bool = False) -> dict
```
Performs evaluation on the Retriever.
@@ -293,4 +339,6 @@ documents a higher rank.
contained in the retrieved docs (common approach in open-domain QA).
If ``False``, retrieval uses a stricter evaluation that checks if the retrieved document ids
are within ids explicitly stated in the labels.
- `return_preds`: Whether to include predictions in the returned dictionary. If True, the returned dictionary
contains the keys "predictions" and "metrics".
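With `return_preds=True`, the result splits into the two documented keys. A minimal sketch for any retriever instance; the index names mirror the defaults above:

```python
result = retriever.eval(label_index="label",
                        doc_index="eval_document",
                        top_k=10,
                        open_domain=True,
                        return_preds=True)
print(result["metrics"])             # aggregated retrieval metrics, e.g. recall
predictions = result["predictions"]  # per-query retrieved documents
```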


@@ -49,6 +49,19 @@ class DensePassageRetriever(BaseRetriever):
Initialize the Retriever, including the two encoder models, from a local or remote model checkpoint.
The checkpoint format matches the Hugging Face Transformers model format.
**Example:**
```python
# remote model from FAIR
DensePassageRetriever(document_store=your_doc_store,
                      query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                      passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base")
# or from local path
DensePassageRetriever(document_store=your_doc_store,
                      query_embedding_model="model_directory/question-encoder",
                      passage_embedding_model="model_directory/context-encoder")
```
:param document_store: An instance of DocumentStore from which to retrieve documents.
:param query_embedding_model: Local path or remote name of question encoder checkpoint. The format matches that of
                              Hugging Face Transformers model hub models.


@@ -25,29 +25,31 @@ class ElasticsearchRetriever(BaseRetriever):
names must match with the filters dict supplied in self.retrieve().
**An example custom_query:**
```python
{
    "size": 10,
    "query": {
        "bool": {
            "should": [{"multi_match": {
                "query": "${question}",  # mandatory $question placeholder
                "type": "most_fields",
                "fields": ["text", "title"]}}],
            "filter": [  # optional custom filters
                {"terms": {"year": "${years}"}},
                {"terms": {"quarter": "${quarters}"}},
                {"range": {"date": {"gte": "${date}"}}}
            ],
        }
    },
}
```

**For this custom_query, a sample retrieve() could be:**
```python
self.retrieve(query="Why did the revenue increase?",
              filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
```
"""
self.document_store: ElasticsearchDocumentStore = document_store
self.custom_query = custom_query