import math
import time

import faiss
import numpy as np
import pytest

from haystack.schema import Document
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.pipeline import DocumentSearchPipeline, Pipeline
from haystack.retriever.dense import EmbeddingRetriever


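# These tests exercise the FAISS-backed DocumentStore; several of them are also parametrized
# (via indirect=True) to run against Milvus through the shared document_store fixture.
#
# The fixture below holds 768-dimensional random vectors, matching the output size of the
# DPR / sentence-transformer models used in the retriever tests. Two entries are float64,
# presumably to check that embeddings are cast to the float32 dtype FAISS expects on write.
# Each dict becomes a haystack Document when written; the explicit equivalent would be
# roughly Document(content="text_1", meta={"name": "name_1"}, embedding=...).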
DOCUMENTS = [
    {"name": "name_1", "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
    {"name": "name_2", "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
    {"name": "name_3", "content": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
    {"name": "name_4", "content": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
    {"name": "name_5", "content": "text_5", "embedding": np.random.rand(768).astype(np.float32)},
    {"name": "name_6", "content": "text_6", "embedding": np.random.rand(768).astype(np.float64)},
]


def test_faiss_index_save_and_load(tmp_path):
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
        index="haystack_test",
        progress_bar=False  # Just to check if the init parameters are kept
    )
    document_store.write_documents(DOCUMENTS)

    # test saving the index
    document_store.save(tmp_path / "haystack_test_faiss")

    # clear existing faiss_index
    document_store.faiss_indexes[document_store.index].reset()

    # test faiss index is cleared
    assert document_store.faiss_indexes[document_store.index].ntotal == 0

    # test loading the index
    new_document_store = FAISSDocumentStore.load(tmp_path / "haystack_test_faiss")

    # check faiss index is restored
    assert new_document_store.faiss_indexes[document_store.index].ntotal == len(DOCUMENTS)
    # check if documents are restored
    assert len(new_document_store.get_all_documents()) == len(DOCUMENTS)
    # Check if the init parameters are kept
    assert not new_document_store.progress_bar


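# The next test repeats the same save/load round trip, but passes explicit index_path and
# config_path arguments, i.e. separate locations for the serialized FAISS index and the
# JSON config that load() needs to re-create the store.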
def test_faiss_index_save_and_load_custom_path(tmp_path):
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
        index="haystack_test",
        progress_bar=False  # Just to check if the init parameters are kept
    )
    document_store.write_documents(DOCUMENTS)

    # test saving the index
    document_store.save(index_path=tmp_path / "haystack_test_faiss", config_path=tmp_path / "custom_path.json")

    # clear existing faiss_index
    document_store.faiss_indexes[document_store.index].reset()

    # test faiss index is cleared
    assert document_store.faiss_indexes[document_store.index].ntotal == 0

    # test loading the index
    new_document_store = FAISSDocumentStore.load(index_path=tmp_path / "haystack_test_faiss", config_path=tmp_path / "custom_path.json")

    # check faiss index is restored
    assert new_document_store.faiss_indexes[document_store.index].ntotal == len(DOCUMENTS)
    # check if documents are restored
    assert len(new_document_store.get_all_documents()) == len(DOCUMENTS)
    # Check if the init parameters are kept
    assert not new_document_store.progress_bar


@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
@pytest.mark.parametrize("batch_size", [2])
def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
    document_store.index_buffer_size = index_buffer_size

    # Write in small batches
    for i in range(0, len(DOCUMENTS), batch_size):
        document_store.write_documents(DOCUMENTS[i: i + batch_size])

    documents_indexed = document_store.get_all_documents()
    assert len(documents_indexed) == len(DOCUMENTS)

    # test if correct vectors are associated with docs
    for doc in documents_indexed:
        # we currently don't get the embeddings back when we call document_store.get_all_documents()
        original_doc = [d for d in DOCUMENTS if d["content"] == doc.content][0]
        stored_emb = document_store.faiss_indexes[document_store.index].reconstruct(int(doc.meta["vector_id"]))
        # compare original input vec with stored one (ignore extra dim added by hnsw)
        assert np.allclose(original_doc["embedding"], stored_emb, rtol=0.01)


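# Note: FAISS stores vectors as float32, so the float64 embeddings in DOCUMENTS are cast down
# when they are written; the loose rtol in the comparison above comfortably covers that.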
@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
@pytest.mark.parametrize("batch_size", [4, 6])
def test_update_docs(document_store, retriever, batch_size):
    # initial write
    document_store.write_documents(DOCUMENTS)

    document_store.update_embeddings(retriever=retriever, batch_size=batch_size)
    documents_indexed = document_store.get_all_documents()
    assert len(documents_indexed) == len(DOCUMENTS)

    # test if correct vectors are associated with docs
    for doc in documents_indexed:
        original_doc = [d for d in DOCUMENTS if d["content"] == doc.content][0]
        updated_embedding = retriever.embed_documents([Document.from_dict(original_doc)])
        stored_doc = document_store.get_all_documents(filters={"name": [doc.meta["name"]]})[0]
        # compare the retriever's embedding for the doc with the one stored in the index
        assert np.allclose(updated_embedding, stored_doc.embedding, rtol=0.01)


@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["milvus", "faiss"], indirect=True)
def test_update_existing_docs(document_store, retriever):
    document_store.duplicate_documents = "overwrite"
    old_document = Document(content="text_1")
    # initial write
    document_store.write_documents([old_document])
    document_store.update_embeddings(retriever=retriever)
    old_documents_indexed = document_store.get_all_documents()
    assert len(old_documents_indexed) == 1

    # Update document data
    new_document = Document(content="text_2")
    new_document.id = old_document.id
    document_store.write_documents([new_document])
    document_store.update_embeddings(retriever=retriever)
    new_documents_indexed = document_store.get_all_documents()
    assert len(new_documents_indexed) == 1

    assert old_documents_indexed[0].id == new_documents_indexed[0].id
    assert old_documents_indexed[0].content == "text_1"
    assert new_documents_indexed[0].content == "text_2"
    assert not np.allclose(old_documents_indexed[0].embedding, new_documents_indexed[0].embedding, rtol=0.01)


@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_update_with_empty_store(document_store, retriever):
    # Call update with empty doc store
    document_store.update_embeddings(retriever=retriever)

    # initial write
    document_store.write_documents(DOCUMENTS)

    documents_indexed = document_store.get_all_documents()

    assert len(documents_indexed) == len(DOCUMENTS)


@pytest.mark.parametrize("index_factory", ["Flat", "HNSW", "IVF1,Flat"])
def test_faiss_retrieving(index_factory, tmp_path):
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:////{tmp_path/'test_faiss_retrieving.db'}", faiss_index_factory_str=index_factory
    )

    document_store.delete_all_documents(index="document")
    if "ivf" in index_factory.lower():
        document_store.train_index(DOCUMENTS)
    document_store.write_documents(DOCUMENTS)

    retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model="deepset/sentence_bert",
        use_gpu=False
    )
    result = retriever.retrieve(query="How to test this?")

    assert len(result) == len(DOCUMENTS)
    assert type(result[0]) == Document

    # Cleanup
    document_store.faiss_indexes[document_store.index].reset()


@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_finding(document_store, retriever):
    document_store.write_documents(DOCUMENTS)
    pipe = DocumentSearchPipeline(retriever=retriever)

    prediction = pipe.run(query="How to test this?", params={"Retriever": {"top_k": 1}})

    assert len(prediction.get('documents', [])) == 1


@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_delete_docs_with_filters(document_store, retriever):
    document_store.write_documents(DOCUMENTS)
    document_store.update_embeddings(retriever=retriever, batch_size=4)
    assert document_store.get_embedding_count() == 6

    document_store.delete_documents(filters={"name": ["name_1", "name_2", "name_3", "name_4"]})

    documents = document_store.get_all_documents()
    assert len(documents) == 2
    assert document_store.get_embedding_count() == 2
    assert {doc.meta["name"] for doc in documents} == {"name_5", "name_6"}


@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_delete_docs_by_id(document_store, retriever):
    document_store.write_documents(DOCUMENTS)
    document_store.update_embeddings(retriever=retriever, batch_size=4)
    assert document_store.get_embedding_count() == 6
    doc_ids = [doc.id for doc in document_store.get_all_documents()]
    ids_to_delete = doc_ids[0:3]

    document_store.delete_documents(ids=ids_to_delete)

    documents = document_store.get_all_documents()
    assert len(documents) == len(doc_ids) - len(ids_to_delete)
    assert document_store.get_embedding_count() == len(doc_ids) - len(ids_to_delete)

    remaining_ids = [doc.id for doc in documents]
    assert all(doc_id not in remaining_ids for doc_id in ids_to_delete)


@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_delete_docs_by_id_with_filters(document_store, retriever):
    document_store.write_documents(DOCUMENTS)
    document_store.update_embeddings(retriever=retriever, batch_size=4)
    assert document_store.get_embedding_count() == 6

    ids_to_delete = [doc.id for doc in document_store.get_all_documents(filters={"name": ["name_1", "name_2"]})]
    ids_not_to_delete = [doc.id for doc in document_store.get_all_documents(filters={"name": ["name_3", "name_4", "name_5", "name_6"]})]

    document_store.delete_documents(ids=ids_to_delete, filters={"name": ["name_1", "name_2", "name_3", "name_4"]})

    documents = document_store.get_all_documents()
    assert len(documents) == len(DOCUMENTS) - len(ids_to_delete)
    assert document_store.get_embedding_count() == len(DOCUMENTS) - len(ids_to_delete)

    assert all(doc.meta["name"] != "name_1" for doc in documents)
    assert all(doc.meta["name"] != "name_2" for doc in documents)

    all_ids_left = [doc.id for doc in documents]
    assert all(doc_id in all_ids_left for doc_id in ids_not_to_delete)


@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_pipeline(document_store, retriever):
    documents = [
        {"name": "name_1", "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
        {"name": "name_2", "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
        {"name": "name_3", "content": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
        {"name": "name_4", "content": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
    ]
    document_store.write_documents(documents)
    pipeline = Pipeline()
    pipeline.add_node(component=retriever, name="FAISS", inputs=["Query"])
    output = pipeline.run(query="How to test this?", params={"FAISS": {"top_k": 3}})
    assert len(output["documents"]) == 3


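# The next test hands a pre-built faiss.IndexIVFFlat to the document store. IVF indexes have
# to be trained on representative vectors before documents can be added, which is why
# train_index() is called before write_documents().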
def test_faiss_passing_index_from_outside(tmp_path):
    d = 768
    nlist = 2
    quantizer = faiss.IndexFlatIP(d)
    index = "haystack_test_1"
    faiss_index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
    faiss_index.set_direct_map_type(faiss.DirectMap.Hashtable)
    faiss_index.nprobe = 2  # search both of the nlist=2 clusters at query time
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}", faiss_index=faiss_index, index=index
    )

    document_store.delete_documents()
    # as it is an IVF index we need to train it before adding docs
    document_store.train_index(DOCUMENTS)

    document_store.write_documents(documents=DOCUMENTS)
    documents_indexed = document_store.get_all_documents()

    # test if vector ids are associated with docs
    for doc in documents_indexed:
        assert 0 <= int(doc.meta["vector_id"]) <= 7


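# With similarity='cosine' the store is expected to L2-normalize embeddings and score them
# with the inner product, so the test below checks that stored vectors match the normalized
# originals and that scores land in the [0, 1] range.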
def test_faiss_cosine_similarity(tmp_path):
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}", similarity='cosine'
    )

    # write documents to the store, then query it to see whether the vectors were normalized
    document_store.write_documents(documents=DOCUMENTS)

    # note that the same query will be used again after updating the embeddings
    query = np.random.rand(768).astype(np.float32)

    query_results = document_store.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True)

    # check if search with cosine similarity returns the correct number of results
    assert len(query_results) == len(DOCUMENTS)
    indexed_docs = {}
    for doc in DOCUMENTS:
        indexed_docs[doc["content"]] = doc["embedding"]

    for doc in query_results:
        result_emb = doc.embedding
        original_emb = np.array([indexed_docs[doc.content]], dtype="float32")
        faiss.normalize_L2(original_emb)

        # check if the stored embedding was normalized
        assert np.allclose(original_emb[0], result_emb, rtol=0.01)

        # check if the score is plausible for cosine similarity
        assert 0 <= doc.score <= 1.0

    # now check if vectors are normalized when updating embeddings
    class MockRetriever:
        # minimal stand-in: update_embeddings() only needs embed_documents()
        def embed_documents(self, docs):
            return [np.random.rand(768).astype(np.float32) for _ in docs]

    retriever = MockRetriever()
    document_store.update_embeddings(retriever=retriever)
    query_results = document_store.query_by_embedding(query_emb=query, top_k=len(DOCUMENTS), return_embedding=True)

    for doc in query_results:
        original_emb = np.array([indexed_docs[doc.content]], dtype="float32")
        faiss.normalize_L2(original_emb)
        # check that the original embedding has been replaced after updating the embeddings
        assert not np.allclose(original_emb[0], doc.embedding, rtol=0.01)


def test_faiss_cosine_sanity_check(tmp_path):
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}", similarity='cosine',
        vector_dim=3
    )

    VEC_1 = np.array([.1, .2, .3], dtype="float32")
    VEC_2 = np.array([.4, .5, .6], dtype="float32")

    # This is the cosine similarity of VEC_1 and VEC_2 calculated with sklearn.metrics.pairwise.cosine_similarity.
    # The score is normalized to yield a value between 0 and 1.
    KNOWN_COSINE = (0.9746317 + 1) / 2

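    # For reference, the value can be checked by hand: dot(VEC_1, VEC_2) = 0.04 + 0.10 + 0.18 = 0.32,
    # |VEC_1| = sqrt(0.14) ≈ 0.37417 and |VEC_2| = sqrt(0.77) ≈ 0.87750, so
    # cos = 0.32 / (0.37417 * 0.87750) ≈ 0.97463 and the scaled score is (0.97463 + 1) / 2 ≈ 0.98732.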
    docs = [{"name": "vec_1", "content": "vec_1", "embedding": VEC_1}]
    document_store.write_documents(documents=docs)

    query_results = document_store.query_by_embedding(query_emb=VEC_2, top_k=1, return_embedding=True)

    # check that faiss returns the same (scaled) cosine similarity; the raw cosine from manual testing with faiss was 0.9746318
    assert math.isclose(query_results[0].score, KNOWN_COSINE, abs_tol=0.000001)