mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-09 22:33:47 +00:00
Fix update_embeddings function in FAISSDocumentStore and add retriever fixture in tests (#481)
* 1. Prevent the update_embeddings function in FAISSDocumentStore from setting faiss_index to None when the document store does not have any docs. 2. Clean up tests by adding a fixture for the retriever. * TfidfRetriever needs a document store with documents during initialization because it calls fit() in its constructor; fix this by checking whether self.paragraphs is None or empty. * Fix naming of the retriever fixture parameters (embedded to embedding and tfid to tfidf)
This commit is contained in:
parent
ecaf7b8f0b
commit
2e9f3c1512
@ -128,17 +128,19 @@ class FAISSDocumentStore(SQLDocumentStore):
|
|||||||
:param index: (SQL) index name for storing the docs and metadata
|
:param index: (SQL) index name for storing the docs and metadata
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
# To clear out the FAISS index contents and frees all memory immediately that is in use by the index
|
if not self.faiss_index:
|
||||||
self.faiss_index.reset()
|
raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
|
||||||
|
|
||||||
index = index or self.index
|
index = index or self.index
|
||||||
documents = self.get_all_documents(index=index)
|
documents = self.get_all_documents(index=index)
|
||||||
|
|
||||||
if len(documents) == 0:
|
if len(documents) == 0:
|
||||||
logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
|
logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
|
||||||
self.faiss_index = None
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# To clear out the FAISS index contents and frees all memory immediately that is in use by the index
|
||||||
|
self.faiss_index.reset()
|
||||||
|
|
||||||
logger.info(f"Updating embeddings for {len(documents)} docs...")
|
logger.info(f"Updating embeddings for {len(documents)} docs...")
|
||||||
embeddings = retriever.embed_passages(documents) # type: ignore
|
embeddings = retriever.embed_passages(documents) # type: ignore
|
||||||
assert len(documents) == len(embeddings)
|
assert len(documents) == len(embeddings)
|
||||||
|
|||||||
@ -167,6 +167,12 @@ class TfidfRetriever(BaseRetriever):
|
|||||||
return documents
|
return documents
|
||||||
|
|
||||||
def fit(self):
|
def fit(self):
|
||||||
|
if not self.paragraphs or len(self.paragraphs) == 0:
|
||||||
|
self.paragraphs = self._get_all_paragraphs()
|
||||||
|
if not self.paragraphs or len(self.paragraphs) == 0:
|
||||||
|
logger.warning("Fit method called with empty document store")
|
||||||
|
return
|
||||||
|
|
||||||
self.df = pd.DataFrame.from_dict(self.paragraphs)
|
self.df = pd.DataFrame.from_dict(self.paragraphs)
|
||||||
self.df["text"] = self.df["text"].apply(lambda x: " ".join(x))
|
self.df["text"] = self.df["text"].apply(lambda x: " ".join(x))
|
||||||
self.tfidf_matrix = self.vectorizer.fit_transform(self.df["text"])
|
self.tfidf_matrix = self.vectorizer.fit_transform(self.df["text"])
|
||||||
@ -7,6 +7,9 @@ from sys import platform
|
|||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
from elasticsearch import Elasticsearch
|
from elasticsearch import Elasticsearch
|
||||||
|
from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever, ElasticsearchRetriever, TfidfRetriever
|
||||||
|
|
||||||
|
from haystack.retriever.dense import DensePassageRetriever, EmbeddingRetriever
|
||||||
|
|
||||||
from haystack import Document
|
from haystack import Document
|
||||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||||
@ -157,6 +160,16 @@ def document_store(request, test_docs_xs, elasticsearch_fixture):
|
|||||||
return get_document_store(request.param)
|
return get_document_store(request.param)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=["es_filter_only", "elsticsearch", "dpr", "embedding", "tfidf"])
|
||||||
|
def retriever(request, document_store):
|
||||||
|
return get_retriever(request.param, document_store)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=["es_filter_only", "elsticsearch", "dpr", "embedding", "tfidf"])
|
||||||
|
def retriever_with_docs(request, document_store_with_docs):
|
||||||
|
return get_retriever(request.param, document_store_with_docs)
|
||||||
|
|
||||||
|
|
||||||
def get_document_store(document_store_type):
|
def get_document_store(document_store_type):
|
||||||
if document_store_type == "sql":
|
if document_store_type == "sql":
|
||||||
if os.path.exists("haystack_test.db"):
|
if os.path.exists("haystack_test.db"):
|
||||||
@ -177,3 +190,27 @@ def get_document_store(document_store_type):
|
|||||||
raise Exception(f"No document store fixture for '{document_store_type}'")
|
raise Exception(f"No document store fixture for '{document_store_type}'")
|
||||||
|
|
||||||
return document_store
|
return document_store
|
||||||
|
|
||||||
|
|
||||||
|
def get_retriever(retriever_type, document_store):
|
||||||
|
|
||||||
|
if retriever_type == "dpr":
|
||||||
|
retriever = DensePassageRetriever(document_store=document_store,
|
||||||
|
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
||||||
|
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
||||||
|
use_gpu=False, embed_title=True,
|
||||||
|
remove_sep_tok_from_untitled_passages=True)
|
||||||
|
elif retriever_type == "tfidf":
|
||||||
|
return TfidfRetriever(document_store=document_store)
|
||||||
|
elif retriever_type == "embedding":
|
||||||
|
retriever = EmbeddingRetriever(document_store=document_store,
|
||||||
|
embedding_model="deepset/sentence_bert",
|
||||||
|
use_gpu=False)
|
||||||
|
elif retriever_type == "elsticsearch":
|
||||||
|
retriever = ElasticsearchRetriever(document_store=document_store)
|
||||||
|
elif retriever_type == "es_filter_only":
|
||||||
|
retriever = ElasticsearchFilterOnlyRetriever(document_store=document_store)
|
||||||
|
else:
|
||||||
|
raise Exception(f"No retriever fixture for '{retriever_type}'")
|
||||||
|
|
||||||
|
return retriever
|
||||||
|
|||||||
@ -1,13 +1,13 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from haystack.retriever.dense import DensePassageRetriever
|
|
||||||
from haystack import Document
|
from haystack import Document
|
||||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
|
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
|
||||||
def test_dpr_inmemory_retrieval(document_store):
|
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||||
|
def test_dpr_inmemory_retrieval(document_store, retriever):
|
||||||
|
|
||||||
documents = [
|
documents = [
|
||||||
Document(
|
Document(
|
||||||
@ -33,11 +33,6 @@ def test_dpr_inmemory_retrieval(document_store):
|
|||||||
|
|
||||||
document_store.delete_all_documents(index="test_dpr")
|
document_store.delete_all_documents(index="test_dpr")
|
||||||
document_store.write_documents(documents, index="test_dpr")
|
document_store.write_documents(documents, index="test_dpr")
|
||||||
retriever = DensePassageRetriever(document_store=document_store,
|
|
||||||
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
|
||||||
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
|
||||||
use_gpu=True, embed_title=True,
|
|
||||||
remove_sep_tok_from_untitled_passages=True)
|
|
||||||
document_store.update_embeddings(retriever=retriever, index="test_dpr")
|
document_store.update_embeddings(retriever=retriever, index="test_dpr")
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
|
|
||||||
|
|||||||
@ -3,21 +3,20 @@ import pytest
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
|
@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
|
||||||
def test_dummy_retriever(document_store_with_docs):
|
@pytest.mark.parametrize("retriever_with_docs", ["es_filter_only"], indirect=True)
|
||||||
from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever
|
def test_dummy_retriever(retriever_with_docs, document_store_with_docs):
|
||||||
retriever = ElasticsearchFilterOnlyRetriever(document_store_with_docs)
|
|
||||||
|
|
||||||
result = retriever.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=1)
|
result = retriever_with_docs.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=1)
|
||||||
assert type(result[0]) == Document
|
assert type(result[0]) == Document
|
||||||
assert result[0].text == "My name is Carla and I live in Berlin"
|
assert result[0].text == "My name is Carla and I live in Berlin"
|
||||||
assert result[0].meta["name"] == "filename1"
|
assert result[0].meta["name"] == "filename1"
|
||||||
|
|
||||||
result = retriever.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=5)
|
result = retriever_with_docs.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=5)
|
||||||
assert type(result[0]) == Document
|
assert type(result[0]) == Document
|
||||||
assert result[0].text == "My name is Carla and I live in Berlin"
|
assert result[0].text == "My name is Carla and I live in Berlin"
|
||||||
assert result[0].meta["name"] == "filename1"
|
assert result[0].meta["name"] == "filename1"
|
||||||
|
|
||||||
result = retriever.retrieve(query="godzilla", filters={"name": ["filename3"]}, top_k=5)
|
result = retriever_with_docs.retrieve(query="godzilla", filters={"name": ["filename3"]}, top_k=5)
|
||||||
assert type(result[0]) == Document
|
assert type(result[0]) == Document
|
||||||
assert result[0].text == "My name is Christelle and I live in Paris"
|
assert result[0].text == "My name is Christelle and I live in Paris"
|
||||||
assert result[0].meta["name"] == "filename3"
|
assert result[0].meta["name"] == "filename3"
|
||||||
|
|||||||
@ -1,35 +1,33 @@
|
|||||||
from haystack.retriever.sparse import ElasticsearchRetriever
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
|
@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
|
||||||
def test_elasticsearch_retrieval(document_store_with_docs):
|
@pytest.mark.parametrize("retriever_with_docs", ["elsticsearch"], indirect=True)
|
||||||
retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
|
def test_elasticsearch_retrieval(retriever_with_docs, document_store_with_docs):
|
||||||
res = retriever.retrieve(query="Who lives in Berlin?")
|
res = retriever_with_docs.retrieve(query="Who lives in Berlin?")
|
||||||
assert res[0].text == "My name is Carla and I live in Berlin"
|
assert res[0].text == "My name is Carla and I live in Berlin"
|
||||||
assert len(res) == 3
|
assert len(res) == 3
|
||||||
assert res[0].meta["name"] == "filename1"
|
assert res[0].meta["name"] == "filename1"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
|
@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
|
||||||
def test_elasticsearch_retrieval_filters(document_store_with_docs):
|
@pytest.mark.parametrize("retriever_with_docs", ["elsticsearch"], indirect=True)
|
||||||
retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
|
def test_elasticsearch_retrieval_filters(retriever_with_docs, document_store_with_docs):
|
||||||
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name": ["filename1"]})
|
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name": ["filename1"]})
|
||||||
assert res[0].text == "My name is Carla and I live in Berlin"
|
assert res[0].text == "My name is Carla and I live in Berlin"
|
||||||
assert len(res) == 1
|
assert len(res) == 1
|
||||||
assert res[0].meta["name"] == "filename1"
|
assert res[0].meta["name"] == "filename1"
|
||||||
|
|
||||||
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["not_existing_value"]})
|
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["not_existing_value"]})
|
||||||
assert len(res) == 0
|
assert len(res) == 0
|
||||||
|
|
||||||
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "not_existing_field": ["not_existing_value"]})
|
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "not_existing_field": ["not_existing_value"]})
|
||||||
assert len(res) == 0
|
assert len(res) == 0
|
||||||
|
|
||||||
retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
|
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["test1","test2"]})
|
||||||
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["test1","test2"]})
|
|
||||||
assert res[0].text == "My name is Carla and I live in Berlin"
|
assert res[0].text == "My name is Carla and I live in Berlin"
|
||||||
assert len(res) == 1
|
assert len(res) == 1
|
||||||
assert res[0].meta["name"] == "filename1"
|
assert res[0].meta["name"] == "filename1"
|
||||||
|
|
||||||
retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
|
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field":["test2"]})
|
||||||
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field":["test2"]})
|
|
||||||
assert len(res) == 0
|
assert len(res) == 0
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from haystack import Finder
|
from haystack import Finder
|
||||||
from haystack.retriever.dense import EmbeddingRetriever
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
|
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
|
||||||
def test_embedding_retriever(document_store):
|
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
|
||||||
|
def test_embedding_retriever(retriever, document_store):
|
||||||
|
|
||||||
documents = [
|
documents = [
|
||||||
{'text': 'By running tox in the command line!', 'meta': {'name': 'How to test this library?', 'question': 'How to test this library?'}},
|
{'text': 'By running tox in the command line!', 'meta': {'name': 'How to test this library?', 'question': 'How to test this library?'}},
|
||||||
@ -20,8 +20,6 @@ def test_embedding_retriever(document_store):
|
|||||||
{'text': 'By running tox in the command line!', 'meta': {'name': 'blah blah blah', 'question': 'blah blah blah'}},
|
{'text': 'By running tox in the command line!', 'meta': {'name': 'blah blah blah', 'question': 'blah blah blah'}},
|
||||||
]
|
]
|
||||||
|
|
||||||
retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False)
|
|
||||||
|
|
||||||
embedded = []
|
embedded = []
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
doc['embedding'] = retriever.embed([doc['meta']['question']])[0]
|
doc['embedding'] = retriever.embed([doc['meta']['question']])[0]
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from haystack.document_store.base import BaseDocumentStore
|
from haystack.document_store.base import BaseDocumentStore
|
||||||
from haystack.retriever.sparse import ElasticsearchRetriever
|
|
||||||
from haystack.finder import Finder
|
from haystack.finder import Finder
|
||||||
|
|
||||||
|
|
||||||
@ -62,9 +61,8 @@ def test_eval_reader(reader, document_store: BaseDocumentStore):
|
|||||||
|
|
||||||
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
|
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
|
||||||
@pytest.mark.parametrize("open_domain", [True, False])
|
@pytest.mark.parametrize("open_domain", [True, False])
|
||||||
def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain):
|
@pytest.mark.parametrize("retriever", ["elsticsearch"], indirect=True)
|
||||||
retriever = ElasticsearchRetriever(document_store=document_store)
|
def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain, retriever):
|
||||||
|
|
||||||
# add eval data (SQUAD format)
|
# add eval data (SQUAD format)
|
||||||
document_store.delete_all_documents(index="test_eval_document")
|
document_store.delete_all_documents(index="test_eval_document")
|
||||||
document_store.delete_all_documents(index="test_feedback")
|
document_store.delete_all_documents(index="test_feedback")
|
||||||
@ -83,8 +81,8 @@ def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain):
|
|||||||
|
|
||||||
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
|
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
|
||||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||||
def test_eval_finder(document_store: BaseDocumentStore, reader):
|
@pytest.mark.parametrize("retriever", ["elsticsearch"], indirect=True)
|
||||||
retriever = ElasticsearchRetriever(document_store=document_store)
|
def test_eval_finder(document_store: BaseDocumentStore, reader, retriever):
|
||||||
finder = Finder(reader=reader, retriever=retriever)
|
finder = Finder(reader=reader, retriever=retriever)
|
||||||
|
|
||||||
# add eval data (SQUAD format)
|
# add eval data (SQUAD format)
|
||||||
|
|||||||
@ -4,8 +4,6 @@ from haystack import Document
|
|||||||
import faiss
|
import faiss
|
||||||
|
|
||||||
from haystack.document_store.faiss import FAISSDocumentStore
|
from haystack.document_store.faiss import FAISSDocumentStore
|
||||||
from haystack.retriever.dense import DensePassageRetriever
|
|
||||||
from haystack.retriever.dense import EmbeddingRetriever
|
|
||||||
from haystack import Finder
|
from haystack import Finder
|
||||||
|
|
||||||
DOCUMENTS = [
|
DOCUMENTS = [
|
||||||
@ -47,7 +45,8 @@ def test_faiss_index_save_and_load(document_store):
|
|||||||
assert document_store.faiss_index.ntotal == 0
|
assert document_store.faiss_index.ntotal == 0
|
||||||
|
|
||||||
# test loading the index
|
# test loading the index
|
||||||
new_document_store = document_store.load(sql_url="sqlite:///haystack_test.db", faiss_file_path="haystack_test_faiss")
|
new_document_store = document_store.load(sql_url="sqlite:///haystack_test.db",
|
||||||
|
faiss_file_path="haystack_test_faiss")
|
||||||
|
|
||||||
# check faiss index is restored
|
# check faiss index is restored
|
||||||
assert new_document_store.faiss_index.ntotal == len(DOCUMENTS)
|
assert new_document_store.faiss_index.ntotal == len(DOCUMENTS)
|
||||||
@ -78,21 +77,15 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
|
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
|
||||||
|
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||||
@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
|
@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
|
||||||
def test_faiss_update_docs(document_store, index_buffer_size):
|
def test_faiss_update_docs(document_store, index_buffer_size, retriever):
|
||||||
# adjust buffer size
|
# adjust buffer size
|
||||||
document_store.index_buffer_size = index_buffer_size
|
document_store.index_buffer_size = index_buffer_size
|
||||||
|
|
||||||
# initial write
|
# initial write
|
||||||
document_store.write_documents(DOCUMENTS)
|
document_store.write_documents(DOCUMENTS)
|
||||||
|
|
||||||
# do the update
|
|
||||||
retriever = DensePassageRetriever(document_store=document_store,
|
|
||||||
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
|
||||||
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
|
||||||
use_gpu=False, embed_title=True,
|
|
||||||
remove_sep_tok_from_untitled_passages=True)
|
|
||||||
|
|
||||||
document_store.update_embeddings(retriever=retriever)
|
document_store.update_embeddings(retriever=retriever)
|
||||||
documents_indexed = document_store.get_all_documents()
|
documents_indexed = document_store.get_all_documents()
|
||||||
|
|
||||||
@ -109,28 +102,40 @@ def test_faiss_update_docs(document_store, index_buffer_size):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
|
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
|
||||||
def test_faiss_retrieving(document_store):
|
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||||
|
def test_faiss_update_with_empty_store(document_store, retriever):
|
||||||
|
# Call update with empty doc store
|
||||||
|
document_store.update_embeddings(retriever=retriever)
|
||||||
|
|
||||||
|
# initial write
|
||||||
document_store.write_documents(DOCUMENTS)
|
document_store.write_documents(DOCUMENTS)
|
||||||
|
|
||||||
retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
|
documents_indexed = document_store.get_all_documents()
|
||||||
use_gpu=False)
|
|
||||||
|
# test document correctness
|
||||||
|
check_data_correctness(documents_indexed, DOCUMENTS)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
|
||||||
|
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
|
||||||
|
def test_faiss_retrieving(document_store, retriever):
|
||||||
|
document_store.write_documents(DOCUMENTS)
|
||||||
result = retriever.retrieve(query="How to test this?")
|
result = retriever.retrieve(query="How to test this?")
|
||||||
assert len(result) == len(DOCUMENTS)
|
assert len(result) == len(DOCUMENTS)
|
||||||
assert type(result[0]) == Document
|
assert type(result[0]) == Document
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
|
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
|
||||||
def test_faiss_finding(document_store):
|
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
|
||||||
|
def test_faiss_finding(document_store, retriever):
|
||||||
document_store.write_documents(DOCUMENTS)
|
document_store.write_documents(DOCUMENTS)
|
||||||
|
|
||||||
retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
|
|
||||||
use_gpu=False)
|
|
||||||
finder = Finder(reader=None, retriever=retriever)
|
finder = Finder(reader=None, retriever=retriever)
|
||||||
|
|
||||||
prediction = finder.get_answers_via_similar_questions(question="How to test this?", top_k_retriever=1)
|
prediction = finder.get_answers_via_similar_questions(question="How to test this?", top_k_retriever=1)
|
||||||
|
|
||||||
assert len(prediction.get('answers', [])) == 1
|
assert len(prediction.get('answers', [])) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_faiss_passing_index_from_outside():
|
def test_faiss_passing_index_from_outside():
|
||||||
d = 768
|
d = 768
|
||||||
nlist = 2
|
nlist = 2
|
||||||
@ -147,4 +152,4 @@ def test_faiss_passing_index_from_outside():
|
|||||||
documents_indexed = document_store.get_all_documents(index="document")
|
documents_indexed = document_store.get_all_documents(index="document")
|
||||||
|
|
||||||
# test document correctness
|
# test document correctness
|
||||||
check_data_correctness(documents_indexed, DOCUMENTS)
|
check_data_correctness(documents_indexed, DOCUMENTS)
|
||||||
|
|||||||
@ -1,11 +1,10 @@
|
|||||||
from haystack import Finder
|
from haystack import Finder
|
||||||
from haystack.retriever.sparse import TfidfRetriever
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
def test_finder_get_answers(reader, document_store_with_docs):
|
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
|
||||||
retriever = TfidfRetriever(document_store=document_store_with_docs)
|
def test_finder_get_answers(reader, retriever_with_docs, document_store_with_docs):
|
||||||
finder = Finder(reader, retriever)
|
finder = Finder(reader, retriever_with_docs)
|
||||||
prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
|
prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
|
||||||
top_k_reader=3)
|
top_k_reader=3)
|
||||||
assert prediction is not None
|
assert prediction is not None
|
||||||
@ -19,9 +18,9 @@ def test_finder_get_answers(reader, document_store_with_docs):
|
|||||||
assert len(prediction["answers"]) == 3
|
assert len(prediction["answers"]) == 3
|
||||||
|
|
||||||
|
|
||||||
def test_finder_offsets(reader, document_store_with_docs):
|
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
|
||||||
retriever = TfidfRetriever(document_store=document_store_with_docs)
|
def test_finder_offsets(reader, retriever_with_docs, document_store_with_docs):
|
||||||
finder = Finder(reader, retriever)
|
finder = Finder(reader, retriever_with_docs)
|
||||||
prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
|
prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
|
||||||
top_k_reader=5)
|
top_k_reader=5)
|
||||||
|
|
||||||
@ -32,9 +31,9 @@ def test_finder_offsets(reader, document_store_with_docs):
|
|||||||
assert prediction["answers"][0]["context"][start:end] == prediction["answers"][0]["answer"]
|
assert prediction["answers"][0]["context"][start:end] == prediction["answers"][0]["answer"]
|
||||||
|
|
||||||
|
|
||||||
def test_finder_get_answers_single_result(reader, document_store_with_docs):
|
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
|
||||||
retriever = TfidfRetriever(document_store=document_store_with_docs)
|
def test_finder_get_answers_single_result(reader, retriever_with_docs, document_store_with_docs):
|
||||||
finder = Finder(reader, retriever)
|
finder = Finder(reader, retriever_with_docs)
|
||||||
query = "testing finder"
|
query = "testing finder"
|
||||||
prediction = finder.get_answers(question=query, top_k_retriever=1,
|
prediction = finder.get_answers(question=query, top_k_retriever=1,
|
||||||
top_k_reader=1)
|
top_k_reader=1)
|
||||||
|
|||||||
@ -1,5 +1,9 @@
|
|||||||
def test_tfidf_retriever():
|
import pytest
|
||||||
from haystack.retriever.sparse import TfidfRetriever
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||||
|
@pytest.mark.parametrize("retriever", ["tfidf"], indirect=True)
|
||||||
|
def test_tfidf_retriever(document_store, retriever):
|
||||||
|
|
||||||
test_docs = [
|
test_docs = [
|
||||||
{"id": "26f84672c6d7aaeb8e2cd53e9c62d62d", "name": "testing the finder 1", "text": "godzilla says hello"},
|
{"id": "26f84672c6d7aaeb8e2cd53e9c62d62d", "name": "testing the finder 1", "text": "godzilla says hello"},
|
||||||
@ -7,11 +11,8 @@ def test_tfidf_retriever():
|
|||||||
{"name": "testing the finder 3", "text": "alien says arghh"}
|
{"name": "testing the finder 3", "text": "alien says arghh"}
|
||||||
]
|
]
|
||||||
|
|
||||||
from haystack.document_store.memory import InMemoryDocumentStore
|
|
||||||
document_store = InMemoryDocumentStore()
|
|
||||||
document_store.write_documents(test_docs)
|
document_store.write_documents(test_docs)
|
||||||
|
|
||||||
retriever = TfidfRetriever(document_store)
|
|
||||||
retriever.fit()
|
retriever.fit()
|
||||||
doc = retriever.retrieve("godzilla", top_k=1)[0]
|
doc = retriever.retrieve("godzilla", top_k=1)[0]
|
||||||
assert doc.id == "26f84672c6d7aaeb8e2cd53e9c62d62d"
|
assert doc.id == "26f84672c6d7aaeb8e2cd53e9c62d62d"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user