Clean API docs and increase coverage (#621)

* Fix docstrings

* Fix docstrings

* docstrings for retrievers and docstores

* Clean and add more docstrings
Branden Chan 2020-11-27 17:17:58 +01:00 committed by GitHub
parent fa55de2fab
commit 9fbd845ef3
12 changed files with 167 additions and 8 deletions


@@ -79,9 +79,8 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
If set to False, an error is raised if the document ID of the document being
added already exists.
:param refresh_type: Type of ES refresh used to control when changes made by a request (e.g. bulk) are made visible to search.
Values:
- 'wait_for' => continue only after changes are visible (slow, but safe)
- 'false' => continue directly (fast, but sometimes unintuitive behaviour when docs are not immediately available after ingestion)
If set to 'wait_for', continue only after changes are visible (slow, but safe).
If set to 'false', continue directly (fast, but sometimes unintuitive behaviour when docs are not immediately available after ingestion).
More info at https://www.elastic.co/guide/en/elasticsearch/reference/6.8/docs-refresh.html
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
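
For illustration, a minimal construction sketch tying these parameters together (assumes Haystack v0.x import paths and an Elasticsearch instance on localhost:9200; later sketches reuse this `document_store`):

from haystack.document_store.elasticsearch import ElasticsearchDocumentStore

document_store = ElasticsearchDocumentStore(
    host="localhost",
    index="document",
    similarity="dot_product",  # default; pairs well with DPR embeddings
    refresh_type="wait_for",   # safe: continue only once writes are searchable
)
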
@@ -220,6 +219,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
}
def get_document_by_id(self, id: str, index=None) -> Optional[Document]:
"""Fetch a document by specifying its text id string"""
index = index or self.index
documents = self.get_documents_by_id([id], index=index)
if documents:
@@ -228,6 +228,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
return None
def get_documents_by_id(self, ids: List[str], index=None) -> List[Document]:
"""Fetch documents by specifying a list of text id strings"""
index = index or self.index
query = {"query": {"ids": {"values": ids}}}
result = self.client.search(index=index, body=query)["hits"]["hits"]
@@ -298,6 +299,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
bulk(self.client, documents_to_index, request_timeout=300, refresh=self.refresh_type)
def write_labels(self, labels: Union[List[Label], List[dict]], index: Optional[str] = None):
"""Write annotation labels into document store."""
index = index or self.label_index
if index and not self.client.indices.exists(index=index):
self._create_label_index(index)
@@ -317,10 +319,16 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type)
def update_document_meta(self, id: str, meta: Dict[str, str]):
"""
Update the metadata dictionary of a document by specifying its string id
"""
body = {"doc": meta}
self.client.update(index=self.index, id=id, body=body, refresh=self.refresh_type)
def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
"""
Return the number of documents in the document store.
"""
index = index or self.index
body: dict = {"query": {"bool": {}}}
@@ -343,6 +351,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
return count
def get_label_count(self, index: Optional[str] = None) -> int:
"""
Return the number of labels in the document store
"""
return self.get_document_count(index=index)
def get_all_documents(
@@ -372,12 +383,18 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
return documents
def get_all_labels(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]:
"""
Return all labels in the document store
"""
index = index or self.label_index
result = self.get_all_documents_in_index(index=index, filters=filters)
labels = [Label.from_dict(hit["_source"]) for hit in result]
return labels
def get_all_documents_in_index(self, index: str, filters: Optional[Dict[str, List[str]]] = None) -> List[dict]:
"""
Return all documents in a specific index in the document store
"""
body = {
"query": {
"bool": {
@@ -409,6 +426,15 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
custom_query: Optional[str] = None,
index: Optional[str] = None,
) -> List[Document]:
"""
Scan through documents in the DocumentStore and return a small number of documents
that are most relevant to the query as defined by the BM25 algorithm.
:param query: The query
:param filters: A dictionary where the keys specify a metadata field and the values are lists of accepted values for that field
:param top_k: How many documents to return per query.
:param index: The name of the index in the DocumentStore from which to retrieve documents
"""
if index is None:
index = self.index
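
A hypothetical call matching the docstring above (the "category" field and its values are invented):

docs = document_store.query(
    query="Who is the father of Arya Stark?",
    filters={"category": ["fiction"]},  # hypothetical metadata field
    top_k=10,
)
for doc in docs:
    print(doc.text[:80])
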
@@ -483,6 +509,17 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
top_k: int = 10,
index: Optional[str] = None,
return_embedding: Optional[bool] = None) -> List[Document]:
"""
Find the documents that are most similar to the provided `query_emb` by using a vector similarity metric.
:param query_emb: Embedding of the query (e.g. gathered from DPR)
:param filters: Optional filters to narrow down the search space.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param top_k: How many documents to return
:param index: Index name for storing the docs and metadata
:param return_embedding: Whether to return the document embedding
:return: A list of documents that are most similar to `query_emb`
"""
if index is None:
index = self.index
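
A sketch of a direct call; in practice `query_emb` comes from a retriever's query encoder, so the random 768-dim vector here is only a stand-in:

import numpy as np

query_emb = np.random.rand(768).astype("float32")  # stand-in for a real DPR query embedding
docs = document_store.query_by_embedding(query_emb=query_emb, top_k=5, return_embedding=False)
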
@@ -572,6 +609,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
return document
def describe_documents(self, index=None):
"""
Return a summary of the documents in the document store
"""
if index is None:
index = self.index
docs = self.get_all_documents(index)


@@ -104,6 +104,7 @@ class FAISSDocumentStore(SQLDocumentStore):
def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
"""
Add new documents to the DocumentStore.
:param documents: List of `Dicts` or List of `Documents`. If they already contain the embeddings, we'll index
them right away in FAISS. If not, you can later call update_embeddings() to create & index them.
:param index: (SQL) index name for storing the docs and metadata
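
A minimal write sketch under v0.x import paths (the sqlite URL is an assumption):

from haystack.document_store.faiss import FAISSDocumentStore

document_store = FAISSDocumentStore(sql_url="sqlite:///faiss_store.db")
document_store.write_documents([
    {"text": "Paris is the capital of France.", "meta": {"source": "wiki"}},
])
# Embeddings can be created and indexed afterwards, e.g. document_store.update_embeddings(retriever)
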
@@ -229,6 +230,9 @@ class FAISSDocumentStore(SQLDocumentStore):
self.faiss_index.train(embeddings)
def delete_all_documents(self, index=None):
"""
Delete all documents from the document store.
"""
index = index or self.index
self.faiss_index.reset()
super().delete_all_documents(index=index)


@@ -47,6 +47,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
self.indexes[index][document.id] = document
def write_labels(self, labels: Union[List[dict], List[Label]], index: Optional[str] = None):
"""Write annotation labels into document store."""
index = index or self.label_index
label_objects = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
@@ -55,6 +56,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
self.indexes[index][label_id] = label
def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
"""Fetch a document by specifying its text id string"""
index = index or self.index
documents = self.get_documents_by_id([id], index=index)
if documents:
@@ -63,6 +65,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
return None
def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> List[Document]:
"""Fetch documents by specifying a list of text id strings"""
index = index or self.index
documents = [self.indexes[index][id] for id in ids]
return documents
@@ -74,6 +77,18 @@ class InMemoryDocumentStore(BaseDocumentStore):
index: Optional[str] = None,
return_embedding: Optional[bool] = None) -> List[Document]:
"""
Find the documents that are most similar to the provided `query_emb` by using a vector similarity metric.
:param query_emb: Embedding of the query (e.g. gathered from DPR)
:param filters: Optional filters to narrow down the search space.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param top_k: How many documents to return
:param index: Index name for storing the docs and metadata
:param return_embedding: Whether to return the document embedding
:return: A list of documents that are most similar to `query_emb`
"""
from numpy import dot
from numpy.linalg import norm
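
The imports hint at the scoring: per candidate document the store computes a cosine similarity along these lines (illustrative vectors):

import numpy as np
from numpy import dot
from numpy.linalg import norm

query_emb = np.array([0.1, 0.3, 0.5])
doc_emb = np.array([0.2, 0.1, 0.4])
score = dot(query_emb, doc_emb) / (norm(query_emb) * norm(doc_emb))  # cosine similarity in [-1, 1]
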
@@ -137,13 +152,19 @@ class InMemoryDocumentStore(BaseDocumentStore):
self.indexes[index][doc.id].embedding = emb
def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
"""
Return the number of documents in the document store.
"""
documents = self.get_all_documents(index=index, filters=filters)
return len(documents)
def get_label_count(self, index: Optional[str] = None) -> int:
"""
Return the number of labels in the document store
"""
index = index or self.label_index
return len(self.indexes[index].items())
def get_all_documents(
self,
index: Optional[str] = None,
@@ -187,6 +208,9 @@ class InMemoryDocumentStore(BaseDocumentStore):
return filtered_documents
def get_all_labels(self, index: str = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]:
"""
Return all labels in the document store
"""
index = index or self.label_index
if filters:


@@ -89,11 +89,13 @@ class SQLDocumentStore(BaseDocumentStore):
self.update_existing_documents = update_existing_documents
def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
"""Fetch a document by specifying its text id string"""
documents = self.get_documents_by_id([id], index)
document = documents[0] if documents else None
return document
def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> List[Document]:
"""Fetch documents by specifying a list of text id strings"""
index = index or self.index
results = self.session.query(DocumentORM).filter(DocumentORM.id.in_(ids), DocumentORM.index == index).all()
documents = [self._convert_sql_row_to_document(row) for row in results]
@@ -101,6 +103,7 @@ class SQLDocumentStore(BaseDocumentStore):
return documents
def get_documents_by_vector_ids(self, vector_ids: List[str], index: Optional[str] = None):
"""Fetch documents by specifying a list of text vector id strings"""
index = index or self.index
results = self.session.query(DocumentORM).filter(
DocumentORM.vector_id.in_(vector_ids),
@@ -138,6 +141,9 @@ class SQLDocumentStore(BaseDocumentStore):
return documents
def get_all_labels(self, index=None, filters: Optional[dict] = None):
"""
Return all labels in the document store
"""
index = index or self.label_index
label_rows = self.session.query(LabelORM).filter_by(index=index).all()
labels = [self._convert_sql_row_to_label(row) for row in label_rows]
@@ -182,6 +188,7 @@ class SQLDocumentStore(BaseDocumentStore):
raise ex
def write_labels(self, labels, index=None):
"""Write annotation labels into document store."""
labels = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels]
index = index or self.label_index
@@ -221,6 +228,9 @@ class SQLDocumentStore(BaseDocumentStore):
self.session.commit()
def update_document_meta(self, id: str, meta: Dict[str, str]):
"""
Update the metadata dictionary of a document by specifying its string id
"""
self.session.query(MetaORM).filter_by(document_id=id).delete()
meta_orms = [MetaORM(name=key, value=value, document_id=id) for key, value in meta.items()]
for m in meta_orms:
@@ -244,6 +254,9 @@ class SQLDocumentStore(BaseDocumentStore):
self.write_labels(labels, index=label_index)
def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
"""
Return the number of documents in the document store.
"""
index = index or self.index
query = self.session.query(DocumentORM).filter_by(index=index)
@@ -256,6 +269,9 @@ class SQLDocumentStore(BaseDocumentStore):
return count
def get_label_count(self, index: Optional[str] = None) -> int:
"""
Return the number of labels in the document store
"""
index = index or self.index
return self.session.query(LabelORM).filter_by(index=index).count()


@@ -41,6 +41,11 @@ class PDFToTextConverter(BaseConverter):
super().__init__(remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages)
def convert(self, file_path: Path, meta: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
"""
Extract text from a .pdf file.
:param file_path: Path to the .pdf file you want to convert
"""
pages = self._read_pdf(file_path, layout=False)
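
A usage sketch (v0.x import path; the file name is hypothetical, and it assumes the returned dict carries the extracted text under a "text" key):

from pathlib import Path
from haystack.file_converter.pdf import PDFToTextConverter

converter = PDFToTextConverter(remove_numeric_tables=True, valid_languages=["en"])
document = converter.convert(file_path=Path("sample.pdf"))
print(document["text"][:200])
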


@@ -3,6 +3,9 @@ from typing import List, Dict, Any
class BasePreProcessor:
def process(self, document: dict) -> List[dict]:
"""
Perform document cleaning and splitting. Takes a single document as input and returns a list of documents.
"""
cleaned_document = self.clean(document)
split_documents = self.split(cleaned_document)
return split_documents


@@ -2,6 +2,10 @@ import re
def clean_wiki_text(text: str) -> str:
"""
Clean Wikipedia text by removing multiple new lines, removing extremely short lines,
adding paragraph breaks, and removing empty paragraphs.
"""
# get rid of multiple new lines
while "\n\n" in text:
text = text.replace("\n\n", "\n")
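
The loop above collapses any run of blank lines into a single newline, e.g.:

text = "Paragraph one.\n\n\n\nParagraph two.\n"
while "\n\n" in text:
    text = text.replace("\n\n", "\n")
assert text == "Paragraph one.\nParagraph two.\n"
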


@@ -52,6 +52,10 @@ class PreProcessor(BasePreProcessor):
self.split_respect_sentence_boundary = split_respect_sentence_boundary
def clean(self, document: dict) -> dict:
"""
Perform document cleaning on a single document and return a single document. This method deals with whitespace, headers, footers,
and empty lines. Its exact functionality is defined by the parameters passed into PreProcessor.__init__().
"""
text = document["text"]
if self.clean_header_footer:
text = self._find_and_remove_header_footer(
@@ -74,6 +78,10 @@ class PreProcessor(BasePreProcessor):
return document
def split(self, document: dict) -> List[dict]:
"""Perform document splitting on a single document. This method can split on different units, at different lengths,
with different strides. It can also respect sectence boundaries. Its exact functionality is defined by
the parameters passed into PreProcessor.__init__(). Takes a single document as input and returns a list of documents. """
if not self.split_by:
return [document]
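
A usage sketch combining clean() and split() via process(); clean_header_footer, split_by and split_respect_sentence_boundary are visible above, while the other parameter names are assumptions about PreProcessor.__init__():

from haystack.preprocessor.preprocessor import PreProcessor  # v0.x import path

processor = PreProcessor(
    clean_header_footer=True,
    split_by="word",
    split_length=100,                      # assumed parameter name
    split_respect_sentence_boundary=True,
)
documents = processor.process({"text": "A long document ...", "meta": {"name": "sample"}})
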


@@ -11,10 +11,7 @@ class TransformersReader(BaseReader):
Transformer-based model for extractive Question Answering using HuggingFace's transformers framework
(https://github.com/huggingface/transformers).
While the underlying model can vary (BERT, RoBERTa, DistilBERT, ...), the interface remains the same.
| With the reader, you can:
- directly get predictions via predict()
With this reader, you can directly get predictions via predict()
"""
def __init__(
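
A usage sketch (v0.x import path; argument names are assumptions, and `docs` stands for Document objects, e.g. the output of a retriever's retrieve()):

from haystack.reader.transformers import TransformersReader

reader = TransformersReader(model_name_or_path="distilbert-base-uncased-distilled-squad")
prediction = reader.predict(question="Who wrote Hamlet?", documents=docs, top_k=3)
print(prediction["answers"][0])
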


@@ -29,6 +29,7 @@ class BaseRetriever(ABC):
pass
def timing(self, fn):
"""Wrapper method used to time functions. """
@wraps(fn)
def wrapper(*args, **kwargs):
if "retrieve_time" not in self.__dict__:


@@ -127,6 +127,15 @@ class DensePassageRetriever(BaseRetriever):
self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)
def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
"""
Scan through documents in the DocumentStore and return a small number of documents
that are most relevant to the query.
:param query: The query
:param filters: A dictionary where the keys specify a metadata field and the values are lists of accepted values for that field
:param top_k: How many documents to return per query.
:param index: The name of the index in the DocumentStore from which to retrieve documents
"""
if index is None:
index = self.document_store.index
query_emb = self.embed_queries(texts=[query])
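
A retrieval sketch under v0.x import paths, using the public DPR checkpoints and the `document_store` from earlier sketches:

from haystack.retriever.dense import DensePassageRetriever

retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
)
docs = retriever.retrieve(query="Who is the father of Arya Stark?", top_k=10)
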
@@ -305,6 +314,12 @@ class DensePassageRetriever(BaseRetriever):
self.processor.save(Path(save_dir))
def save(self, save_dir: Union[Path, str]):
"""
Save DensePassageRetriever to the specified directory.
:param save_dir: Directory to save to.
:return: None
"""
save_dir = Path(save_dir)
self.model.save(save_dir, lm1_name="query_encoder", lm2_name="passage_encoder")
save_dir = str(save_dir)
@@ -323,6 +338,9 @@ class DensePassageRetriever(BaseRetriever):
use_fast_tokenizers: bool = True,
similarity_function: str = "dot_product",
):
"""
Load DensePassageRetriever from the specified directory.
"""
load_dir = Path(load_dir)
dpr = cls(
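
Together, save() and load() give a round trip like this sketch (directory name hypothetical):

retriever.save("saved_models/dpr")
reloaded = DensePassageRetriever.load(load_dir="saved_models/dpr", document_store=document_store)
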
@@ -401,6 +419,15 @@ class EmbeddingRetriever(BaseRetriever):
raise NotImplementedError
def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
"""
Scan through documents in the DocumentStore and return a small number of documents
that are most relevant to the query.
:param query: The query
:param filters: A dictionary where the keys specify a metadata field and the values are lists of accepted values for that field
:param top_k: How many documents to return per query.
:param index: The name of the index in the DocumentStore from which to retrieve documents
"""
if index is None:
index = self.document_store.index
query_emb = self.embed(texts=[query])
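
A sketch analogous to the DPR example above (v0.x import path; the sentence-embedding model name is an assumption):

from haystack.retriever.dense import EmbeddingRetriever

retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert")
docs = retriever.retrieve(query="effects of climate change", top_k=10)
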


@@ -55,6 +55,15 @@ class ElasticsearchRetriever(BaseRetriever):
self.custom_query = custom_query
def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
"""
Scan through documents in the DocumentStore and return a small number of documents
that are most relevant to the query.
:param query: The query
:param filters: A dictionary where the keys specify a metadata field and the values are lists of accepted values for that field
:param top_k: How many documents to return per query.
:param index: The name of the index in the DocumentStore from which to retrieve documents
"""
if index is None:
index = self.document_store.index
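
A BM25 retrieval sketch (v0.x import path; the filter field is hypothetical):

from haystack.retriever.sparse import ElasticsearchRetriever

retriever = ElasticsearchRetriever(document_store=document_store)
docs = retriever.retrieve(query="Who is the father of Arya Stark?", filters={"category": ["fiction"]}, top_k=10)
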
@@ -69,6 +78,15 @@ class ElasticsearchFilterOnlyRetriever(ElasticsearchRetriever):
"""
def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
"""
Scan through documents in the DocumentStore and return a small number of documents
that are most relevant to the query.
:param query: The query
:param filters: A dictionary where the keys specify a metadata field and the values are lists of accepted values for that field
:param top_k: How many documents to return per query.
:param index: The name of the index in the DocumentStore from which to retrieve documents
"""
if index is None:
index = self.document_store.index
documents = self.document_store.query(query=None, filters=filters, top_k=top_k,
@@ -132,6 +150,15 @@ class TfidfRetriever(BaseRetriever):
return indices_and_scores
def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
"""
Scan through documents in the DocumentStore and return a small number of documents
that are most relevant to the query.
:param query: The query
:param filters: A dictionary where the keys specify a metadata field and the values are lists of accepted values for that field
:param top_k: How many documents to return per query.
:param index: The name of the index in the DocumentStore from which to retrieve documents
"""
if filters:
raise NotImplementedError("Filters are not implemented in TfidfRetriever.")
if index:
@@ -165,6 +192,9 @@ class TfidfRetriever(BaseRetriever):
return documents
def fit(self):
"""
Perform training on this class according to the TF-IDF algorithm.
"""
if not self.paragraphs or len(self.paragraphs) == 0:
self.paragraphs = self._get_all_paragraphs()
if not self.paragraphs or len(self.paragraphs) == 0:
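
A fit-then-retrieve sketch (v0.x import path; query text invented):

from haystack.retriever.sparse import TfidfRetriever

retriever = TfidfRetriever(document_store=document_store)
retriever.fit()  # (re)builds the TF-IDF matrix over the paragraphs in the store
docs = retriever.retrieve(query="What did Albert Einstein work on?", top_k=5)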