Mirror of https://github.com/deepset-ai/haystack.git
Synced 2025-11-02 10:49:30 +00:00
Fix Windows CI OOM (#1878)
* set fixture scope to "function"
* run FARMReader without multiprocessing
* dispose of ray after tests
* run most expensive tests first in test files
* run garbage collector between tests

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent 7bdb782871
commit fc8df2163d
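
For context, the diff below applies the same two patterns across the whole test suite: model-loading fixtures move from "module" to "function" scope so each model is released after its test (with an autouse fixture forcing garbage collection in between), and every FARMReader is built with num_processes=0 to avoid forked inference workers. A minimal sketch of the combined pattern, assuming a stock pytest + haystack setup (the module layout here is illustrative, not part of the commit):

import gc

import pytest

from haystack.nodes import FARMReader


@pytest.fixture(scope="function", autouse=True)
def gc_cleanup(request):
    # Free unreferenced model tensors after every test to cap CI memory use.
    yield
    gc.collect()


@pytest.fixture(scope="function")  # function scope: the reader is rebuilt (and freed) per test
def reader():
    # num_processes=0 disables FARM's multiprocessing pool, so no worker
    # processes hold their own copy of the model on memory-tight Windows CI.
    return FARMReader(model_name_or_path="deepset/roberta-base-squad2", num_processes=0)
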
@@ -2,6 +2,7 @@ import subprocess
 import time
 from subprocess import run
 from sys import platform
+import gc

 import numpy as np
 import psutil
@@ -117,6 +118,15 @@ def pytest_collection_modifyitems(config,items):
             item.add_marker(skip_docstore)


+@pytest.fixture(scope="function", autouse=True)
+def gc_cleanup(request):
+    """
+    Run garbage collector between tests in order to reduce memory footprint for CI.
+    """
+    yield
+    gc.collect()
+
+
 @pytest.fixture(scope="session")
 def elasticsearch_fixture():
     # test if a ES cluster is already running. If not, download and start an ES instance locally.
@@ -247,7 +257,7 @@ def xpdf_fixture():
     )


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def rag_generator():
     return RAGenerator(
         model_name_or_path="facebook/rag-token-nq",
@@ -256,17 +266,17 @@ def rag_generator():
     )


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def question_generator():
     return QuestionGenerator(model_name_or_path="valhalla/t5-small-e2e-qg")


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def eli5_generator():
     return Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5", max_length=20)


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def summarizer():
     return TransformersSummarizer(
         model_name_or_path="google/pegasus-xsum",
@@ -274,21 +284,21 @@ def summarizer():
     )


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def en_to_de_translator():
     return TransformersTranslator(
         model_name_or_path="Helsinki-NLP/opus-mt-en-de",
     )


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def de_to_en_translator():
     return TransformersTranslator(
         model_name_or_path="Helsinki-NLP/opus-mt-de-en",
     )


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def test_docs_xs():
     return [
         # current "dict" format for a document
@@ -300,7 +310,7 @@ def test_docs_xs():
     ]


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def reader_without_normalized_scores():
     return FARMReader(
         model_name_or_path="distilbert-base-uncased-distilled-squad",
@@ -311,7 +321,7 @@ def reader_without_normalized_scores():
     )


-@pytest.fixture(params=["farm", "transformers"], scope="module")
+@pytest.fixture(params=["farm", "transformers"], scope="function")
 def reader(request):
     if request.param == "farm":
         return FARMReader(
@@ -328,32 +338,32 @@ def reader(request):
     )


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def table_reader():
     return TableReader(model_name_or_path="google/tapas-base-finetuned-wtq")


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def ranker_two_logits():
     return SentenceTransformersRanker(
         model_name_or_path="deepset/gbert-base-germandpr-reranking",
     )

-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def ranker():
     return SentenceTransformersRanker(
         model_name_or_path="cross-encoder/ms-marco-MiniLM-L-12-v2",
     )


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def document_classifier():
     return TransformersDocumentClassifier(
         model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",
         use_gpu=False
     )

-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def zero_shot_document_classifier():
     return TransformersDocumentClassifier(
         model_name_or_path="cross-encoder/nli-distilroberta-base",
@@ -362,7 +372,7 @@ def zero_shot_document_classifier():
         labels=["negative", "positive"]
     )

-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def batched_document_classifier():
     return TransformersDocumentClassifier(
         model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",
@@ -370,7 +380,7 @@
         batch_size=16
     )

-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def indexing_document_classifier():
     return TransformersDocumentClassifier(
         model_name_or_path="bhadresh-savani/distilbert-base-uncased-emotion",
@@ -381,7 +391,7 @@

 # TODO Fix bug in test_no_answer_output when using
 # @pytest.fixture(params=["farm", "transformers"])
-@pytest.fixture(params=["farm"], scope="module")
+@pytest.fixture(params=["farm"], scope="function")
 def no_answer_reader(request):
     if request.param == "farm":
         return FARMReader(
@@ -401,14 +411,14 @@ def no_answer_reader(request):
     )


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def prediction(reader, test_docs_xs):
     docs = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs]
     prediction = reader.predict(query="Who lives in Berlin?", documents=docs, top_k=5)
     return prediction


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def no_answer_prediction(no_answer_reader, test_docs_xs):
     docs = [Document.from_dict(d) if isinstance(d, dict) else d for d in test_docs_xs]
     prediction = no_answer_reader.predict(query="What is the meaning of life?", documents=docs, top_k=5)
@@ -543,7 +553,7 @@ def get_document_store(document_store_type, embedding_dim=768, embedding_field="
     return document_store


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def adaptive_model_qa(num_processes):
     """
     PyTest Fixture for a Question Answering Inferencer based on PyTorch.
@@ -571,7 +581,7 @@ def adaptive_model_qa(num_processes):
     assert len(children) == 0


-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def bert_base_squad2(request):
     model = QAInferencer.load(
         "deepset/minilm-uncased-squad2",
@@ -6,6 +6,7 @@ components:
     params:
       no_ans_boost: -10
      model_name_or_path: deepset/roberta-base-squad2
+      num_processes: 0
 - name: ESRetriever
   type: ElasticsearchRetriever
   params:
@@ -6,6 +6,7 @@ components:
     params:
       no_ans_boost: -10
      model_name_or_path: deepset/minilm-uncased-squad2
+      num_processes: 0
 - name: Retriever
   type: TfidfRetriever
   params:
@@ -2,8 +2,8 @@ from haystack.nodes import FARMReader
 import torch

 def test_distillation():
-    student = FARMReader(model_name_or_path="prajjwal1/bert-tiny")
-    teacher = FARMReader(model_name_or_path="prajjwal1/bert-small")
+    student = FARMReader(model_name_or_path="prajjwal1/bert-tiny", num_processes=0)
+    teacher = FARMReader(model_name_or_path="prajjwal1/bert-small", num_processes=0)

     # create a checkpoint of weights before distillation
     student_weights = []
@@ -1,16 +1,74 @@
 import pytest
 from haystack.document_stores.base import BaseDocumentStore
+from haystack.document_stores.memory import InMemoryDocumentStore
 from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
+from haystack.nodes.answer_generator.transformers import RAGenerator, RAGeneratorType
+from haystack.nodes.retriever.dense import EmbeddingRetriever
 from haystack.nodes.preprocessor import PreProcessor
 from haystack.nodes.evaluator import EvalAnswers, EvalDocuments
 from haystack.nodes.query_classifier.transformers import TransformersQueryClassifier
 from haystack.nodes.retriever.dense import DensePassageRetriever
 from haystack.nodes.retriever.sparse import ElasticsearchRetriever
 from haystack.pipelines.base import Pipeline
-from haystack.pipelines import ExtractiveQAPipeline
+from haystack.pipelines import ExtractiveQAPipeline, GenerativeQAPipeline, SearchSummarizationPipeline
 from haystack.pipelines.standard_pipelines import DocumentSearchPipeline, FAQPipeline, RetrieverQuestionGenerationPipeline, TranslationWrapperPipeline
+from haystack.nodes.summarizer.transformers import TransformersSummarizer
 from haystack.schema import Answer, Document, EvaluationResult, Label, MultiLabel, Span


+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
+@pytest.mark.parametrize("retriever_with_docs", ["embedding"], indirect=True)
+def test_generativeqa_calculate_metrics(document_store_with_docs: InMemoryDocumentStore, rag_generator, retriever_with_docs):
+    document_store_with_docs.update_embeddings(retriever=retriever_with_docs)
+    pipeline = GenerativeQAPipeline(generator=rag_generator, retriever=retriever_with_docs)
+    eval_result: EvaluationResult = pipeline.eval(
+        labels=EVAL_LABELS,
+        params={"Retriever": {"top_k": 5}}
+    )
+
+    metrics = eval_result.calculate_metrics()
+
+    assert "Retriever" in eval_result
+    assert "Generator" in eval_result
+    assert len(eval_result) == 2
+
+    assert metrics["Retriever"]["mrr"] == 0.5
+    assert metrics["Retriever"]["map"] == 0.5
+    assert metrics["Retriever"]["recall_multi_hit"] == 0.5
+    assert metrics["Retriever"]["recall_single_hit"] == 0.5
+    assert metrics["Retriever"]["precision"] == 1.0/6
+    assert metrics["Generator"]["exact_match"] == 0.0
+    assert metrics["Generator"]["f1"] == 1.0/3
+
+
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
+@pytest.mark.parametrize("retriever_with_docs", ["embedding"], indirect=True)
+def test_summarizer_calculate_metrics(document_store_with_docs: ElasticsearchDocumentStore, summarizer, retriever_with_docs):
+    document_store_with_docs.update_embeddings(retriever=retriever_with_docs)
+    pipeline = SearchSummarizationPipeline(retriever=retriever_with_docs, summarizer=summarizer, return_in_answer_format=True)
+    eval_result: EvaluationResult = pipeline.eval(
+        labels=EVAL_LABELS,
+        params={"Retriever": {"top_k": 5}}
+    )
+
+    metrics = eval_result.calculate_metrics()
+
+    assert "Retriever" in eval_result
+    assert "Summarizer" in eval_result
+    assert len(eval_result) == 2
+
+    assert metrics["Retriever"]["mrr"] == 0.5
+    assert metrics["Retriever"]["map"] == 0.5
+    assert metrics["Retriever"]["recall_multi_hit"] == 0.5
+    assert metrics["Retriever"]["recall_single_hit"] == 0.5
+    assert metrics["Retriever"]["precision"] == 1.0/6
+    assert metrics["Summarizer"]["mrr"] == 0.5
+    assert metrics["Summarizer"]["map"] == 0.5
+    assert metrics["Summarizer"]["recall_multi_hit"] == 0.5
+    assert metrics["Summarizer"]["recall_single_hit"] == 0.5
+    assert metrics["Summarizer"]["precision"] == 1.0/6
+
+
 @pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
 @pytest.mark.parametrize("batch_size", [None, 20])
 def test_add_eval_data(document_store, batch_size):
@@ -1,72 +0,0 @@
-import pytest
-from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
-from haystack.nodes.retriever.dense import EmbeddingRetriever
-from haystack.document_stores.memory import InMemoryDocumentStore
-from haystack.nodes.summarizer.transformers import TransformersSummarizer
-from haystack.pipelines import GenerativeQAPipeline, SearchSummarizationPipeline
-from haystack.schema import EvaluationResult
-from test_eval import EVAL_LABELS
-
-
-# had to be separated from other eval tests to work around OOM in Windows CI
-
-@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
-@pytest.mark.parametrize("retriever_with_docs", ["embedding"], indirect=True)
-def test_generativeqa_calculate_metrics(document_store_with_docs: InMemoryDocumentStore, rag_generator, retriever_with_docs):
-    document_store_with_docs.update_embeddings(retriever=retriever_with_docs)
-    pipeline = GenerativeQAPipeline(generator=rag_generator, retriever=retriever_with_docs)
-    eval_result: EvaluationResult = pipeline.eval(
-        labels=EVAL_LABELS,
-        params={"Retriever": {"top_k": 5}}
-    )
-
-    metrics = eval_result.calculate_metrics()
-
-    assert "Retriever" in eval_result
-    assert "Generator" in eval_result
-    assert len(eval_result) == 2
-
-    assert metrics["Retriever"]["mrr"] == 0.5
-    assert metrics["Retriever"]["map"] == 0.5
-    assert metrics["Retriever"]["recall_multi_hit"] == 0.5
-    assert metrics["Retriever"]["recall_single_hit"] == 0.5
-    assert metrics["Retriever"]["precision"] == 1.0/6
-    assert metrics["Generator"]["exact_match"] == 0.0
-    assert metrics["Generator"]["f1"] == 1.0/3
-
-
-@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
-def test_summarizer_calculate_metrics(document_store_with_docs: ElasticsearchDocumentStore):
-    summarizer = TransformersSummarizer(
-        model_name_or_path="sshleifer/distill-pegasus-xsum-16-4",
-        use_gpu=False
-    )
-    document_store_with_docs.embedding_dim = 384
-    retriever = EmbeddingRetriever(
-        document_store=document_store_with_docs,
-        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
-        use_gpu=False
-    )
-    document_store_with_docs.update_embeddings(retriever=retriever)
-    pipeline = SearchSummarizationPipeline(retriever=retriever, summarizer=summarizer, return_in_answer_format=True)
-    eval_result: EvaluationResult = pipeline.eval(
-        labels=EVAL_LABELS,
-        params={"Retriever": {"top_k": 5}}
-    )
-
-    metrics = eval_result.calculate_metrics()
-
-    assert "Retriever" in eval_result
-    assert "Summarizer" in eval_result
-    assert len(eval_result) == 2
-
-    assert metrics["Retriever"]["mrr"] == 0.5
-    assert metrics["Retriever"]["map"] == 0.5
-    assert metrics["Retriever"]["recall_multi_hit"] == 0.5
-    assert metrics["Retriever"]["recall_single_hit"] == 0.5
-    assert metrics["Retriever"]["precision"] == 1.0/6
-    assert metrics["Summarizer"]["mrr"] == 0.0
-    assert metrics["Summarizer"]["map"] == 0.0
-    assert metrics["Summarizer"]["recall_multi_hit"] == 0.0
-    assert metrics["Summarizer"]["recall_single_hit"] == 0.0
-    assert metrics["Summarizer"]["precision"] == 0.0
@@ -12,7 +12,7 @@ def test_extractor(document_store_with_docs):

     es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
     ner = EntityExtractor()
-    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")
+    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", num_processes=0)

     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
@@ -36,7 +36,7 @@ def test_extractor_output_simplifier(document_store_with_docs):

     es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
     ner = EntityExtractor()
-    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")
+    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", num_processes=0)

     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
@@ -12,6 +12,35 @@ from haystack.pipelines import TranslationWrapperPipeline, GenerativeQAPipeline
 from conftest import DOCS_WITH_EMBEDDINGS


+# Keeping few (retriever,document_store) combination to reduce test time
+@pytest.mark.slow
+@pytest.mark.generator
+@pytest.mark.parametrize(
+    "retriever,document_store",
+    [("embedding", "memory")],
+    indirect=True,
+)
+def test_generator_pipeline_with_translator(
+    document_store,
+    retriever,
+    rag_generator,
+    en_to_de_translator,
+    de_to_en_translator
+):
+    document_store.write_documents(DOCS_WITH_EMBEDDINGS)
+    query = "Was ist die Hauptstadt der Bundesrepublik Deutschland?"
+    base_pipeline = GenerativeQAPipeline(retriever=retriever, generator=rag_generator)
+    pipeline = TranslationWrapperPipeline(
+        input_translator=de_to_en_translator,
+        output_translator=en_to_de_translator,
+        pipeline=base_pipeline
+    )
+    output = pipeline.run(query=query, params={"Generator": {"top_k": 2}, "Retriever": {"top_k": 1}})
+    answers = output["answers"]
+    assert len(answers) == 2
+    assert "berlin" in answers[0].answer
+
+
 @pytest.mark.slow
 @pytest.mark.generator
 def test_rag_token_generator(rag_generator):
@@ -105,32 +134,3 @@ def test_lfqa_pipeline_invalid_converter(document_store, retriever):
     with pytest.raises(Exception) as exception_info:
         output = pipeline.run(query=query, params={"top_k": 1})
     assert ("does not have a valid __call__ method signature" in str(exception_info.value))
-
-
-# Keeping few (retriever,document_store) combination to reduce test time
-@pytest.mark.slow
-@pytest.mark.generator
-@pytest.mark.parametrize(
-    "retriever,document_store",
-    [("embedding", "memory")],
-    indirect=True,
-)
-def test_generator_pipeline_with_translator(
-    document_store,
-    retriever,
-    rag_generator,
-    en_to_de_translator,
-    de_to_en_translator
-):
-    document_store.write_documents(DOCS_WITH_EMBEDDINGS)
-    query = "Was ist die Hauptstadt der Bundesrepublik Deutschland?"
-    base_pipeline = GenerativeQAPipeline(retriever=retriever, generator=rag_generator)
-    pipeline = TranslationWrapperPipeline(
-        input_translator=de_to_en_translator,
-        output_translator=en_to_de_translator,
-        pipeline=base_pipeline
-    )
-    output = pipeline.run(query=query, params={"Generator": {"top_k": 2}, "Retriever": {"top_k": 1}})
-    answers = output["answers"]
-    assert len(answers) == 2
-    assert "berlin" in answers[0].answer
@@ -65,6 +65,7 @@ def test_load_and_save_yaml(document_store, tmp_path):
            params:
              model_name_or_path: deepset/roberta-base-squad2
              no_ans_boost: -10
+             num_processes: 0
            type: FARMReader
         pipelines:
         - name: query
@@ -127,6 +128,7 @@ def test_load_and_save_yaml_prebuilt_pipelines(document_store, tmp_path):
            params:
              model_name_or_path: deepset/roberta-base-squad2
              no_ans_boost: -10
+             num_processes: 0
            type: FARMReader
         pipelines:
         - name: query
@@ -23,7 +23,7 @@ def test_node_names_validation(document_store_with_docs, tmp_path):
         name="Retriever",
         inputs=["Query"])
     pipeline.add_node(
-        component=FARMReader(model_name_or_path="deepset/minilm-uncased-squad2"),
+        component=FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0),
         name="Reader",
         inputs=["Retriever"])
@@ -50,7 +50,7 @@ def test_node_names_validation(document_store_with_docs, tmp_path):
 def test_debug_attributes_global(document_store_with_docs, tmp_path):

     es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
-    reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2")
+    reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0)

     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
@@ -81,7 +81,7 @@ def test_debug_attributes_global(document_store_with_docs, tmp_path):
 def test_debug_attributes_per_node(document_store_with_docs, tmp_path):

     es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
-    reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2")
+    reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0)

     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
@@ -111,7 +111,7 @@ def test_debug_attributes_per_node(document_store_with_docs, tmp_path):
 def test_global_debug_attributes_override_node_ones(document_store_with_docs, tmp_path):

     es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
-    reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2")
+    reader = FARMReader(model_name_or_path="deepset/minilm-uncased-squad2", num_processes=0)

     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
@@ -17,3 +17,13 @@ def test_load_pipeline(document_store_with_docs):
     assert ray.serve.get_deployment(name="Reader").num_replicas == 1
     assert prediction["query"] == "Who lives in Berlin?"
     assert prediction["answers"][0].answer == "Carla"
+
+
+@pytest.fixture(scope="function", autouse=True)
+def shutdown_ray():
+    yield
+    try:
+        import ray
+        ray.shutdown()
+    except:
+        pass
@@ -56,7 +56,7 @@ def test_prediction_attributes(prediction):
 def test_model_download_options():
     # download disabled and model is not cached locally
     with pytest.raises(OSError):
-        impossible_reader = FARMReader("mfeb/albert-xxlarge-v2-squad2", local_files_only=True)
+        impossible_reader = FARMReader("mfeb/albert-xxlarge-v2-squad2", local_files_only=True, num_processes=0)

 def test_answer_attributes(prediction):
     # TODO Transformers answer also has meta key