	Fix update_embeddings function in FAISSDocumentStore and add retriever fixture in tests (#481)
* 1. Prevent the update_embeddings function in FAISSDocumentStore from setting faiss_index to None when the document store does not contain any documents. 2. Clean up the tests by adding a fixture for the retriever.
* TfidfRetriever needs a document store that already contains documents at initialization time, because it calls fit() in its constructor; fix this by checking whether self.paragraphs is None.
* Fix the naming of the retriever fixtures (embedded -> embedding and tfid -> tfidf).
This commit is contained in:
parent ecaf7b8f0b
commit 2e9f3c1512
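The test refactor below leans on pytest's indirect parametrization: the strings passed to @pytest.mark.parametrize(..., indirect=True) are delivered to the fixture as request.param, and the fixture turns them into a ready-made retriever. A minimal, self-contained sketch of that mechanism (the helper and test names here are illustrative, not part of the commit):

import pytest

def build_retriever(name):
    # Stand-in for the get_retriever() helper added to conftest.py in this commit;
    # the real helper constructs a DPR / embedding / tfidf / Elasticsearch retriever.
    return f"retriever-for-{name}"

@pytest.fixture
def retriever(request):
    # With indirect=True, the parametrized value arrives here as request.param.
    return build_retriever(request.param)

@pytest.mark.parametrize("retriever", ["tfidf", "embedding"], indirect=True)
def test_uses_retriever_fixture(retriever):
    # The test body receives the fixture's return value, not the raw string,
    # so each parameter runs the same test against a different retriever.
    assert retriever.startswith("retriever-for-")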
@@ -128,17 +128,19 @@ class FAISSDocumentStore(SQLDocumentStore):
         :param index: (SQL) index name for storing the docs and metadata
         :return: None
         """
-        # To clear out the FAISS index contents and frees all memory immediately that is in use by the index
-        self.faiss_index.reset()
+        if not self.faiss_index:
+            raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
 
         index = index or self.index
         documents = self.get_all_documents(index=index)
 
         if len(documents) == 0:
             logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
-            self.faiss_index = None
             return
 
+        # To clear out the FAISS index contents and frees all memory immediately that is in use by the index
+        self.faiss_index.reset()
+
         logger.info(f"Updating embeddings for {len(documents)} docs...")
         embeddings = retriever.embed_passages(documents)  # type: ignore
         assert len(documents) == len(embeddings)
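Taken together, the patched method now validates the index first, returns early (leaving faiss_index intact) when the store holds no documents, and only then clears and refills the index. A condensed sketch of that control flow, assuming a store object with the same attributes (a simplification, not the literal Haystack method):

def update_embeddings_sketch(store, retriever, index=None):
    if not store.faiss_index:
        raise ValueError("Couldn't find a FAISS index ...")

    documents = store.get_all_documents(index=index or store.index)
    if len(documents) == 0:
        # Early return: the FAISS index is no longer set to None here.
        return

    # Clear old vectors only when there is something to re-embed.
    store.faiss_index.reset()
    embeddings = retriever.embed_passages(documents)
    assert len(documents) == len(embeddings)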
@@ -167,6 +167,12 @@ class TfidfRetriever(BaseRetriever):
         return documents
 
     def fit(self):
+        if not self.paragraphs or len(self.paragraphs) == 0:
+            self.paragraphs = self._get_all_paragraphs()
+            if not self.paragraphs or len(self.paragraphs) == 0:
+                logger.warning("Fit method called with empty document store")
+                return
+
         self.df = pd.DataFrame.from_dict(self.paragraphs)
         self.df["text"] = self.df["text"].apply(lambda x: " ".join(x))
         self.tfidf_matrix = self.vectorizer.fit_transform(self.df["text"])
@@ -7,6 +7,9 @@ from sys import platform
 import pytest
 import requests
 from elasticsearch import Elasticsearch
+from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever, ElasticsearchRetriever, TfidfRetriever
+
+from haystack.retriever.dense import DensePassageRetriever, EmbeddingRetriever
 
 from haystack import Document
 from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
@@ -157,6 +160,16 @@ def document_store(request, test_docs_xs, elasticsearch_fixture):
     return get_document_store(request.param)
 
 
+@pytest.fixture(params=["es_filter_only", "elsticsearch", "dpr", "embedding", "tfidf"])
+def retriever(request, document_store):
+    return get_retriever(request.param, document_store)
+
+
+@pytest.fixture(params=["es_filter_only", "elsticsearch", "dpr", "embedding", "tfidf"])
+def retriever_with_docs(request, document_store_with_docs):
+    return get_retriever(request.param, document_store_with_docs)
+
+
 def get_document_store(document_store_type):
     if document_store_type == "sql":
         if os.path.exists("haystack_test.db"):
@@ -177,3 +190,27 @@ def get_document_store(document_store_type):
         raise Exception(f"No document store fixture for '{document_store_type}'")
 
     return document_store
+
+
+def get_retriever(retriever_type, document_store):
+
+    if retriever_type == "dpr":
+        retriever = DensePassageRetriever(document_store=document_store,
+                                          query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
+                                          passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
+                                          use_gpu=False, embed_title=True,
+                                          remove_sep_tok_from_untitled_passages=True)
+    elif retriever_type == "tfidf":
+        return TfidfRetriever(document_store=document_store)
+    elif retriever_type == "embedding":
+        retriever = EmbeddingRetriever(document_store=document_store,
+                                       embedding_model="deepset/sentence_bert",
+                                       use_gpu=False)
+    elif retriever_type == "elsticsearch":
+        retriever = ElasticsearchRetriever(document_store=document_store)
+    elif retriever_type == "es_filter_only":
+        retriever = ElasticsearchFilterOnlyRetriever(document_store=document_store)
+    else:
+        raise Exception(f"No retriever fixture for '{retriever_type}'")
+
+    return retriever
@@ -1,13 +1,13 @@
 import pytest
 import time
 
-from haystack.retriever.dense import DensePassageRetriever
 from haystack import Document
 from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
 
 
 @pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
-def test_dpr_inmemory_retrieval(document_store):
+@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
+def test_dpr_inmemory_retrieval(document_store, retriever):
 
     documents = [
         Document(
@@ -33,11 +33,6 @@ def test_dpr_inmemory_retrieval(document_store):
 
     document_store.delete_all_documents(index="test_dpr")
     document_store.write_documents(documents, index="test_dpr")
-    retriever = DensePassageRetriever(document_store=document_store,
-                                      query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
-                                      passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
-                                      use_gpu=True, embed_title=True,
-                                      remove_sep_tok_from_untitled_passages=True)
     document_store.update_embeddings(retriever=retriever, index="test_dpr")
     time.sleep(2)
 
@@ -3,21 +3,20 @@ import pytest
 
 
 @pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
-def test_dummy_retriever(document_store_with_docs):
-    from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever
-    retriever = ElasticsearchFilterOnlyRetriever(document_store_with_docs)
+@pytest.mark.parametrize("retriever_with_docs", ["es_filter_only"], indirect=True)
+def test_dummy_retriever(retriever_with_docs, document_store_with_docs):
 
-    result = retriever.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=1)
+    result = retriever_with_docs.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=1)
     assert type(result[0]) == Document
     assert result[0].text == "My name is Carla and I live in Berlin"
     assert result[0].meta["name"] == "filename1"
 
-    result = retriever.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=5)
+    result = retriever_with_docs.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=5)
     assert type(result[0]) == Document
     assert result[0].text == "My name is Carla and I live in Berlin"
     assert result[0].meta["name"] == "filename1"
 
-    result = retriever.retrieve(query="godzilla", filters={"name": ["filename3"]}, top_k=5)
+    result = retriever_with_docs.retrieve(query="godzilla", filters={"name": ["filename3"]}, top_k=5)
     assert type(result[0]) == Document
     assert result[0].text == "My name is Christelle and I live in Paris"
     assert result[0].meta["name"] == "filename3"
@@ -1,35 +1,33 @@
-from haystack.retriever.sparse import ElasticsearchRetriever
 import pytest
 
 
 @pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
-def test_elasticsearch_retrieval(document_store_with_docs):
-    retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
-    res = retriever.retrieve(query="Who lives in Berlin?")
+@pytest.mark.parametrize("retriever_with_docs", ["elsticsearch"], indirect=True)
+def test_elasticsearch_retrieval(retriever_with_docs, document_store_with_docs):
+    res = retriever_with_docs.retrieve(query="Who lives in Berlin?")
     assert res[0].text == "My name is Carla and I live in Berlin"
     assert len(res) == 3
     assert res[0].meta["name"] == "filename1"
 
 
 @pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
-def test_elasticsearch_retrieval_filters(document_store_with_docs):
-    retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
-    res = retriever.retrieve(query="Who lives in Berlin?", filters={"name": ["filename1"]})
+@pytest.mark.parametrize("retriever_with_docs", ["elsticsearch"], indirect=True)
+def test_elasticsearch_retrieval_filters(retriever_with_docs, document_store_with_docs):
+    res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name": ["filename1"]})
     assert res[0].text == "My name is Carla and I live in Berlin"
     assert len(res) == 1
     assert res[0].meta["name"] == "filename1"
 
-    res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["not_existing_value"]})
+    res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["not_existing_value"]})
    assert len(res) == 0
 
-    res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "not_existing_field": ["not_existing_value"]})
+    res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "not_existing_field": ["not_existing_value"]})
     assert len(res) == 0
 
-    retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
-    res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["test1","test2"]})
+    res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["test1","test2"]})
     assert res[0].text == "My name is Carla and I live in Berlin"
     assert len(res) == 1
     assert res[0].meta["name"] == "filename1"
 
-    retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
-    res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field":["test2"]})
+    res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field":["test2"]})
     assert len(res) == 0
@@ -1,10 +1,10 @@
 import pytest
 from haystack import Finder
-from haystack.retriever.dense import EmbeddingRetriever
 
 
 @pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
-def test_embedding_retriever(document_store):
+@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
+def test_embedding_retriever(retriever, document_store):
 
     documents = [
         {'text': 'By running tox in the command line!', 'meta': {'name': 'How to test this library?', 'question': 'How to test this library?'}},
@@ -20,8 +20,6 @@ def test_embedding_retriever(document_store):
         {'text': 'By running tox in the command line!', 'meta': {'name': 'blah blah blah', 'question': 'blah blah blah'}},
     ]
 
-    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False)
-
     embedded = []
     for doc in documents:
         doc['embedding'] = retriever.embed([doc['meta']['question']])[0]
@@ -1,6 +1,5 @@
 import pytest
 from haystack.document_store.base import BaseDocumentStore
-from haystack.retriever.sparse import ElasticsearchRetriever
 from haystack.finder import Finder
 
 
@@ -62,9 +61,8 @@ def test_eval_reader(reader, document_store: BaseDocumentStore):
 
 @pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
 @pytest.mark.parametrize("open_domain", [True, False])
-def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain):
-    retriever = ElasticsearchRetriever(document_store=document_store)
-
+@pytest.mark.parametrize("retriever", ["elsticsearch"], indirect=True)
+def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain, retriever):
     # add eval data (SQUAD format)
     document_store.delete_all_documents(index="test_eval_document")
     document_store.delete_all_documents(index="test_feedback")
@@ -83,8 +81,8 @@ def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain):
 
 @pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
-def test_eval_finder(document_store: BaseDocumentStore, reader):
-    retriever = ElasticsearchRetriever(document_store=document_store)
+@pytest.mark.parametrize("retriever", ["elsticsearch"], indirect=True)
+def test_eval_finder(document_store: BaseDocumentStore, reader, retriever):
     finder = Finder(reader=reader, retriever=retriever)
 
     # add eval data (SQUAD format)
@@ -4,8 +4,6 @@ from haystack import Document
 import faiss
 
 from haystack.document_store.faiss import FAISSDocumentStore
-from haystack.retriever.dense import DensePassageRetriever
-from haystack.retriever.dense import EmbeddingRetriever
 from haystack import Finder
 
 DOCUMENTS = [
@@ -47,7 +45,8 @@ def test_faiss_index_save_and_load(document_store):
     assert document_store.faiss_index.ntotal == 0
 
     # test loading the index
-    new_document_store = document_store.load(sql_url="sqlite:///haystack_test.db", faiss_file_path="haystack_test_faiss")
+    new_document_store = document_store.load(sql_url="sqlite:///haystack_test.db",
+                                             faiss_file_path="haystack_test_faiss")
 
     # check faiss index is restored
     assert new_document_store.faiss_index.ntotal == len(DOCUMENTS)
@@ -78,21 +77,15 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
 
 
 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
+@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
 @pytest.mark.parametrize("index_buffer_size", [10_000, 2])
-def test_faiss_update_docs(document_store, index_buffer_size):
+def test_faiss_update_docs(document_store, index_buffer_size, retriever):
     # adjust buffer size
     document_store.index_buffer_size = index_buffer_size
 
     # initial write
     document_store.write_documents(DOCUMENTS)
 
-    # do the update
-    retriever = DensePassageRetriever(document_store=document_store,
-                                      query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
-                                      passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
-                                      use_gpu=False, embed_title=True,
-                                      remove_sep_tok_from_untitled_passages=True)
-
     document_store.update_embeddings(retriever=retriever)
     documents_indexed = document_store.get_all_documents()
 
@@ -109,28 +102,40 @@
 
 
 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
-def test_faiss_retrieving(document_store):
+@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
+def test_faiss_update_with_empty_store(document_store, retriever):
+    # Call update with empty doc store
+    document_store.update_embeddings(retriever=retriever)
+
+    # initial write
     document_store.write_documents(DOCUMENTS)
 
-    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
-                                   use_gpu=False)
+    documents_indexed = document_store.get_all_documents()
+
+    # test document correctness
+    check_data_correctness(documents_indexed, DOCUMENTS)
+
+
+@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
+@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
+def test_faiss_retrieving(document_store, retriever):
+    document_store.write_documents(DOCUMENTS)
     result = retriever.retrieve(query="How to test this?")
     assert len(result) == len(DOCUMENTS)
     assert type(result[0]) == Document
 
 
 @pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
-def test_faiss_finding(document_store):
+@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
+def test_faiss_finding(document_store, retriever):
     document_store.write_documents(DOCUMENTS)
 
-    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert",
-                                   use_gpu=False)
     finder = Finder(reader=None, retriever=retriever)
 
     prediction = finder.get_answers_via_similar_questions(question="How to test this?", top_k_retriever=1)
 
     assert len(prediction.get('answers', [])) == 1
 
 
 def test_faiss_passing_index_from_outside():
     d = 768
     nlist = 2
@@ -147,4 +152,4 @@ def test_faiss_passing_index_from_outside():
     documents_indexed = document_store.get_all_documents(index="document")
 
     # test document correctness
-    check_data_correctness(documents_indexed, DOCUMENTS)
+    check_data_correctness(documents_indexed, DOCUMENTS)
@@ -1,11 +1,10 @@
 from haystack import Finder
-from haystack.retriever.sparse import TfidfRetriever
 import pytest
 
 
-def test_finder_get_answers(reader, document_store_with_docs):
-    retriever = TfidfRetriever(document_store=document_store_with_docs)
-    finder = Finder(reader, retriever)
+@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
+def test_finder_get_answers(reader, retriever_with_docs, document_store_with_docs):
+    finder = Finder(reader, retriever_with_docs)
     prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
                                     top_k_reader=3)
     assert prediction is not None
@@ -19,9 +18,9 @@ def test_finder_get_answers(reader, document_store_with_docs):
     assert len(prediction["answers"]) == 3
 
 
-def test_finder_offsets(reader, document_store_with_docs):
-    retriever = TfidfRetriever(document_store=document_store_with_docs)
-    finder = Finder(reader, retriever)
+@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
+def test_finder_offsets(reader, retriever_with_docs, document_store_with_docs):
+    finder = Finder(reader, retriever_with_docs)
     prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
                                     top_k_reader=5)
 
@@ -32,9 +31,9 @@ def test_finder_offsets(reader, document_store_with_docs):
     assert prediction["answers"][0]["context"][start:end] == prediction["answers"][0]["answer"]
 
 
-def test_finder_get_answers_single_result(reader, document_store_with_docs):
-    retriever = TfidfRetriever(document_store=document_store_with_docs)
-    finder = Finder(reader, retriever)
+@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
+def test_finder_get_answers_single_result(reader, retriever_with_docs, document_store_with_docs):
+    finder = Finder(reader, retriever_with_docs)
     query = "testing finder"
     prediction = finder.get_answers(question=query, top_k_retriever=1,
                                     top_k_reader=1)
@@ -1,5 +1,9 @@
-def test_tfidf_retriever():
-    from haystack.retriever.sparse import TfidfRetriever
+import pytest
+
+
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+@pytest.mark.parametrize("retriever", ["tfidf"], indirect=True)
+def test_tfidf_retriever(document_store, retriever):
 
     test_docs = [
         {"id": "26f84672c6d7aaeb8e2cd53e9c62d62d", "name": "testing the finder 1", "text": "godzilla says hello"},
@@ -7,11 +11,8 @@ def test_tfidf_retriever():
         {"name": "testing the finder 3", "text": "alien says arghh"}
     ]
 
-    from haystack.document_store.memory import InMemoryDocumentStore
-    document_store = InMemoryDocumentStore()
     document_store.write_documents(test_docs)
 
-    retriever = TfidfRetriever(document_store)
     retriever.fit()
     doc = retriever.retrieve("godzilla", top_k=1)[0]
     assert doc.id == "26f84672c6d7aaeb8e2cd53e9c62d62d"