Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-12-04 19:06:44 +00:00)

Commit d2b26a99ff: Add more tests (#213)
Parent: 549f3a1285
@@ -136,8 +136,10 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
                 del doc["meta"]
         bulk(self.client, documents, request_timeout=300)

-    def get_document_count(self) -> int:
-        result = self.client.count()
+    def get_document_count(self, index: Optional[str] = None,) -> int:
+        if index is None:
+            index = self.index
+        result = self.client.count(index=index)
         count = result["count"]
         return count
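For context, a minimal sketch of how the new optional index argument can be used (an illustration, not part of the commit; it assumes a running local Elasticsearch and reuses the "haystack_test" index name from the test fixtures below):

    from haystack.database.elasticsearch import ElasticsearchDocumentStore

    document_store = ElasticsearchDocumentStore(index="haystack_test")
    # No argument: the count now falls back to the store's own index (self.index).
    print(document_store.get_document_count())
    # Explicit argument: any index name can be counted through the same store.
    print(document_store.get_document_count(index="haystack_test"))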
@@ -2,10 +2,17 @@ import tarfile
 import time
 import urllib.request
 from subprocess import Popen, PIPE, STDOUT, run
+import os

 import pytest
 from elasticsearch import Elasticsearch

+from haystack.reader.farm import FARMReader
+from haystack.reader.transformers import TransformersReader
+
+from haystack.database.sql import SQLDocumentStore
+from haystack.database.memory import InMemoryDocumentStore
+from haystack.database.elasticsearch import ElasticsearchDocumentStore


 @pytest.fixture(scope='session')
 def elasticsearch_dir(tmpdir_factory):
@@ -19,6 +26,7 @@ def elasticsearch_fixture(elasticsearch_dir):
         client = Elasticsearch(hosts=[{"host": "localhost"}])
         client.info()
     except:
+        print("Downloading and starting an Elasticsearch instance for the tests ...")
         thetarfile = "https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.6.1-linux-x86_64.tar.gz"
         ftpstream = urllib.request.urlopen(thetarfile)
         thetarfile = tarfile.open(fileobj=ftpstream, mode="r|gz")
@@ -41,3 +49,47 @@ def xpdf_fixture():
             """pdftotext is not installed. It is part of xpdf or poppler-utils software suite.
             You can download for your OS from here: https://www.xpdfreader.com/download.html."""
         )
+
+
+@pytest.fixture()
+def test_docs_xs():
+    return [
+        {"name": "filename1", "text": "My name is Carla and I live in Berlin", "meta": {"meta_field": "test1"}},
+        {"name": "filename2", "text": "My name is Paul and I live in New York", "meta": {"meta_field": "test2"}},
+        {"name": "filename3", "text": "My name is Christelle and I live in Paris", "meta": {"meta_field": "test3"}}
+    ]
+
+
+@pytest.fixture(params=["farm", "transformers"])
+def reader(request):
+    if request.param == "farm":
+        return FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad",
+                          use_gpu=False, top_k_per_sample=5, num_processes=0)
+    if request.param == "transformers":
+        return TransformersReader(model="distilbert-base-uncased-distilled-squad",
+                                  tokenizer="distilbert-base-uncased",
+                                  use_gpu=-1)
+
+
+@pytest.fixture(params=["sql", "memory", "elasticsearch"])
+def document_store_with_docs(request, test_docs_xs, elasticsearch_fixture):
+    if request.param == "sql":
+        if os.path.exists("qa_test.db"):
+            os.remove("qa_test.db")
+        document_store = SQLDocumentStore(url="sqlite:///qa_test.db")
+        document_store.write_documents(test_docs_xs)
+
+    if request.param == "memory":
+        document_store = InMemoryDocumentStore()
+        document_store.write_documents(test_docs_xs)
+
+    if request.param == "elasticsearch":
+        # make sure we start from a fresh index
+        client = Elasticsearch()
+        client.indices.delete(index='haystack_test', ignore=[404])
+        document_store = ElasticsearchDocumentStore(index="haystack_test")
+        assert document_store.get_document_count() == 0
+        document_store.write_documents(test_docs_xs)
+        time.sleep(2)
+
+    return document_store
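A note on how these fixtures compose: pytest expands parametrized fixtures as a cross product, so a test that requests both reader (farm, transformers) and document_store_with_docs (sql, memory, elasticsearch) runs 2 x 3 = 6 times. A minimal sketch of pinning a single combination, mirroring the markers that appear commented out in the finder tests below (the test name here is hypothetical):

    import pytest

    @pytest.mark.parametrize("reader", [("farm")], indirect=True)
    @pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
    def test_pinned_combination(reader, document_store_with_docs):
        # indirect=True routes each parameter through the fixture of the same
        # name, so the fixture body builds the actual reader / document store.
        assert reader is not None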
@@ -1,7 +0,0 @@
-from haystack.reader.farm import FARMReader
-
-
-def test_farm_reader():
-    reader = FARMReader(model_name_or_path="deepset/bert-base-cased-squad2", use_gpu=False)
-    assert reader is not None
-    assert isinstance(reader, FARMReader)
@@ -1,40 +1,48 @@
 from haystack import Finder
-from haystack.database.sql import SQLDocumentStore
-from haystack.reader.transformers import TransformersReader
 from haystack.retriever.sparse import TfidfRetriever
+import pytest


-def test_finder_get_answers():
-    test_docs = [
-        {"name": "testing the finder 1", "text": "testing the finder with pyhton unit test 1", "meta": {"test": "test"}},
-        {"name": "testing the finder 2", "text": "testing the finder with pyhton unit test 2", "meta": {"test": "test"}},
-        {"name": "testing the finder 3", "text": "testing the finder with pyhton unit test 3", "meta": {"test": "test"}}
-    ]
-
-    document_store = SQLDocumentStore(url="sqlite:///qa_test.db")
-    document_store.write_documents(test_docs)
-    retriever = TfidfRetriever(document_store=document_store)
-    reader = TransformersReader(model="distilbert-base-uncased-distilled-squad",
-                                tokenizer="distilbert-base-uncased", use_gpu=-1)
+#@pytest.mark.parametrize("reader", [("farm")], indirect=True)
+#@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
+def test_finder_get_answers(reader, document_store_with_docs):
+    retriever = TfidfRetriever(document_store=document_store_with_docs)
     finder = Finder(reader, retriever)
-    prediction = finder.get_answers(question="testing finder", top_k_retriever=10,
-                                    top_k_reader=5)
+    prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
+                                    top_k_reader=3)
     assert prediction is not None
+    assert prediction["question"] == "Who lives in Berlin?"
+    assert prediction["answers"][0]["answer"] == "Carla"
+    assert prediction["answers"][0]["probability"] <= 1
+    assert prediction["answers"][0]["probability"] >= 0
+    assert prediction["answers"][0]["meta"]["meta_field"] == "test1"
+    assert prediction["answers"][0]["context"] == "My name is Carla and I live in Berlin"
+    assert prediction["answers"][0]["document_id"] == "0"
+
+    assert len(prediction["answers"]) == 3


-def test_finder_get_answers_single_result():
-    test_docs = [
-        {"name": "testing the finder 1", "text": "testing the finder with pyhton unit test 1"},
-        {"name": "testing the finder 2", "text": "testing the finder with pyhton unit test 2"},
-        {"name": "testing the finder 3", "text": "testing the finder with pyhton unit test 3"}
-    ]
-
-    document_store = SQLDocumentStore(url="sqlite:///qa_test.db")
-    document_store.write_documents(test_docs)
-    retriever = TfidfRetriever(document_store=document_store)
-    reader = TransformersReader(model="distilbert-base-uncased-distilled-squad",
-                                tokenizer="distilbert-base-uncased", use_gpu=-1)
+def test_finder_offsets(reader, document_store_with_docs):
+    retriever = TfidfRetriever(document_store=document_store_with_docs)
+    finder = Finder(reader, retriever)
+    prediction = finder.get_answers(question="Who lives in Berlin?", top_k_retriever=10,
+                                    top_k_reader=5)
+
+    assert prediction["answers"][0]["offset_start"] == 11
+    #TODO enable again when FARM is upgraded incl. the new offset calc
+    # assert prediction["answers"][0]["offset_end"] == 16
+    start = prediction["answers"][0]["offset_start"]
+    end = prediction["answers"][0]["offset_end"]
+    #assert prediction["answers"][0]["context"][start:end] == prediction["answers"][0]["answer"]
+
+
+def test_finder_get_answers_single_result(reader, document_store_with_docs):
+    retriever = TfidfRetriever(document_store=document_store_with_docs)
     finder = Finder(reader, retriever)
     prediction = finder.get_answers(question="testing finder", top_k_retriever=1,
                                     top_k_reader=1)
     assert prediction is not None
+    assert len(prediction["answers"]) == 1
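For reference, the result shape these assertions imply (a sketch reconstructed from the test expectations, not taken from the commit; the probability value is illustrative):

    prediction = {
        "question": "Who lives in Berlin?",
        "answers": [
            {
                "answer": "Carla",
                "probability": 0.9,  # asserted only to lie within [0, 1]
                "context": "My name is Carla and I live in Berlin",
                "offset_start": 11,
                "offset_end": 16,    # end-offset assertion disabled until the FARM upgrade
                "meta": {"meta_field": "test1"},
                "document_id": "0",
            },
            # ... up to top_k_reader entries
        ],
    }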
test/test_reader.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+import pytest
+
+from haystack.reader.base import BaseReader
+from haystack.database.base import Document
+
+
+def test_reader_basic(reader):
+    assert reader is not None
+    assert isinstance(reader, BaseReader)
+
+
+def test_output(reader, test_docs_xs):
+    docs = []
+    for d in test_docs_xs:
+        doc = Document(id=d["name"], text=d["text"], meta=d["meta"])
+        docs.append(doc)
+    results = reader.predict(question="Who lives in Berlin?", documents=docs, top_k=5)
+    assert results is not None
+    assert results["question"] == "Who lives in Berlin?"
+    assert results["answers"][0]["answer"] == "Carla"
+    assert results["answers"][0]["offset_start"] == 11
+    #TODO enable again when FARM is upgraded incl. the new offset calc
+    # assert results["answers"][0]["offset_end"] == 16
+    assert results["answers"][0]["probability"] <= 1
+    assert results["answers"][0]["probability"] >= 0
+    assert results["answers"][0]["context"] == "My name is Carla and I live in Berlin"
+    assert results["answers"][0]["document_id"] == "filename1"
+    assert len(results["answers"]) == 5
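A minimal sketch of invoking the new module on its own (an illustration only; it assumes pytest is installed and the distilbert models can be downloaded, and is equivalent to running "pytest test/test_reader.py -v" from the repository root):

    import pytest

    # Each test runs once per reader backend (farm, transformers) because the
    # reader fixture is parametrized in the shared fixtures above.
    pytest.main(["test/test_reader.py", "-v"])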