mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-09 06:13:43 +00:00
Fix update_embeddings function in FAISSDocumentStore and add retriever fixture in tests (#481)
* 1. Prevent the update_embeddings function in FAISSDocumentStore from setting faiss_index to None when the document store does not have any docs. 2. Clean up tests by adding a fixture for the retriever. * TfidfRetriever needs a document store with documents during initialization because it calls fit() in its constructor; fix this by checking whether self.paragraphs is None or empty. * Fix naming of the retriever fixture parameters (embedded to embedding and tfid to tfidf)
This commit is contained in:
parent
ecaf7b8f0b
commit
2e9f3c1512
@ -128,17 +128,19 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
:param index: (SQL) index name for storing the docs and metadata
|
||||
:return: None
|
||||
"""
|
||||
# To clear out the FAISS index contents and frees all memory immediately that is in use by the index
|
||||
self.faiss_index.reset()
|
||||
if not self.faiss_index:
|
||||
raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
|
||||
|
||||
index = index or self.index
|
||||
documents = self.get_all_documents(index=index)
|
||||
|
||||
if len(documents) == 0:
|
||||
logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
|
||||
self.faiss_index = None
|
||||
return
|
||||
|
||||
# To clear out the FAISS index contents and frees all memory immediately that is in use by the index
|
||||
self.faiss_index.reset()
|
||||
|
||||
logger.info(f"Updating embeddings for {len(documents)} docs...")
|
||||
embeddings = retriever.embed_passages(documents) # type: ignore
|
||||
assert len(documents) == len(embeddings)
|
||||
|
||||
@ -167,6 +167,12 @@ class TfidfRetriever(BaseRetriever):
|
||||
return documents
|
||||
|
||||
def fit(self):
    """
    Build the TF-IDF matrix from the paragraphs held by this retriever.

    If ``self.paragraphs`` is empty or None, the paragraphs are (re-)fetched via
    ``self._get_all_paragraphs()``. If there is still nothing to fit on, a warning
    is logged and the method returns without setting ``self.df`` or
    ``self.tfidf_matrix``.

    :return: None
    """
    if not self.paragraphs:
        # Nothing cached yet (e.g. the retriever was created on an empty store):
        # ask the document store again before giving up.
        self.paragraphs = self._get_all_paragraphs()
        if not self.paragraphs:
            logger.warning("Fit method called with empty document store")
            return

    self.df = pd.DataFrame.from_dict(self.paragraphs)
    # Each "text" entry is joined with single spaces into one string
    # (presumably a sequence of strings -- confirm against _get_all_paragraphs).
    self.df["text"] = self.df["text"].apply(lambda x: " ".join(x))
    self.tfidf_matrix = self.vectorizer.fit_transform(self.df["text"])
|
||||
@ -7,6 +7,9 @@ from sys import platform
|
||||
import pytest
|
||||
import requests
|
||||
from elasticsearch import Elasticsearch
|
||||
from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever, ElasticsearchRetriever, TfidfRetriever
|
||||
|
||||
from haystack.retriever.dense import DensePassageRetriever, EmbeddingRetriever
|
||||
|
||||
from haystack import Document
|
||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||
@ -157,6 +160,16 @@ def document_store(request, test_docs_xs, elasticsearch_fixture):
|
||||
return get_document_store(request.param)
|
||||
|
||||
|
||||
@pytest.fixture(params=["es_filter_only", "elsticsearch", "dpr", "embedding", "tfidf"])
def retriever(request, document_store):
    """Build the retriever named by ``request.param`` on top of the ``document_store`` fixture."""
    retriever_type = request.param
    return get_retriever(retriever_type, document_store)
|
||||
|
||||
|
||||
@pytest.fixture(params=["es_filter_only", "elsticsearch", "dpr", "embedding", "tfidf"])
def retriever_with_docs(request, document_store_with_docs):
    """Build the retriever named by ``request.param`` on top of the ``document_store_with_docs`` fixture."""
    retriever_type = request.param
    return get_retriever(retriever_type, document_store_with_docs)
|
||||
|
||||
|
||||
def get_document_store(document_store_type):
|
||||
if document_store_type == "sql":
|
||||
if os.path.exists("haystack_test.db"):
|
||||
@ -177,3 +190,27 @@ def get_document_store(document_store_type):
|
||||
raise Exception(f"No document store fixture for '{document_store_type}'")
|
||||
|
||||
return document_store
|
||||
|
||||
|
||||
def get_retriever(retriever_type, document_store):
    """
    Create the retriever used by the ``retriever`` / ``retriever_with_docs`` fixtures.

    :param retriever_type: One of "dpr", "tfidf", "embedding", "elsticsearch"
                           (historical spelling used by the existing tests; the correct
                           spelling "elasticsearch" is accepted as an alias) or "es_filter_only".
    :param document_store: Document store instance handed to the retriever's constructor.
    :return: The retriever instance for the requested type.
    :raises Exception: If ``retriever_type`` is not one of the supported keys.
    """
    if retriever_type == "dpr":
        return DensePassageRetriever(document_store=document_store,
                                     query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
                                     passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                     use_gpu=False, embed_title=True,
                                     remove_sep_tok_from_untitled_passages=True)
    if retriever_type == "tfidf":
        return TfidfRetriever(document_store=document_store)
    if retriever_type == "embedding":
        return EmbeddingRetriever(document_store=document_store,
                                  embedding_model="deepset/sentence_bert",
                                  use_gpu=False)
    # "elsticsearch" is the key the existing parametrizations pass; keep it working
    # and additionally accept the correctly spelled name.
    if retriever_type in ("elsticsearch", "elasticsearch"):
        return ElasticsearchRetriever(document_store=document_store)
    if retriever_type == "es_filter_only":
        return ElasticsearchFilterOnlyRetriever(document_store=document_store)

    raise Exception(f"No retriever fixture for '{retriever_type}'")
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
import pytest
|
||||
import time
|
||||
|
||||
from haystack.retriever.dense import DensePassageRetriever
|
||||
from haystack import Document
|
||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
|
||||
def test_dpr_inmemory_retrieval(document_store):
|
||||
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||
def test_dpr_inmemory_retrieval(document_store, retriever):
|
||||
|
||||
documents = [
|
||||
Document(
|
||||
@ -33,11 +33,6 @@ def test_dpr_inmemory_retrieval(document_store):
|
||||
|
||||
document_store.delete_all_documents(index="test_dpr")
|
||||
document_store.write_documents(documents, index="test_dpr")
|
||||
retriever = DensePassageRetriever(document_store=document_store,
|
||||
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
||||
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
||||
use_gpu=True, embed_title=True,
|
||||
remove_sep_tok_from_untitled_passages=True)
|
||||
document_store.update_embeddings(retriever=retriever, index="test_dpr")
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
@ -3,21 +3,20 @@ import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
|
||||
def test_dummy_retriever(document_store_with_docs):
|
||||
from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever
|
||||
retriever = ElasticsearchFilterOnlyRetriever(document_store_with_docs)
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["es_filter_only"], indirect=True)
|
||||
def test_dummy_retriever(retriever_with_docs, document_store_with_docs):
|
||||
|
||||
result = retriever.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=1)
|
||||
result = retriever_with_docs.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=1)
|
||||
assert type(result[0]) == Document
|
||||
assert result[0].text == "My name is Carla and I live in Berlin"
|
||||
assert result[0].meta["name"] == "filename1"
|
||||
|
||||
result = retriever.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=5)
|
||||
result = retriever_with_docs.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=5)
|
||||
assert type(result[0]) == Document
|
||||
assert result[0].text == "My name is Carla and I live in Berlin"
|
||||
assert result[0].meta["name"] == "filename1"
|
||||
|
||||
result = retriever.retrieve(query="godzilla", filters={"name": ["filename3"]}, top_k=5)
|
||||
result = retriever_with_docs.retrieve(query="godzilla", filters={"name": ["filename3"]}, top_k=5)
|
||||
assert type(result[0]) == Document
|
||||
assert result[0].text == "My name is Christelle and I live in Paris"
|
||||
assert result[0].meta["name"] == "filename3"
|
||||
|
||||
@ -1,35 +1,33 @@
|
||||
from haystack.retriever.sparse import ElasticsearchRetriever
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
|
||||
def test_elasticsearch_retrieval(document_store_with_docs):
|
||||
retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
|
||||
res = retriever.retrieve(query="Who lives in Berlin?")
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["elsticsearch"], indirect=True)
|
||||
def test_elasticsearch_retrieval(retriever_with_docs, document_store_with_docs):
|
||||
res = retriever_with_docs.retrieve(query="Who lives in Berlin?")
|
||||
assert res[0].text == "My name is Carla and I live in Berlin"
|
||||
assert len(res) == 3
|
||||
assert res[0].meta["name"] == "filename1"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
|
||||
def test_elasticsearch_retrieval_filters(document_store_with_docs):
|
||||
retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
|
||||
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name": ["filename1"]})
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["elsticsearch"], indirect=True)
|
||||
def test_elasticsearch_retrieval_filters(retriever_with_docs, document_store_with_docs):
|
||||
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name": ["filename1"]})
|
||||
assert res[0].text == "My name is Carla and I live in Berlin"
|
||||
assert len(res) == 1
|
||||
assert res[0].meta["name"] == "filename1"
|
||||
|
||||
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["not_existing_value"]})
|
||||
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["not_existing_value"]})
|
||||
assert len(res) == 0
|
||||
|
||||
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "not_existing_field": ["not_existing_value"]})
|
||||
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "not_existing_field": ["not_existing_value"]})
|
||||
assert len(res) == 0
|
||||
|
||||
retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
|
||||
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["test1","test2"]})
|
||||
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["test1","test2"]})
|
||||
assert res[0].text == "My name is Carla and I live in Berlin"
|
||||
assert len(res) == 1
|
||||
assert res[0].meta["name"] == "filename1"
|
||||
|
||||
retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
|
||||
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field":["test2"]})
|
||||
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field":["test2"]})
|
||||
assert len(res) == 0
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
import pytest
|
||||
from haystack import Finder
|
||||
from haystack.retriever.dense import EmbeddingRetriever
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
|
||||
def test_embedding_retriever(document_store):
|
||||
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
|
||||
def test_embedding_retriever(retriever, document_store):
|
||||
|
||||
documents = [
|
||||
{'text': 'By running tox in the command line!', 'meta': {'name': 'How to test this library?', 'question': 'How to test this library?'}},
|
||||
@ -20,8 +20,6 @@ def test_embedding_retriever(document_store):
|
||||
{'text': 'By running tox in the command line!', 'meta': {'name': 'blah blah blah', 'question': 'blah blah blah'}},
|
||||
]
|
||||
|
||||
retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False)
|
||||
|
||||
embedded = []
|
||||
for doc in documents:
|
||||
doc['embedding'] = retriever.embed([doc['meta']['question']])[0]
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import pytest
|
||||
from haystack.document_store.base import BaseDocumentStore
|
||||
from haystack.retriever.sparse import ElasticsearchRetriever
|
||||
from haystack.finder import Finder
|
||||
|
||||
|
||||
@ -62,9 +61,8 @@ def test_eval_reader(reader, document_store: BaseDocumentStore):
|
||||
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
|
||||
@pytest.mark.parametrize("open_domain", [True, False])
|
||||
def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain):
|
||||
retriever = ElasticsearchRetriever(document_store=document_store)
|
||||
|
||||
@pytest.mark.parametrize("retriever", ["elsticsearch"], indirect=True)
|
||||
def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain, retriever):
|
||||
# add eval data (SQUAD format)
|
||||
document_store.delete_all_documents(index="test_eval_document")
|
||||
document_store.delete_all_documents(index="test_feedback")
|
||||
@ -83,8 +81,8 @@ def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain):
|
||||
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
|
||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||
def test_eval_finder(document_store: BaseDocumentStore, reader):
|
||||
retriever = ElasticsearchRetriever(document_store=document_store)
|
||||
@pytest.mark.parametrize("retriever", ["elsticsearch"], indirect=True)
|
||||
def test_eval_finder(document_store: BaseDocumentStore, reader, retriever):
|
||||
finder = Finder(reader=reader, retriever=retriever)
|
||||
|
||||
# add eval data (SQUAD format)
|
||||
|
||||
@ -4,8 +4,6 @@ from haystack import Document
|
||||
import faiss
|
||||
|
||||
from haystack.document_store.faiss import FAISSDocumentStore
|
||||
from haystack.retriever.dense import DensePassageRetriever
|
||||
from haystack.retriever.dense import EmbeddingRetriever
|
||||
from haystack import Finder
|
||||
|
||||
DOCUMENTS = [
|
||||
@ -47,7 +45,8 @@ def test_faiss_index_save_and_load(document_store):
|
||||
assert document_store.faiss_index.ntotal == 0
|
||||
|
||||
# test loading the index
|
||||
new_document_store = document_store.load(sql_url="sqlite:///haystack_test.db", faiss_file_path="haystack_test_faiss")
|
||||
new_document_store = document_store.load(sql_url="sqlite:///haystack_test.db",
|
||||
faiss_file_path="haystack_test_faiss")
|
||||
|
||||
# check faiss index is restored
|
||||
assert new_document_store.faiss_index.ntotal == len(DOCUMENTS)
|
||||
@ -78,21 +77,15 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
|
||||
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||
@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
|
||||
def test_faiss_update_docs(document_store, index_buffer_size):
|
||||
def test_faiss_update_docs(document_store, index_buffer_size, retriever):
|
||||
# adjust buffer size
|
||||
document_store.index_buffer_size = index_buffer_size
|
||||
|
||||
# initial write
|
||||
document_store.write_documents(DOCUMENTS)
|
||||
|
||||
# do the update
|
||||
retriever = DensePassageRetriever(document_store=document_store,
|
||||
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
||||
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
||||
use_gpu=False, embed_title=True,
|
||||
remove_sep_tok_from_untitled_passages=True)
|
||||
|
||||
document_store.update_embeddings(retriever=retriever)
|
||||
documents_indexed = document_store.get_all_documents()
|
||||
|
||||
@ -109,28 +102,40 @@ def test_faiss_update_docs(document_store, index_buffer_size):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
|
||||
def test_faiss_retrieving(document_store):
|
||||
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||
def test_faiss_update_with_empty_store(document_store, retriever):
|
||||
# Call update with empty doc store
|
||||
document_store.update_embeddings(retriever=retriever)
|
||||
|
||||
# initial write
|
||||
document_store.write_documents(DOCUMENTS)
|
||||
|
||||
retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
|
||||
use_gpu=False)
|
||||
documents_indexed = document_store.get_all_documents()
|
||||
|
||||
# test document correctness
|
||||
check_data_correctness(documents_indexed, DOCUMENTS)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
def test_faiss_retrieving(document_store, retriever):
    """Write DOCUMENTS to the FAISS store and check the retriever returns all of them as Documents."""
    document_store.write_documents(DOCUMENTS)
    result = retriever.retrieve(query="How to test this?")
    assert len(result) == len(DOCUMENTS)
    # isinstance instead of an exact type() comparison, so Document subclasses are accepted too
    assert isinstance(result[0], Document)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
|
||||
def test_faiss_finding(document_store):
|
||||
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
|
||||
def test_faiss_finding(document_store, retriever):
|
||||
document_store.write_documents(DOCUMENTS)
|
||||
|
||||
retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
|
||||
use_gpu=False)
|
||||
finder = Finder(reader=None, retriever=retriever)
|
||||
|
||||
prediction = finder.get_answers_via_similar_questions(question="How to test this?", top_k_retriever=1)
|
||||
|
||||
assert len(prediction.get('answers', [])) == 1
|
||||
|
||||
|
||||
def test_faiss_passing_index_from_outside():
|
||||
d = 768
|
||||
nlist = 2
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
from haystack import Finder
|
||||
from haystack.retriever.sparse import TfidfRetriever
|
||||
import pytest
|
||||
|
||||
|
||||
def test_finder_get_answers(reader, document_store_with_docs):
|
||||
retriever = TfidfRetriever(document_store=document_store_with_docs)
|
||||
finder = Finder(reader, retriever)
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
|
||||
def test_finder_get_answers(reader, retriever_with_docs, document_store_with_docs):
|
||||
finder = Finder(reader, retriever_with_docs)
|
||||
prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
|
||||
top_k_reader=3)
|
||||
assert prediction is not None
|
||||
@ -19,9 +18,9 @@ def test_finder_get_answers(reader, document_store_with_docs):
|
||||
assert len(prediction["answers"]) == 3
|
||||
|
||||
|
||||
def test_finder_offsets(reader, document_store_with_docs):
|
||||
retriever = TfidfRetriever(document_store=document_store_with_docs)
|
||||
finder = Finder(reader, retriever)
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
|
||||
def test_finder_offsets(reader, retriever_with_docs, document_store_with_docs):
|
||||
finder = Finder(reader, retriever_with_docs)
|
||||
prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
|
||||
top_k_reader=5)
|
||||
|
||||
@ -32,9 +31,9 @@ def test_finder_offsets(reader, document_store_with_docs):
|
||||
assert prediction["answers"][0]["context"][start:end] == prediction["answers"][0]["answer"]
|
||||
|
||||
|
||||
def test_finder_get_answers_single_result(reader, document_store_with_docs):
|
||||
retriever = TfidfRetriever(document_store=document_store_with_docs)
|
||||
finder = Finder(reader, retriever)
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
|
||||
def test_finder_get_answers_single_result(reader, retriever_with_docs, document_store_with_docs):
|
||||
finder = Finder(reader, retriever_with_docs)
|
||||
query = "testing finder"
|
||||
prediction = finder.get_answers(question=query, top_k_retriever=1,
|
||||
top_k_reader=1)
|
||||
|
||||
@ -1,5 +1,9 @@
|
||||
def test_tfidf_retriever():
|
||||
from haystack.retriever.sparse import TfidfRetriever
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||
@pytest.mark.parametrize("retriever", ["tfidf"], indirect=True)
|
||||
def test_tfidf_retriever(document_store, retriever):
|
||||
|
||||
test_docs = [
|
||||
{"id": "26f84672c6d7aaeb8e2cd53e9c62d62d", "name": "testing the finder 1", "text": "godzilla says hello"},
|
||||
@ -7,11 +11,8 @@ def test_tfidf_retriever():
|
||||
{"name": "testing the finder 3", "text": "alien says arghh"}
|
||||
]
|
||||
|
||||
from haystack.document_store.memory import InMemoryDocumentStore
|
||||
document_store = InMemoryDocumentStore()
|
||||
document_store.write_documents(test_docs)
|
||||
|
||||
retriever = TfidfRetriever(document_store)
|
||||
retriever.fit()
|
||||
doc = retriever.retrieve("godzilla", top_k=1)[0]
|
||||
assert doc.id == "26f84672c6d7aaeb8e2cd53e9c62d62d"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user