# Added new formatting for examples in docstrings (#555)

Commit 40c5c8edb4 (parent 727767388a)
@@ -25,29 +25,31 @@ that get substituted during runtime. The placeholder(${filter_name_1}, ${filter_
names must match with the filters dict supplied in self.retrieve().
::

**An example custom_query:**
```python
{
    "size": 10,
    "query": {
        "bool": {
            "should": [{"multi_match": {
                "query": "${question}",                 // mandatory $question placeholder
                "type": "most_fields",
                "fields": ["text", "title"]}}],
            "filter": [                                 // optional custom filters
                {"terms": {"year": "${years}"}},
                {"terms": {"quarter": "${quarters}"}},
                {"range": {"date": {"gte": "${date}"}}}
            ],
        }
    },
}
```

**For this custom_query, a sample retrieve() could be:**
```python
self.retrieve(query="Why did the revenue increase?",
              filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
```
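
For context, a minimal sketch of wiring such a `custom_query` into a retriever at construction time (import paths follow the Haystack layout of this commit and may differ in other releases; the query string below is a trimmed variant of the example above):

```python
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore  # path assumed for this version
from haystack.retriever.sparse import ElasticsearchRetriever

# custom_query is passed as a raw string; the ${...} placeholders are
# substituted from the `filters` dict at retrieve() time.
custom_query = """
{
    "size": 10,
    "query": {
        "bool": {
            "should": [{"multi_match": {"query": "${question}", "type": "most_fields", "fields": ["text", "title"]}}],
            "filter": [{"terms": {"year": "${years}"}}]
        }
    }
}
"""

document_store = ElasticsearchDocumentStore()
retriever = ElasticsearchRetriever(document_store=document_store, custom_query=custom_query)
docs = retriever.retrieve(query="Why did the revenue increase?", filters={"years": ["2019"]})
```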
<a name="sparse.ElasticsearchFilterOnlyRetriever"></a>
## ElasticsearchFilterOnlyRetriever
@@ -92,12 +94,25 @@ Karpukhin, Vladimir, et al. (2020): "Dense Passage Retrieval for Open-Domain Que
#### \_\_init\_\_

```python
 | __init__(document_store: BaseDocumentStore, query_embedding_model: str = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: str = "facebook/dpr-ctx_encoder-single-nq-base", max_seq_len_query: int = 64, max_seq_len_passage: int = 256, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, similarity_function: str = "dot_product")
```

Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
The checkpoint format matches huggingface transformers' model format

**Example:**

```python
# remote model from FAIR
DensePassageRetriever(document_store=your_doc_store,
                      query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                      passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base")
# or from local path
DensePassageRetriever(document_store=your_doc_store,
                      query_embedding_model="model_directory/question-encoder",
                      passage_embedding_model="model_directory/context-encoder")
```

**Arguments**:

- `document_store`: An instance of DocumentStore from which to retrieve documents.
@@ -107,7 +122,8 @@ Currently available remote names: ``"facebook/dpr-question_encoder-single-nq-bas
- `passage_embedding_model`: Local path or remote name of passage encoder checkpoint. The format equals the
one used by hugging-face transformers' modelhub models
Currently available remote names: ``"facebook/dpr-ctx_encoder-single-nq-base"``
- `max_seq_len_query`: Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down.
- `max_seq_len_passage`: Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down.
- `use_gpu`: Whether to use gpu or not
- `batch_size`: Number of questions or passages to encode at once
- `embed_title`: Whether to concatenate title and passage to a text pair that is then used to create the embedding.
@@ -116,9 +132,6 @@ titles contain meaningful information for retrieval (topic, entities etc.) .
The title is expected to be present in doc.meta["name"] and can be supplied in the documents
before writing them to the DocumentStore like this:
{"text": "my text", "meta": {"name": "my title"}}.
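
A short sketch of supplying titles this way (assuming an already initialised `document_store`; the texts are placeholders):

```python
# Each dict becomes a Document; embed_title reads the title from meta["name"].
docs = [
    {"text": "my text", "meta": {"name": "my title"}},
    {"text": "another text", "meta": {"name": "another title"}},
]
document_store.write_documents(docs)
```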
<a name="dense.DensePassageRetriever.embed_queries"></a>
#### embed\_queries
@@ -154,6 +167,39 @@ Create embeddings for a list of passages using the passage encoder

Embeddings of documents / passages, shape (batch_size, embedding_dim)

<a name="dense.DensePassageRetriever.train"></a>
#### train

```python
 | train(data_dir: str, train_filename: str, dev_filename: str = None, test_filename: str = None, batch_size: int = 2, embed_title: bool = True, num_hard_negatives: int = 1, num_negatives: int = 0, n_epochs: int = 3, evaluate_every: int = 1000, n_gpu: int = 1, learning_rate: float = 1e-5, epsilon: float = 1e-08, weight_decay: float = 0.0, num_warmup_steps: int = 100, grad_acc_steps: int = 1, optimizer_name: str = "TransformersAdamW", optimizer_correct_bias: bool = True, save_dir: str = "../saved_models/dpr-tutorial", query_encoder_save_dir: str = "lm1", passage_encoder_save_dir: str = "lm2")
```

Train a DensePassageRetrieval model.

**Arguments**:

- `data_dir`: Directory where training file, dev file and test file are present
- `train_filename`: training filename
- `dev_filename`: development set filename, file to be used by model in eval step of training
- `test_filename`: test set filename, file to be used by model in test step after training
- `batch_size`: total number of samples in 1 batch of data
- `embed_title`: whether to concatenate passage title with each passage. The default setting in official DPR embeds passage title with the corresponding passage
- `num_hard_negatives`: number of hard negative passages (passages which are very similar to the query, e.g. scored highly by BM25, but do not contain the answer)
- `num_negatives`: number of negative passages (any random passage from the dataset which does not contain the answer to the query)
- `n_epochs`: number of epochs to train the model on
- `evaluate_every`: number of training steps after which evaluation is run
- `n_gpu`: number of gpus to train on
- `learning_rate`: learning rate of optimizer
- `epsilon`: epsilon parameter of optimizer
- `weight_decay`: weight decay parameter of optimizer
- `grad_acc_steps`: number of steps to accumulate gradient over before back-propagation is done
- `optimizer_name`: what optimizer to use (default: TransformersAdamW)
- `num_warmup_steps`: number of warmup steps
- `optimizer_correct_bias`: Whether to correct bias in optimizer
- `save_dir`: directory where models are saved
- `query_encoder_save_dir`: directory inside save_dir where query_encoder model files are saved
- `passage_encoder_save_dir`: directory inside save_dir where passage_encoder model files are saved
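
As a rough illustration, a training call under assumed data (the directory and DPR-format file names are placeholders, and `retriever` is a DensePassageRetriever instance):

```python
# Fine-tune both encoders on DPR-style training data; values mirror the
# defaults in the signature above, spelled out here for clarity.
retriever.train(
    data_dir="data/dpr_training",        # hypothetical directory
    train_filename="train.json",         # hypothetical DPR-format files
    dev_filename="dev.json",
    test_filename="test.json",
    n_epochs=3,
    batch_size=2,
    num_hard_negatives=1,
    evaluate_every=1000,
    save_dir="../saved_models/dpr-tutorial",
)
```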

<a name="dense.EmbeddingRetriever"></a>
## EmbeddingRetriever
@@ -271,7 +317,7 @@ that are most relevant to the query.
#### eval

```python
 | eval(label_index: str = "label", doc_index: str = "eval_document", label_origin: str = "gold_label", top_k: int = 10, open_domain: bool = False, return_preds: bool = False) -> dict
```

Performs evaluation on the Retriever.
@@ -293,4 +339,6 @@ documents a higher rank.
contained in the retrieved docs (common approach in open-domain QA).
If ``False``, retrieval uses a stricter evaluation that checks if the retrieved document ids
are within ids explicitly stated in the labels.
- `return_preds`: Whether to add predictions in the returned dictionary. If True, the returned dictionary
contains the keys "predictions" and "metrics".
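
A brief sketch of calling eval with predictions returned (index names are the defaults from the signature above; the exact metric names inside "metrics" are not spelled out in this diff):

```python
# Evaluate retrieval quality against gold labels and keep the raw
# predictions alongside the aggregate metrics.
result = retriever.eval(
    label_index="label",
    doc_index="eval_document",
    top_k=10,
    return_preds=True,
)
print(result["metrics"])            # aggregate retrieval metrics
for pred in result["predictions"][:3]:
    print(pred)                     # per-query retrieved documents
```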

@@ -49,6 +49,19 @@ class DensePassageRetriever(BaseRetriever):
Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
The checkpoint format matches huggingface transformers' model format

**Example:**

```python
# remote model from FAIR
DensePassageRetriever(document_store=your_doc_store,
                      query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                      passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base")
# or from local path
DensePassageRetriever(document_store=your_doc_store,
                      query_embedding_model="model_directory/question-encoder",
                      passage_embedding_model="model_directory/context-encoder")
```

:param document_store: An instance of DocumentStore from which to retrieve documents.
:param query_embedding_model: Local path or remote name of question encoder checkpoint. The format equals the
one used by hugging-face transformers' modelhub models

@@ -25,29 +25,31 @@ class ElasticsearchRetriever(BaseRetriever):
names must match with the filters dict supplied in self.retrieve().
::

**An example custom_query:**
```python
{
    "size": 10,
    "query": {
        "bool": {
            "should": [{"multi_match": {
                "query": "${question}",                 // mandatory $question placeholder
                "type": "most_fields",
                "fields": ["text", "title"]}}],
            "filter": [                                 // optional custom filters
                {"terms": {"year": "${years}"}},
                {"terms": {"quarter": "${quarters}"}},
                {"range": {"date": {"gte": "${date}"}}}
            ],
        }
    },
}
```

**For this custom_query, a sample retrieve() could be:**
```python
self.retrieve(query="Why did the revenue increase?",
              filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
```
"""
self.document_store: ElasticsearchDocumentStore = document_store
self.custom_query = custom_query