Add more tests (#213)

This commit is contained in:
Malte Pietsch 2020-07-10 10:54:56 +02:00 committed by GitHub
parent 549f3a1285
commit d2b26a99ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 120 additions and 37 deletions

View File

@ -136,8 +136,10 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
del doc["meta"]
bulk(self.client, documents, request_timeout=300)
def get_document_count(self) -> int:
result = self.client.count()
def get_document_count(self, index: Optional[str] = None) -> int:
    """
    Return the number of documents reported by the client for an index.

    :param index: Name of the index to count documents in. When omitted (None),
                  ``self.index`` is used instead.
    :return: The ``count`` field of the client's count response.
    """
    target_index = self.index if index is None else index
    response = self.client.count(index=target_index)
    return response["count"]

View File

@ -2,10 +2,17 @@ import tarfile
import time
import urllib.request
from subprocess import Popen, PIPE, STDOUT, run
import os
import pytest
from elasticsearch import Elasticsearch
from haystack.reader.farm import FARMReader
from haystack.reader.transformers import TransformersReader
from haystack.database.sql import SQLDocumentStore
from haystack.database.memory import InMemoryDocumentStore
from haystack.database.elasticsearch import ElasticsearchDocumentStore
@pytest.fixture(scope='session')
def elasticsearch_dir(tmpdir_factory):
@ -19,6 +26,7 @@ def elasticsearch_fixture(elasticsearch_dir):
client = Elasticsearch(hosts=[{"host": "localhost"}])
client.info()
except:
print("Downloading and starting an Elasticsearch instance for the tests ...")
thetarfile = "https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.6.1-linux-x86_64.tar.gz"
ftpstream = urllib.request.urlopen(thetarfile)
thetarfile = tarfile.open(fileobj=ftpstream, mode="r|gz")
@ -41,3 +49,47 @@ def xpdf_fixture():
"""pdftotext is not installed. It is part of xpdf or poppler-utils software suite.
You can download for your OS from here: https://www.xpdfreader.com/download.html."""
)
@pytest.fixture()
def test_docs_xs():
    """Provide three small documents (name, text, meta) shared by the tests."""
    carla = {"name": "filename1", "text": "My name is Carla and I live in Berlin", "meta": {"meta_field": "test1"}}
    paul = {"name": "filename2", "text": "My name is Paul and I live in New York", "meta": {"meta_field": "test2"}}
    christelle = {"name": "filename3", "text": "My name is Christelle and I live in Paris", "meta": {"meta_field": "test3"}}
    return [carla, paul, christelle]
@pytest.fixture(params=["farm", "transformers"])
def reader(request):
    """
    Build the reader selected by ``request.param``.

    "farm" yields a FARMReader and "transformers" a TransformersReader, both based on
    ``distilbert-base-uncased-distilled-squad``. Any other value yields None.
    """
    reader_type = request.param
    qa_reader = None
    if reader_type == "farm":
        qa_reader = FARMReader(
            model_name_or_path="distilbert-base-uncased-distilled-squad",
            use_gpu=False,
            top_k_per_sample=5,
            num_processes=0,
        )
    elif reader_type == "transformers":
        qa_reader = TransformersReader(
            model="distilbert-base-uncased-distilled-squad",
            tokenizer="distilbert-base-uncased",
            use_gpu=-1,
        )
    return qa_reader
@pytest.fixture(params=["sql", "memory", "elasticsearch"])
def document_store_with_docs(request, test_docs_xs, elasticsearch_fixture):
    """
    Build a document store of the parametrized type and fill it with ``test_docs_xs``.

    :param request: pytest request object; ``request.param`` selects "sql", "memory"
                    or "elasticsearch".
    :param test_docs_xs: Documents written into the freshly created store.
    :param elasticsearch_fixture: Not used in the body; requesting it makes pytest run
                                  that fixture before this one.
    :return: The populated document store.
    :raises ValueError: If ``request.param`` is not one of the supported store types
                        (previously this surfaced as an UnboundLocalError on return).
    """
    if request.param == "sql":
        # drop a leftover database file so the store starts empty
        if os.path.exists("qa_test.db"):
            os.remove("qa_test.db")
        document_store = SQLDocumentStore(url="sqlite:///qa_test.db")
        document_store.write_documents(test_docs_xs)
    elif request.param == "memory":
        document_store = InMemoryDocumentStore()
        document_store.write_documents(test_docs_xs)
    elif request.param == "elasticsearch":
        # make sure we start from a fresh index
        client = Elasticsearch()
        client.indices.delete(index='haystack_test', ignore=[404])
        document_store = ElasticsearchDocumentStore(index="haystack_test")
        assert document_store.get_document_count() == 0
        document_store.write_documents(test_docs_xs)
        # presumably gives Elasticsearch time to make the new docs searchable -- TODO confirm
        time.sleep(2)
    else:
        raise ValueError(f"Unsupported document store type for this fixture: {request.param!r}")
    return document_store

View File

@ -1,7 +0,0 @@
from haystack.reader.farm import FARMReader
def test_farm_reader():
    """Smoke test: a FARMReader can be constructed for ``deepset/bert-base-cased-squad2``."""
    farm_reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False)
    assert farm_reader is not None
    assert isinstance(farm_reader, FARMReader)

View File

@ -1,40 +1,48 @@
from haystack import Finder
from haystack.database.sql import SQLDocumentStore
from haystack.reader.transformers import TransformersReader
from haystack.retriever.sparse import TfidfRetriever
import pytest
# NOTE(review): a second `def test_finder_get_answers(reader, document_store_with_docs)`
# appears further down and would replace this parameterless definition at import time.
# This copy only builds its own SQLDocumentStore, TfidfRetriever and TransformersReader;
# it looks like the pre-change version left over from the diff view -- confirm before
# relying on it.
def test_finder_get_answers():
test_docs = [
{"name": "testing the finder 1", "text": "testing the finder with pyhton unit test 1", "meta": {"test": "test"}},
{"name": "testing the finder 2", "text": "testing the finder with pyhton unit test 2", "meta": {"test": "test"}},
{"name": "testing the finder 3", "text": "testing the finder with pyhton unit test 3", "meta": {"test": "test"}}
]
document_store = SQLDocumentStore(url="sqlite:///qa_test.db")
document_store.write_documents(test_docs)
retriever = TfidfRetriever(document_store=document_store)
reader = TransformersReader(model="distilbert-base-uncased-distilled-squad",
tokenizer="distilbert-base-uncased", use_gpu=-1)
#@pytest.mark.parametrize("reader", [("farm")], indirect=True)
#@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
def test_finder_get_answers(reader, document_store_with_docs):
    """
    End-to-end check of ``Finder.get_answers`` using the ``reader`` and
    ``document_store_with_docs`` fixtures.

    Asks "Who lives in Berlin?" and verifies the top answer (text, probability range,
    meta, context, document id) as well as the number of returned answers.
    """
    retriever = TfidfRetriever(document_store=document_store_with_docs)
    finder = Finder(reader, retriever)
    # Only one query is issued: an earlier "testing finder" call whose result was
    # overwritten right away has been dropped.
    prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
                                    top_k_reader=3)
    assert prediction is not None
    assert prediction["question"] == "Who lives in Berlin?"
    assert prediction["answers"][0]["answer"] == "Carla"
    assert prediction["answers"][0]["probability"] <= 1
    assert prediction["answers"][0]["probability"] >= 0
    assert prediction["answers"][0]["meta"]["meta_field"] == "test1"
    assert prediction["answers"][0]["context"] == "My name is Carla and I live in Berlin"
    assert prediction["answers"][0]["document_id"] == "0"
    # top_k_reader=3 was requested above
    assert len(prediction["answers"]) == 3
# NOTE(review): a second `def test_finder_get_answers_single_result(reader,
# document_store_with_docs)` appears further down and would replace this parameterless
# definition at import time. This copy only declares `test_docs`; it looks like the
# pre-change version left over from the diff view -- confirm before relying on it.
def test_finder_get_answers_single_result():
test_docs = [
{"name": "testing the finder 1", "text": "testing the finder with pyhton unit test 1"},
{"name": "testing the finder 2", "text": "testing the finder with pyhton unit test 2"},
{"name": "testing the finder 3", "text": "testing the finder with pyhton unit test 3"}
]
def test_finder_offsets(reader, document_store_with_docs):
    """
    Check the character offsets reported for the top answer to "Who lives in Berlin?".

    Only ``offset_start`` is asserted; the ``offset_end`` and slice checks stay disabled
    until the TODO below is resolved.
    """
    retriever = TfidfRetriever(document_store=document_store_with_docs)
    finder = Finder(reader, retriever)
    prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
                                    top_k_reader=5)

    # Statements that followed here wrote an undefined `test_docs` into a new
    # SQLDocumentStore and rebound `retriever`/`reader` after the prediction was
    # already computed; they have been removed.
    assert prediction["answers"][0]["offset_start"] == 11
    #TODO enable again when FARM is upgraded incl. the new offset calc
    # assert prediction["answers"][0]["offset_end"] == 16
    # start/end are kept for the disabled slice assertion below; reading them still
    # requires both offset keys to be present in the answer.
    start = prediction["answers"][0]["offset_start"]
    end = prediction["answers"][0]["offset_end"]
    #assert prediction["answers"][0]["context"][start:end] == prediction["answers"][0]["answer"]
def test_finder_get_answers_single_result(reader, document_store_with_docs):
    """Requesting one retrieved document and one answer must yield exactly one answer."""
    tfidf_retriever = TfidfRetriever(document_store=document_store_with_docs)
    qa_finder = Finder(reader, tfidf_retriever)
    result = qa_finder.get_answers(
        question="testing finder",
        top_k_retriever=1,
        top_k_reader=1,
    )
    assert result is not None
    assert len(result["answers"]) == 1

28
test/test_reader.py Normal file
View File

@ -0,0 +1,28 @@
import pytest
from haystack.reader.base import BaseReader
from haystack.database.base import Document
def test_reader_basic(reader):
    """The ``reader`` fixture must provide an object that is a BaseReader instance."""
    assert reader is not None
    assert isinstance(reader, BaseReader)
def test_output(reader, test_docs_xs):
    """
    Run ``reader.predict`` directly on Documents built from ``test_docs_xs`` and verify
    the structure and content of the returned answers.
    """
    documents = [
        Document(id=entry["name"], text=entry["text"], meta=entry["meta"])
        for entry in test_docs_xs
    ]
    results = reader.predict(question="Who lives in Berlin?", documents=documents, top_k=5)
    assert results is not None
    assert results["question"] == "Who lives in Berlin?"

    top_answer = results["answers"][0]
    assert top_answer["answer"] == "Carla"
    assert top_answer["offset_start"] == 11
    #TODO enable again when FARM is upgraded incl. the new offset calc
    # assert results["answers"][0]["offset_end"] == 16
    assert 0 <= top_answer["probability"] <= 1
    assert top_answer["context"] == "My name is Carla and I live in Berlin"
    assert top_answer["document_id"] == "filename1"
    # top_k=5 was requested above
    assert len(results["answers"]) == 5