Fix update_embeddings function in FAISSDocumentStore and add retriever fixture in tests (#481)

* 1. Prevent the update_embeddings function in FAISSDocumentStore from setting faiss_index to None when the document store does not have any docs.

2. Clean up tests by adding a fixture for retrievers.

* TfidfRetriever needs a document store with documents during initialization because it calls fit() in its constructor; fix this by checking whether self.paragraphs is None or empty

* Fix naming of retriever's fixture (embedded to embedding and tfid to tfidf)
This commit is contained in:
Lalit Pagaria 2020-10-14 16:15:04 +02:00 committed by GitHub
parent ecaf7b8f0b
commit 2e9f3c1512
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 111 additions and 73 deletions

View File

@ -128,17 +128,19 @@ class FAISSDocumentStore(SQLDocumentStore):
:param index: (SQL) index name for storing the docs and metadata
:return: None
"""
# To clear out the FAISS index contents and frees all memory immediately that is in use by the index
self.faiss_index.reset()
if not self.faiss_index:
raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
index = index or self.index
documents = self.get_all_documents(index=index)
if len(documents) == 0:
logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
self.faiss_index = None
return
# Clear the FAISS index contents and immediately free all memory in use by the index
self.faiss_index.reset()
logger.info(f"Updating embeddings for {len(documents)} docs...")
embeddings = retriever.embed_passages(documents) # type: ignore
assert len(documents) == len(embeddings)

View File

@ -167,6 +167,12 @@ class TfidfRetriever(BaseRetriever):
return documents
def fit(self):
    """
    Fit the TF-IDF vectorizer on the paragraphs of the document store.

    If no paragraphs are cached on the retriever yet, they are fetched via
    `self._get_all_paragraphs()` first. When there are still none, a warning is logged and
    the method returns without touching `self.df` or `self.tfidf_matrix`.

    Each paragraph is expected to carry a "text" entry holding a sequence of strings; these
    are joined with single spaces before being passed to the vectorizer.

    :return: None
    """
    if not self.paragraphs:
        # Nothing cached (None or empty): try to load the paragraphs from the document store.
        self.paragraphs = self._get_all_paragraphs()

    if not self.paragraphs:
        # Still nothing to fit on, e.g. the retriever was created on an empty document store.
        logger.warning("Fit method called with empty document store")
        return

    self.df = pd.DataFrame.from_dict(self.paragraphs)
    self.df["text"] = self.df["text"].apply(" ".join)
    self.tfidf_matrix = self.vectorizer.fit_transform(self.df["text"])

View File

@ -7,6 +7,9 @@ from sys import platform
import pytest
import requests
from elasticsearch import Elasticsearch
from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever, ElasticsearchRetriever, TfidfRetriever
from haystack.retriever.dense import DensePassageRetriever, EmbeddingRetriever
from haystack import Document
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
@ -157,6 +160,16 @@ def document_store(request, test_docs_xs, elasticsearch_fixture):
return get_document_store(request.param)
# Parametrized fixture: returns the retriever that get_retriever() builds for
# request.param on top of the `document_store` fixture. The "elsticsearch" key
# is spelled exactly as get_retriever() expects it.
@pytest.fixture(params=["es_filter_only", "elsticsearch", "dpr", "embedding", "tfidf"])
def retriever(request, document_store):
return get_retriever(request.param, document_store)
# Parametrized fixture: same as `retriever`, but the retriever is built on top of
# the `document_store_with_docs` fixture instead of `document_store`.
@pytest.fixture(params=["es_filter_only", "elsticsearch", "dpr", "embedding", "tfidf"])
def retriever_with_docs(request, document_store_with_docs):
return get_retriever(request.param, document_store_with_docs)
def get_document_store(document_store_type):
if document_store_type == "sql":
if os.path.exists("haystack_test.db"):
@ -177,3 +190,27 @@ def get_document_store(document_store_type):
raise Exception(f"No document store fixture for '{document_store_type}'")
return document_store
def get_retriever(retriever_type, document_store):
    """
    Build the retriever requested by the `retriever` / `retriever_with_docs` test fixtures.

    :param retriever_type: Key selecting the retriever to build: "dpr", "tfidf", "embedding",
                           "es_filter_only" or "elsticsearch". The latter is the (misspelled) key
                           used by the existing parametrizations; the correctly spelled
                           "elasticsearch" is accepted as an alias.
    :param document_store: Document store instance the retriever is created on top of.
    :return: The constructed retriever.
    :raises Exception: If `retriever_type` is not one of the known keys.
    """
    if retriever_type == "dpr":
        retriever = DensePassageRetriever(document_store=document_store,
                                          query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                          passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                          use_gpu=False, embed_title=True,
                                          remove_sep_tok_from_untitled_passages=True)
    elif retriever_type == "tfidf":
        # Assign instead of returning early so every branch leaves through the single return below.
        retriever = TfidfRetriever(document_store=document_store)
    elif retriever_type == "embedding":
        retriever = EmbeddingRetriever(document_store=document_store,
                                       embedding_model="deepset/sentence_bert",
                                       use_gpu=False)
    elif retriever_type in ("elsticsearch", "elasticsearch"):
        # Keep the historical misspelling working for existing tests.
        retriever = ElasticsearchRetriever(document_store=document_store)
    elif retriever_type == "es_filter_only":
        retriever = ElasticsearchFilterOnlyRetriever(document_store=document_store)
    else:
        raise Exception(f"No retriever fixture for '{retriever_type}'")

    return retriever

View File

@ -1,13 +1,13 @@
import pytest
import time
from haystack.retriever.dense import DensePassageRetriever
from haystack import Document
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
def test_dpr_inmemory_retrieval(document_store):
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
def test_dpr_inmemory_retrieval(document_store, retriever):
documents = [
Document(
@ -33,11 +33,6 @@ def test_dpr_inmemory_retrieval(document_store):
document_store.delete_all_documents(index="test_dpr")
document_store.write_documents(documents, index="test_dpr")
retriever = DensePassageRetriever(document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=True, embed_title=True,
remove_sep_tok_from_untitled_passages=True)
document_store.update_embeddings(retriever=retriever, index="test_dpr")
time.sleep(2)

View File

@ -3,21 +3,20 @@ import pytest
@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
def test_dummy_retriever(document_store_with_docs):
from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever
retriever = ElasticsearchFilterOnlyRetriever(document_store_with_docs)
@pytest.mark.parametrize("retriever_with_docs", ["es_filter_only"], indirect=True)
def test_dummy_retriever(retriever_with_docs, document_store_with_docs):
result = retriever.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=1)
result = retriever_with_docs.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=1)
assert type(result[0]) == Document
assert result[0].text == "My name is Carla and I live in Berlin"
assert result[0].meta["name"] == "filename1"
result = retriever.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=5)
result = retriever_with_docs.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=5)
assert type(result[0]) == Document
assert result[0].text == "My name is Carla and I live in Berlin"
assert result[0].meta["name"] == "filename1"
result = retriever.retrieve(query="godzilla", filters={"name": ["filename3"]}, top_k=5)
result = retriever_with_docs.retrieve(query="godzilla", filters={"name": ["filename3"]}, top_k=5)
assert type(result[0]) == Document
assert result[0].text == "My name is Christelle and I live in Paris"
assert result[0].meta["name"] == "filename3"

View File

@ -1,35 +1,33 @@
from haystack.retriever.sparse import ElasticsearchRetriever
import pytest
@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
def test_elasticsearch_retrieval(document_store_with_docs):
retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
res = retriever.retrieve(query="Who lives in Berlin?")
@pytest.mark.parametrize("retriever_with_docs", ["elsticsearch"], indirect=True)
def test_elasticsearch_retrieval(retriever_with_docs, document_store_with_docs):
res = retriever_with_docs.retrieve(query="Who lives in Berlin?")
assert res[0].text == "My name is Carla and I live in Berlin"
assert len(res) == 3
assert res[0].meta["name"] == "filename1"
@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
def test_elasticsearch_retrieval_filters(document_store_with_docs):
retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name": ["filename1"]})
@pytest.mark.parametrize("retriever_with_docs", ["elsticsearch"], indirect=True)
def test_elasticsearch_retrieval_filters(retriever_with_docs, document_store_with_docs):
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name": ["filename1"]})
assert res[0].text == "My name is Carla and I live in Berlin"
assert len(res) == 1
assert res[0].meta["name"] == "filename1"
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["not_existing_value"]})
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["not_existing_value"]})
assert len(res) == 0
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "not_existing_field": ["not_existing_value"]})
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "not_existing_field": ["not_existing_value"]})
assert len(res) == 0
retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["test1","test2"]})
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["test1","test2"]})
assert res[0].text == "My name is Carla and I live in Berlin"
assert len(res) == 1
assert res[0].meta["name"] == "filename1"
retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field":["test2"]})
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field":["test2"]})
assert len(res) == 0

View File

@ -1,10 +1,10 @@
import pytest
from haystack import Finder
from haystack.retriever.dense import EmbeddingRetriever
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
def test_embedding_retriever(document_store):
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
def test_embedding_retriever(retriever, document_store):
documents = [
{'text': 'By running tox in the command line!', 'meta': {'name': 'How to test this library?', 'question': 'How to test this library?'}},
@ -20,8 +20,6 @@ def test_embedding_retriever(document_store):
{'text': 'By running tox in the command line!', 'meta': {'name': 'blah blah blah', 'question': 'blah blah blah'}},
]
retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False)
embedded = []
for doc in documents:
doc['embedding'] = retriever.embed([doc['meta']['question']])[0]

View File

@ -1,6 +1,5 @@
import pytest
from haystack.document_store.base import BaseDocumentStore
from haystack.retriever.sparse import ElasticsearchRetriever
from haystack.finder import Finder
@ -62,9 +61,8 @@ def test_eval_reader(reader, document_store: BaseDocumentStore):
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("open_domain", [True, False])
def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain):
retriever = ElasticsearchRetriever(document_store=document_store)
@pytest.mark.parametrize("retriever", ["elsticsearch"], indirect=True)
def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain, retriever):
# add eval data (SQUAD format)
document_store.delete_all_documents(index="test_eval_document")
document_store.delete_all_documents(index="test_feedback")
@ -83,8 +81,8 @@ def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain):
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_eval_finder(document_store: BaseDocumentStore, reader):
retriever = ElasticsearchRetriever(document_store=document_store)
@pytest.mark.parametrize("retriever", ["elsticsearch"], indirect=True)
def test_eval_finder(document_store: BaseDocumentStore, reader, retriever):
finder = Finder(reader=reader, retriever=retriever)
# add eval data (SQUAD format)

View File

@ -4,8 +4,6 @@ from haystack import Document
import faiss
from haystack.document_store.faiss import FAISSDocumentStore
from haystack.retriever.dense import DensePassageRetriever
from haystack.retriever.dense import EmbeddingRetriever
from haystack import Finder
DOCUMENTS = [
@ -47,7 +45,8 @@ def test_faiss_index_save_and_load(document_store):
assert document_store.faiss_index.ntotal == 0
# test loading the index
new_document_store = document_store.load(sql_url="sqlite:///haystack_test.db", faiss_file_path="haystack_test_faiss")
new_document_store = document_store.load(sql_url="sqlite:///haystack_test.db",
faiss_file_path="haystack_test_faiss")
# check faiss index is restored
assert new_document_store.faiss_index.ntotal == len(DOCUMENTS)
@ -78,21 +77,15 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
def test_faiss_update_docs(document_store, index_buffer_size):
def test_faiss_update_docs(document_store, index_buffer_size, retriever):
# adjust buffer size
document_store.index_buffer_size = index_buffer_size
# initial write
document_store.write_documents(DOCUMENTS)
# do the update
retriever = DensePassageRetriever(document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False, embed_title=True,
remove_sep_tok_from_untitled_passages=True)
document_store.update_embeddings(retriever=retriever)
documents_indexed = document_store.get_all_documents()
@ -109,28 +102,40 @@ def test_faiss_update_docs(document_store, index_buffer_size):
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
def test_faiss_retrieving(document_store):
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
def test_faiss_update_with_empty_store(document_store, retriever):
# Call update with empty doc store
document_store.update_embeddings(retriever=retriever)
# initial write
document_store.write_documents(DOCUMENTS)
retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
use_gpu=False)
documents_indexed = document_store.get_all_documents()
# test document correctness
check_data_correctness(documents_indexed, DOCUMENTS)
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
def test_faiss_retrieving(document_store, retriever):
document_store.write_documents(DOCUMENTS)
result = retriever.retrieve(query="How to test this?")
assert len(result) == len(DOCUMENTS)
assert type(result[0]) == Document
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
def test_faiss_finding(document_store):
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
def test_faiss_finding(document_store, retriever):
document_store.write_documents(DOCUMENTS)
retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
use_gpu=False)
finder = Finder(reader=None, retriever=retriever)
prediction = finder.get_answers_via_similar_questions(question="How to test this?", top_k_retriever=1)
assert len(prediction.get('answers', [])) == 1
def test_faiss_passing_index_from_outside():
d = 768
nlist = 2
@ -147,4 +152,4 @@ def test_faiss_passing_index_from_outside():
documents_indexed = document_store.get_all_documents(index="document")
# test document correctness
check_data_correctness(documents_indexed, DOCUMENTS)
check_data_correctness(documents_indexed, DOCUMENTS)

View File

@ -1,11 +1,10 @@
from haystack import Finder
from haystack.retriever.sparse import TfidfRetriever
import pytest
def test_finder_get_answers(reader, document_store_with_docs):
retriever = TfidfRetriever(document_store=document_store_with_docs)
finder = Finder(reader, retriever)
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
def test_finder_get_answers(reader, retriever_with_docs, document_store_with_docs):
finder = Finder(reader, retriever_with_docs)
prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
top_k_reader=3)
assert prediction is not None
@ -19,9 +18,9 @@ def test_finder_get_answers(reader, document_store_with_docs):
assert len(prediction["answers"]) == 3
def test_finder_offsets(reader, document_store_with_docs):
retriever = TfidfRetriever(document_store=document_store_with_docs)
finder = Finder(reader, retriever)
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
def test_finder_offsets(reader, retriever_with_docs, document_store_with_docs):
finder = Finder(reader, retriever_with_docs)
prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
top_k_reader=5)
@ -32,9 +31,9 @@ def test_finder_offsets(reader, document_store_with_docs):
assert prediction["answers"][0]["context"][start:end] == prediction["answers"][0]["answer"]
def test_finder_get_answers_single_result(reader, document_store_with_docs):
retriever = TfidfRetriever(document_store=document_store_with_docs)
finder = Finder(reader, retriever)
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
def test_finder_get_answers_single_result(reader, retriever_with_docs, document_store_with_docs):
finder = Finder(reader, retriever_with_docs)
query = "testing finder"
prediction = finder.get_answers(question=query, top_k_retriever=1,
top_k_reader=1)

View File

@ -1,5 +1,9 @@
def test_tfidf_retriever():
from haystack.retriever.sparse import TfidfRetriever
import pytest
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["tfidf"], indirect=True)
def test_tfidf_retriever(document_store, retriever):
test_docs = [
{"id": "26f84672c6d7aaeb8e2cd53e9c62d62d", "name": "testing the finder 1", "text": "godzilla says hello"},
@ -7,11 +11,8 @@ def test_tfidf_retriever():
{"name": "testing the finder 3", "text": "alien says arghh"}
]
from haystack.document_store.memory import InMemoryDocumentStore
document_store = InMemoryDocumentStore()
document_store.write_documents(test_docs)
retriever = TfidfRetriever(document_store)
retriever.fit()
doc = retriever.retrieve("godzilla", top_k=1)[0]
assert doc.id == "26f84672c6d7aaeb8e2cd53e9c62d62d"