mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-03 18:36:04 +00:00
Add more tests (#213)
This commit is contained in:
parent
549f3a1285
commit
d2b26a99ff
@ -136,8 +136,10 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
del doc["meta"]
|
||||
bulk(self.client, documents, request_timeout=300)
|
||||
|
||||
def get_document_count(self) -> int:
|
||||
result = self.client.count()
|
||||
def get_document_count(self, index: Optional[str] = None,) -> int:
|
||||
if index is None:
|
||||
index = self.index
|
||||
result = self.client.count(index=index)
|
||||
count = result["count"]
|
||||
return count
|
||||
|
||||
|
||||
@ -2,10 +2,17 @@ import tarfile
|
||||
import time
|
||||
import urllib.request
|
||||
from subprocess import Popen, PIPE, STDOUT, run
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from haystack.reader.farm import FARMReader
|
||||
from haystack.reader.transformers import TransformersReader
|
||||
|
||||
from haystack.database.sql import SQLDocumentStore
|
||||
from haystack.database.memory import InMemoryDocumentStore
|
||||
from haystack.database.elasticsearch import ElasticsearchDocumentStore
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def elasticsearch_dir(tmpdir_factory):
|
||||
@ -19,6 +26,7 @@ def elasticsearch_fixture(elasticsearch_dir):
|
||||
client = Elasticsearch(hosts=[{"host": "localhost"}])
|
||||
client.info()
|
||||
except:
|
||||
print("Downloading and starting an Elasticsearch instance for the tests ...")
|
||||
thetarfile = "https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.6.1-linux-x86_64.tar.gz"
|
||||
ftpstream = urllib.request.urlopen(thetarfile)
|
||||
thetarfile = tarfile.open(fileobj=ftpstream, mode="r|gz")
|
||||
@ -41,3 +49,47 @@ def xpdf_fixture():
|
||||
"""pdftotext is not installed. It is part of xpdf or poppler-utils software suite.
|
||||
You can download for your OS from here: https://www.xpdfreader.com/download.html."""
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_docs_xs():
|
||||
return [
|
||||
{"name": "filename1", "text": "My name is Carla and I live in Berlin", "meta": {"meta_field": "test1"}},
|
||||
{"name": "filename2", "text": "My name is Paul and I live in New York", "meta": {"meta_field": "test2"}},
|
||||
{"name": "filename3", "text": "My name is Christelle and I live in Paris", "meta": {"meta_field": "test3"}}
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(params=["farm", "transformers"])
|
||||
def reader(request):
|
||||
if request.param == "farm":
|
||||
return FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad",
|
||||
use_gpu=False, top_k_per_sample=5, num_processes=0)
|
||||
if request.param == "transformers":
|
||||
return TransformersReader(model="distilbert-base-uncased-distilled-squad",
|
||||
tokenizer="distilbert-base-uncased",
|
||||
use_gpu=-1)
|
||||
|
||||
|
||||
@pytest.fixture(params=["sql", "memory", "elasticsearch"])
|
||||
def document_store_with_docs(request, test_docs_xs, elasticsearch_fixture):
|
||||
if request.param == "sql":
|
||||
if os.path.exists("qa_test.db"):
|
||||
os.remove("qa_test.db")
|
||||
document_store = SQLDocumentStore(url="sqlite:///qa_test.db")
|
||||
document_store.write_documents(test_docs_xs)
|
||||
|
||||
if request.param == "memory":
|
||||
document_store = InMemoryDocumentStore()
|
||||
document_store.write_documents(test_docs_xs)
|
||||
|
||||
if request.param == "elasticsearch":
|
||||
# make sure we start from a fresh index
|
||||
client = Elasticsearch()
|
||||
client.indices.delete(index='haystack_test', ignore=[404])
|
||||
document_store = ElasticsearchDocumentStore(index="haystack_test")
|
||||
assert document_store.get_document_count() == 0
|
||||
document_store.write_documents(test_docs_xs)
|
||||
time.sleep(2)
|
||||
|
||||
return document_store
|
||||
|
||||
@ -1,7 +0,0 @@
|
||||
from haystack.reader.farm import FARMReader
|
||||
|
||||
|
||||
def test_farm_reader():
|
||||
reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False)
|
||||
assert reader is not None
|
||||
assert isinstance(reader, FARMReader)
|
||||
@ -1,40 +1,48 @@
|
||||
from haystack import Finder
|
||||
from haystack.database.sql import SQLDocumentStore
|
||||
from haystack.reader.transformers import TransformersReader
|
||||
from haystack.retriever.sparse import TfidfRetriever
|
||||
import pytest
|
||||
|
||||
|
||||
def test_finder_get_answers():
|
||||
test_docs = [
|
||||
{"name": "testing the finder 1", "text": "testing the finder with pyhton unit test 1", "meta": {"test": "test"}},
|
||||
{"name": "testing the finder 2", "text": "testing the finder with pyhton unit test 2", "meta": {"test": "test"}},
|
||||
{"name": "testing the finder 3", "text": "testing the finder with pyhton unit test 3", "meta": {"test": "test"}}
|
||||
]
|
||||
|
||||
document_store = SQLDocumentStore(url="sqlite:///qa_test.db")
|
||||
document_store.write_documents(test_docs)
|
||||
retriever = TfidfRetriever(document_store=document_store)
|
||||
reader = TransformersReader(model="distilbert-base-uncased-distilled-squad",
|
||||
tokenizer="distilbert-base-uncased", use_gpu=-1)
|
||||
#@pytest.mark.parametrize("reader", [("farm")], indirect=True)
|
||||
#@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
|
||||
def test_finder_get_answers(reader, document_store_with_docs):
|
||||
retriever = TfidfRetriever(document_store=document_store_with_docs)
|
||||
finder = Finder(reader, retriever)
|
||||
prediction = finder.get_answers(question="testing finder", top_k_retriever=10,
|
||||
top_k_reader=5)
|
||||
prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
|
||||
top_k_reader=3)
|
||||
assert prediction is not None
|
||||
assert prediction["question"] == "Who lives in Berlin?"
|
||||
assert prediction["answers"][0]["answer"] == "Carla"
|
||||
assert prediction["answers"][0]["probability"] <= 1
|
||||
assert prediction["answers"][0]["probability"] >= 0
|
||||
assert prediction["answers"][0]["meta"]["meta_field"] == "test1"
|
||||
assert prediction["answers"][0]["context"] == "My name is Carla and I live in Berlin"
|
||||
assert prediction["answers"][0]["document_id"] == "0"
|
||||
|
||||
assert len(prediction["answers"]) == 3
|
||||
|
||||
|
||||
def test_finder_get_answers_single_result():
|
||||
test_docs = [
|
||||
{"name": "testing the finder 1", "text": "testing the finder with pyhton unit test 1"},
|
||||
{"name": "testing the finder 2", "text": "testing the finder with pyhton unit test 2"},
|
||||
{"name": "testing the finder 3", "text": "testing the finder with pyhton unit test 3"}
|
||||
]
|
||||
def test_finder_offsets(reader, document_store_with_docs):
|
||||
retriever = TfidfRetriever(document_store=document_store_with_docs)
|
||||
finder = Finder(reader, retriever)
|
||||
prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
|
||||
top_k_reader=5)
|
||||
|
||||
document_store = SQLDocumentStore(url="sqlite:///qa_test.db")
|
||||
document_store.write_documents(test_docs)
|
||||
retriever = TfidfRetriever(document_store=document_store)
|
||||
reader = TransformersReader(model="distilbert-base-uncased-distilled-squad",
|
||||
tokenizer="distilbert-base-uncased", use_gpu=-1)
|
||||
assert prediction["answers"][0]["offset_start"] == 11
|
||||
#TODO enable again when FARM is upgraded incl. the new offset calc
|
||||
# assert prediction["answers"][0]["offset_end"] == 16
|
||||
start = prediction["answers"][0]["offset_start"]
|
||||
end = prediction["answers"][0]["offset_end"]
|
||||
#assert prediction["answers"][0]["context"][start:end] == prediction["answers"][0]["answer"]
|
||||
|
||||
|
||||
def test_finder_get_answers_single_result(reader, document_store_with_docs):
|
||||
retriever = TfidfRetriever(document_store=document_store_with_docs)
|
||||
finder = Finder(reader, retriever)
|
||||
prediction = finder.get_answers(question="testing finder", top_k_retriever=1,
|
||||
top_k_reader=1)
|
||||
assert prediction is not None
|
||||
assert len(prediction["answers"]) == 1
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
28
test/test_reader.py
Normal file
28
test/test_reader.py
Normal file
@ -0,0 +1,28 @@
|
||||
import pytest
|
||||
|
||||
from haystack.reader.base import BaseReader
|
||||
from haystack.database.base import Document
|
||||
|
||||
|
||||
def test_reader_basic(reader):
|
||||
assert reader is not None
|
||||
assert isinstance(reader, BaseReader)
|
||||
|
||||
|
||||
def test_output(reader, test_docs_xs):
|
||||
docs = []
|
||||
for d in test_docs_xs:
|
||||
doc = Document(id=d["name"], text=d["text"], meta=d["meta"])
|
||||
docs.append(doc)
|
||||
results = reader.predict(question="Who lives in Berlin?", documents=docs, top_k=5)
|
||||
assert results is not None
|
||||
assert results["question"] == "Who lives in Berlin?"
|
||||
assert results["answers"][0]["answer"] == "Carla"
|
||||
assert results["answers"][0]["offset_start"] == 11
|
||||
#TODO enable again when FARM is upgraded incl. the new offset calc
|
||||
# assert results["answers"][0]["offset_end"] == 16
|
||||
assert results["answers"][0]["probability"] <= 1
|
||||
assert results["answers"][0]["probability"] >= 0
|
||||
assert results["answers"][0]["context"] == "My name is Carla and I live in Berlin"
|
||||
assert results["answers"][0]["document_id"] == "filename1"
|
||||
assert len(results["answers"]) == 5
|
||||
Loading…
x
Reference in New Issue
Block a user