Fix type casting for vectors in FAISS (#399)

* Fix type casting for vectors in FAISS

Co-authored-by: philipp-bode <philipp.bode@student.hpi.de>

* add type casts for elastic; refactor embedding retriever tests

* fix case: empty embedding field

* fix faiss test tolerance

* add assert in test_faiss_retrieving

Co-authored-by: philipp-bode <philipp.bode@student.hpi.de>
Malte Pietsch, 2020-09-18 17:08:13 +02:00 (committed by GitHub)
commit db6864d159, parent 4ea4cfd282
5 changed files with 49 additions and 29 deletions
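
Background: FAISS indexes store vectors as float32, while Elasticsearch serializes documents to JSON and cannot encode numpy arrays at all. The EmbeddingRetriever previously cast all embeddings to float64 to keep Elasticsearch happy, which in turn broke FAISS. This commit moves the dtype handling into each document store instead. A minimal sketch of the two failure modes (illustrative only; the exact error raised for float64 input varies by faiss version):

    import json
    import numpy as np
    import faiss

    emb = np.random.rand(2, 64)  # numpy defaults to float64

    # Elasticsearch path: json.dumps() raises "TypeError: Object of type
    # ndarray is not JSON serializable" unless arrays become plain lists.
    json.dumps({"embedding": emb[0].tolist()})

    # FAISS path: indexes expect float32, so float64 input must be cast.
    index = faiss.IndexFlatL2(64)
    index.add(emb.astype(np.float32))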


@@ -208,6 +208,10 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
                 **doc.to_dict(field_map=self._create_document_field_map())
             }  # type: Dict[str, Any]
+            # cast embedding type as ES cannot deal with np.array
+            if _doc[self.embedding_field] is not None:
+                _doc[self.embedding_field] = _doc[self.embedding_field].tolist()
             # rename id for elastic
             _doc["_id"] = str(_doc.pop("id"))


@@ -67,6 +67,7 @@ class FAISSDocumentStore(SQLDocumentStore):
             if add_vectors:
                 embeddings = [doc.embedding for doc in document_objects[i: i + self.index_buffer_size]]
                 hnsw_vectors = self._get_hnsw_vectors(embeddings=embeddings, phi=phi)
+                hnsw_vectors = hnsw_vectors.astype(np.float32)
                 faiss_index.add(hnsw_vectors)
 
             docs_to_write_in_sql = []
@@ -130,6 +131,7 @@ class FAISSDocumentStore(SQLDocumentStore):
             vector_id = faiss_index.ntotal
             embeddings = [doc.embedding for doc in documents[i: i + self.index_buffer_size]]
             hnsw_vectors = self._get_hnsw_vectors(embeddings=embeddings, phi=phi)
+            hnsw_vectors = hnsw_vectors.astype(np.float32)
             faiss_index.add(hnsw_vectors)
 
             for doc in documents[i: i + self.index_buffer_size]:
@ -150,7 +152,7 @@ class FAISSDocumentStore(SQLDocumentStore):
raise Exception("Query filters are not implemented for the FAISSDocumentStore.")
if not self.faiss_index:
raise Exception("No index exists. Use 'update_embeddings()` to create an index.")
query_emb = query_emb.reshape(1, -1)
query_emb = query_emb.reshape(1, -1).astype(np.float32)
aux_dim = np.zeros(len(query_emb), dtype="float32")
hnsw_vectors = np.hstack((query_emb, aux_dim.reshape(-1, 1)))
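
The query side gets the same treatment: a query embedding arrives as a 1-D array (often float64), is reshaped to a (1, dim) matrix, and is cast to float32 before FAISS will search it. The zero column appended below mirrors the auxiliary dimension that _get_hnsw_vectors adds at indexing time (the usual transformation for running inner-product search on an L2-based HNSW index). A shape/dtype sketch, with dim chosen arbitrarily:

    import numpy as np

    dim = 768
    query_emb = np.random.rand(dim)                           # 1-D, float64
    query_emb = query_emb.reshape(1, -1).astype(np.float32)   # (1, 768), float32

    aux_dim = np.zeros(len(query_emb), dtype="float32")
    hnsw_vectors = np.hstack((query_emb, aux_dim.reshape(-1, 1)))
    assert hnsw_vectors.shape == (1, dim + 1)
    assert hnsw_vectors.dtype == np.float32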


@@ -338,8 +338,7 @@ class EmbeddingRetriever(BaseRetriever):
         # text is single string, sentence-transformers needs a list of strings
         # get back list of numpy embedding vectors
         emb = self.embedding_model.encode(texts)  # type: ignore
-        # cast to float64 as float32 can cause trouble when serializing for ES
-        emb = [(r.astype('float64')) for r in emb]
+        emb = [r for r in emb]
         return emb
 
     def embed_queries(self, texts: List[str]) -> List[np.array]:
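
With each store normalizing dtypes itself, the retriever no longer needs the lossy float64 detour; the comprehension simply unpacks the (n, dim) result into one vector per input text. sentence-transformers typically returns float32 numpy arrays, which FAISS accepts directly and which Elasticsearch now receives as plain lists via the tolist() cast above. A sketch (model name taken from the tests below):

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("deepset/sentence_bert")
    emb = model.encode(["How to test this?"])  # numpy array, typically float32
    vectors = [r for r in emb]                 # list of 1-D vectors, one per text
    print(vectors[0].dtype)                    # float32

The remaining two files update the tests: the embedding retriever test is generalized across document stores, and the FAISS tests get shared fixture data plus two new end-to-end cases.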


@@ -1,12 +1,10 @@
+import pytest
 from haystack import Finder
+from haystack.retriever.dense import EmbeddingRetriever
-def test_faq_retriever_in_memory_store():
-    from haystack.document_store.memory import InMemoryDocumentStore
-    from haystack.retriever.dense import EmbeddingRetriever
-    document_store = InMemoryDocumentStore(embedding_field="embedding")
+@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
+def test_embedding_retriever(document_store):
     documents = [
         {'text': 'By running tox in the command line!', 'meta': {'name': 'How to test this library?', 'question': 'How to test this library?'}},


@@ -3,19 +3,20 @@ import pytest
 from haystack import Document
 from haystack.retriever.dense import DensePassageRetriever
+from haystack.retriever.dense import EmbeddingRetriever
+from haystack import Finder
+
+DOCUMENTS = [
+    {"name": "name_1", "text": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
+    {"name": "name_2", "text": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
+    {"name": "name_3", "text": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
+]
 
 
 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
 @pytest.mark.parametrize("index_buffer_size", [10_000, 2])
 def test_faiss_write_docs(document_store, index_buffer_size):
-    documents = [
-        {"name": "name_1", "text": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
-        {"name": "name_2", "text": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
-        {"name": "name_3", "text": "text_3", "embedding": np.random.rand(768).astype(np.float32)},
-    ]
     document_store.index_buffer_size = index_buffer_size
-    document_store.write_documents(documents)
+    document_store.write_documents(DOCUMENTS)
 
     documents_indexed = document_store.get_all_documents()
 
     # test if correct vector_ids are assigned
@@ -25,14 +26,14 @@ def test_faiss_write_docs(document_store, index_buffer_size):
     # test if correct vectors are associated with docs
     for i, doc in enumerate(documents_indexed):
         # we currently don't get the embeddings back when we call document_store.get_all_documents()
-        original_doc = [d for d in documents if d["text"] == doc.text][0]
+        original_doc = [d for d in DOCUMENTS if d["text"] == doc.text][0]
         stored_emb = document_store.faiss_index.reconstruct(int(doc.meta["vector_id"]))
         # compare original input vec with stored one (ignore extra dim added by hnsw)
-        assert np.allclose(original_doc["embedding"], stored_emb[:-1], rtol=0.0001)
+        assert np.allclose(original_doc["embedding"], stored_emb[:-1], rtol=0.01)
 
     # test insertion of documents in an existing index fails
     with pytest.raises(Exception):
-        document_store.write_documents(documents)
+        document_store.write_documents(DOCUMENTS)
 
     # test saving the index
     document_store.save("haystack_test_faiss")
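
Why the tolerance changed: DOCUMENTS now deliberately includes one float64 embedding (name_3), which the store casts to float32 before indexing, so the reconstructed vector matches the original only up to float32 precision (roughly 7 significant digits). rtol=0.01 leaves generous headroom for that cast. Illustration:

    import numpy as np

    x64 = np.random.rand(768)       # float64 original, as for name_3
    x32 = x64.astype(np.float32)    # the dtype FAISS actually stores
    print(np.abs(x64 - x32).max())  # on the order of 1e-8
    assert np.allclose(x64, x32, rtol=0.01)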
@@ -43,17 +44,12 @@
 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
 @pytest.mark.parametrize("index_buffer_size", [10_000, 2])
 def test_faiss_update_docs(document_store, index_buffer_size):
-    documents = [
-        {"name": "name_1", "text": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
-        {"name": "name_2", "text": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
-        {"name": "name_3", "text": "text_3", "embedding": np.random.rand(768).astype(np.float32)},
-    ]
     # adjust buffer size
     document_store.index_buffer_size = index_buffer_size
 
     # initial write
-    document_store.write_documents(documents)
+    document_store.write_documents(DOCUMENTS)
 
     # do the update
     retriever = DensePassageRetriever(document_store=document_store,
@@ -66,20 +62,41 @@
     documents_indexed = document_store.get_all_documents()
 
     # test if number of documents is correct
-    assert len(documents_indexed) == len(documents)
+    assert len(documents_indexed) == len(DOCUMENTS)
 
     # test that no two docs have the same vector_id assigned
     vector_ids = set()
     for i, doc in enumerate(documents_indexed):
         vector_ids.add(doc.meta["vector_id"])
-    assert len(vector_ids) == len(documents)
+    assert len(vector_ids) == len(DOCUMENTS)
 
     # test if correct vectors are associated with docs
     for i, doc in enumerate(documents_indexed):
-        original_doc = [d for d in documents if d["text"] == doc.text][0]
+        original_doc = [d for d in DOCUMENTS if d["text"] == doc.text][0]
         updated_embedding = retriever.embed_passages([Document.from_dict(original_doc)])
         stored_emb = document_store.faiss_index.reconstruct(int(doc.meta["vector_id"]))
         # compare original input vec with stored one (ignore extra dim added by hnsw)
-        assert np.allclose(updated_embedding, stored_emb[:-1], rtol=0.0001)
+        assert np.allclose(updated_embedding, stored_emb[:-1], rtol=0.01)
+
+
+@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
+def test_faiss_retrieving(document_store):
+    document_store.write_documents(DOCUMENTS)
+    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False)
+    result = retriever.retrieve(query="How to test this?")
+    assert len(result) == 3
+    assert type(result[0]) == Document
+
+
+@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
+def test_faiss_finding(document_store):
+    document_store.write_documents(DOCUMENTS)
+    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False)
+    finder = Finder(reader=None, retriever=retriever)
+    prediction = finder.get_answers_via_similar_questions(question="How to test this?", top_k_retriever=1)
+    assert len(prediction.get('answers', [])) == 1
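
To try the new tests locally (the test file path and extras are assumptions about the repo layout, not part of this diff):

    pip install faiss-cpu pytest
    pytest test/test_faiss.py -k "retrieving or finding"

Both tests download the deepset/sentence_bert model and run on CPU (use_gpu=False); the elasticsearch-parametrized retriever test additionally needs a running Elasticsearch instance.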