mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-02 18:59:28 +00:00
Added new formatting for examples in docstrings (#555)
This commit is contained in:
parent
727767388a
commit
40c5c8edb4
@ -25,29 +25,31 @@ that get substituted during runtime. The placeholder(${filter_name_1}, ${filter_
|
||||
names must match with the filters dict supplied in self.retrieve().
|
||||
::
|
||||
|
||||
An example custom_query:
|
||||
**An example custom_query:**
|
||||
```python
|
||||
{
|
||||
"size": 10,
|
||||
"query": {
|
||||
"bool": {
|
||||
"should": [{"multi_match": {
|
||||
"query": "${question}", // mandatory $question placeholder
|
||||
"type": "most_fields",
|
||||
"fields": ["text", "title"]}}],
|
||||
"filter": [ // optional custom filters
|
||||
{"terms": {"year": "${years}"}},
|
||||
{"terms": {"quarter": "${quarters}"}},
|
||||
{"range": {"date": {"gte": "${date}"}}}
|
||||
],
|
||||
|
||||
}
|
||||
},
|
||||
> "size": 10,
|
||||
> "query": {
|
||||
> "bool": {
|
||||
> "should": [{"multi_match": {
|
||||
> "query": "${question}", // mandatory $question placeholder
|
||||
> "type": "most_fields",
|
||||
> "fields": ["text", "title"]}}],
|
||||
> "filter": [ // optional custom filters
|
||||
> {"terms": {"year": "${years}"}},
|
||||
> {"terms": {"quarter": "${quarters}"}},
|
||||
> {"range": {"date": {"gte": "${date}"}}}
|
||||
> ],
|
||||
> }
|
||||
> },
|
||||
}
|
||||
```
|
||||
|
||||
For this custom_query, a sample retrieve() could be:
|
||||
::
|
||||
**For this custom_query, a sample retrieve() could be:**
|
||||
```python
|
||||
self.retrieve(query="Why did the revenue increase?",
|
||||
filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
|
||||
> filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
|
||||
```
|
||||
|
||||
<a name="sparse.ElasticsearchFilterOnlyRetriever"></a>
|
||||
## ElasticsearchFilterOnlyRetriever
|
||||
@ -92,12 +94,25 @@ Karpukhin, Vladimir, et al. (2020): "Dense Passage Retrieval for Open-Domain Que
|
||||
#### \_\_init\_\_
|
||||
|
||||
```python
|
||||
| __init__(document_store: BaseDocumentStore, query_embedding_model: str = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: str = "facebook/dpr-ctx_encoder-single-nq-base", max_seq_len: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, remove_sep_tok_from_untitled_passages: bool = True)
|
||||
| __init__(document_store: BaseDocumentStore, query_embedding_model: str = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: str = "facebook/dpr-ctx_encoder-single-nq-base", max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product")
|
||||
```
|
||||
|
||||
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
|
||||
The checkpoint format matches huggingface transformers' model format
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
# remote model from FAIR
|
||||
DensePassageRetriever(document_store=your_doc_store,
|
||||
> query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
||||
> passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base")
|
||||
# or from local path
|
||||
DensePassageRetriever(document_store=your_doc_store,
|
||||
> query_embedding_model="model_directory/question-encoder",
|
||||
> passage_embedding_model="model_directory/context-encoder")
|
||||
```
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `document_store`: An instance of DocumentStore from which to retrieve documents.
|
||||
@ -107,7 +122,8 @@ Currently available remote names: ``"facebook/dpr-question_encoder-single-nq-bas
|
||||
- `passage_embedding_model`: Local path or remote name of passage encoder checkpoint. The format equals the
|
||||
one used by hugging-face transformers' modelhub models
|
||||
Currently available remote names: ``"facebook/dpr-ctx_encoder-single-nq-base"``
|
||||
- `max_seq_len`: Longest length of each sequence
|
||||
- `max_seq_len_query`: Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down."
|
||||
- `max_seq_len_passage`: Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down."
|
||||
- `use_gpu`: Whether to use gpu or not
|
||||
- `batch_size`: Number of questions or passages to encode at once
|
||||
- `embed_title`: Whether to concatenate title and passage to a text pair that is then used to create the embedding.
|
||||
@ -116,9 +132,6 @@ titles contain meaningful information for retrieval (topic, entities etc.) .
|
||||
The title is expected to be present in doc.meta["name"] and can be supplied in the documents
|
||||
before writing them to the DocumentStore like this:
|
||||
{"text": "my text", "meta": {"name": "my title"}}.
|
||||
- `remove_sep_tok_from_untitled_passages`: If embed_title is ``True``, there are different strategies to deal with documents that don't have a title.
|
||||
If this param is ``True`` => Embed passage as single text, similar to embed_title = False (i.e [CLS] passage_tok1 ... [SEP]).
|
||||
If this param is ``False`` => Embed passage as text pair with empty title (i.e. [CLS] [SEP] passage_tok1 ... [SEP])
|
||||
|
||||
<a name="dense.DensePassageRetriever.embed_queries"></a>
|
||||
#### embed\_queries
|
||||
@ -154,6 +167,39 @@ Create embeddings for a list of passages using the passage encoder
|
||||
|
||||
Embeddings of documents / passages shape (batch_size, embedding_dim)
|
||||
|
||||
<a name="dense.DensePassageRetriever.train"></a>
|
||||
#### train
|
||||
|
||||
```python
|
||||
| train(data_dir: str, train_filename: str, dev_filename: str = None, test_filename: str = None, batch_size: int = 2, embed_title: bool = True, num_hard_negatives: int = 1, num_negatives: int = 0, n_epochs: int = 3, evaluate_every: int = 1000, n_gpu: int = 1, learning_rate: float = 1e-5, epsilon: float = 1e-08, weight_decay: float = 0.0, num_warmup_steps: int = 100, grad_acc_steps: int = 1, optimizer_name: str = "TransformersAdamW", optimizer_correct_bias: bool = True, save_dir: str = "../saved_models/dpr-tutorial", query_encoder_save_dir: str = "lm1", passage_encoder_save_dir: str = "lm2")
|
||||
```
|
||||
|
||||
train a DensePassageRetrieval model
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `data_dir`: Directory where training file, dev file and test file are present
|
||||
- `train_filename`: training filename
|
||||
- `dev_filename`: development set filename, file to be used by model in eval step of training
|
||||
- `test_filename`: test set filename, file to be used by model in test step after training
|
||||
- `batch_size`: total number of samples in 1 batch of data
|
||||
- `embed_title`: whether to concatenate passage title with each passage. The default setting in official DPR embeds passage title with the corresponding passage
|
||||
- `num_hard_negatives`: number of hard negative passages(passages which are very similar(high score by BM25) to query but do not contain the answer
|
||||
- `num_negatives`: number of negative passages(any random passage from dataset which do not contain answer to query)
|
||||
- `n_epochs`: number of epochs to train the model on
|
||||
- `evaluate_every`: number of training steps after evaluation is run
|
||||
- `n_gpu`: number of gpus to train on
|
||||
- `learning_rate`: learning rate of optimizer
|
||||
- `epsilon`: epsilon parameter of optimizer
|
||||
- `weight_decay`: weight decay parameter of optimizer
|
||||
- `grad_acc_steps`: number of steps to accumulate gradient over before back-propagation is done
|
||||
- `optimizer_name`: what optimizer to use (default: TransformersAdamW)
|
||||
- `num_warmup_steps`: number of warmup steps
|
||||
- `optimizer_correct_bias`: Whether to correct bias in optimizer
|
||||
- `save_dir`: directory where models are saved
|
||||
- `query_encoder_save_dir`: directory inside save_dir where query_encoder model files are saved
|
||||
- `passage_encoder_save_dir`: directory inside save_dir where passage_encoder model files are saved
|
||||
|
||||
<a name="dense.EmbeddingRetriever"></a>
|
||||
## EmbeddingRetriever
|
||||
|
||||
@ -271,7 +317,7 @@ that are most relevant to the query.
|
||||
#### eval
|
||||
|
||||
```python
|
||||
| eval(label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold_label", top_k: int = 10, open_domain: bool = False) -> dict
|
||||
| eval(label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold_label", top_k: int = 10, open_domain: bool = False, return_preds: bool = False) -> dict
|
||||
```
|
||||
|
||||
Performs evaluation on the Retriever.
|
||||
@ -293,4 +339,6 @@ documents a higher rank.
|
||||
contained in the retrieved docs (common approach in open-domain QA).
|
||||
If ``False``, retrieval uses a stricter evaluation that checks if the retrieved document ids
|
||||
are within ids explicitly stated in the labels.
|
||||
- `return_preds`: Whether to add predictions in the returned dictionary. If True, the returned dictionary
|
||||
contains the keys "predictions" and "metrics".
|
||||
|
||||
|
||||
@ -49,6 +49,19 @@ class DensePassageRetriever(BaseRetriever):
|
||||
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
|
||||
The checkpoint format matches huggingface transformers' model format
|
||||
|
||||
**Example:**
|
||||
|
||||
```python
|
||||
# remote model from FAIR
|
||||
DensePassageRetriever(document_store=your_doc_store,
|
||||
> query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
||||
> passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base")
|
||||
# or from local path
|
||||
DensePassageRetriever(document_store=your_doc_store,
|
||||
> query_embedding_model="model_directory/question-encoder",
|
||||
> passage_embedding_model="model_directory/context-encoder")
|
||||
```
|
||||
|
||||
:param document_store: An instance of DocumentStore from which to retrieve documents.
|
||||
:param query_embedding_model: Local path or remote name of question encoder checkpoint. The format equals the
|
||||
one used by hugging-face transformers' modelhub models
|
||||
|
||||
@ -25,29 +25,31 @@ class ElasticsearchRetriever(BaseRetriever):
|
||||
names must match with the filters dict supplied in self.retrieve().
|
||||
::
|
||||
|
||||
An example custom_query:
|
||||
**An example custom_query:**
|
||||
```python
|
||||
{
|
||||
"size": 10,
|
||||
"query": {
|
||||
"bool": {
|
||||
"should": [{"multi_match": {
|
||||
"query": "${question}", // mandatory $question placeholder
|
||||
"type": "most_fields",
|
||||
"fields": ["text", "title"]}}],
|
||||
"filter": [ // optional custom filters
|
||||
{"terms": {"year": "${years}"}},
|
||||
{"terms": {"quarter": "${quarters}"}},
|
||||
{"range": {"date": {"gte": "${date}"}}}
|
||||
],
|
||||
> "size": 10,
|
||||
> "query": {
|
||||
> "bool": {
|
||||
> "should": [{"multi_match": {
|
||||
> "query": "${question}", // mandatory $question placeholder
|
||||
> "type": "most_fields",
|
||||
> "fields": ["text", "title"]}}],
|
||||
> "filter": [ // optional custom filters
|
||||
> {"terms": {"year": "${years}"}},
|
||||
> {"terms": {"quarter": "${quarters}"}},
|
||||
> {"range": {"date": {"gte": "${date}"}}}
|
||||
> ],
|
||||
> }
|
||||
> },
|
||||
}
|
||||
```
|
||||
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
For this custom_query, a sample retrieve() could be:
|
||||
::
|
||||
self.retrieve(query="Why did the revenue increase?",
|
||||
filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
|
||||
**For this custom_query, a sample retrieve() could be:**
|
||||
```python
|
||||
self.retrieve(query="Why did the revenue increase?",
|
||||
> filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
|
||||
```
|
||||
"""
|
||||
self.document_store: ElasticsearchDocumentStore = document_store
|
||||
self.custom_query = custom_query
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user