Add batch_size and generators to document stores. (#733)

* Add batch update of embeddings in document stores

* Resolve merge conflict

* Remove document ordering dependency in tests

* Adjust index buffer size for tests

* Adjust ES Scroll Slice

* Use generator for document store pagination

* Add pagination for InMemoryDocumentStore

* Fix missing index parameter in FAISS update_embeddings()

* Fix FAISS update_embeddings()

* Update FAISS tests

* Update eval tests

* Revert code formatting change

* Fix document count in FAISS update embeddings

* Fix vector_ids reset in SQLDocumentStore

* Update docstrings

* Update docstring
Tanay Soni 2021-01-21 16:00:08 +01:00 committed by GitHub
parent 0b583b8972
commit 337376c81d
10 changed files with 468 additions and 312 deletions
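Below, a minimal sketch of how the APIs added in this commit fit together; the retriever model defaults and the example documents are illustrative, not part of the commit:

from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.retriever.dense import DensePassageRetriever

document_store = ElasticsearchDocumentStore(host="localhost", embedding_field="embedding", embedding_dim=768)
retriever = DensePassageRetriever(document_store=document_store)

docs = [{"text": "some text", "meta": {"name": "doc_1"}}]

# write_documents() now flushes to Elasticsearch's bulk API every batch_size documents
document_store.write_documents(docs, batch_size=10_000)

# update_embeddings() fetches, embeds and indexes documents batch by batch instead of all at once
document_store.update_embeddings(retriever, batch_size=10_000)

# get_all_documents_generator() streams documents without loading the whole index into memory
for doc in document_store.get_all_documents_generator(batch_size=10_000):
    print(doc.id)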

View File

@ -3,7 +3,7 @@ import logging
import time
from copy import deepcopy
from string import Template
from typing import List, Optional, Union, Dict, Any
from typing import List, Optional, Union, Dict, Any, Generator
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk, scan
from elasticsearch.exceptions import RequestError
@ -13,6 +13,7 @@ from scipy.special import expit
from haystack.document_store.base import BaseDocumentStore
from haystack import Document, Label
from haystack.retriever.base import BaseRetriever
from haystack.utils import get_batches_from_generator
logger = logging.getLogger(__name__)
@ -232,8 +233,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
documents = [self._convert_es_hit_to_document(hit, return_embedding=self.return_embedding) for hit in result]
return documents
def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None,
batch_size: Optional[int] = None):
def write_documents(
self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, batch_size: int = 10_000
):
"""
Indexes documents for later queries in Elasticsearch.
@ -253,7 +255,6 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
should be changed to what you have set for self.text_field and self.name_field.
:param index: Elasticsearch index where the documents should be indexed. If not supplied, self.index will be used.
:param batch_size: Number of documents that are passed to Elasticsearch's bulk function at a time.
If `None`, all documents will be passed to bulk at once.
:return: None
"""
@ -298,22 +299,21 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
_doc.pop("meta")
documents_to_index.append(_doc)
if batch_size is not None:
# Pass batch_size number of documents to bulk
if len(documents_to_index) % batch_size == 0:
bulk(self.client, documents_to_index, request_timeout=300, refresh=self.refresh_type)
documents_to_index = []
# Pass batch_size number of documents to bulk
if len(documents_to_index) % batch_size == 0:
bulk(self.client, documents_to_index, request_timeout=300, refresh=self.refresh_type)
documents_to_index = []
if documents_to_index:
bulk(self.client, documents_to_index, request_timeout=300, refresh=self.refresh_type)
def write_labels(self, labels: Union[List[Label], List[dict]], index: Optional[str] = None,
batch_size: Optional[int] = None):
def write_labels(
self, labels: Union[List[Label], List[dict]], index: Optional[str] = None, batch_size: int = 10_000
):
"""Write annotation labels into document store.
:param labels: A list of Python dictionaries or a list of Haystack Label objects.
:param batch_size: Number of labels that are passed to Elasticsearch's bulk function at a time.
If `None`, all labels will be passed to bulk at once.
"""
index = index or self.label_index
if index and not self.client.indices.exists(index=index):
@ -339,11 +339,10 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
labels_to_index.append(_label)
if batch_size is not None:
# Pass batch_size number of labels to bulk
if len(labels_to_index) % batch_size == 0:
bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type)
labels_to_index = []
# Pass batch_size number of labels to bulk
if len(labels_to_index) % batch_size == 0:
bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type)
labels_to_index = []
if labels_to_index:
bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type)
@ -387,10 +386,11 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
return self.get_document_count(index=index)
def get_all_documents(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
) -> List[Document]:
"""
Get documents from the document store.
@ -400,28 +400,62 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
:param filters: Optional filters to narrow down the documents to return.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param return_embedding: Whether to return the document embeddings.
:param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
"""
result = self.get_all_documents_generator(
index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
)
documents = list(result)
return documents
def get_all_documents_generator(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
) -> Generator[Document, None, None]:
"""
Get documents from the document store. Under-the-hood, documents are fetched in batches from the
document store and yielded as individual documents. This method can be used to iteratively process
a large number of documents without having to load all documents in memory.
:param index: Name of the index to get the documents from. If None, the
DocumentStore's default index (self.index) will be used.
:param filters: Optional filters to narrow down the documents to return.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param return_embedding: Whether to return the document embeddings.
:param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
"""
if index is None:
index = self.index
result = self.get_all_documents_in_index(index=index, filters=filters)
if return_embedding is None:
return_embedding = self.return_embedding
documents = [self._convert_es_hit_to_document(hit, return_embedding=return_embedding) for hit in result]
return documents
result = self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size)
for hit in result:
document = self._convert_es_hit_to_document(hit, return_embedding=return_embedding)
yield document
def get_all_labels(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]:
def get_all_labels(
self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None, batch_size: int = 10_000
) -> List[Label]:
"""
Return all labels in the document store
"""
index = index or self.label_index
result = self.get_all_documents_in_index(index=index, filters=filters)
result = list(self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size))
labels = [Label.from_dict(hit["_source"]) for hit in result]
return labels
def get_all_documents_in_index(self, index: str, filters: Optional[Dict[str, List[str]]] = None) -> List[dict]:
def _get_all_documents_in_index(
self,
index: str,
filters: Optional[Dict[str, List[str]]] = None,
batch_size: int = 10_000,
) -> Generator[dict, None, None]:
"""
Return all documents in a specific index in the document store
"""
@ -444,9 +478,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
}
)
body["query"]["bool"]["filter"] = filter_clause
result = list(scan(self.client, query=body, index=index))
return result
result = scan(self.client, query=body, index=index, size=batch_size, scroll="1d")
yield from result
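For reference, elasticsearch.helpers.scan wraps the scroll API: size controls how many hits are fetched per request and scroll="1d" keeps the scroll context alive between requests. A standalone sketch of the same call (client setup and index name are illustrative):

from elasticsearch import Elasticsearch
from elasticsearch.helpers import scan

client = Elasticsearch()
# scan() returns a lazy generator; hits are fetched from Elasticsearch page by page
hits = scan(client, query={"query": {"match_all": {}}}, index="document", size=10_000, scroll="1d")
for hit in hits:
    print(hit["_source"])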
def query(
self,
@ -683,13 +717,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
}
return stats
def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None):
def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None, batch_size: int = 10_000):
"""
Updates the embeddings in the document store using the encoding model specified in the retriever.
This can be useful if you want to add or change the embeddings for your documents (e.g. after changing the retriever config).
:param retriever: Retriever
:param retriever: Retriever to use to update the embeddings.
:param index: Index name to update
:param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
:return: None
"""
if index is None:
@ -698,26 +733,29 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
if not self.embedding_field:
raise RuntimeError("Specify the arg `embedding_field` when initializing ElasticsearchDocumentStore()")
# TODO Index embeddings every X batches to avoid OOM for huge document collections
docs = self.get_all_documents(index)
logger.info(f"Updating embeddings for {len(docs)} docs ...")
embeddings = retriever.embed_passages(docs) # type: ignore
assert len(docs) == len(embeddings)
logger.info(f"Updating embeddings for {self.get_document_count(index=index)} docs ...")
if embeddings[0].shape[0] != self.embedding_dim:
raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})."
"Specify the arg `embedding_dim` when initializing ElasticsearchDocumentStore()")
doc_updates = []
for doc, emb in zip(docs, embeddings):
update = {"_op_type": "update",
"_index": index,
"_id": doc.id,
"doc": {self.embedding_field: emb.tolist()},
}
doc_updates.append(update)
result = self.get_all_documents_generator(index, batch_size=batch_size)
for document_batch in get_batches_from_generator(result, batch_size):
if len(document_batch) == 0:
break
embeddings = retriever.embed_passages(document_batch) # type: ignore
assert len(document_batch) == len(embeddings)
bulk(self.client, doc_updates, request_timeout=300, refresh=self.refresh_type)
if embeddings[0].shape[0] != self.embedding_dim:
raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})."
"Specify the arg `embedding_dim` when initializing ElasticsearchDocumentStore()")
doc_updates = []
for doc, emb in zip(document_batch, embeddings):
update = {"_op_type": "update",
"_index": index,
"_id": doc.id,
"doc": {self.embedding_field: emb.tolist()},
}
doc_updates.append(update)
bulk(self.client, doc_updates, request_timeout=300, refresh=self.refresh_type)
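The same generator-plus-batching pattern used above is also available to user code, e.g. for custom per-batch processing. A minimal sketch, assuming a document_store instance as in the example near the top:

from haystack.utils import get_batches_from_generator

doc_generator = document_store.get_all_documents_generator(index="document", batch_size=10_000)
for document_batch in get_batches_from_generator(doc_generator, 10_000):
    # each batch is a tuple of at most 10_000 Document objects;
    # only one batch is held in memory at a time
    print(f"processing {len(document_batch)} documents")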
def delete_all_documents(self, index: str, filters: Optional[Dict[str, List[str]]] = None):
"""

View File

@ -1,14 +1,14 @@
import logging
from sys import platform
from pathlib import Path
from typing import Union, List, Optional, Dict
from typing import Union, List, Optional, Dict, Generator
from tqdm import tqdm
import numpy as np
from haystack import Document
from haystack.document_store.sql import SQLDocumentStore
from haystack.retriever.base import BaseRetriever
from haystack.utils import get_batches_from_generator
from scipy.special import expit
if platform != 'win32' and platform != 'cygwin':
@ -34,7 +34,6 @@ class FAISSDocumentStore(SQLDocumentStore):
def __init__(
self,
sql_url: str = "sqlite:///",
index_buffer_size: int = 10_000,
vector_dim: int = 768,
faiss_index_factory_str: str = "Flat",
faiss_index: Optional[faiss.swigfaiss.Index] = None,
@ -47,8 +46,6 @@ class FAISSDocumentStore(SQLDocumentStore):
"""
:param sql_url: SQL connection URL for database. It defaults to local file based SQLite DB. For large scale
deployment, Postgres is recommended.
:param index_buffer_size: When working with large datasets, the ingestion process(FAISS + SQL) can be buffered in
smaller chunks to reduce memory footprint.
:param vector_dim: the embedding vector size.
:param faiss_index_factory_str: Create a new FAISS index of the specified type.
The type is determined from the given string following the conventions
@ -85,7 +82,6 @@ class FAISSDocumentStore(SQLDocumentStore):
if "ivf" in faiss_index_factory_str.lower(): # enable reconstruction of vectors for inverted index
self.faiss_index.set_direct_map_type(faiss.DirectMap.Hashtable)
self.index_buffer_size = index_buffer_size
self.return_embedding = return_embedding
if similarity == "dot_product":
self.similarity = similarity
@ -111,13 +107,16 @@ class FAISSDocumentStore(SQLDocumentStore):
index = faiss.index_factory(vector_dim, index_factory, metric_type)
return index
def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
def write_documents(
self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, batch_size: int = 10_000
):
"""
Add new documents to the DocumentStore.
:param documents: List of `Dicts` or List of `Documents`. If they already contain the embeddings, we'll index
them right away in FAISS. If not, you can later call update_embeddings() to create & index them.
:param index: (SQL) index name for storing the docs and metadata
:param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
:return:
"""
# vector index
@ -136,15 +135,15 @@ class FAISSDocumentStore(SQLDocumentStore):
"`FAISSDocumentStore` does not support update in existing `faiss_index`.\n"
"Please call `update_embeddings` method to repopulate `faiss_index`")
for i in range(0, len(document_objects), self.index_buffer_size):
vector_id = self.faiss_index.ntotal
vector_id = self.faiss_index.ntotal
for i in range(0, len(document_objects), batch_size):
if add_vectors:
embeddings = [doc.embedding for doc in document_objects[i: i + self.index_buffer_size]]
embeddings = [doc.embedding for doc in document_objects[i: i + batch_size]]
embeddings = np.array(embeddings, dtype="float32")
self.faiss_index.add(embeddings)
docs_to_write_in_sql = []
for doc in document_objects[i : i + self.index_buffer_size]:
for doc in document_objects[i: i + batch_size]:
meta = doc.meta
if add_vectors:
meta["vector_id"] = vector_id
@ -158,13 +157,14 @@ class FAISSDocumentStore(SQLDocumentStore):
self.index: "embedding",
}
def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None):
def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None, batch_size: int = 10_000):
"""
Updates the embeddings in the document store using the encoding model specified in the retriever.
This can be useful if you want to add or change the embeddings for your documents (e.g. after changing the retriever config).
:param retriever: Retriever to use to get embeddings for text
:param index: (SQL) index name for storing the docs and metadata
:param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
:return: None
"""
if not self.faiss_index:
@ -172,55 +172,85 @@ class FAISSDocumentStore(SQLDocumentStore):
# Faiss does not support update in existing index data so clear all existing data in it
self.faiss_index.reset()
self.reset_vector_ids(index=index)
index = index or self.index
documents = self.get_all_documents(index=index)
if len(documents) == 0:
document_count = self.get_document_count(index=index)
if document_count == 0:
logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
return
# To clear out the FAISS index contents and frees all memory immediately that is in use by the index
self.faiss_index.reset()
logger.info(f"Updating embeddings for {document_count} docs...")
vector_id = self.faiss_index.ntotal
logger.info(f"Updating embeddings for {len(documents)} docs...")
embeddings = retriever.embed_passages(documents) # type: ignore
assert len(documents) == len(embeddings)
for i, doc in enumerate(documents):
doc.embedding = embeddings[i]
result = self.get_all_documents_generator(index=index, batch_size=batch_size, return_embedding=False)
batched_documents = get_batches_from_generator(result, batch_size)
with tqdm(total=document_count) as progress_bar:
for document_batch in batched_documents:
embeddings = retriever.embed_passages(document_batch) # type: ignore
assert len(document_batch) == len(embeddings)
logger.info("Indexing embeddings and updating vectors_ids...")
for i in tqdm(range(0, len(documents), self.index_buffer_size)):
vector_id_map = {}
vector_id = self.faiss_index.ntotal
embeddings = [doc.embedding for doc in documents[i: i + self.index_buffer_size]]
embeddings = np.array(embeddings, dtype="float32")
self.faiss_index.add(embeddings)
embeddings_to_index = np.array(embeddings, dtype="float32")
self.faiss_index.add(embeddings_to_index)
for doc in documents[i: i + self.index_buffer_size]:
vector_id_map[doc.id] = vector_id
vector_id += 1
self.update_vector_ids(vector_id_map, index=index)
vector_id_map = {}
for doc in document_batch:
vector_id_map[doc.id] = vector_id
vector_id += 1
self.update_vector_ids(vector_id_map, index=index)
progress_bar.update(batch_size)
progress_bar.close()
def get_all_documents(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
) -> List[Document]:
result = self.get_all_documents_generator(
index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
)
documents = list(result)
return documents
def get_all_documents_generator(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
) -> Generator[Document, None, None]:
"""
Get documents from the document store.
Get all documents from the document store. Under-the-hood, documents are fetched in batches from the
document store and yielded as individual documents. This method can be used to iteratively process
a large number of documents without having to load all documents in memory.
:param index: Name of the index to get the documents from. If None, the
DocumentStore's default index (self.index) will be used.
:param filters: Optional filters to narrow down the documents to return.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param return_embedding: Whether to return the document embeddings.
:param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
"""
documents = super(FAISSDocumentStore, self).get_all_documents(index=index, filters=filters)
documents = super(FAISSDocumentStore, self).get_all_documents_generator(
index=index, filters=filters, batch_size=batch_size
)
if return_embedding is None:
return_embedding = self.return_embedding
if return_embedding:
for doc in documents:
if return_embedding:
if doc.meta and doc.meta.get("vector_id") is not None:
doc.embedding = self.faiss_index.reconstruct(int(doc.meta["vector_id"]))
yield doc
def get_documents_by_id(
self, ids: List[str], index: Optional[str] = None, batch_size: int = 10_000
) -> List[Document]:
documents = super(FAISSDocumentStore, self).get_documents_by_id(ids=ids, index=index)
if self.return_embedding:
for doc in documents:
if doc.meta and doc.meta.get("vector_id") is not None:
doc.embedding = self.faiss_index.reconstruct(int(doc.meta["vector_id"]))
@ -309,7 +339,6 @@ class FAISSDocumentStore(SQLDocumentStore):
cls,
faiss_file_path: Union[str, Path],
sql_url: str,
index_buffer_size: int = 10_000,
):
"""
Load a saved FAISS index from a file and connect to the SQL database.
@ -318,8 +347,6 @@ class FAISSDocumentStore(SQLDocumentStore):
:param faiss_file_path: Stored FAISS index file. Can be created via calling `save()`
:param sql_url: Connection string to the SQL database that contains your docs and metadata.
:param index_buffer_size: When working with large datasets, the ingestion process(FAISS + SQL) can be buffered in
smaller chunks to reduce memory footprint.
:return:
"""
"""
@ -328,7 +355,6 @@ class FAISSDocumentStore(SQLDocumentStore):
return cls(
faiss_index=faiss_index,
sql_url=sql_url,
index_buffer_size=index_buffer_size,
vector_dim=faiss_index.d
)
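A short sketch of the FAISS store with the new batch_size parameter (the SQLite URL and retriever defaults are illustrative). With return_embedding=True, embeddings are reconstructed on the fly from the FAISS index via the stored vector_id:

from haystack.document_store.faiss import FAISSDocumentStore
from haystack.retriever.dense import DensePassageRetriever

document_store = FAISSDocumentStore(sql_url="sqlite:///faiss_docs.db", vector_dim=768)
retriever = DensePassageRetriever(document_store=document_store)

document_store.write_documents([{"text": "some text", "meta": {"name": "doc_1"}}], batch_size=10_000)

# embeds and indexes documents in batches, updating vector_ids in SQL as it goes
document_store.update_embeddings(retriever, batch_size=10_000)

for doc in document_store.get_all_documents_generator(return_embedding=True):
    print(doc.id, doc.embedding.shape)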

View File

@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Generator
from uuid import uuid4
from collections import defaultdict
@ -183,13 +183,26 @@ class InMemoryDocumentStore(BaseDocumentStore):
return len(self.indexes[index].items())
def get_all_documents(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
) -> List[Document]:
result = self.get_all_documents_generator(index=index, filters=filters, return_embedding=return_embedding)
documents = list(result)
return documents
def get_all_documents_generator(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
) -> Generator[Document, None, None]:
"""
Get documents from the document store.
Get all documents from the document store. The method returns a Python generator that yields individual
documents.
:param index: Name of the index to get the documents from. If None, the
DocumentStore's default index (self.index) will be used.
@ -222,7 +235,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
else:
filtered_documents = documents
return filtered_documents
yield from filtered_documents
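A quick sketch for the in-memory store; since all documents already live in memory, the generator mainly provides API parity with the other document stores:

from haystack.document_store.memory import InMemoryDocumentStore

document_store = InMemoryDocumentStore()
document_store.write_documents([{"text": "text1"}, {"text": "text2"}])
for doc in document_store.get_all_documents_generator():
    print(doc.text)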
def get_all_labels(self, index: str = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]:
"""

View File

@ -1,15 +1,15 @@
import itertools
import logging
from typing import Any, Dict, Union, List, Optional
from typing import Any, Dict, Union, List, Optional, Generator
from uuid import uuid4
from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey, Boolean, Text
from sqlalchemy import and_, func, create_engine, Column, Integer, String, DateTime, ForeignKey, Boolean, Text, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.sql import case
from sqlalchemy.sql import case, null
from haystack.document_store.base import BaseDocumentStore
from haystack import Document, Label
from haystack.document_store.base import BaseDocumentStore
logger = logging.getLogger(__name__)
@ -73,7 +73,6 @@ class SQLDocumentStore(BaseDocumentStore):
index: str = "document",
label_index: str = "label",
update_existing_documents: bool = False,
batch_size: int = 32766,
):
"""
An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
@ -87,11 +86,6 @@ class SQLDocumentStore(BaseDocumentStore):
If set to False, an error is raised if the document ID of the document being
added already exists. Using this parameter could cause performance degradation
for document insertion.
:param batch_size: Maximum number of variable parameters and rows fetched in a single SQL statement,
to help in excessive memory allocations. In most methods of the DocumentStore this means number of documents fetched in one query.
Tune this value based on host machine main memory.
For SQLite versions prior to v3.32.0 keep this value less than 1000.
More info refer: https://www.sqlite.org/limits.html
"""
engine = create_engine(url)
ORMBase.metadata.create_all(engine)
@ -102,7 +96,6 @@ class SQLDocumentStore(BaseDocumentStore):
self.update_existing_documents = update_existing_documents
if getattr(self, "similarity", None) is None:
self.similarity = None
self.batch_size = batch_size
def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
"""Fetch a document by specifying its text id string"""
@ -110,14 +103,14 @@ class SQLDocumentStore(BaseDocumentStore):
document = documents[0] if documents else None
return document
def get_documents_by_id(self, ids: List[str], index: Optional[str] = None) -> List[Document]:
def get_documents_by_id(self, ids: List[str], index: Optional[str] = None, batch_size: int = 10_000) -> List[Document]:
"""Fetch documents by specifying a list of text id strings"""
index = index or self.index
documents = []
for i in range(0, len(ids), self.batch_size):
for i in range(0, len(ids), batch_size):
query = self.session.query(DocumentORM).filter(
DocumentORM.id.in_(ids[i: i + self.batch_size]),
DocumentORM.id.in_(ids[i: i + batch_size]),
DocumentORM.index == index
)
for row in query.all():
@ -125,14 +118,14 @@ class SQLDocumentStore(BaseDocumentStore):
return documents
def get_documents_by_vector_ids(self, vector_ids: List[str], index: Optional[str] = None):
def get_documents_by_vector_ids(self, vector_ids: List[str], index: Optional[str] = None, batch_size: int = 10_000):
"""Fetch documents by specifying a list of text vector id strings"""
index = index or self.index
documents = []
for i in range(0, len(vector_ids), self.batch_size):
for i in range(0, len(vector_ids), batch_size):
query = self.session.query(DocumentORM).filter(
DocumentORM.vector_id.in_(vector_ids[i: i + self.batch_size]),
DocumentORM.vector_id.in_(vector_ids[i: i + batch_size]),
DocumentORM.index == index
)
for row in query.all():
@ -142,13 +135,25 @@ class SQLDocumentStore(BaseDocumentStore):
return sorted_documents
def get_all_documents(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None,
) -> List[Document]:
documents = list(self.get_all_documents_generator(index=index, filters=filters))
return documents
def get_all_documents_generator(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
) -> Generator[Document, None, None]:
"""
Get documents from the document store.
Get documents from the document store. Under-the-hood, documents are fetched in batches from the
document store and yielded as individual documents. This method can be used to iteratively process
a large number of documents without having to load all documents in memory.
:param index: Name of the index to get the documents from. If None, the
DocumentStore's default index (self.index) will be used.
@ -177,26 +182,31 @@ class SQLDocumentStore(BaseDocumentStore):
)
documents_map = {}
for row in documents_query.all():
for i, row in enumerate(self._windowed_query(documents_query, DocumentORM.id, batch_size), start=1):
documents_map[row.id] = Document(
id=row.id,
text=row.text,
meta=None if row.vector_id is None else {"vector_id": row.vector_id} # type: ignore
meta=None if row.vector_id is None else {"vector_id": row.vector_id} # type: ignore
)
if i % batch_size == 0:
documents_map = self._get_documents_meta(documents_map)
yield from documents_map.values()
documents_map = {}
if documents_map:
documents_map = self._get_documents_meta(documents_map)
yield from documents_map.values()
for doc_ids in self.chunked_iterable(documents_map.keys(), size=self.batch_size):
meta_query = self.session.query(
MetaORM.document_id,
MetaORM.name,
MetaORM.value
).filter(MetaORM.document_id.in_(doc_ids))
def _get_documents_meta(self, documents_map):
doc_ids = documents_map.keys()
meta_query = self.session.query(
MetaORM.document_id,
MetaORM.name,
MetaORM.value
).filter(MetaORM.document_id.in_(doc_ids))
for row in meta_query.all():
if documents_map[row.document_id].meta is None:
documents_map[row.document_id].meta = {}
documents_map[row.document_id].meta[row.name] = row.value # type: ignore
return list(documents_map.values())
for row in meta_query.all():
documents_map[row.document_id].meta[row.name] = row.value # type: ignore
return documents_map
def get_all_labels(self, index=None, filters: Optional[dict] = None):
"""
@ -209,17 +219,20 @@ class SQLDocumentStore(BaseDocumentStore):
return labels
def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
def write_documents(
self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, batch_size: int = 10_000
):
"""
Indexes documents for later queries.
:param documents: a list of Python dictionaries or a list of Haystack Document objects.
:param documents: a list of Python dictionaries or a list of Haystack Document objects.
For documents as dictionaries, the format is {"text": "<the-actual-text>"}.
Optionally: Include meta data via {"text": "<the-actual-text>",
"meta":{"name": "<some-document-name>, "author": "somebody", ...}}
It can be used for filtering and is accessible in the responses of the Finder.
:param index: add an optional index attribute to documents. It can be later used for filtering. For instance,
documents for evaluation can be indexed in a separate index than the documents for search.
:param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
:return: None
"""
@ -233,10 +246,10 @@ class SQLDocumentStore(BaseDocumentStore):
else:
document_objects = documents
for i in range(0, len(document_objects), self.batch_size):
for doc in document_objects[i: i + self.batch_size]:
for i in range(0, len(document_objects), batch_size):
for doc in document_objects[i: i + batch_size]:
meta_fields = doc.meta or {}
vector_id = meta_fields.get("vector_id")
vector_id = meta_fields.pop("vector_id", None)
meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()]
doc_orm = DocumentORM(id=doc.id, text=doc.text, vector_id=vector_id, meta=meta_orms, index=index)
if self.update_existing_documents:
@ -275,15 +288,16 @@ class SQLDocumentStore(BaseDocumentStore):
self.session.add(label_orm)
self.session.commit()
def update_vector_ids(self, vector_id_map: Dict[str, str], index: Optional[str] = None):
def update_vector_ids(self, vector_id_map: Dict[str, str], index: Optional[str] = None, batch_size: int = 10_000):
"""
Update vector_ids for given document_ids.
:param vector_id_map: dict containing mapping of document_id -> vector_id.
:param index: filter documents by the optional index attribute for documents in database.
:param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
"""
index = index or self.index
for chunk_map in self.chunked_dict(vector_id_map, size=self.batch_size):
for chunk_map in self.chunked_dict(vector_id_map, size=batch_size):
self.session.query(DocumentORM).filter(
DocumentORM.id.in_(chunk_map),
DocumentORM.index == index
@ -300,6 +314,14 @@ class SQLDocumentStore(BaseDocumentStore):
self.session.rollback()
raise ex
def reset_vector_ids(self, index: Optional[str] = None):
"""
Set the vector IDs of all documents to None
"""
index = index or self.index
self.session.query(DocumentORM).filter_by(index=index).update({DocumentORM.vector_id: null()})
self.session.commit()
def update_document_meta(self, id: str, meta: Dict[str, str]):
"""
Update the metadata dictionary of a document by specifying its string id
@ -392,16 +414,55 @@ class SQLDocumentStore(BaseDocumentStore):
session.commit()
return instance
# Refer: https://alexwlchan.net/2018/12/iterating-in-fixed-size-chunks/
def chunked_iterable(self, iterable, size):
it = iter(iterable)
while True:
chunk = tuple(itertools.islice(it, size))
if not chunk:
break
yield chunk
def chunked_dict(self, dictionary, size):
it = iter(dictionary)
for i in range(0, len(dictionary), size):
yield {k: dictionary[k] for k in itertools.islice(it, size)}
def _column_windows(self, session, column, windowsize):
"""Return a series of WHERE clauses against
a given column that break it into windows.
Each clause covers a contiguous range of ids
(column >= start AND column < end, or column >= start for the last window).
The code is taken from: https://github.com/sqlalchemy/sqlalchemy/wiki/RangeQuery-and-WindowedRangeQuery
"""
def int_for_range(start_id, end_id):
if end_id:
return and_(
column >= start_id,
column < end_id
)
else:
return column >= start_id
q = session.query(
column,
func.row_number(). \
over(order_by=column). \
label('rownum')
). \
from_self(column)
if windowsize > 1:
q = q.filter(text("rownum %% %d=1" % windowsize))
intervals = [id for id, in q]
while intervals:
start = intervals.pop(0)
if intervals:
end = intervals[0]
else:
end = None
yield int_for_range(start, end)
def _windowed_query(self, q, column, windowsize):
""""Break a Query into windows on a given column."""
for whereclause in self._column_windows(
q.session,
column, windowsize):
for row in q.filter(whereclause).order_by(column):
yield row
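To illustrate the windowing logic in plain Python (the ids below are hypothetical; the real implementation computes the boundary ids in SQL via row_number() so that only every windowsize-th id is fetched):

ids = [1, 2, 3, 5, 8, 9, 11, 14]   # ids present in the table, in order
windowsize = 3

boundaries = ids[::windowsize]      # every windowsize-th id -> [1, 5, 11]
windows = list(zip(boundaries, boundaries[1:] + [None]))
# -> [(1, 5), (5, 11), (11, None)]
# Each (start, end) pair corresponds to a WHERE clause
# "column >= start AND column < end" (or "column >= start" for the last window);
# _windowed_query then runs the base query once per window, ordered by the column.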

View File

@ -1,5 +1,6 @@
import json
from collections import defaultdict
from itertools import islice
import logging
import pprint
import pandas as pd
@ -113,3 +114,14 @@ def convert_labels_to_squad(labels_file: str):
with open("labels_in_squad_format.json", "w+") as outfile:
json.dump(labels_in_squad_format, outfile)
def get_batches_from_generator(iterable, n):
"""
Batch elements of an iterable into fixed-length chunks or blocks.
"""
it = iter(iterable)
x = tuple(islice(it, n))
while x:
yield x
x = tuple(islice(it, n))
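A short usage example of the helper defined above; the final batch may be shorter than n:

from haystack.utils import get_batches_from_generator

for batch in get_batches_from_generator(range(7), 3):
    print(batch)
# (0, 1, 2)
# (3, 4, 5)
# (6,)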

View File

@ -258,7 +258,9 @@ def get_document_store(document_store_type, embedding_field="embedding"):
)
elif document_store_type == "faiss":
document_store = FAISSDocumentStore(
sql_url="sqlite://", return_embedding=True, embedding_field=embedding_field
sql_url="sqlite://",
return_embedding=True,
embedding_field=embedding_field,
)
return document_store
else:

View File

@ -85,6 +85,20 @@ def test_get_document_count(document_store):
assert document_store.get_document_count(filters={"meta_field_for_count": ["b"]}) == 3
@pytest.mark.elasticsearch
def test_get_all_documents_generator(document_store):
documents = [
{"text": "text1", "id": "1", "meta_field_for_count": "a"},
{"text": "text2", "id": "2", "meta_field_for_count": "b"},
{"text": "text3", "id": "3", "meta_field_for_count": "b"},
{"text": "text4", "id": "4", "meta_field_for_count": "b"},
{"text": "text5", "id": "5", "meta_field_for_count": "b"},
]
document_store.write_documents(documents)
assert len(list(document_store.get_all_documents_generator(batch_size=2))) == 5
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch", "sql", "faiss"], indirect=True)
@pytest.mark.parametrize("update_existing_documents", [True, False])
@ -168,6 +182,41 @@ def test_document_with_embeddings(document_store):
assert isinstance(documents_with_embedding[0].embedding, (list, np.ndarray))
@pytest.mark.parametrize("retriever", ["dpr", "embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
def test_update_embeddings(document_store, retriever):
documents = []
for i in range(23):
documents.append({"text": f"text_{i}", "id": str(i), "meta_field": f"value_{i}"})
documents.append({"text": "text_0", "id": "23", "meta_field": "value_0"})
document_store.write_documents(documents, index="haystack_test_1")
document_store.update_embeddings(retriever, index="haystack_test_1")
documents = document_store.get_all_documents(index="haystack_test_1", return_embedding=True)
assert len(documents) == 24
for doc in documents:
assert type(doc.embedding) is np.ndarray
documents = document_store.get_all_documents(
index="haystack_test_1",
filters={"meta_field": ["value_0", "value_23"]},
return_embedding=True,
)
np.testing.assert_array_equal(documents[0].embedding, documents[1].embedding)
documents = document_store.get_all_documents(
index="haystack_test_1",
filters={"meta_field": ["value_0", "value_10"]},
return_embedding=True,
)
np.testing.assert_raises(
AssertionError,
np.testing.assert_array_equal,
documents[0].embedding,
documents[1].embedding
)
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_delete_documents(document_store_with_docs):

View File

@ -18,40 +18,47 @@ def test_dpr_retrieval(document_store, retriever, return_embedding):
documents = [
Document(
text="""Aaron Aaron ( or ; ""Ahärôn"") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman (""prophet"") to the Pharaoh. Part of the Law (Torah) that Moses received from""",
meta={"name": "0"}
meta={"name": "0"},
id="1",
),
Document(
text="""Democratic Republic of the Congo to the south. Angola's capital, Luanda, lies on the Atlantic coast in the northwest of the country. Angola, although located in a tropical zone, has a climate that is not characterized for this region, due to the confluence of three factors: As a result, Angola's climate is characterized by two seasons: rainfall from October to April and drought, known as ""Cacimbo"", from May to August, drier, as the name implies, and with lower temperatures. On the other hand, while the coastline has high rainfall rates, decreasing from North to South and from to , with""",
id="2",
),
Document(
text="""Schopenhauer, describing him as an ultimately shallow thinker: ""Schopenhauer has quite a crude mind ... where real depth starts, his comes to an end."" His friend Bertrand Russell had a low opinion on the philosopher, and attacked him in his famous ""History of Western Philosophy"" for hypocritically praising asceticism yet not acting upon it. On the opposite isle of Russell on the foundations of mathematics, the Dutch mathematician L. E. J. Brouwer incorporated the ideas of Kant and Schopenhauer in intuitionism, where mathematics is considered a purely mental activity, instead of an analytic activity wherein objective properties of reality are""",
meta={"name": "1"}
meta={"name": "1"},
id="3",
),
Document(
text="""The Dothraki vocabulary was created by David J. Peterson well in advance of the adaptation. HBO hired the Language Creatio""",
meta={"name": "2"}
meta={"name": "2"},
id="4",
),
Document(
text="""The title of the episode refers to the Great Sept of Baelor, the main religious building in King's Landing, where the episode's pivotal scene takes place. In the world created by George R. R. Martin""",
meta={}
meta={},
id="5",
)
]
document_store.return_embedding = return_embedding
document_store.write_documents(documents)
document_store.update_embeddings(retriever=retriever)
time.sleep(1)
docs_with_emb = document_store.get_all_documents()
if return_embedding is True:
assert (len(docs_with_emb[0].embedding) == 768)
assert (abs(docs_with_emb[0].embedding[0] - (-0.3063)) < 0.001)
assert (abs(docs_with_emb[1].embedding[0] - (-0.3914)) < 0.001)
assert (abs(docs_with_emb[2].embedding[0] - (-0.2470)) < 0.001)
assert (abs(docs_with_emb[3].embedding[0] - (-0.0802)) < 0.001)
assert (abs(docs_with_emb[4].embedding[0] - (-0.0551)) < 0.001)
doc_1 = document_store.get_document_by_id("1")
assert (len(doc_1.embedding) == 768)
assert (abs(doc_1.embedding[0] - (-0.3063)) < 0.001)
doc_2 = document_store.get_document_by_id("2")
assert (abs(doc_2.embedding[0] - (-0.3914)) < 0.001)
doc_3 = document_store.get_document_by_id("3")
assert (abs(doc_3.embedding[0] - (-0.2470)) < 0.001)
doc_4 = document_store.get_document_by_id("4")
assert (abs(doc_4.embedding[0] - (-0.0802)) < 0.001)
doc_5 = document_store.get_document_by_id("5")
assert (abs(doc_5.embedding[0] - (-0.0551)) < 0.001)
res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?")

View File

@ -1,107 +1,72 @@
import pytest
from haystack.document_store.base import BaseDocumentStore
from haystack.document_store.memory import InMemoryDocumentStore
from haystack.preprocessor.preprocessor import PreProcessor
from haystack.finder import Finder
@pytest.mark.parametrize("batch_size", [None, 20])
@pytest.mark.elasticsearch
def test_add_eval_data(document_store):
def test_add_eval_data(document_store, batch_size):
# add eval data (SQUAD format)
document_store.delete_all_documents(index="test_eval_document")
document_store.delete_all_documents(index="test_feedback")
document_store.add_eval_data(filename="samples/squad/small.json", doc_index="test_eval_document", label_index="test_feedback")
document_store.add_eval_data(
filename="samples/squad/small.json",
doc_index="haystack_test_eval_document",
label_index="haystack_test_feedback",
batch_size=batch_size,
)
assert document_store.get_document_count(index="test_eval_document") == 87
assert document_store.get_label_count(index="test_feedback") == 1214
assert document_store.get_document_count(index="haystack_test_eval_document") == 87
assert document_store.get_label_count(index="haystack_test_feedback") == 1214
# test documents
docs = document_store.get_all_documents(index="test_eval_document")
assert docs[0].text[:10] == "The Norman"
docs = document_store.get_all_documents(index="haystack_test_eval_document", filters={"name": ["Normans"]})
assert docs[0].meta["name"] == "Normans"
assert len(docs[0].meta.keys()) == 1
# test labels
labels = document_store.get_all_labels(index="test_feedback")
assert labels[0].answer == "France"
assert labels[0].no_answer == False
assert labels[0].is_correct_answer == True
assert labels[0].is_correct_document == True
assert labels[0].question == 'In what country is Normandy located?'
assert labels[0].origin == "gold_label"
assert labels[0].offset_start_in_doc == 159
labels = document_store.get_all_labels(index="haystack_test_feedback")
label = None
for l in labels:
if l.question == "In what country is Normandy located?":
label = l
break
assert label.answer == "France"
assert label.no_answer == False
assert label.is_correct_answer == True
assert label.is_correct_document == True
assert label.question == "In what country is Normandy located?"
assert label.origin == "gold_label"
assert label.offset_start_in_doc == 159
# check combination
assert labels[0].document_id == docs[0].id
start = labels[0].offset_start_in_doc
end = start+len(labels[0].answer)
assert docs[0].text[start:end] == "France"
# clean up
document_store.delete_all_documents(index="test_eval_document")
document_store.delete_all_documents(index="test_feedback")
@pytest.mark.elasticsearch
def test_add_eval_data_batchwise(document_store):
# add eval data (SQUAD format in jsonl)
document_store.delete_all_documents(index="test_eval_document")
document_store.delete_all_documents(index="test_feedback")
document_store.add_eval_data(filename="samples/squad/small.json",
doc_index="test_eval_document",
label_index="test_feedback",
batch_size=20)
assert document_store.get_document_count(index="test_eval_document") == 87
assert document_store.get_label_count(index="test_feedback") == 1214
# test documents
docs = document_store.get_all_documents(index="test_eval_document")
assert docs[0].text[:10] == "The Norman"
assert docs[0].meta["name"] == "Normans"
assert len(docs[0].meta.keys()) == 1
# test labels
labels = document_store.get_all_labels(index="test_feedback")
assert labels[0].answer == "France"
assert labels[0].no_answer == False
assert labels[0].is_correct_answer == True
assert labels[0].is_correct_document == True
assert labels[0].question == 'In what country is Normandy located?'
assert labels[0].origin == "gold_label"
assert labels[0].offset_start_in_doc == 159
# check combination
assert labels[0].document_id == docs[0].id
start = labels[0].offset_start_in_doc
end = start+len(labels[0].answer)
assert docs[0].text[start:end] == "France"
# clean up
document_store.delete_all_documents(index="test_eval_document")
document_store.delete_all_documents(index="test_feedback")
doc = document_store.get_document_by_id(label.document_id, index="haystack_test_eval_document")
start = label.offset_start_in_doc
end = start + len(label.answer)
assert doc.text[start:end] == "France"
@pytest.mark.elasticsearch
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_eval_reader(reader, document_store: BaseDocumentStore):
# add eval data (SQUAD format)
document_store.delete_all_documents(index="test_eval_document")
document_store.delete_all_documents(index="test_feedback")
document_store.add_eval_data(filename="samples/squad/tiny.json", doc_index="test_eval_document", label_index="test_feedback")
assert document_store.get_document_count(index="test_eval_document") == 2
document_store.add_eval_data(
filename="samples/squad/tiny.json",
doc_index="haystack_test_eval_document",
label_index="haystack_test_feedback",
)
assert document_store.get_document_count(index="haystack_test_eval_document") == 2
# eval reader
reader_eval_results = reader.eval(document_store=document_store, label_index="test_feedback",
doc_index="test_eval_document", device="cpu")
reader_eval_results = reader.eval(
document_store=document_store,
label_index="haystack_test_feedback",
doc_index="haystack_test_eval_document",
device="cpu",
)
assert reader_eval_results["f1"] > 0.65
assert reader_eval_results["f1"] < 0.67
assert reader_eval_results["EM"] == 0.5
assert reader_eval_results["top_n_accuracy"] == 1.0
# clean up
document_store.delete_all_documents(index="test_eval_document")
document_store.delete_all_documents(index="test_feedback")
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
@ -109,22 +74,22 @@ def test_eval_reader(reader, document_store: BaseDocumentStore):
@pytest.mark.parametrize("retriever", ["elasticsearch"], indirect=True)
def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain, retriever):
# add eval data (SQUAD format)
document_store.delete_all_documents(index="test_eval_document")
document_store.delete_all_documents(index="test_feedback")
document_store.add_eval_data(filename="samples/squad/tiny.json", doc_index="test_eval_document", label_index="test_feedback")
assert document_store.get_document_count(index="test_eval_document") == 2
document_store.add_eval_data(
filename="samples/squad/tiny.json",
doc_index="haystack_test_eval_document",
label_index="haystack_test_feedback",
)
assert document_store.get_document_count(index="haystack_test_eval_document") == 2
# eval retriever
results = retriever.eval(top_k=1, label_index="test_feedback", doc_index="test_eval_document", open_domain=open_domain)
results = retriever.eval(
top_k=1, label_index="haystack_test_feedback", doc_index="haystack_test_eval_document", open_domain=open_domain
)
assert results["recall"] == 1.0
assert results["mrr"] == 1.0
if not open_domain:
assert results["map"] == 1.0
# clean up
document_store.delete_all_documents(index="test_eval_document")
document_store.delete_all_documents(index="test_feedback")
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
@ -134,13 +99,17 @@ def test_eval_finder(document_store: BaseDocumentStore, reader, retriever):
finder = Finder(reader=reader, retriever=retriever)
# add eval data (SQUAD format)
document_store.delete_all_documents(index="test_eval_document")
document_store.delete_all_documents(index="test_feedback")
document_store.add_eval_data(filename="samples/squad/tiny.json", doc_index="test_eval_document", label_index="test_feedback")
assert document_store.get_document_count(index="test_eval_document") == 2
document_store.add_eval_data(
filename="samples/squad/tiny.json",
doc_index="haystack_test_eval_document",
label_index="haystack_test_feedback",
)
assert document_store.get_document_count(index="haystack_test_eval_document") == 2
# eval finder
results = finder.eval(label_index="test_feedback", doc_index="test_eval_document", top_k_retriever=1, top_k_reader=5)
results = finder.eval(
label_index="haystack_test_feedback", doc_index="haystack_test_eval_document", top_k_retriever=1, top_k_reader=5
)
assert results["retriever_recall"] == 1.0
assert results["retriever_map"] == 1.0
assert abs(results["reader_topk_f1"] - 0.66666) < 0.001
@ -151,24 +120,19 @@ def test_eval_finder(document_store: BaseDocumentStore, reader, retriever):
assert results["reader_top1_accuracy"] <= results["reader_topk_accuracy"]
# batch eval finder
results_batch = finder.eval_batch(label_index="test_feedback", doc_index="test_eval_document", top_k_retriever=1,
top_k_reader=5)
results_batch = finder.eval_batch(
label_index="haystack_test_feedback", doc_index="haystack_test_eval_document", top_k_retriever=1, top_k_reader=5
)
assert results_batch["retriever_recall"] == 1.0
assert results_batch["retriever_map"] == 1.0
assert results_batch["reader_top1_f1"] == results["reader_top1_f1"]
assert results_batch["reader_top1_em"] == results["reader_top1_em"]
assert results_batch["reader_topk_accuracy"] == results["reader_topk_accuracy"]
# clean up
document_store.delete_all_documents(index="test_eval_document")
document_store.delete_all_documents(index="test_feedback")
@pytest.mark.elasticsearch
def test_eval_data_splitting(document_store):
def test_eval_data_split_word(document_store):
# splitting by word
document_store.delete_all_documents(index="test_eval_document")
document_store.delete_all_documents(index="test_feedback")
preprocessor = PreProcessor(
clean_empty_lines=False,
clean_whitespace=False,
@ -176,22 +140,24 @@ def test_eval_data_splitting(document_store):
split_by="word",
split_length=4,
split_overlap=0,
split_respect_sentence_boundary=False
split_respect_sentence_boundary=False,
)
document_store.add_eval_data(filename="samples/squad/tiny.json",
doc_index="test_eval_document",
label_index="test_feedback",
preprocessor=preprocessor)
labels = document_store.get_all_labels_aggregated(index="test_feedback")
docs = document_store.get_all_documents(index="test_eval_document")
document_store.add_eval_data(
filename="samples/squad/tiny.json",
doc_index="haystack_test_eval_document",
label_index="haystack_test_feedback",
preprocessor=preprocessor,
)
labels = document_store.get_all_labels_aggregated(index="haystack_test_feedback")
docs = document_store.get_all_documents(index="haystack_test_eval_document")
assert len(docs) == 5
assert len(set(labels[0].multiple_document_ids)) == 2
# splitting by passage
document_store.delete_all_documents(index="test_eval_document")
document_store.delete_all_documents(index="test_feedback")
@pytest.mark.elasticsearch
def test_eval_data_split_passage(document_store):
# splitting by passage
preprocessor = PreProcessor(
clean_empty_lines=False,
clean_whitespace=False,
@ -202,10 +168,12 @@ def test_eval_data_splitting(document_store):
split_respect_sentence_boundary=False
)
document_store.add_eval_data(filename="samples/squad/tiny_passages.json",
doc_index="test_eval_document",
label_index="test_feedback",
preprocessor=preprocessor)
docs = document_store.get_all_documents(index="test_eval_document")
document_store.add_eval_data(
filename="samples/squad/tiny_passages.json",
doc_index="haystack_test_eval_document",
label_index="haystack_test_feedback",
preprocessor=preprocessor,
)
docs = document_store.get_all_documents(index="haystack_test_eval_document")
assert len(docs) == 2
assert len(docs[1].text) == 56

View File

@ -19,21 +19,6 @@ DOCUMENTS = [
]
def check_data_correctness(documents_indexed, documents_inserted):
# test if correct vector_ids are assigned
for i, doc in enumerate(documents_indexed):
assert doc.meta["vector_id"] == str(i)
# test if number of documents is correct
assert len(documents_indexed) == len(documents_inserted)
# test that no two docs have the same vector_id assigned
vector_ids = set()
for i, doc in enumerate(documents_indexed):
vector_ids.add(doc.meta["vector_id"])
assert len(vector_ids) == len(documents_inserted)
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
def test_faiss_index_save_and_load(document_store):
document_store.write_documents(DOCUMENTS)
@ -65,6 +50,7 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
document_store.write_documents(DOCUMENTS[i: i + batch_size])
documents_indexed = document_store.get_all_documents()
assert len(documents_indexed) == len(DOCUMENTS)
# test if correct vectors are associated with docs
for i, doc in enumerate(documents_indexed):
@ -74,34 +60,26 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
# compare original input vec with stored one (ignore extra dim added by hnsw)
assert np.allclose(original_doc["embedding"], stored_emb, rtol=0.01)
# test document correctness
check_data_correctness(documents_indexed, DOCUMENTS)
@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
def test_faiss_update_docs(document_store, index_buffer_size, retriever):
# adjust buffer size
document_store.index_buffer_size = index_buffer_size
@pytest.mark.parametrize("batch_size", [4, 6])
def test_faiss_update_docs(document_store, retriever, batch_size):
# initial write
document_store.write_documents(DOCUMENTS)
document_store.update_embeddings(retriever=retriever)
document_store.update_embeddings(retriever=retriever, batch_size=batch_size)
documents_indexed = document_store.get_all_documents()
assert len(documents_indexed) == len(DOCUMENTS)
# test if correct vectors are associated with docs
for i, doc in enumerate(documents_indexed):
for doc in documents_indexed:
original_doc = [d for d in DOCUMENTS if d["text"] == doc.text][0]
updated_embedding = retriever.embed_passages([Document.from_dict(original_doc)])
stored_emb = document_store.faiss_index.reconstruct(int(doc.meta["vector_id"]))
stored_doc = document_store.get_all_documents(filters={"name": [doc.meta["name"]]})[0]
# compare original input vec with stored one (ignore extra dim added by hnsw)
assert np.allclose(updated_embedding, stored_emb, rtol=0.01)
# test document correctness
check_data_correctness(documents_indexed, DOCUMENTS)
assert np.allclose(updated_embedding, stored_doc.embedding, rtol=0.01)
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@ -115,8 +93,7 @@ def test_faiss_update_with_empty_store(document_store, retriever):
documents_indexed = document_store.get_all_documents()
# test document correctness
check_data_correctness(documents_indexed, DOCUMENTS)
assert len(documents_indexed) == len(DOCUMENTS)
@pytest.mark.parametrize("index_factory", ["Flat", "HNSW", "IVF1,Flat"])
@ -190,5 +167,8 @@ def test_faiss_passing_index_from_outside():
document_store.write_documents(documents=DOCUMENTS, index="document")
documents_indexed = document_store.get_all_documents(index="document")
# test document correctness
check_data_correctness(documents_indexed, DOCUMENTS)
# test if vector_ids are associated with docs
for doc in documents_indexed:
assert 0 <= int(doc.meta["vector_id"]) <= 7