Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-10 23:04:02 +00:00)
Add batch_size and generators to document stores. (#733)
* Add batch update of embeddings in document stores
* Resolve merge conflict
* Remove document ordering dependency in tests
* Adjust index buffer size for tests
* Adjust ES Scroll Slice
* Use generator for document store pagination
* Add pagination for InMemoryDocumentStore
* Fix missing index parameter in FAISS update_embeddings()
* Fix FAISS update_embeddings()
* Update FAISS tests
* Update eval tests
* Revert code formatting change
* Fix document count in FAISS update embeddings
* Fix vector_ids reset in SQLDocumentStore
* Update docstrings
* Update docstring
This commit is contained in:
parent 0b583b8972
commit 337376c81d
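
For orientation, a minimal usage sketch of the batched, generator-based API this commit introduces (assumes a running Elasticsearch instance; `docs` is an existing list of documents, `retriever` an initialised retriever, and `process()` a hypothetical placeholder):

    from haystack.document_store.elasticsearch import ElasticsearchDocumentStore

    document_store = ElasticsearchDocumentStore(embedding_field="embedding", embedding_dim=768)

    # write_documents() now sends documents to Elasticsearch's bulk endpoint in fixed-size batches
    document_store.write_documents(docs, batch_size=10_000)

    # get_all_documents_generator() streams documents instead of loading them all into memory
    for doc in document_store.get_all_documents_generator(batch_size=10_000):
        process(doc)  # hypothetical per-document processing

    # update_embeddings() embeds and re-indexes the documents batch by batch
    document_store.update_embeddings(retriever, batch_size=10_000)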
@@ -3,7 +3,7 @@ import logging
 import time
 from copy import deepcopy
 from string import Template
-from typing import List, Optional, Union, Dict, Any
+from typing import List, Optional, Union, Dict, Any, Generator
 from elasticsearch import Elasticsearch
 from elasticsearch.helpers import bulk, scan
 from elasticsearch.exceptions import RequestError
@@ -13,6 +13,7 @@ from scipy.special import expit
 from haystack.document_store.base import BaseDocumentStore
 from haystack import Document, Label
 from haystack.retriever.base import BaseRetriever
+from haystack.utils import get_batches_from_generator

 logger = logging.getLogger(__name__)

@@ -232,8 +233,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         documents = [self._convert_es_hit_to_document(hit, return_embedding=self.return_embedding) for hit in result]
         return documents

-    def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None,
-                        batch_size: Optional[int] = None):
+    def write_documents(
+        self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, batch_size: int = 10_000
+    ):
         """
         Indexes documents for later queries in Elasticsearch.

@@ -253,7 +255,6 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
                                   should be changed to what you have set for self.text_field and self.name_field.
         :param index: Elasticsearch index where the documents should be indexed. If not supplied, self.index will be used.
         :param batch_size: Number of documents that are passed to Elasticsearch's bulk function at a time.
-                           If `None`, all documents will be passed to bulk at once.
         :return: None
         """

@@ -298,22 +299,21 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
                 _doc.pop("meta")
             documents_to_index.append(_doc)

-            if batch_size is not None:
-                # Pass batch_size number of documents to bulk
-                if len(documents_to_index) % batch_size == 0:
-                    bulk(self.client, documents_to_index, request_timeout=300, refresh=self.refresh_type)
-                    documents_to_index = []
+            # Pass batch_size number of documents to bulk
+            if len(documents_to_index) % batch_size == 0:
+                bulk(self.client, documents_to_index, request_timeout=300, refresh=self.refresh_type)
+                documents_to_index = []

         if documents_to_index:
             bulk(self.client, documents_to_index, request_timeout=300, refresh=self.refresh_type)

-    def write_labels(self, labels: Union[List[Label], List[dict]], index: Optional[str] = None,
-                     batch_size: Optional[int] = None):
+    def write_labels(
+        self, labels: Union[List[Label], List[dict]], index: Optional[str] = None, batch_size: int = 10_000
+    ):
         """Write annotation labels into document store.

         :param labels: A list of Python dictionaries or a list of Haystack Label objects.
         :param batch_size: Number of labels that are passed to Elasticsearch's bulk function at a time.
-                           If `None`, all labels will be passed to bulk at once.
         """
         index = index or self.label_index
         if index and not self.client.indices.exists(index=index):
@@ -339,11 +339,10 @@ class ElasticsearchDocumentStore(BaseDocumentStore):

             labels_to_index.append(_label)

-            if batch_size is not None:
-                # Pass batch_size number of labels to bulk
-                if len(labels_to_index) % batch_size == 0:
-                    bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type)
-                    labels_to_index = []
+            # Pass batch_size number of labels to bulk
+            if len(labels_to_index) % batch_size == 0:
+                bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type)
+                labels_to_index = []

         if labels_to_index:
             bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type)
@@ -387,10 +386,11 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         return self.get_document_count(index=index)

     def get_all_documents(
         self,
         index: Optional[str] = None,
         filters: Optional[Dict[str, List[str]]] = None,
-        return_embedding: Optional[bool] = None
+        return_embedding: Optional[bool] = None,
+        batch_size: int = 10_000,
     ) -> List[Document]:
         """
         Get documents from the document store.
@@ -400,28 +400,62 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         :param filters: Optional filters to narrow down the documents to return.
                         Example: {"name": ["some", "more"], "category": ["only_one"]}
         :param return_embedding: Whether to return the document embeddings.
+        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
+        """
+        result = self.get_all_documents_generator(
+            index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
+        )
+        documents = list(result)
+        return documents
+
+    def get_all_documents_generator(
+        self,
+        index: Optional[str] = None,
+        filters: Optional[Dict[str, List[str]]] = None,
+        return_embedding: Optional[bool] = None,
+        batch_size: int = 10_000,
+    ) -> Generator[Document, None, None]:
+        """
+        Get documents from the document store. Under-the-hood, documents are fetched in batches from the
+        document store and yielded as individual documents. This method can be used to iteratively process
+        a large number of documents without having to load all documents in memory.
+
+        :param index: Name of the index to get the documents from. If None, the
+                      DocumentStore's default index (self.index) will be used.
+        :param filters: Optional filters to narrow down the documents to return.
+                        Example: {"name": ["some", "more"], "category": ["only_one"]}
+        :param return_embedding: Whether to return the document embeddings.
+        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
         """

         if index is None:
             index = self.index

-        result = self.get_all_documents_in_index(index=index, filters=filters)
         if return_embedding is None:
             return_embedding = self.return_embedding
-        documents = [self._convert_es_hit_to_document(hit, return_embedding=return_embedding) for hit in result]

-        return documents
+        result = self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size)
+        for hit in result:
+            document = self._convert_es_hit_to_document(hit, return_embedding=return_embedding)
+            yield document

-    def get_all_labels(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]:
+    def get_all_labels(
+        self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None, batch_size: int = 10_000
+    ) -> List[Label]:
         """
         Return all labels in the document store
         """
         index = index or self.label_index
-        result = self.get_all_documents_in_index(index=index, filters=filters)
+        result = list(self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size))
         labels = [Label.from_dict(hit["_source"]) for hit in result]
         return labels

-    def get_all_documents_in_index(self, index: str, filters: Optional[Dict[str, List[str]]] = None) -> List[dict]:
+    def _get_all_documents_in_index(
+        self,
+        index: str,
+        filters: Optional[Dict[str, List[str]]] = None,
+        batch_size: int = 10_000,
+    ) -> Generator[dict, None, None]:
         """
         Return all documents in a specific index in the document store
         """
@@ -444,9 +478,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
                 }
             )
             body["query"]["bool"]["filter"] = filter_clause
-        result = list(scan(self.client, query=body, index=index))

-        return result
+        result = scan(self.client, query=body, index=index, size=batch_size, scroll="1d")
+        yield from result

     def query(
         self,
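For context (not part of the diff): elasticsearch.helpers.scan is itself a generator that pages through results with the scroll API, so yielding from it instead of wrapping it in list() keeps only one page of `batch_size` hits in client memory at a time, and `scroll="1d"` keeps the server-side cursor alive for slow consumers. Roughly, the traffic looks like this (illustrative sketch):

    # POST /<index>/_search?scroll=1d   with body {"size": batch_size, "query": ...}
    # POST /_search/scroll              with the returned scroll_id, repeated until no hits remain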
@@ -683,13 +717,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         }
         return stats

-    def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None):
+    def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None, batch_size: int = 10_000):
         """
         Updates the embeddings in the the document store using the encoding model specified in the retriever.
         This can be useful if want to add or change the embeddings for your documents (e.g. after changing the retriever config).

-        :param retriever: Retriever
+        :param retriever: Retriever to use to update the embeddings.
         :param index: Index name to update
+        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
         :return: None
         """
         if index is None:
@@ -698,26 +733,29 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         if not self.embedding_field:
             raise RuntimeError("Specify the arg `embedding_field` when initializing ElasticsearchDocumentStore()")

-        # TODO Index embeddings every X batches to avoid OOM for huge document collections
-        docs = self.get_all_documents(index)
-        logger.info(f"Updating embeddings for {len(docs)} docs ...")
-        embeddings = retriever.embed_passages(docs)  # type: ignore
-        assert len(docs) == len(embeddings)
+        logger.info(f"Updating embeddings for {self.get_document_count(index=index)} docs ...")

-        if embeddings[0].shape[0] != self.embedding_dim:
-            raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
-                               f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})."
-                               "Specify the arg `embedding_dim` when initializing ElasticsearchDocumentStore()")
-        doc_updates = []
-        for doc, emb in zip(docs, embeddings):
-            update = {"_op_type": "update",
-                      "_index": index,
-                      "_id": doc.id,
-                      "doc": {self.embedding_field: emb.tolist()},
-                      }
-            doc_updates.append(update)
+        result = self.get_all_documents_generator(index, batch_size=batch_size)
+        for document_batch in get_batches_from_generator(result, batch_size):
+            if len(document_batch) == 0:
+                break
+            embeddings = retriever.embed_passages(document_batch)  # type: ignore
+            assert len(document_batch) == len(embeddings)

-        bulk(self.client, doc_updates, request_timeout=300, refresh=self.refresh_type)
+            if embeddings[0].shape[0] != self.embedding_dim:
+                raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
+                                   f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})."
+                                   "Specify the arg `embedding_dim` when initializing ElasticsearchDocumentStore()")
+            doc_updates = []
+            for doc, emb in zip(document_batch, embeddings):
+                update = {"_op_type": "update",
+                          "_index": index,
+                          "_id": doc.id,
+                          "doc": {self.embedding_field: emb.tolist()},
+                          }
+                doc_updates.append(update)
+
+            bulk(self.client, doc_updates, request_timeout=300, refresh=self.refresh_type)

     def delete_all_documents(self, index: str, filters: Optional[Dict[str, List[str]]] = None):
         """
@@ -1,14 +1,14 @@
 import logging
 from sys import platform
 from pathlib import Path
-from typing import Union, List, Optional, Dict
+from typing import Union, List, Optional, Dict, Generator
 from tqdm import tqdm
 import numpy as np

 from haystack import Document
 from haystack.document_store.sql import SQLDocumentStore
 from haystack.retriever.base import BaseRetriever
+from haystack.utils import get_batches_from_generator
 from scipy.special import expit

 if platform != 'win32' and platform != 'cygwin':
@@ -34,7 +34,6 @@ class FAISSDocumentStore(SQLDocumentStore):
     def __init__(
         self,
         sql_url: str = "sqlite:///",
-        index_buffer_size: int = 10_000,
         vector_dim: int = 768,
         faiss_index_factory_str: str = "Flat",
         faiss_index: Optional[faiss.swigfaiss.Index] = None,
@@ -47,8 +46,6 @@ class FAISSDocumentStore(SQLDocumentStore):
         """
         :param sql_url: SQL connection URL for database. It defaults to local file based SQLite DB. For large scale
                         deployment, Postgres is recommended.
-        :param index_buffer_size: When working with large datasets, the ingestion process(FAISS + SQL) can be buffered in
-                                  smaller chunks to reduce memory footprint.
         :param vector_dim: the embedding vector size.
         :param faiss_index_factory_str: Create a new FAISS index of the specified type.
                                         The type is determined from the given string following the conventions
@@ -85,7 +82,6 @@ class FAISSDocumentStore(SQLDocumentStore):
             if "ivf" in faiss_index_factory_str.lower():  # enable reconstruction of vectors for inverted index
                 self.faiss_index.set_direct_map_type(faiss.DirectMap.Hashtable)

-        self.index_buffer_size = index_buffer_size
         self.return_embedding = return_embedding
         if similarity == "dot_product":
             self.similarity = similarity
@@ -111,13 +107,16 @@ class FAISSDocumentStore(SQLDocumentStore):
             index = faiss.index_factory(vector_dim, index_factory, metric_type)
         return index

-    def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
+    def write_documents(
+        self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, batch_size: int = 10_000
+    ):
         """
         Add new documents to the DocumentStore.

         :param documents: List of `Dicts` or List of `Documents`. If they already contain the embeddings, we'll index
                           them right away in FAISS. If not, you can later call update_embeddings() to create & index them.
         :param index: (SQL) index name for storing the docs and metadata
+        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
         :return:
         """
         # vector index
@@ -136,15 +135,15 @@ class FAISSDocumentStore(SQLDocumentStore):
                                  "`FAISSDocumentStore` does not support update in existing `faiss_index`.\n"
                                  "Please call `update_embeddings` method to repopulate `faiss_index`")

-        for i in range(0, len(document_objects), self.index_buffer_size):
-            vector_id = self.faiss_index.ntotal
+        vector_id = self.faiss_index.ntotal
+        for i in range(0, len(document_objects), batch_size):
             if add_vectors:
-                embeddings = [doc.embedding for doc in document_objects[i: i + self.index_buffer_size]]
+                embeddings = [doc.embedding for doc in document_objects[i: i + batch_size]]
                 embeddings = np.array(embeddings, dtype="float32")
                 self.faiss_index.add(embeddings)

             docs_to_write_in_sql = []
-            for doc in document_objects[i : i + self.index_buffer_size]:
+            for doc in document_objects[i: i + batch_size]:
                 meta = doc.meta
                 if add_vectors:
                     meta["vector_id"] = vector_id
@@ -158,13 +157,14 @@ class FAISSDocumentStore(SQLDocumentStore):
             self.index: "embedding",
         }

-    def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None):
+    def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None, batch_size: int = 10_000):
         """
         Updates the embeddings in the the document store using the encoding model specified in the retriever.
         This can be useful if want to add or change the embeddings for your documents (e.g. after changing the retriever config).

         :param retriever: Retriever to use to get embeddings for text
         :param index: (SQL) index name for storing the docs and metadata
+        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
         :return: None
         """
         if not self.faiss_index:
@@ -172,55 +172,85 @@ class FAISSDocumentStore(SQLDocumentStore):

         # Faiss does not support update in existing index data so clear all existing data in it
         self.faiss_index.reset()
+        self.reset_vector_ids(index=index)

         index = index or self.index
-        documents = self.get_all_documents(index=index)

-        if len(documents) == 0:
+        document_count = self.get_document_count(index=index)
+        if document_count == 0:
             logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
             return

-        # To clear out the FAISS index contents and frees all memory immediately that is in use by the index
-        self.faiss_index.reset()
+        logger.info(f"Updating embeddings for {document_count} docs...")
+        vector_id = self.faiss_index.ntotal

-        logger.info(f"Updating embeddings for {len(documents)} docs...")
-        embeddings = retriever.embed_passages(documents)  # type: ignore
-        assert len(documents) == len(embeddings)
-        for i, doc in enumerate(documents):
-            doc.embedding = embeddings[i]
+        result = self.get_all_documents_generator(index=index, batch_size=batch_size, return_embedding=False)
+        batched_documents = get_batches_from_generator(result, batch_size)
+        with tqdm(total=document_count) as progress_bar:
+            for document_batch in batched_documents:
+                embeddings = retriever.embed_passages(document_batch)  # type: ignore
+                assert len(document_batch) == len(embeddings)

-        logger.info("Indexing embeddings and updating vectors_ids...")
-        for i in tqdm(range(0, len(documents), self.index_buffer_size)):
-            vector_id_map = {}
-            vector_id = self.faiss_index.ntotal
-            embeddings = [doc.embedding for doc in documents[i: i + self.index_buffer_size]]
-            embeddings = np.array(embeddings, dtype="float32")
-            self.faiss_index.add(embeddings)
+                embeddings_to_index = np.array(embeddings, dtype="float32")
+                self.faiss_index.add(embeddings_to_index)

-            for doc in documents[i: i + self.index_buffer_size]:
-                vector_id_map[doc.id] = vector_id
-                vector_id += 1
-            self.update_vector_ids(vector_id_map, index=index)
+                vector_id_map = {}
+                for doc in document_batch:
+                    vector_id_map[doc.id] = vector_id
+                    vector_id += 1
+                self.update_vector_ids(vector_id_map, index=index)
+                progress_bar.update(batch_size)
+        progress_bar.close()

     def get_all_documents(
         self,
         index: Optional[str] = None,
         filters: Optional[Dict[str, List[str]]] = None,
-        return_embedding: Optional[bool] = None
+        return_embedding: Optional[bool] = None,
+        batch_size: int = 10_000,
     ) -> List[Document]:
+        result = self.get_all_documents_generator(
+            index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
+        )
+        documents = list(result)
+        return documents
+
+    def get_all_documents_generator(
+        self,
+        index: Optional[str] = None,
+        filters: Optional[Dict[str, List[str]]] = None,
+        return_embedding: Optional[bool] = None,
+        batch_size: int = 10_000,
+    ) -> Generator[Document, None, None]:
         """
-        Get documents from the document store.
+        Get all documents from the document store. Under-the-hood, documents are fetched in batches from the
+        document store and yielded as individual documents. This method can be used to iteratively process
+        a large number of documents without having to load all documents in memory.

         :param index: Name of the index to get the documents from. If None, the
                       DocumentStore's default index (self.index) will be used.
         :param filters: Optional filters to narrow down the documents to return.
                         Example: {"name": ["some", "more"], "category": ["only_one"]}
         :param return_embedding: Whether to return the document embeddings.
+        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
         """
-        documents = super(FAISSDocumentStore, self).get_all_documents(index=index, filters=filters)
+        documents = super(FAISSDocumentStore, self).get_all_documents_generator(
+            index=index, filters=filters, batch_size=batch_size
+        )
         if return_embedding is None:
             return_embedding = self.return_embedding
-        if return_embedding:
+
+        for doc in documents:
+            if return_embedding:
+                if doc.meta and doc.meta.get("vector_id") is not None:
+                    doc.embedding = self.faiss_index.reconstruct(int(doc.meta["vector_id"]))
+            yield doc
+
+    def get_documents_by_id(
+        self, ids: List[str], index: Optional[str] = None, batch_size: int = 10_000
+    ) -> List[Document]:
+        documents = super(FAISSDocumentStore, self).get_documents_by_id(ids=ids, index=index)
+        if self.return_embedding:
             for doc in documents:
                 if doc.meta and doc.meta.get("vector_id") is not None:
                     doc.embedding = self.faiss_index.reconstruct(int(doc.meta["vector_id"]))
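A minimal sketch of the corresponding FAISS flow after this change (assumes an already initialised retriever and `docs` as a list of documents or dicts):

    from haystack.document_store.faiss import FAISSDocumentStore

    document_store = FAISSDocumentStore(sql_url="sqlite:///", vector_dim=768)
    document_store.write_documents(docs, batch_size=10_000)

    # Resets the FAISS index and the stored vector_ids, then re-embeds and re-indexes
    # the documents batch by batch, assigning vector_ids from faiss_index.ntotal onwards.
    document_store.update_embeddings(retriever, batch_size=10_000)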
@@ -309,7 +339,6 @@ class FAISSDocumentStore(SQLDocumentStore):
         cls,
         faiss_file_path: Union[str, Path],
         sql_url: str,
-        index_buffer_size: int = 10_000,
     ):
         """
         Load a saved FAISS index from a file and connect to the SQL database.
@@ -318,8 +347,6 @@ class FAISSDocumentStore(SQLDocumentStore):

         :param faiss_file_path: Stored FAISS index file. Can be created via calling `save()`
         :param sql_url: Connection string to the SQL database that contains your docs and metadata.
-        :param index_buffer_size: When working with large datasets, the ingestion process(FAISS + SQL) can be buffered in
-                                  smaller chunks to reduce memory footprint.
         :return:
         """
         """
@@ -328,7 +355,6 @@ class FAISSDocumentStore(SQLDocumentStore):
         return cls(
             faiss_index=faiss_index,
             sql_url=sql_url,
-            index_buffer_size=index_buffer_size,
             vector_dim=faiss_index.d
         )

@@ -1,5 +1,5 @@
 from copy import deepcopy
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Generator
 from uuid import uuid4
 from collections import defaultdict

@@ -183,13 +183,26 @@ class InMemoryDocumentStore(BaseDocumentStore):
         return len(self.indexes[index].items())

     def get_all_documents(
         self,
         index: Optional[str] = None,
         filters: Optional[Dict[str, List[str]]] = None,
-        return_embedding: Optional[bool] = None
+        return_embedding: Optional[bool] = None,
+        batch_size: int = 10_000,
     ) -> List[Document]:
+        result = self.get_all_documents_generator(index=index, filters=filters, return_embedding=return_embedding)
+        documents = list(result)
+        return documents
+
+    def get_all_documents_generator(
+        self,
+        index: Optional[str] = None,
+        filters: Optional[Dict[str, List[str]]] = None,
+        return_embedding: Optional[bool] = None,
+        batch_size: int = 10_000,
+    ) -> Generator[Document, None, None]:
         """
-        Get documents from the document store.
+        Get all documents from the document store. The methods returns a Python Generator that yields individual
+        documents.

         :param index: Name of the index to get the documents from. If None, the
                       DocumentStore's default index (self.index) will be used.
@@ -222,7 +235,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
         else:
             filtered_documents = documents

-        return filtered_documents
+        yield from filtered_documents

     def get_all_labels(self, index: str = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]:
         """
@@ -1,15 +1,15 @@
 import itertools
 import logging
-from typing import Any, Dict, Union, List, Optional
+from typing import Any, Dict, Union, List, Optional, Generator
 from uuid import uuid4

-from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey, Boolean, Text
+from sqlalchemy import and_, func, create_engine, Column, Integer, String, DateTime, ForeignKey, Boolean, Text, text
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import relationship, sessionmaker
-from sqlalchemy.sql import case
+from sqlalchemy.sql import case, null

-from haystack.document_store.base import BaseDocumentStore
 from haystack import Document, Label
+from haystack.document_store.base import BaseDocumentStore

 logger = logging.getLogger(__name__)

@@ -73,7 +73,6 @@ class SQLDocumentStore(BaseDocumentStore):
         index: str = "document",
         label_index: str = "label",
         update_existing_documents: bool = False,
-        batch_size: int = 32766,
     ):
         """
         An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
@@ -87,11 +86,6 @@ class SQLDocumentStore(BaseDocumentStore):
                                            If set to False, an error is raised if the document ID of the document being
                                            added already exists. Using this parameter could cause performance degradation
                                            for document insertion.
-        :param batch_size: Maximum number of variable parameters and rows fetched in a single SQL statement,
-                           to help in excessive memory allocations. In most methods of the DocumentStore this means number of documents fetched in one query.
-                           Tune this value based on host machine main memory.
-                           For SQLite versions prior to v3.32.0 keep this value less than 1000.
-                           More info refer: https://www.sqlite.org/limits.html
         """
         engine = create_engine(url)
         ORMBase.metadata.create_all(engine)
@@ -102,7 +96,6 @@ class SQLDocumentStore(BaseDocumentStore):
         self.update_existing_documents = update_existing_documents
         if getattr(self, "similarity", None) is None:
             self.similarity = None
-        self.batch_size = batch_size

     def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
         """Fetch a document by specifying its text id string"""
@@ -110,14 +103,14 @@ class SQLDocumentStore(BaseDocumentStore):
         document = documents[0] if documents else None
         return document

-    def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> List[Document]:
+    def get_documents_by_id(self, ids: List[str], index: Optional[str] = None, batch_size: int = 10_000) -> List[Document]:
         """Fetch documents by specifying a list of text id strings"""
         index = index or self.index

         documents = []
-        for i in range(0, len(ids), self.batch_size):
+        for i in range(0, len(ids), batch_size):
             query = self.session.query(DocumentORM).filter(
-                DocumentORM.id.in_(ids[i: i + self.batch_size]),
+                DocumentORM.id.in_(ids[i: i + batch_size]),
                 DocumentORM.index == index
             )
             for row in query.all():
@@ -125,14 +118,14 @@ class SQLDocumentStore(BaseDocumentStore):

         return documents

-    def get_documents_by_vector_ids(self, vector_ids: List[str], index: Optional[str] = None):
+    def get_documents_by_vector_ids(self, vector_ids: List[str], index: Optional[str] = None, batch_size: int = 10_000):
         """Fetch documents by specifying a list of text vector id strings"""
         index = index or self.index

         documents = []
-        for i in range(0, len(vector_ids), self.batch_size):
+        for i in range(0, len(vector_ids), batch_size):
             query = self.session.query(DocumentORM).filter(
-                DocumentORM.vector_id.in_(vector_ids[i: i + self.batch_size]),
+                DocumentORM.vector_id.in_(vector_ids[i: i + batch_size]),
                 DocumentORM.index == index
             )
             for row in query.all():
@@ -142,13 +135,25 @@ class SQLDocumentStore(BaseDocumentStore):
         return sorted_documents

     def get_all_documents(
         self,
         index: Optional[str] = None,
         filters: Optional[Dict[str, List[str]]] = None,
-        return_embedding: Optional[bool] = None
+        return_embedding: Optional[bool] = None,
     ) -> List[Document]:
+        documents = list(self.get_all_documents_generator(index=index, filters=filters))
+        return documents
+
+    def get_all_documents_generator(
+        self,
+        index: Optional[str] = None,
+        filters: Optional[Dict[str, List[str]]] = None,
+        return_embedding: Optional[bool] = None,
+        batch_size: int = 10_000,
+    ) -> Generator[Document, None, None]:
         """
-        Get documents from the document store.
+        Get documents from the document store. Under-the-hood, documents are fetched in batches from the
+        document store and yielded as individual documents. This method can be used to iteratively process
+        a large number of documents without having to load all documents in memory.

         :param index: Name of the index to get the documents from. If None, the
                       DocumentStore's default index (self.index) will be used.
@@ -177,26 +182,31 @@ class SQLDocumentStore(BaseDocumentStore):
         )

         documents_map = {}
-        for row in documents_query.all():
+        for i, row in enumerate(self._windowed_query(documents_query, DocumentORM.id, batch_size), start=1):
             documents_map[row.id] = Document(
                 id=row.id,
                 text=row.text,
                 meta=None if row.vector_id is None else {"vector_id": row.vector_id}  # type: ignore
             )
+            if i % batch_size == 0:
+                documents_map = self._get_documents_meta(documents_map)
+                yield from documents_map.values()
+                documents_map = {}
+        if documents_map:
+            documents_map = self._get_documents_meta(documents_map)
+            yield from documents_map.values()

-        for doc_ids in self.chunked_iterable(documents_map.keys(), size=self.batch_size):
-            meta_query = self.session.query(
-                MetaORM.document_id,
-                MetaORM.name,
-                MetaORM.value
-            ).filter(MetaORM.document_id.in_(doc_ids))
+    def _get_documents_meta(self, documents_map):
+        doc_ids = documents_map.keys()
+        meta_query = self.session.query(
+            MetaORM.document_id,
+            MetaORM.name,
+            MetaORM.value
+        ).filter(MetaORM.document_id.in_(doc_ids))

-            for row in meta_query.all():
-                if documents_map[row.document_id].meta is None:
-                    documents_map[row.document_id].meta = {}
-                documents_map[row.document_id].meta[row.name] = row.value  # type: ignore
-
-        return list(documents_map.values())
+        for row in meta_query.all():
+            documents_map[row.document_id].meta[row.name] = row.value  # type: ignore
+        return documents_map

     def get_all_labels(self, index=None, filters: Optional[dict] = None):
         """
@@ -209,17 +219,20 @@ class SQLDocumentStore(BaseDocumentStore):

         return labels

-    def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
+    def write_documents(
+        self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, batch_size: int = 10_000
+    ):
         """
         Indexes documents for later queries.

         :param documents: a list of Python dictionaries or a list of Haystack Document objects.
                           For documents as dictionaries, the format is {"text": "<the-actual-text>"}.
                           Optionally: Include meta data via {"text": "<the-actual-text>",
                           "meta":{"name": "<some-document-name>, "author": "somebody", ...}}
                           It can be used for filtering and is accessible in the responses of the Finder.
         :param index: add an optional index attribute to documents. It can be later used for filtering. For instance,
                       documents for evaluation can be indexed in a separate index than the documents for search.
+        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.

         :return: None
         """
@@ -233,10 +246,10 @@ class SQLDocumentStore(BaseDocumentStore):
         else:
             document_objects = documents

-        for i in range(0, len(document_objects), self.batch_size):
-            for doc in document_objects[i: i + self.batch_size]:
+        for i in range(0, len(document_objects), batch_size):
+            for doc in document_objects[i: i + batch_size]:
                 meta_fields = doc.meta or {}
-                vector_id = meta_fields.get("vector_id")
+                vector_id = meta_fields.pop("vector_id", None)
                 meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()]
                 doc_orm = DocumentORM(id=doc.id, text=doc.text, vector_id=vector_id, meta=meta_orms, index=index)
                 if self.update_existing_documents:
@@ -275,15 +288,16 @@ class SQLDocumentStore(BaseDocumentStore):
             self.session.add(label_orm)
         self.session.commit()

-    def update_vector_ids(self, vector_id_map: Dict[str, str], index: Optional[str] = None):
+    def update_vector_ids(self, vector_id_map: Dict[str, str], index: Optional[str] = None, batch_size: int = 10_000):
         """
         Update vector_ids for given document_ids.

         :param vector_id_map: dict containing mapping of document_id -> vector_id.
         :param index: filter documents by the optional index attribute for documents in database.
+        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
         """
         index = index or self.index
-        for chunk_map in self.chunked_dict(vector_id_map, size=self.batch_size):
+        for chunk_map in self.chunked_dict(vector_id_map, size=batch_size):
             self.session.query(DocumentORM).filter(
                 DocumentORM.id.in_(chunk_map),
                 DocumentORM.index == index
@@ -300,6 +314,14 @@ class SQLDocumentStore(BaseDocumentStore):
                 self.session.rollback()
                 raise ex

+    def reset_vector_ids(self, index: Optional[str] = None):
+        """
+        Set vector IDs for all documents as None
+        """
+        index = index or self.index
+        self.session.query(DocumentORM).filter_by(index=index).update({DocumentORM.vector_id: null()})
+        self.session.commit()
+
     def update_document_meta(self, id: str, meta: Dict[str, str]):
         """
         Update the metadata dictionary of a document by specifying its string id
@@ -392,16 +414,55 @@ class SQLDocumentStore(BaseDocumentStore):
             session.commit()
         return instance

-    # Refer: https://alexwlchan.net/2018/12/iterating-in-fixed-size-chunks/
-    def chunked_iterable(self, iterable, size):
-        it = iter(iterable)
-        while True:
-            chunk = tuple(itertools.islice(it, size))
-            if not chunk:
-                break
-            yield chunk

     def chunked_dict(self, dictionary, size):
         it = iter(dictionary)
         for i in range(0, len(dictionary), size):
             yield {k: dictionary[k] for k in itertools.islice(it, size)}

+    def _column_windows(self, session, column, windowsize):
+        """Return a series of WHERE clauses against
+        a given column that break it into windows.
+
+        Result is an iterable of tuples, consisting of
+        ((start, end), whereclause), where (start, end) are the ids.
+
+        The code is taken from: https://github.com/sqlalchemy/sqlalchemy/wiki/RangeQuery-and-WindowedRangeQuery
+        """
+
+        def int_for_range(start_id, end_id):
+            if end_id:
+                return and_(
+                    column >= start_id,
+                    column < end_id
+                )
+            else:
+                return column >= start_id
+
+        q = session.query(
+            column,
+            func.row_number(). \
+                over(order_by=column). \
+                label('rownum')
+        ). \
+            from_self(column)
+        if windowsize > 1:
+            q = q.filter(text("rownum %% %d=1" % windowsize))
+
+        intervals = [id for id, in q]
+
+        while intervals:
+            start = intervals.pop(0)
+            if intervals:
+                end = intervals[0]
+            else:
+                end = None
+            yield int_for_range(start, end)
+
+    def _windowed_query(self, q, column, windowsize):
+        """"Break a Query into windows on a given column."""
+
+        for whereclause in self._column_windows(
+                q.session,
+                column, windowsize):
+            for row in q.filter(whereclause).order_by(column):
+                yield row
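The windowed-query helpers added above are what keep the SQL store's generator memory-friendly: _column_windows() uses a row_number() query to emit WHERE clauses that slice the id column into windows of `windowsize` rows, and _windowed_query() re-runs the original query once per window, so get_all_documents_generator() only ever materialises one window of rows at a time. Illustratively (not part of the diff), the generated statements look like:

    # window 1:    SELECT ... FROM document WHERE id >= id_0 AND id < id_10000 ORDER BY id
    # window 2:    SELECT ... FROM document WHERE id >= id_10000 AND id < id_20000 ORDER BY id
    # last window: SELECT ... FROM document WHERE id >= id_N ORDER BY id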
@@ -1,5 +1,6 @@
 import json
 from collections import defaultdict
+from itertools import islice
 import logging
 import pprint
 import pandas as pd
@@ -113,3 +114,14 @@ def convert_labels_to_squad(labels_file: str):

     with open("labels_in_squad_format.json", "w+") as outfile:
         json.dump(labels_in_squad_format, outfile)
+
+
+def get_batches_from_generator(iterable, n):
+    """
+    Batch elements of an iterable into fixed-length chunks or blocks.
+    """
+    it = iter(iterable)
+    x = tuple(islice(it, n))
+    while x:
+        yield x
+        x = tuple(islice(it, n))
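The helper above is what the document stores use to turn a document generator into fixed-size batches. A quick illustration of its behaviour (not part of the diff):

    from haystack.utils import get_batches_from_generator

    batches = list(get_batches_from_generator(range(5), 2))
    # -> [(0, 1), (2, 3), (4,)]  # fixed-length tuples, with a possibly shorter final batch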
@@ -258,7 +258,9 @@ def get_document_store(document_store_type, embedding_field="embedding"):
         )
     elif document_store_type == "faiss":
         document_store = FAISSDocumentStore(
-            sql_url="sqlite://", return_embedding=True, embedding_field=embedding_field
+            sql_url="sqlite://",
+            return_embedding=True,
+            embedding_field=embedding_field,
         )
         return document_store
     else:
@@ -85,6 +85,20 @@ def test_get_document_count(document_store):
     assert document_store.get_document_count(filters={"meta_field_for_count": ["b"]}) == 3


+@pytest.mark.elasticsearch
+def test_get_all_documents_generator(document_store):
+    documents = [
+        {"text": "text1", "id": "1", "meta_field_for_count": "a"},
+        {"text": "text2", "id": "2", "meta_field_for_count": "b"},
+        {"text": "text3", "id": "3", "meta_field_for_count": "b"},
+        {"text": "text4", "id": "4", "meta_field_for_count": "b"},
+        {"text": "text5", "id": "5", "meta_field_for_count": "b"},
+    ]
+
+    document_store.write_documents(documents)
+    assert len(list(document_store.get_all_documents_generator(batch_size=2))) == 5
+
+
 @pytest.mark.elasticsearch
 @pytest.mark.parametrize("document_store", ["elasticsearch", "sql", "faiss"], indirect=True)
 @pytest.mark.parametrize("update_existing_documents", [True, False])
@@ -168,6 +182,41 @@ def test_document_with_embeddings(document_store):
     assert isinstance(documents_with_embedding[0].embedding, (list, np.ndarray))


+@pytest.mark.parametrize("retriever", ["dpr", "embedding"], indirect=True)
+@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
+def test_update_embeddings(document_store, retriever):
+    documents = []
+    for i in range(23):
+        documents.append({"text": f"text_{i}", "id": str(i), "meta_field": f"value_{i}"})
+    documents.append({"text": "text_0", "id": "23", "meta_field": "value_0"})
+
+    document_store.write_documents(documents, index="haystack_test_1")
+    document_store.update_embeddings(retriever, index="haystack_test_1")
+    documents = document_store.get_all_documents(index="haystack_test_1", return_embedding=True)
+    assert len(documents) == 24
+    for doc in documents:
+        assert type(doc.embedding) is np.ndarray
+
+    documents = document_store.get_all_documents(
+        index="haystack_test_1",
+        filters={"meta_field": ["value_0", "value_23"]},
+        return_embedding=True,
+    )
+    np.testing.assert_array_equal(documents[0].embedding, documents[1].embedding)
+
+    documents = document_store.get_all_documents(
+        index="haystack_test_1",
+        filters={"meta_field": ["value_0", "value_10"]},
+        return_embedding=True,
+    )
+    np.testing.assert_raises(
+        AssertionError,
+        np.testing.assert_array_equal,
+        documents[0].embedding,
+        documents[1].embedding
+    )
+
+
 @pytest.mark.elasticsearch
 @pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
 def test_delete_documents(document_store_with_docs):
@@ -18,40 +18,47 @@ def test_dpr_retrieval(document_store, retriever, return_embedding):
     documents = [
         Document(
             text="""Aaron Aaron ( or ; ""Ahärôn"") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman (""prophet"") to the Pharaoh. Part of the Law (Torah) that Moses received from""",
-            meta={"name": "0"}
+            meta={"name": "0"},
+            id="1",
         ),
         Document(
             text="""Democratic Republic of the Congo to the south. Angola's capital, Luanda, lies on the Atlantic coast in the northwest of the country. Angola, although located in a tropical zone, has a climate that is not characterized for this region, due to the confluence of three factors: As a result, Angola's climate is characterized by two seasons: rainfall from October to April and drought, known as ""Cacimbo"", from May to August, drier, as the name implies, and with lower temperatures. On the other hand, while the coastline has high rainfall rates, decreasing from North to South and from to , with""",
+            id="2",
         ),
         Document(
             text="""Schopenhauer, describing him as an ultimately shallow thinker: ""Schopenhauer has quite a crude mind ... where real depth starts, his comes to an end."" His friend Bertrand Russell had a low opinion on the philosopher, and attacked him in his famous ""History of Western Philosophy"" for hypocritically praising asceticism yet not acting upon it. On the opposite isle of Russell on the foundations of mathematics, the Dutch mathematician L. E. J. Brouwer incorporated the ideas of Kant and Schopenhauer in intuitionism, where mathematics is considered a purely mental activity, instead of an analytic activity wherein objective properties of reality are""",
-            meta={"name": "1"}
+            meta={"name": "1"},
+            id="3",
         ),
         Document(
             text="""The Dothraki vocabulary was created by David J. Peterson well in advance of the adaptation. HBO hired the Language Creatio""",
-            meta={"name": "2"}
+            meta={"name": "2"},
+            id="4",
        ),
         Document(
             text="""The title of the episode refers to the Great Sept of Baelor, the main religious building in King's Landing, where the episode's pivotal scene takes place. In the world created by George R. R. Martin""",
-            meta={}
+            meta={},
+            id="5",
         )
     ]

     document_store.return_embedding = return_embedding
     document_store.write_documents(documents)
     document_store.update_embeddings(retriever=retriever)

     time.sleep(1)

-    docs_with_emb = document_store.get_all_documents()

     if return_embedding is True:
-        assert (len(docs_with_emb[0].embedding) == 768)
-        assert (abs(docs_with_emb[0].embedding[0] - (-0.3063)) < 0.001)
-        assert (abs(docs_with_emb[1].embedding[0] - (-0.3914)) < 0.001)
-        assert (abs(docs_with_emb[2].embedding[0] - (-0.2470)) < 0.001)
-        assert (abs(docs_with_emb[3].embedding[0] - (-0.0802)) < 0.001)
-        assert (abs(docs_with_emb[4].embedding[0] - (-0.0551)) < 0.001)
+        doc_1 = document_store.get_document_by_id("1")
+        assert (len(doc_1.embedding) == 768)
+        assert (abs(doc_1.embedding[0] - (-0.3063)) < 0.001)
+        doc_2 = document_store.get_document_by_id("2")
+        assert (abs(doc_2.embedding[0] - (-0.3914)) < 0.001)
+        doc_3 = document_store.get_document_by_id("3")
+        assert (abs(doc_3.embedding[0] - (-0.2470)) < 0.001)
+        doc_4 = document_store.get_document_by_id("4")
+        assert (abs(doc_4.embedding[0] - (-0.0802)) < 0.001)
+        doc_5 = document_store.get_document_by_id("5")
+        assert (abs(doc_5.embedding[0] - (-0.0551)) < 0.001)

     res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?")

@@ -1,107 +1,72 @@
 import pytest
 from haystack.document_store.base import BaseDocumentStore
-from haystack.document_store.memory import InMemoryDocumentStore
 from haystack.preprocessor.preprocessor import PreProcessor
 from haystack.finder import Finder


+@pytest.mark.parametrize("batch_size", [None, 20])
 @pytest.mark.elasticsearch
-def test_add_eval_data(document_store):
+def test_add_eval_data(document_store, batch_size):
     # add eval data (SQUAD format)
-    document_store.delete_all_documents(index="test_eval_document")
-    document_store.delete_all_documents(index="test_feedback")
-    document_store.add_eval_data(filename="samples/squad/small.json", doc_index="test_eval_document", label_index="test_feedback")
+    document_store.add_eval_data(
+        filename="samples/squad/small.json",
+        doc_index="haystack_test_eval_document",
+        label_index="haystack_test_feedback",
+        batch_size=batch_size,
+    )

-    assert document_store.get_document_count(index="test_eval_document") == 87
-    assert document_store.get_label_count(index="test_feedback") == 1214
+    assert document_store.get_document_count(index="haystack_test_eval_document") == 87
+    assert document_store.get_label_count(index="haystack_test_feedback") == 1214

     # test documents
-    docs = document_store.get_all_documents(index="test_eval_document")
-    assert docs[0].text[:10] == "The Norman"
+    docs = document_store.get_all_documents(index="haystack_test_eval_document", filters={"name": ["Normans"]})
     assert docs[0].meta["name"] == "Normans"
     assert len(docs[0].meta.keys()) == 1

     # test labels
-    labels = document_store.get_all_labels(index="test_feedback")
-    assert labels[0].answer == "France"
-    assert labels[0].no_answer == False
-    assert labels[0].is_correct_answer == True
-    assert labels[0].is_correct_document == True
-    assert labels[0].question == 'In what country is Normandy located?'
-    assert labels[0].origin == "gold_label"
-    assert labels[0].offset_start_in_doc == 159
+    labels = document_store.get_all_labels(index="haystack_test_feedback")
+    label = None
+    for l in labels:
+        if l.question == "In what country is Normandy located?":
+            label = l
+            break
+    assert label.answer == "France"
+    assert label.no_answer == False
+    assert label.is_correct_answer == True
+    assert label.is_correct_document == True
+    assert label.question == "In what country is Normandy located?"
+    assert label.origin == "gold_label"
+    assert label.offset_start_in_doc == 159

     # check combination
-    assert labels[0].document_id == docs[0].id
-    start = labels[0].offset_start_in_doc
-    end = start+len(labels[0].answer)
-    assert docs[0].text[start:end] == "France"
+    doc = document_store.get_document_by_id(label.document_id, index="haystack_test_eval_document")
+    start = label.offset_start_in_doc
+    end = start + len(label.answer)
+    assert doc.text[start:end] == "France"

-    # clean up
-    document_store.delete_all_documents(index="test_eval_document")
-    document_store.delete_all_documents(index="test_feedback")
-
-
-@pytest.mark.elasticsearch
-def test_add_eval_data_batchwise(document_store):
-    # add eval data (SQUAD format in jsonl)
-    document_store.delete_all_documents(index="test_eval_document")
-    document_store.delete_all_documents(index="test_feedback")
-    document_store.add_eval_data(filename="samples/squad/small.json",
-                                 doc_index="test_eval_document",
-                                 label_index="test_feedback",
-                                 batch_size=20)
-
-    assert document_store.get_document_count(index="test_eval_document") == 87
-    assert document_store.get_label_count(index="test_feedback") == 1214
-
-    # test documents
-    docs = document_store.get_all_documents(index="test_eval_document")
-    assert docs[0].text[:10] == "The Norman"
-    assert docs[0].meta["name"] == "Normans"
-    assert len(docs[0].meta.keys()) == 1
-
-    # test labels
-    labels = document_store.get_all_labels(index="test_feedback")
-    assert labels[0].answer == "France"
-    assert labels[0].no_answer == False
-    assert labels[0].is_correct_answer == True
-    assert labels[0].is_correct_document == True
-    assert labels[0].question == 'In what country is Normandy located?'
-    assert labels[0].origin == "gold_label"
-    assert labels[0].offset_start_in_doc == 159
-
-    # check combination
-    assert labels[0].document_id == docs[0].id
-    start = labels[0].offset_start_in_doc
-    end = start+len(labels[0].answer)
-    assert docs[0].text[start:end] == "France"
-
-    # clean up
-    document_store.delete_all_documents(index="test_eval_document")
-    document_store.delete_all_documents(index="test_feedback")


 @pytest.mark.elasticsearch
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
 def test_eval_reader(reader, document_store: BaseDocumentStore):
     # add eval data (SQUAD format)
-    document_store.delete_all_documents(index="test_eval_document")
-    document_store.delete_all_documents(index="test_feedback")
-    document_store.add_eval_data(filename="samples/squad/tiny.json", doc_index="test_eval_document", label_index="test_feedback")
-    assert document_store.get_document_count(index="test_eval_document") == 2
+    document_store.add_eval_data(
+        filename="samples/squad/tiny.json",
+        doc_index="haystack_test_eval_document",
+        label_index="haystack_test_feedback",
+    )
+    assert document_store.get_document_count(index="haystack_test_eval_document") == 2
     # eval reader
-    reader_eval_results = reader.eval(document_store=document_store, label_index="test_feedback",
-                                      doc_index="test_eval_document", device="cpu")
+    reader_eval_results = reader.eval(
+        document_store=document_store,
+        label_index="haystack_test_feedback",
+        doc_index="haystack_test_eval_document",
+        device="cpu",
+    )
     assert reader_eval_results["f1"] > 0.65
     assert reader_eval_results["f1"] < 0.67
     assert reader_eval_results["EM"] == 0.5
     assert reader_eval_results["top_n_accuracy"] == 1.0

-    # clean up
-    document_store.delete_all_documents(index="test_eval_document")
-    document_store.delete_all_documents(index="test_feedback")


 @pytest.mark.elasticsearch
 @pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
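The `[None, 20]` parametrization above feeds straight into the new `batch_size` argument of `add_eval_data`. A hedged sketch of the same call outside the test suite, reusing the SQuAD sample file and index names from these tests and assuming `document_store` is an already initialized Elasticsearch document store as in the fixtures:

```python
# Index SQuAD-format evaluation documents and labels in chunks of 20 at a time;
# with batch_size=None everything is written in a single pass.
document_store.add_eval_data(
    filename="samples/squad/small.json",
    doc_index="haystack_test_eval_document",
    label_index="haystack_test_feedback",
    batch_size=20,
)
```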
@@ -109,22 +74,22 @@ def test_eval_reader(reader, document_store: BaseDocumentStore):
 @pytest.mark.parametrize("retriever", ["elasticsearch"], indirect=True)
 def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain, retriever):
     # add eval data (SQUAD format)
-    document_store.delete_all_documents(index="test_eval_document")
-    document_store.delete_all_documents(index="test_feedback")
-    document_store.add_eval_data(filename="samples/squad/tiny.json", doc_index="test_eval_document", label_index="test_feedback")
-    assert document_store.get_document_count(index="test_eval_document") == 2
+    document_store.add_eval_data(
+        filename="samples/squad/tiny.json",
+        doc_index="haystack_test_eval_document",
+        label_index="haystack_test_feedback",
+    )
+    assert document_store.get_document_count(index="haystack_test_eval_document") == 2

     # eval retriever
-    results = retriever.eval(top_k=1, label_index="test_feedback", doc_index="test_eval_document", open_domain=open_domain)
+    results = retriever.eval(
+        top_k=1, label_index="haystack_test_feedback", doc_index="haystack_test_eval_document", open_domain=open_domain
+    )
     assert results["recall"] == 1.0
     assert results["mrr"] == 1.0
     if not open_domain:
         assert results["map"] == 1.0

-    # clean up
-    document_store.delete_all_documents(index="test_eval_document")
-    document_store.delete_all_documents(index="test_feedback")


 @pytest.mark.elasticsearch
 @pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
@@ -134,13 +99,17 @@ def test_eval_finder(document_store: BaseDocumentStore, reader, retriever):
     finder = Finder(reader=reader, retriever=retriever)

     # add eval data (SQUAD format)
-    document_store.delete_all_documents(index="test_eval_document")
-    document_store.delete_all_documents(index="test_feedback")
-    document_store.add_eval_data(filename="samples/squad/tiny.json", doc_index="test_eval_document", label_index="test_feedback")
-    assert document_store.get_document_count(index="test_eval_document") == 2
+    document_store.add_eval_data(
+        filename="samples/squad/tiny.json",
+        doc_index="haystack_test_eval_document",
+        label_index="haystack_test_feedback",
+    )
+    assert document_store.get_document_count(index="haystack_test_eval_document") == 2

     # eval finder
-    results = finder.eval(label_index="test_feedback", doc_index="test_eval_document", top_k_retriever=1, top_k_reader=5)
+    results = finder.eval(
+        label_index="haystack_test_feedback", doc_index="haystack_test_eval_document", top_k_retriever=1, top_k_reader=5
+    )
     assert results["retriever_recall"] == 1.0
     assert results["retriever_map"] == 1.0
     assert abs(results["reader_topk_f1"] - 0.66666) < 0.001
@@ -151,24 +120,19 @@ def test_eval_finder(document_store: BaseDocumentStore, reader, retriever):
     assert results["reader_top1_accuracy"] <= results["reader_topk_accuracy"]

     # batch eval finder
-    results_batch = finder.eval_batch(label_index="test_feedback", doc_index="test_eval_document", top_k_retriever=1,
-                                      top_k_reader=5)
+    results_batch = finder.eval_batch(
+        label_index="haystack_test_feedback", doc_index="haystack_test_eval_document", top_k_retriever=1, top_k_reader=5
+    )
     assert results_batch["retriever_recall"] == 1.0
     assert results_batch["retriever_map"] == 1.0
     assert results_batch["reader_top1_f1"] == results["reader_top1_f1"]
     assert results_batch["reader_top1_em"] == results["reader_top1_em"]
     assert results_batch["reader_topk_accuracy"] == results["reader_topk_accuracy"]

-    # clean up
-    document_store.delete_all_documents(index="test_eval_document")
-    document_store.delete_all_documents(index="test_feedback")

 @pytest.mark.elasticsearch
-def test_eval_data_splitting(document_store):
+def test_eval_data_split_word(document_store):
     # splitting by word
-    document_store.delete_all_documents(index="test_eval_document")
-    document_store.delete_all_documents(index="test_feedback")

     preprocessor = PreProcessor(
         clean_empty_lines=False,
         clean_whitespace=False,
@@ -176,22 +140,24 @@ def test_eval_data_splitting(document_store):
         split_by="word",
         split_length=4,
         split_overlap=0,
-        split_respect_sentence_boundary=False
+        split_respect_sentence_boundary=False,
     )

-    document_store.add_eval_data(filename="samples/squad/tiny.json",
-                                 doc_index="test_eval_document",
-                                 label_index="test_feedback",
-                                 preprocessor=preprocessor)
-    labels = document_store.get_all_labels_aggregated(index="test_feedback")
-    docs = document_store.get_all_documents(index="test_eval_document")
+    document_store.add_eval_data(
+        filename="samples/squad/tiny.json",
+        doc_index="haystack_test_eval_document",
+        label_index="haystack_test_feedback",
+        preprocessor=preprocessor,
+    )
+    labels = document_store.get_all_labels_aggregated(index="haystack_test_feedback")
+    docs = document_store.get_all_documents(index="haystack_test_eval_document")
     assert len(docs) == 5
     assert len(set(labels[0].multiple_document_ids)) == 2

-    # splitting by passage
-    document_store.delete_all_documents(index="test_eval_document")
-    document_store.delete_all_documents(index="test_feedback")

+@pytest.mark.elasticsearch
+def test_eval_data_split_passage(document_store):
+    # splitting by passage
     preprocessor = PreProcessor(
         clean_empty_lines=False,
         clean_whitespace=False,
@@ -202,10 +168,12 @@ def test_eval_data_splitting(document_store):
         split_respect_sentence_boundary=False
     )

-    document_store.add_eval_data(filename="samples/squad/tiny_passages.json",
-                                 doc_index="test_eval_document",
-                                 label_index="test_feedback",
-                                 preprocessor=preprocessor)
-    docs = document_store.get_all_documents(index="test_eval_document")
+    document_store.add_eval_data(
+        filename="samples/squad/tiny_passages.json",
+        doc_index="haystack_test_eval_document",
+        label_index="haystack_test_feedback",
+        preprocessor=preprocessor,
+    )
+    docs = document_store.get_all_documents(index="haystack_test_eval_document")
     assert len(docs) == 2
     assert len(docs[1].text) == 56
@@ -19,21 +19,6 @@ DOCUMENTS = [
 ]


-def check_data_correctness(documents_indexed, documents_inserted):
-    # test if correct vector_ids are assigned
-    for i, doc in enumerate(documents_indexed):
-        assert doc.meta["vector_id"] == str(i)
-
-    # test if number of documents is correct
-    assert len(documents_indexed) == len(documents_inserted)
-
-    # test if two docs have same vector_is assigned
-    vector_ids = set()
-    for i, doc in enumerate(documents_indexed):
-        vector_ids.add(doc.meta["vector_id"])
-    assert len(vector_ids) == len(documents_inserted)
-
-
 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
 def test_faiss_index_save_and_load(document_store):
     document_store.write_documents(DOCUMENTS)
@@ -65,6 +50,7 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
         document_store.write_documents(DOCUMENTS[i: i + batch_size])

     documents_indexed = document_store.get_all_documents()
+    assert len(documents_indexed) == len(DOCUMENTS)

     # test if correct vectors are associated with docs
     for i, doc in enumerate(documents_indexed):
@@ -74,34 +60,26 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
         # compare original input vec with stored one (ignore extra dim added by hnsw)
         assert np.allclose(original_doc["embedding"], stored_emb, rtol=0.01)

-    # test document correctness
-    check_data_correctness(documents_indexed, DOCUMENTS)


 @pytest.mark.slow
 @pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
-@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
-def test_faiss_update_docs(document_store, index_buffer_size, retriever):
-    # adjust buffer size
-    document_store.index_buffer_size = index_buffer_size
-
+@pytest.mark.parametrize("batch_size", [4, 6])
+def test_faiss_update_docs(document_store, retriever, batch_size):
     # initial write
     document_store.write_documents(DOCUMENTS)

-    document_store.update_embeddings(retriever=retriever)
+    document_store.update_embeddings(retriever=retriever, batch_size=batch_size)
     documents_indexed = document_store.get_all_documents()
+    assert len(documents_indexed) == len(DOCUMENTS)

     # test if correct vectors are associated with docs
-    for i, doc in enumerate(documents_indexed):
+    for doc in documents_indexed:
         original_doc = [d for d in DOCUMENTS if d["text"] == doc.text][0]
         updated_embedding = retriever.embed_passages([Document.from_dict(original_doc)])
-        stored_emb = document_store.faiss_index.reconstruct(int(doc.meta["vector_id"]))
+        stored_doc = document_store.get_all_documents(filters={"name": [doc.meta["name"]]})[0]
         # compare original input vec with stored one (ignore extra dim added by hnsw)
-        assert np.allclose(updated_embedding, stored_emb, rtol=0.01)
+        assert np.allclose(updated_embedding, stored_doc.embedding, rtol=0.01)

-    # test document correctness
-    check_data_correctness(documents_indexed, DOCUMENTS)


 @pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
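The FAISS test above now drives `update_embeddings` with an explicit `batch_size`. A minimal sketch of the same pattern, not part of the diff and assuming `document_store`, `retriever`, and `DOCUMENTS` are set up as in the surrounding fixtures:

```python
# Write the raw documents first, then compute and store embeddings in
# batches of four so only a slice of the corpus is embedded and flushed
# to the FAISS index at a time.
document_store.write_documents(DOCUMENTS)
document_store.update_embeddings(retriever=retriever, batch_size=4)
```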
@@ -115,8 +93,7 @@ def test_faiss_update_with_empty_store(document_store, retriever):

     documents_indexed = document_store.get_all_documents()

-    # test document correctness
-    check_data_correctness(documents_indexed, DOCUMENTS)
+    assert len(documents_indexed) == len(DOCUMENTS)


 @pytest.mark.parametrize("index_factory", ["Flat", "HNSW", "IVF1,Flat"])
@@ -190,5 +167,8 @@ def test_faiss_passing_index_from_outside():
     document_store.write_documents(documents=DOCUMENTS, index="document")
     documents_indexed = document_store.get_all_documents(index="document")

-    # test document correctness
-    check_data_correctness(documents_indexed, DOCUMENTS)
+    # test if vectors ids are associated with docs
+    for doc in documents_indexed:
+        assert 0 <= int(doc.meta["vector_id"]) <= 7