import faiss
import numpy as np
import pytest

from haystack import Document
from haystack import Finder
from haystack.document_store.faiss import FAISSDocumentStore
from haystack.retriever.dense import DensePassageRetriever
from haystack.retriever.dense import EmbeddingRetriever
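
# The parametrized `document_store` fixture used by the tests below is assumed
# to come from the suite's conftest.py (hence `indirect=True` in the
# parametrize decorators).

# Note: two of the fixture documents deliberately carry float64 embeddings.
# FAISS stores vectors as float32, so those entries additionally exercise the
# document store's dtype handling on write.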
DOCUMENTS = [
    {"name": "name_1", "text": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
    {"name": "name_2", "text": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
    {"name": "name_3", "text": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
    {"name": "name_4", "text": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
    {"name": "name_5", "text": "text_5", "embedding": np.random.rand(768).astype(np.float32)},
    {"name": "name_6", "text": "text_6", "embedding": np.random.rand(768).astype(np.float64)},
]


def check_data_correctness(documents_indexed, documents_inserted):
    # test that the expected vector_ids were assigned
    for i, doc in enumerate(documents_indexed):
        assert doc.meta["vector_id"] == str(i)

    # test that the number of documents is correct
    assert len(documents_indexed) == len(documents_inserted)

    # test that no two docs were assigned the same vector_id
    vector_ids = set()
    for doc in documents_indexed:
        vector_ids.add(doc.meta["vector_id"])
    assert len(vector_ids) == len(documents_inserted)


@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
def test_faiss_index_save_and_load(document_store):
    document_store.write_documents(DOCUMENTS)

    # test saving the index
    document_store.save("haystack_test_faiss")

    # clear existing faiss_index
    document_store.faiss_index.reset()

    # test faiss index is cleared
    assert document_store.faiss_index.ntotal == 0

    # test loading the index
    new_document_store = document_store.load(sql_url="sqlite:///haystack_test.db",
                                             faiss_file_path="haystack_test_faiss")

    # check faiss index is restored
    assert new_document_store.faiss_index.ntotal == len(DOCUMENTS)
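
    # Note: save() persists only the FAISS index file; the document texts and
    # metadata live in the SQL database behind `sql_url`, so load() needs both
    # to fully restore the store.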


@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
@pytest.mark.parametrize("batch_size", [2])
def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
    document_store.index_buffer_size = index_buffer_size

    # Write in small batches
    for i in range(0, len(DOCUMENTS), batch_size):
        document_store.write_documents(DOCUMENTS[i: i + batch_size])

    documents_indexed = document_store.get_all_documents()

    # test if correct vectors are associated with docs
    for doc in documents_indexed:
        # we currently don't get the embeddings back when we call document_store.get_all_documents()
        original_doc = [d for d in DOCUMENTS if d["text"] == doc.text][0]
        stored_emb = document_store.faiss_index.reconstruct(int(doc.meta["vector_id"]))
        # compare original input vec with stored one (ignore extra dim added by hnsw)
        assert np.allclose(original_doc["embedding"], stored_emb, rtol=0.01)

    # test document correctness
    check_data_correctness(documents_indexed, DOCUMENTS)
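
# (A small index_buffer_size makes the store flush vectors to the FAISS index
# in several chunks instead of one; the parametrized value of 2 above is
# presumably meant to exercise that buffered write path together with the
# batched writes.)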


@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
def test_faiss_update_docs(document_store, index_buffer_size):
    # adjust buffer size
    document_store.index_buffer_size = index_buffer_size

    # initial write
    document_store.write_documents(DOCUMENTS)

    # do the update
    retriever = DensePassageRetriever(document_store=document_store,
                                      query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                      passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                      use_gpu=False,
                                      embed_title=True,
                                      remove_sep_tok_from_untitled_passages=True)

    document_store.update_embeddings(retriever=retriever)
    documents_indexed = document_store.get_all_documents()

    # test if correct vectors are associated with docs
    for doc in documents_indexed:
        original_doc = [d for d in DOCUMENTS if d["text"] == doc.text][0]
        updated_embedding = retriever.embed_passages([Document.from_dict(original_doc)])
        stored_emb = document_store.faiss_index.reconstruct(int(doc.meta["vector_id"]))
        # compare original input vec with stored one (ignore extra dim added by hnsw)
        assert np.allclose(updated_embedding, stored_emb, rtol=0.01)

    # test document correctness
    check_data_correctness(documents_indexed, DOCUMENTS)
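
    # Note: update_embeddings() is expected to recompute embeddings for all
    # stored documents with the given retriever and overwrite the vectors in
    # the FAISS index, which is why the freshly embedded passages must match
    # what reconstruct() returns.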


@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
def test_faiss_retrieving(document_store):
    document_store.write_documents(DOCUMENTS)

    retriever = EmbeddingRetriever(document_store=document_store,
                                   embedding_model="deepset/sentence_bert",
                                   use_gpu=False)
    result = retriever.retrieve(query="How to test this?")

    assert len(result) == len(DOCUMENTS)
    assert type(result[0]) == Document
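
    # Note: retrieve() is called without an explicit top_k; assuming the
    # retriever's default top_k exceeds the six fixture documents, all of them
    # are returned, which is what the length check above relies on.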


@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
def test_faiss_finding(document_store):
    document_store.write_documents(DOCUMENTS)

    retriever = EmbeddingRetriever(document_store=document_store,
                                   embedding_model="deepset/sentence_bert",
                                   use_gpu=False)
    finder = Finder(reader=None, retriever=retriever)

    prediction = finder.get_answers_via_similar_questions(question="How to test this?", top_k_retriever=1)

    assert len(prediction.get('answers', [])) == 1
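
    # Note: with reader=None the Finder skips the reading step; the retrieved
    # documents are returned as FAQ-style answers, so top_k_retriever=1 should
    # yield exactly one answer.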


def test_faiss_passing_index_from_outside():
    d = 768
    nlist = 2
    quantizer = faiss.IndexFlatIP(d)
    faiss_index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
    faiss_index.nprobe = 2
    document_store = FAISSDocumentStore(sql_url="sqlite:///haystack_test_faiss.db", faiss_index=faiss_index)

    document_store.delete_all_documents(index="document")

    # as it is an IVF index, we need to train it before adding docs
    document_store.train_index(DOCUMENTS)

    document_store.write_documents(documents=DOCUMENTS, index="document")
    documents_indexed = document_store.get_all_documents(index="document")

    # test document correctness
    check_data_correctness(documents_indexed, DOCUMENTS)
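
    # Note: with nlist=2 and nprobe=2 every IVF cell is probed during search,
    # so queries against this tiny fixture behave like an exact (flat) search.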