Fix Windows CI OOM (#1878)

* set fixture scope to "function"

* run FARMReader without multiprocessing

* dispose of ray after tests

* run most expensive tasks first in test files

* run expensive tests first

* run garbage collector between tests
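
Taken together, these changes apply one pattern: load each heavyweight model inside a single test's lifetime and reclaim the memory immediately afterwards. A minimal conftest.py sketch of that pattern (illustrative only; the fixture and model names below are placeholders, not the exact ones this commit touches):

    import gc

    import pytest

    from haystack.nodes import FARMReader


    @pytest.fixture(scope="function", autouse=True)
    def gc_cleanup():
        # Runs around every test: yield hands control to the test,
        # then gc.collect() frees whatever the test left behind.
        yield
        gc.collect()


    @pytest.fixture(scope="function")  # "function" scope drops the model after each test
    def reader():
        # num_processes=0 disables FARM's multiprocessing pool, so no forked
        # workers hold extra copies of the model on the Windows CI runner.
        return FARMReader(
            model_name_or_path="deepset/roberta-base-squad2",
            num_processes=0,
        )

Scoping fixtures to "function" trades runtime for memory: models are re-loaded per test, but only one lives in the worker at a time.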

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
tstadel 2021-12-22 17:20:23 +01:00 committed by GitHub
parent 7bdb782871
commit fc8df2163d
12 changed files with 142 additions and 132 deletions

View File

@@ -2,6 +2,7 @@ import subprocess
 import time
 from subprocess import run
 from sys import platform
+import gc
 
 import numpy as np
 import psutil
@@ -117,6 +118,15 @@ def pytest_collection_modifyitems(config,items):
             item.add_marker(skip_docstore)
 
 
+@pytest.fixture(scope="function", autouse=True)
+def gc_cleanup(request):
+    """
+    Run garbage collector between tests in order to reduce memory footprint for CI.
+    """
+    yield
+    gc.collect()
+
+
 @pytest.fixture(scope="session")
 def elasticsearch_fixture():
     # test if a ES cluster is already running. If not, download and start an ES instance locally.
@@ -247,7 +257,7 @@ def xpdf_fixture():
     )
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def rag_generator():
     return RAGenerator(
         model_name_or_path="facebook/rag-token-nq",
@@ -256,17 +266,17 @@ def rag_generator():
     )
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def question_generator():
     return QuestionGenerator(model_name_or_path="valhalla/t5-small-e2e-qg")
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def eli5_generator():
     return Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5", max_length=20)
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def summarizer():
     return TransformersSummarizer(
         model_name_or_path="google/pegasus-xsum",
@@ -274,21 +284,21 @@ def summarizer():
     )
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def en_to_de_translator():
     return TransformersTranslator(
         model_name_or_path="Helsinki-NLP/opus-mt-en-de",
     )
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def de_to_en_translator():
     return TransformersTranslator(
         model_name_or_path="Helsinki-NLP/opus-mt-de-en",
     )
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def test_docs_xs():
     return [
         # current "dict" format for a document
@@ -300,7 +310,7 @@ def test_docs_xs():
     ]
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def reader_without_normalized_scores():
     return FARMReader(
         model_name_or_path="distilbert-base-uncased-distilled-squad",
@@ -311,7 +321,7 @@ def reader_without_normalized_scores():
     )
 
 
-@pytest.fixture(params=["farm", "transformers"], scope="module")
+@pytest.fixture(params=["farm", "transformers"], scope="function")
 def reader(request):
     if request.param == "farm":
         return FARMReader(
@@ -328,32 +338,32 @@ def reader(request):
     )
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def table_reader():
     return TableReader(model_name_or_path="google/tapas-base-finetuned-wtq")
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def ranker_two_logits():
     return SentenceTransformersRanker(
         model_name_or_path="deepset/gbert-base-germandpr-reranking",
     )
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def ranker():
     return SentenceTransformersRanker(
         model_name_or_path="cross-encoder/ms-marco-MiniLM-L-12-v2",
     )
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def document_classifier():
     return TransformersDocumentClassifier(
         model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",
         use_gpu=False
     )
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def zero_shot_document_classifier():
     return TransformersDocumentClassifier(
         model_name_or_path="cross-encoder/nli-distilroberta-base",
@@ -362,7 +372,7 @@ def zero_shot_document_classifier():
         labels=["negative", "positive"]
     )
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def batched_document_classifier():
     return TransformersDocumentClassifier(
         model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",
@@ -370,7 +380,7 @@ def batched_document_classifier():
         batch_size=16
     )
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def indexing_document_classifier():
     return TransformersDocumentClassifier(
         model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",
@@ -381,7 +391,7 @@ def indexing_document_classifier():
 
 # TODO Fix bug in test_no_answer_output when using
 # @pytest.fixture(params=["farm", "transformers"])
-@pytest.fixture(params=["farm"], scope="module")
+@pytest.fixture(params=["farm"], scope="function")
 def no_answer_reader(request):
     if request.param == "farm":
         return FARMReader(
@@ -401,14 +411,14 @@ def no_answer_reader(request):
     )
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def prediction(reader, test_docs_xs):
     docs = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs]
     prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=5)
     return prediction
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def no_answer_prediction(no_answer_reader, test_docs_xs):
     docs = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs]
     prediction = no_answer_reader.predict(query="What is the meaning of life?", documents=docs, top_k=5)
@@ -543,7 +553,7 @@ def get_document_store(document_store_type, embedding_dim=768, embedding_field="
     return document_store
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def adaptive_model_qa(num_processes):
     """
     PyTest Fixture for a Question Answering Inferencer based on PyTorch.
@@ -571,7 +581,7 @@ def adaptive_model_qa(num_processes):
     assert len(children) == 0
 
 
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def bert_base_squad2(request):
     model = QAInferencer.load(
         "deepset/minilm-uncased-squad2",

View File

@@ -6,6 +6,7 @@ components:
     params:
       no_ans_boost: -10
       model_name_or_path: deepset/roberta-base-squad2
+      num_processes: 0
   - name: ESRetriever
     type: ElasticsearchRetriever
     params:

View File

@@ -6,6 +6,7 @@ components:
     params:
      no_ans_boost: -10
      model_name_or_path: deepset/minilm-uncased-squad2
+      num_processes: 0
   - name: Retriever
     type: TfidfRetriever
     params:

View File

@@ -2,8 +2,8 @@ from haystack.nodes import FARMReader
 import torch
 
 def test_distillation():
-    student = FARMReader(model_name_or_path="prajjwal1/bert-tiny")
-    teacher = FARMReader(model_name_or_path="prajjwal1/bert-small")
+    student = FARMReader(model_name_or_path="prajjwal1/bert-tiny", num_processes=0)
+    teacher = FARMReader(model_name_or_path="prajjwal1/bert-small", num_processes=0)
 
     # create a checkpoint of weights before distillation
     student_weights = []

View File

@@ -1,16 +1,74 @@
 import pytest
 
 from haystack.document_stores.base import BaseDocumentStore
+from haystack.document_stores.memory import InMemoryDocumentStore
+from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
+from haystack.nodes.answer_generator.transformers import RAGenerator, RAGeneratorType
+from haystack.nodes.retriever.dense import EmbeddingRetriever
 from haystack.nodes.preprocessor import PreProcessor
 from haystack.nodes.evaluator import EvalAnswers, EvalDocuments
 from haystack.nodes.query_classifier.transformers import TransformersQueryClassifier
 from haystack.nodes.retriever.dense import DensePassageRetriever
 from haystack.nodes.retriever.sparse import ElasticsearchRetriever
 from haystack.pipelines.base import Pipeline
-from haystack.pipelines import ExtractiveQAPipeline
+from haystack.pipelines import ExtractiveQAPipeline, GenerativeQAPipeline, SearchSummarizationPipeline
 from haystack.pipelines.standard_pipelines import DocumentSearchPipeline, FAQPipeline, RetrieverQuestionGenerationPipeline, TranslationWrapperPipeline
+from haystack.nodes.summarizer.transformers import TransformersSummarizer
 from haystack.schema import Answer, Document, EvaluationResult, Label, MultiLabel, Span
 
 
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
+@pytest.mark.parametrize("retriever_with_docs", ["embedding"], indirect=True)
+def test_generativeqa_calculate_metrics(document_store_with_docs: InMemoryDocumentStore, rag_generator, retriever_with_docs):
+    document_store_with_docs.update_embeddings(retriever=retriever_with_docs)
+    pipeline = GenerativeQAPipeline(generator=rag_generator, retriever=retriever_with_docs)
+    eval_result: EvaluationResult = pipeline.eval(
+        labels=EVAL_LABELS,
+        params={"Retriever": {"top_k": 5}}
+    )
+
+    metrics = eval_result.calculate_metrics()
+
+    assert "Retriever" in eval_result
+    assert "Generator" in eval_result
+    assert len(eval_result) == 2
+
+    assert metrics["Retriever"]["mrr"] == 0.5
+    assert metrics["Retriever"]["map"] == 0.5
+    assert metrics["Retriever"]["recall_multi_hit"] == 0.5
+    assert metrics["Retriever"]["recall_single_hit"] == 0.5
+    assert metrics["Retriever"]["precision"] == 1.0/6
+
+    assert metrics["Generator"]["exact_match"] == 0.0
+    assert metrics["Generator"]["f1"] == 1.0/3
+
+
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
+@pytest.mark.parametrize("retriever_with_docs", ["embedding"], indirect=True)
+def test_summarizer_calculate_metrics(document_store_with_docs: ElasticsearchDocumentStore, summarizer, retriever_with_docs):
+    document_store_with_docs.update_embeddings(retriever=retriever_with_docs)
+    pipeline = SearchSummarizationPipeline(retriever=retriever_with_docs, summarizer=summarizer, return_in_answer_format=True)
+    eval_result: EvaluationResult = pipeline.eval(
+        labels=EVAL_LABELS,
+        params={"Retriever": {"top_k": 5}}
+    )
+
+    metrics = eval_result.calculate_metrics()
+
+    assert "Retriever" in eval_result
+    assert "Summarizer" in eval_result
+    assert len(eval_result) == 2
+
+    assert metrics["Retriever"]["mrr"] == 0.5
+    assert metrics["Retriever"]["map"] == 0.5
+    assert metrics["Retriever"]["recall_multi_hit"] == 0.5
+    assert metrics["Retriever"]["recall_single_hit"] == 0.5
+    assert metrics["Retriever"]["precision"] == 1.0/6
+
+    assert metrics["Summarizer"]["mrr"] == 0.5
+    assert metrics["Summarizer"]["map"] == 0.5
+    assert metrics["Summarizer"]["recall_multi_hit"] == 0.5
+    assert metrics["Summarizer"]["recall_single_hit"] == 0.5
+    assert metrics["Summarizer"]["precision"] == 1.0/6
+
+
 @pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
 @pytest.mark.parametrize("batch_size", [None, 20])
 def test_add_eval_data(document_store, batch_size):

View File

@@ -1,72 +0,0 @@
-import pytest
-
-from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
-from haystack.nodes.retriever.dense import EmbeddingRetriever
-from haystack.document_stores.memory import InMemoryDocumentStore
-from haystack.nodes.summarizer.transformers import TransformersSummarizer
-from haystack.pipelines import GenerativeQAPipeline, SearchSummarizationPipeline
-from haystack.schema import EvaluationResult
-
-from test_eval import EVAL_LABELS
-
-
-# had to be separated from other eval tests to work around OOM in Windows CI
-@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
-@pytest.mark.parametrize("retriever_with_docs", ["embedding"], indirect=True)
-def test_generativeqa_calculate_metrics(document_store_with_docs: InMemoryDocumentStore, rag_generator, retriever_with_docs):
-    document_store_with_docs.update_embeddings(retriever=retriever_with_docs)
-    pipeline = GenerativeQAPipeline(generator=rag_generator, retriever=retriever_with_docs)
-    eval_result: EvaluationResult = pipeline.eval(
-        labels=EVAL_LABELS,
-        params={"Retriever": {"top_k": 5}}
-    )
-
-    metrics = eval_result.calculate_metrics()
-
-    assert "Retriever" in eval_result
-    assert "Generator" in eval_result
-    assert len(eval_result) == 2
-
-    assert metrics["Retriever"]["mrr"] == 0.5
-    assert metrics["Retriever"]["map"] == 0.5
-    assert metrics["Retriever"]["recall_multi_hit"] == 0.5
-    assert metrics["Retriever"]["recall_single_hit"] == 0.5
-    assert metrics["Retriever"]["precision"] == 1.0/6
-
-    assert metrics["Generator"]["exact_match"] == 0.0
-    assert metrics["Generator"]["f1"] == 1.0/3
-
-
-@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
-def test_summarizer_calculate_metrics(document_store_with_docs: ElasticsearchDocumentStore):
-    summarizer = TransformersSummarizer(
-        model_name_or_path="sshleifer/distill-pegasus-xsum-16-4",
-        use_gpu=False
-    )
-    document_store_with_docs.embedding_dim = 384
-    retriever = EmbeddingRetriever(
-        document_store=document_store_with_docs,
-        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
-        use_gpu=False
-    )
-    document_store_with_docs.update_embeddings(retriever=retriever)
-    pipeline = SearchSummarizationPipeline(retriever=retriever, summarizer=summarizer, return_in_answer_format=True)
-    eval_result: EvaluationResult = pipeline.eval(
-        labels=EVAL_LABELS,
-        params={"Retriever": {"top_k": 5}}
-    )
-
-    metrics = eval_result.calculate_metrics()
-
-    assert "Retriever" in eval_result
-    assert "Summarizer" in eval_result
-    assert len(eval_result) == 2
-
-    assert metrics["Retriever"]["mrr"] == 0.5
-    assert metrics["Retriever"]["map"] == 0.5
-    assert metrics["Retriever"]["recall_multi_hit"] == 0.5
-    assert metrics["Retriever"]["recall_single_hit"] == 0.5
-    assert metrics["Retriever"]["precision"] == 1.0/6
-
-    assert metrics["Summarizer"]["mrr"] == 0.0
-    assert metrics["Summarizer"]["map"] == 0.0
-    assert metrics["Summarizer"]["recall_multi_hit"] == 0.0
-    assert metrics["Summarizer"]["recall_single_hit"] == 0.0
-    assert metrics["Summarizer"]["precision"] == 0.0

View File

@@ -12,7 +12,7 @@ def test_extractor(document_store_with_docs):
     es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
     ner = EntityExtractor()
-    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")
+    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", num_processes=0)
 
     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
@@ -36,7 +36,7 @@ def test_extractor_output_simplifier(document_store_with_docs):
     es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
     ner = EntityExtractor()
-    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")
+    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", num_processes=0)
 
     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])

View File

@@ -12,6 +12,35 @@ from haystack.pipelines import TranslationWrapperPipeline, GenerativeQAPipeline
 from conftest import DOCS_WITH_EMBEDDINGS
 
+
+# Keeping few (retriever,document_store) combination to reduce test time
+@pytest.mark.slow
+@pytest.mark.generator
+@pytest.mark.parametrize(
+    "retriever,document_store",
+    [("embedding", "memory")],
+    indirect=True,
+)
+def test_generator_pipeline_with_translator(
+    document_store,
+    retriever,
+    rag_generator,
+    en_to_de_translator,
+    de_to_en_translator
+):
+    document_store.write_documents(DOCS_WITH_EMBEDDINGS)
+    query = "Was ist die Hauptstadt der Bundesrepublik Deutschland?"
+    base_pipeline = GenerativeQAPipeline(retriever=retriever, generator=rag_generator)
+    pipeline = TranslationWrapperPipeline(
+        input_translator=de_to_en_translator,
+        output_translator=en_to_de_translator,
+        pipeline=base_pipeline
+    )
+    output = pipeline.run(query=query, params={"Generator": {"top_k": 2}, "Retriever": {"top_k": 1}})
+    answers = output["answers"]
+    assert len(answers) == 2
+    assert "berlin" in answers[0].answer
+
 @pytest.mark.slow
 @pytest.mark.generator
 def test_rag_token_generator(rag_generator):
@@ -105,32 +134,3 @@ def test_lfqa_pipeline_invalid_converter(document_store, retriever):
     with pytest.raises(Exception) as exception_info:
         output = pipeline.run(query=query, params={"top_k": 1})
     assert ("does not have a valid __call__ method signature" in str(exception_info.value))
-
-
-# Keeping few (retriever,document_store) combination to reduce test time
-@pytest.mark.slow
-@pytest.mark.generator
-@pytest.mark.parametrize(
-    "retriever,document_store",
-    [("embedding", "memory")],
-    indirect=True,
-)
-def test_generator_pipeline_with_translator(
-    document_store,
-    retriever,
-    rag_generator,
-    en_to_de_translator,
-    de_to_en_translator
-):
-    document_store.write_documents(DOCS_WITH_EMBEDDINGS)
-    query = "Was ist die Hauptstadt der Bundesrepublik Deutschland?"
-    base_pipeline = GenerativeQAPipeline(retriever=retriever, generator=rag_generator)
-    pipeline = TranslationWrapperPipeline(
-        input_translator=de_to_en_translator,
-        output_translator=en_to_de_translator,
-        pipeline=base_pipeline
-    )
-    output = pipeline.run(query=query, params={"Generator": {"top_k": 2}, "Retriever": {"top_k": 1}})
-    answers = output["answers"]
-    assert len(answers) == 2
-    assert "berlin" in answers[0].answer

View File

@@ -65,6 +65,7 @@ def test_load_and_save_yaml(document_store, tmp_path):
           params:
             model_name_or_path: deepset/roberta-base-squad2
            no_ans_boost: -10
+            num_processes: 0
           type: FARMReader
         pipelines:
         - name: query
@@ -127,6 +128,7 @@ def test_load_and_save_yaml_prebuilt_pipelines(document_store, tmp_path):
           params:
             model_name_or_path: deepset/roberta-base-squad2
            no_ans_boost: -10
+            num_processes: 0
           type: FARMReader
         pipelines:
         - name: query

View File

@@ -23,7 +23,7 @@ def test_node_names_validation(document_store_with_docs, tmp_path):
         name="Retriever",
         inputs=["Query"])
     pipeline.add_node(
-        component=FARMReader(model_name_or_path="deepset/minilm-uncased-squad2"),
+        component=FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0),
         name="Reader",
         inputs=["Retriever"])
 
@@ -50,7 +50,7 @@ def test_node_names_validation(document_store_with_docs, tmp_path):
 def test_debug_attributes_global(document_store_with_docs, tmp_path):
     es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
-    reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2")
+    reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0)
 
     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
@@ -81,7 +81,7 @@ def test_debug_attributes_global(document_store_with_docs, tmp_path):
 def test_debug_attributes_per_node(document_store_with_docs, tmp_path):
     es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
-    reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2")
+    reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0)
 
     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
@@ -111,7 +111,7 @@ def test_debug_attributes_per_node(document_store_with_docs, tmp_path):
 def test_global_debug_attributes_override_node_ones(document_store_with_docs, tmp_path):
     es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
-    reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2")
+    reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0)
 
     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])

View File

@@ -17,3 +17,13 @@ def test_load_pipeline(document_store_with_docs):
     assert ray.serve.get_deployment(name="Reader").num_replicas == 1
     assert prediction["query"] == "Who lives in Berlin?"
     assert prediction["answers"][0].answer == "Carla"
+
+
+@pytest.fixture(scope="function", autouse=True)
+def shutdown_ray():
+    yield
+    try:
+        import ray
+        ray.shutdown()
+    except:
+        pass

View File

@@ -56,7 +56,7 @@ def test_prediction_attributes(prediction):
 def test_model_download_options():
     # download disabled and model is not cached locally
     with pytest.raises(OSError):
-        impossible_reader = FARMReader("mfeb/albert-xxlarge-v2-squad2", local_files_only=True)
+        impossible_reader = FARMReader("mfeb/albert-xxlarge-v2-squad2", local_files_only=True, num_processes=0)
 
 def test_answer_attributes(prediction):
     # TODO Transformers answer also has meta key