Mirror of https://github.com/deepset-ai/haystack.git
Clean API docs and increase coverage (#621)
* Fix docstrings
* Fix docstrings
* docstrings for retrievers and docstores
* Clean and add more docstrings

Parent: fa55de2fab
Commit: 9fbd845ef3
@@ -79,9 +79,8 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
                            If set to False, an error is raised if the document ID of the document being
                            added already exists.
        :param refresh_type: Type of ES refresh used to control when changes made by a request (e.g. bulk) are made visible to search.
                             Values:
                             - 'wait_for' => continue only after changes are visible (slow, but safe)
                             - 'false' => continue directly (fast, but sometimes unintuitive behaviour when docs are not immediately available after ingestion)
                             If set to 'wait_for', continue only after changes are visible (slow, but safe).
                             If set to 'false', continue directly (fast, but sometimes unintuitive behaviour when docs are not immediately available after ingestion).
                             More info at https://www.elastic.co/guide/en/elasticsearch/reference/6.8/docs-refresh.html
        :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
                           more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
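For illustration, a minimal sketch of how these two parameters are typically passed when constructing the store. The import path and the host/index arguments are assumed from the haystack 0.x layout, and a reachable Elasticsearch instance is assumed:

```python
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore

# 'wait_for' trades ingestion speed for read-after-write safety;
# 'dot_product' suits DPR embeddings, 'cosine' suits Sentence-BERT models.
document_store = ElasticsearchDocumentStore(
    host="localhost",
    index="document",
    refresh_type="wait_for",
    similarity="dot_product",
)
```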
@@ -220,6 +219,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
        }

    def get_document_by_id(self, id: str, index=None) -> Optional[Document]:
        """Fetch a document by specifying its text id string"""
        index = index or self.index
        documents = self.get_documents_by_id([id], index=index)
        if documents:

@@ -228,6 +228,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
        return None

    def get_documents_by_id(self, ids: List[str], index=None) -> List[Document]:
        """Fetch documents by specifying a list of text id strings"""
        index = index or self.index
        query = {"query": {"ids": {"values": ids}}}
        result = self.client.search(index=index, body=query)["hits"]["hits"]
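As a usage sketch for the two lookup helpers documented above (the store instance and the IDs are placeholders):

```python
# Fetch one document, or several at once, by their string IDs.
doc = document_store.get_document_by_id("some-doc-id")          # Document or None
docs = document_store.get_documents_by_id(["id-1", "id-2"])     # List[Document]
if doc is not None:
    print(doc.text[:100], doc.meta)
```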
@@ -298,6 +299,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
        bulk(self.client, documents_to_index, request_timeout=300, refresh=self.refresh_type)

    def write_labels(self, labels: Union[List[Label], List[dict]], index: Optional[str] = None):
        """Write annotation labels into document store."""
        index = index or self.label_index
        if index and not self.client.indices.exists(index=index):
            self._create_label_index(index)

@@ -317,10 +319,16 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
        bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type)

    def update_document_meta(self, id: str, meta: Dict[str, str]):
        """
        Update the metadata dictionary of a document by specifying its string id
        """
        body = {"doc": meta}
        self.client.update(index=self.index, id=id, body=body, refresh=self.refresh_type)

    def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
        """
        Return the number of documents in the document store.
        """
        index = index or self.index

        body: dict = {"query": {"bool": {}}}

@@ -343,6 +351,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
        return count

    def get_label_count(self, index: Optional[str] = None) -> int:
        """
        Return the number of labels in the document store
        """
        return self.get_document_count(index=index)

    def get_all_documents(

@@ -372,12 +383,18 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
        return documents

    def get_all_labels(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]:
        """
        Return all labels in the document store
        """
        index = index or self.label_index
        result = self.get_all_documents_in_index(index=index, filters=filters)
        labels = [Label.from_dict(hit["_source"]) for hit in result]
        return labels

    def get_all_documents_in_index(self, index: str, filters: Optional[Dict[str, List[str]]] = None) -> List[dict]:
        """
        Return all documents in a specific index in the document store
        """
        body = {
            "query": {
                "bool": {
@@ -409,6 +426,15 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
        custom_query: Optional[str] = None,
        index: Optional[str] = None,
    ) -> List[Document]:
        """
        Scan through documents in DocumentStore and return a small number of documents
        that are most relevant to the query as defined by the BM25 algorithm.

        :param query: The query
        :param filters: A dictionary where the keys specify a metadata field and the value is a list of accepted values for that field
        :param top_k: How many documents to return per query.
        :param index: The name of the index in the DocumentStore from which to retrieve documents
        """

        if index is None:
            index = self.index
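An illustrative call to this BM25 query method; the `category` metadata field and its values are assumptions for the example only:

```python
# Return the 5 best BM25 matches, restricted by a metadata filter.
results = document_store.query(
    query="What is the capital of France?",
    filters={"category": ["geography"]},
    top_k=5,
)
for doc in results:
    print(doc.score, doc.text[:80])
```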
@@ -483,6 +509,17 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
        top_k: int = 10,
        index: Optional[str] = None,
        return_embedding: Optional[bool] = None) -> List[Document]:
        """
        Find the document that is most similar to the provided `query_emb` by using a vector similarity metric.

        :param query_emb: Embedding of the query (e.g. gathered from DPR)
        :param filters: Optional filters to narrow down the search space.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param top_k: How many documents to return
        :param index: Index name for storing the docs and metadata
        :param return_embedding: To return document embedding
        :return:
        """
        if index is None:
            index = self.index
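A sketch of calling the embedding query directly with a raw vector. In practice `query_emb` usually comes from a retriever such as DPR; the 768-dimensional size below is only an assumption matching common DPR checkpoints:

```python
import numpy as np

query_emb = np.random.rand(768).astype("float32")  # stand-in for a real query embedding

similar_docs = document_store.query_by_embedding(
    query_emb=query_emb,
    filters={"category": ["geography"]},
    top_k=3,
    return_embedding=False,
)
```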
@@ -572,6 +609,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
        return document

    def describe_documents(self, index=None):
        """
        Return a summary of the documents in the document store
        """
        if index is None:
            index = self.index
        docs = self.get_all_documents(index)
@@ -104,6 +104,7 @@ class FAISSDocumentStore(SQLDocumentStore):
    def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
        """
        Add new documents to the DocumentStore.

        :param documents: List of `Dicts` or List of `Documents`. If they already contain the embeddings, we'll index
                          them right away in FAISS. If not, you can later call update_embeddings() to create & index them.
        :param index: (SQL) index name for storing the docs and metadata
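A minimal sketch of the two ingestion paths this docstring describes, assuming a FAISSDocumentStore backed by a local SQLite file (module path and `sql_url` taken from the haystack 0.x API):

```python
from haystack.document_store.faiss import FAISSDocumentStore

document_store = FAISSDocumentStore(sql_url="sqlite:///faiss_docs.db")

# Plain dicts without embeddings: write them now, embed later.
docs = [
    {"text": "Paris is the capital of France.", "meta": {"category": "geography"}},
    {"text": "Berlin is the capital of Germany.", "meta": {"category": "geography"}},
]
document_store.write_documents(docs)
# document_store.update_embeddings(retriever)  # e.g. with a DensePassageRetriever
```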
@@ -229,6 +230,9 @@ class FAISSDocumentStore(SQLDocumentStore):
        self.faiss_index.train(embeddings)

    def delete_all_documents(self, index=None):
        """
        Delete all documents from the document store.
        """
        index = index or self.index
        self.faiss_index.reset()
        super().delete_all_documents(index=index)
@@ -47,6 +47,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
        self.indexes[index][document.id] = document

    def write_labels(self, labels: Union[List[dict], List[Label]], index: Optional[str] = None):
        """Write annotation labels into document store."""
        index = index or self.label_index
        label_objects = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]

@@ -55,6 +56,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
        self.indexes[index][label_id] = label

    def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
        """Fetch a document by specifying its text id string"""
        index = index or self.index
        documents = self.get_documents_by_id([id], index=index)
        if documents:

@@ -63,6 +65,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
        return None

    def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> List[Document]:
        """Fetch documents by specifying a list of text id strings"""
        index = index or self.index
        documents = [self.indexes[index][id] for id in ids]
        return documents
@@ -74,6 +77,18 @@ class InMemoryDocumentStore(BaseDocumentStore):
        index: Optional[str] = None,
        return_embedding: Optional[bool] = None) -> List[Document]:

        """
        Find the document that is most similar to the provided `query_emb` by using a vector similarity metric.

        :param query_emb: Embedding of the query (e.g. gathered from DPR)
        :param filters: Optional filters to narrow down the search space.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param top_k: How many documents to return
        :param index: Index name for storing the docs and metadata
        :param return_embedding: To return document embedding
        :return:
        """

        from numpy import dot
        from numpy.linalg import norm
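The numpy imports above are used to score the query embedding against each stored document embedding. Roughly, the cosine variant of that similarity looks like this (a simplified sketch, not the store's exact code):

```python
from numpy import dot
from numpy.linalg import norm

def cosine_similarity(query_emb, doc_emb):
    # Dot product of the two vectors divided by the product of their norms.
    return dot(query_emb, doc_emb) / (norm(query_emb) * norm(doc_emb))
```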
@@ -137,13 +152,19 @@ class InMemoryDocumentStore(BaseDocumentStore):
        self.indexes[index][doc.id].embedding = emb

    def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
        """
        Return the number of documents in the document store.
        """
        documents = self.get_all_documents(index=index, filters=filters)
        return len(documents)

    def get_label_count(self, index: Optional[str] = None) -> int:
        """
        Return the number of labels in the document store
        """
        index = index or self.label_index
        return len(self.indexes[index].items())

    def get_all_documents(
        self,
        index: Optional[str] = None,

@@ -187,6 +208,9 @@ class InMemoryDocumentStore(BaseDocumentStore):
        return filtered_documents

    def get_all_labels(self, index: str = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]:
        """
        Return all labels in the document store
        """
        index = index or self.label_index

        if filters:
@@ -89,11 +89,13 @@ class SQLDocumentStore(BaseDocumentStore):
        self.update_existing_documents = update_existing_documents

    def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
        """Fetch a document by specifying its text id string"""
        documents = self.get_documents_by_id([id], index)
        document = documents[0] if documents else None
        return document

    def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> List[Document]:
        """Fetch documents by specifying a list of text id strings"""
        index = index or self.index
        results = self.session.query(DocumentORM).filter(DocumentORM.id.in_(ids), DocumentORM.index == index).all()
        documents = [self._convert_sql_row_to_document(row) for row in results]

@@ -101,6 +103,7 @@ class SQLDocumentStore(BaseDocumentStore):
        return documents

    def get_documents_by_vector_ids(self, vector_ids: List[str], index: Optional[str] = None):
        """Fetch documents by specifying a list of text vector id strings"""
        index = index or self.index
        results = self.session.query(DocumentORM).filter(
            DocumentORM.vector_id.in_(vector_ids),

@@ -138,6 +141,9 @@ class SQLDocumentStore(BaseDocumentStore):
        return documents

    def get_all_labels(self, index=None, filters: Optional[dict] = None):
        """
        Return all labels in the document store
        """
        index = index or self.label_index
        label_rows = self.session.query(LabelORM).filter_by(index=index).all()
        labels = [self._convert_sql_row_to_label(row) for row in label_rows]

@@ -182,6 +188,7 @@ class SQLDocumentStore(BaseDocumentStore):
        raise ex

    def write_labels(self, labels, index=None):
        """Write annotation labels into document store."""

        labels = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
        index = index or self.label_index

@@ -221,6 +228,9 @@ class SQLDocumentStore(BaseDocumentStore):
        self.session.commit()

    def update_document_meta(self, id: str, meta: Dict[str, str]):
        """
        Update the metadata dictionary of a document by specifying its string id
        """
        self.session.query(MetaORM).filter_by(document_id=id).delete()
        meta_orms = [MetaORM(name=key, value=value, document_id=id) for key, value in meta.items()]
        for m in meta_orms:

@@ -244,6 +254,9 @@ class SQLDocumentStore(BaseDocumentStore):
        self.write_labels(labels, index=label_index)

    def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
        """
        Return the number of documents in the document store.
        """
        index = index or self.index
        query = self.session.query(DocumentORM).filter_by(index=index)

@@ -256,6 +269,9 @@ class SQLDocumentStore(BaseDocumentStore):
        return count

    def get_label_count(self, index: Optional[str] = None) -> int:
        """
        Return the number of labels in the document store
        """
        index = index or self.index
        return self.session.query(LabelORM).filter_by(index=index).count()
@@ -41,6 +41,11 @@ class PDFToTextConverter(BaseConverter):
        super().__init__(remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages)

    def convert(self, file_path: Path, meta: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
        """
        Extract text from a .pdf file.

        :param file_path: Path to the .pdf file you want to convert
        """

        pages = self._read_pdf(file_path, layout=False)
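A short usage sketch for convert(); it assumes the `pdftotext` binary (poppler-utils) is installed, the file name is a placeholder, and the usual `{"text": ..., "meta": ...}` return layout:

```python
from pathlib import Path
from haystack.file_converter.pdf import PDFToTextConverter  # path assumed from haystack 0.x

converter = PDFToTextConverter(remove_numeric_tables=True, valid_languages=["en"])
doc = converter.convert(file_path=Path("sample.pdf"))
print(doc["text"][:200])  # extracted text; metadata, if passed, sits under doc["meta"]
```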
@@ -3,6 +3,9 @@ from typing import List, Dict, Any

class BasePreProcessor:
    def process(self, document: dict) -> List[dict]:
        """
        Perform document cleaning and splitting. Takes a single document as input and returns a list of documents.
        """
        cleaned_document = self.clean(document)
        split_documents = self.split(cleaned_document)
        return split_documents
@@ -2,6 +2,10 @@ import re


def clean_wiki_text(text: str) -> str:
    """
    Clean wikipedia text by removing multiple new lines, removing extremely short lines,
    adding paragraph breaks and removing empty paragraphs
    """
    # get rid of multiple new lines
    while "\n\n" in text:
        text = text.replace("\n\n", "\n")
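For illustration, a call on a small snippet (import path assumed from the haystack 0.x preprocessor module):

```python
from haystack.preprocessor.cleaning import clean_wiki_text

raw = "Heading\n\n\nshort\nA longer paragraph of actual article text that should survive the cleaning step.\n\n"
print(clean_wiki_text(raw))
```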
@@ -52,6 +52,10 @@ class PreProcessor(BasePreProcessor):
        self.split_respect_sentence_boundary = split_respect_sentence_boundary

    def clean(self, document: dict) -> dict:
        """
        Perform document cleaning on a single document and return a single document. This method will deal with whitespaces, headers, footers
        and empty lines. Its exact functionality is defined by the parameters passed into PreProcessor.__init__().
        """
        text = document["text"]
        if self.clean_header_footer:
            text = self._find_and_remove_header_footer(
@@ -74,6 +78,10 @@ class PreProcessor(BasePreProcessor):
        return document

    def split(self, document: dict) -> List[dict]:
        """Perform document splitting on a single document. This method can split on different units, at different lengths,
        with different strides. It can also respect sentence boundaries. Its exact functionality is defined by
        the parameters passed into PreProcessor.__init__(). Takes a single document as input and returns a list of documents."""

        if not self.split_by:
            return [document]
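A sketch of driving clean() and split() through process(), with behaviour controlled entirely by the constructor arguments (the values and the import path are illustrative assumptions based on the haystack 0.x API):

```python
from haystack.preprocessor.preprocessor import PreProcessor

processor = PreProcessor(
    clean_empty_lines=True,
    clean_whitespace=True,
    clean_header_footer=False,
    split_by="word",
    split_length=100,
    split_respect_sentence_boundary=True,
)
docs = processor.process({"text": "A long raw text ...", "meta": {"source": "example"}})
# -> a list of smaller documents, cleaned and split according to the settings above
```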
@@ -11,10 +11,7 @@ class TransformersReader(BaseReader):
    Transformer based model for extractive Question Answering using HuggingFace's transformers framework
    (https://github.com/huggingface/transformers).
    While the underlying model can vary (BERT, Roberta, DistilBERT ...), the interface remains the same.

    | With the reader, you can:

    - directly get predictions via predict()
    With this reader, you can directly get predictions via predict()
    """

    def __init__(
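A minimal sketch of getting predictions directly via predict(). The model name is the usual SQuAD-distilled checkpoint and the import paths follow the haystack 0.x layout; exact keyword names may differ slightly between versions:

```python
from haystack import Document
from haystack.reader.transformers import TransformersReader

docs = [Document(text="Johann Wolfgang von Goethe wrote Faust, a tragic play in two parts.")]

reader = TransformersReader("distilbert-base-uncased-distilled-squad")
result = reader.predict(question="Who wrote Faust?", documents=docs, top_k=1)
print(result["answers"][0]["answer"])
```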
@@ -29,6 +29,7 @@ class BaseRetriever(ABC):
        pass

    def timing(self, fn):
        """Wrapper method used to time functions."""
        @wraps(fn)
        def wrapper(*args, **kwargs):
            if "retrieve_time" not in self.__dict__:
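The wrapper follows the standard decorator pattern: wrap the function with functools.wraps and accumulate the elapsed time on the instance. An illustrative sketch of that pattern, not the library's exact implementation:

```python
from functools import wraps
from time import perf_counter

def timing(self, fn):
    """Accumulate the wrapped function's runtime on the instance (sketch)."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if "retrieve_time" not in self.__dict__:
            self.retrieve_time = 0.0
        start = perf_counter()
        result = fn(*args, **kwargs)
        self.retrieve_time += perf_counter() - start
        return result
    return wrapper
```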
@@ -127,6 +127,15 @@ class DensePassageRetriever(BaseRetriever):
        self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)

    def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
        """
        Scan through documents in DocumentStore and return a small number of documents
        that are most relevant to the query.

        :param query: The query
        :param filters: A dictionary where the keys specify a metadata field and the value is a list of accepted values for that field
        :param top_k: How many documents to return per query.
        :param index: The name of the index in the DocumentStore from which to retrieve documents
        """
        if index is None:
            index = self.document_store.index
        query_emb = self.embed_queries(texts=[query])
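A usage sketch that ties the retriever to an existing document store. The model names are the standard facebook DPR checkpoints; the constructor keywords are assumed from the haystack 0.x API:

```python
from haystack.retriever.dense import DensePassageRetriever

retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
)
docs = retriever.retrieve(query="Who developed the theory of relativity?", top_k=5)
```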
@@ -305,6 +314,12 @@ class DensePassageRetriever(BaseRetriever):
        self.processor.save(Path(save_dir))

    def save(self, save_dir: Union[Path, str]):
        """
        Save DensePassageRetriever to the specified directory.

        :param save_dir: Directory to save to.
        :return: None
        """
        save_dir = Path(save_dir)
        self.model.save(save_dir, lm1_name="query_encoder", lm2_name="passage_encoder")
        save_dir = str(save_dir)

@@ -323,6 +338,9 @@ class DensePassageRetriever(BaseRetriever):
        use_fast_tokenizers: bool = True,
        similarity_function: str = "dot_product",
    ):
        """
        Load DensePassageRetriever from the specified directory.
        """

        load_dir = Path(load_dir)
        dpr = cls(
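A short sketch of the save/load round trip. The directory name is arbitrary, and the load() keywords are assumed from the haystack 0.x API (the document store is supplied again when loading):

```python
retriever.save("saved_models/dpr")

from haystack.retriever.dense import DensePassageRetriever
reloaded = DensePassageRetriever.load(
    load_dir="saved_models/dpr",
    document_store=document_store,
)
```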
@@ -401,6 +419,15 @@ class EmbeddingRetriever(BaseRetriever):
        raise NotImplementedError

    def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
        """
        Scan through documents in DocumentStore and return a small number of documents
        that are most relevant to the query.

        :param query: The query
        :param filters: A dictionary where the keys specify a metadata field and the value is a list of accepted values for that field
        :param top_k: How many documents to return per query.
        :param index: The name of the index in the DocumentStore from which to retrieve documents
        """
        if index is None:
            index = self.document_store.index
        query_emb = self.embed(texts=[query])
@@ -55,6 +55,15 @@ class ElasticsearchRetriever(BaseRetriever):
        self.custom_query = custom_query

    def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
        """
        Scan through documents in DocumentStore and return a small number of documents
        that are most relevant to the query.

        :param query: The query
        :param filters: A dictionary where the keys specify a metadata field and the value is a list of accepted values for that field
        :param top_k: How many documents to return per query.
        :param index: The name of the index in the DocumentStore from which to retrieve documents
        """
        if index is None:
            index = self.document_store.index
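The sparse retriever is used the same way; a minimal sketch (import path assumed from the haystack 0.x layout, filter field is an example):

```python
from haystack.retriever.sparse import ElasticsearchRetriever

retriever = ElasticsearchRetriever(document_store=document_store)
docs = retriever.retrieve(
    query="What is the capital of France?",
    filters={"category": ["geography"]},
    top_k=5,
)
```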
@@ -69,6 +78,15 @@ class ElasticsearchFilterOnlyRetriever(ElasticsearchRetriever):
    """

    def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
        """
        Scan through documents in DocumentStore and return a small number of documents
        that are most relevant to the query.

        :param query: The query
        :param filters: A dictionary where the keys specify a metadata field and the value is a list of accepted values for that field
        :param top_k: How many documents to return per query.
        :param index: The name of the index in the DocumentStore from which to retrieve documents
        """
        if index is None:
            index = self.document_store.index
        documents = self.document_store.query(query=None, filters=filters, top_k=top_k,
@@ -132,6 +150,15 @@ class TfidfRetriever(BaseRetriever):
        return indices_and_scores

    def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
        """
        Scan through documents in DocumentStore and return a small number of documents
        that are most relevant to the query.

        :param query: The query
        :param filters: A dictionary where the keys specify a metadata field and the value is a list of accepted values for that field
        :param top_k: How many documents to return per query.
        :param index: The name of the index in the DocumentStore from which to retrieve documents
        """
        if filters:
            raise NotImplementedError("Filters are not implemented in TfidfRetriever.")
        if index:
@@ -165,6 +192,9 @@ class TfidfRetriever(BaseRetriever):
        return documents

    def fit(self):
        """
        Performs training on this class according to the TF-IDF algorithm.
        """
        if not self.paragraphs or len(self.paragraphs) == 0:
            self.paragraphs = self._get_all_paragraphs()
        if not self.paragraphs or len(self.paragraphs) == 0:
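fit() builds the TF-IDF representation over all paragraphs in the store and should run before retrieve(). A minimal sketch, using an InMemoryDocumentStore purely for illustration (import paths assumed from the haystack 0.x layout):

```python
from haystack.document_store.memory import InMemoryDocumentStore
from haystack.retriever.sparse import TfidfRetriever

document_store = InMemoryDocumentStore()
document_store.write_documents([{"text": "Paris is the capital of France.", "meta": {}}])

retriever = TfidfRetriever(document_store=document_store)
retriever.fit()  # build the TF-IDF matrix over the stored paragraphs
docs = retriever.retrieve(query="capital of France", top_k=1)
```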