fix: Support isolated node eval in run_batch in Generators (#5291)
* Add isolated node eval to BaseGenerator's run_batch
* Add unit tests

commit 0697f5c63e
parent 395854d823
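Isolated node evaluation feeds a node the "perfect" documents taken from the eval labels instead of the output of the preceding node, so its metrics can be read as an upper bound for that node alone. run() already supported this for generators; this commit brings run_batch() to parity. A minimal sketch of the new behavior, assuming haystack v1 where BaseGenerator.predict_batch's default implementation delegates to predict; StubGenerator and the sample data are illustrative, not part of this diff:

    from typing import Dict, List, Optional

    from haystack.nodes.answer_generator.base import BaseGenerator
    from haystack.schema import Answer, Document, Label, MultiLabel

    class StubGenerator(BaseGenerator):
        """Illustrative generator that answers without calling a model."""

        def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None, max_tokens: Optional[int] = None) -> Dict:
            return {"query": query, "answers": [Answer(answer="Carla")]}

    doc = Document(content="My name is Carla and I live in Berlin")
    label = MultiLabel(
        labels=[
            Label(
                query="Who lives in Berlin?",
                answer=Answer(answer="Carla"),
                document=doc,
                is_correct_answer=True,
                is_correct_document=True,
                origin="gold-label",
            )
        ]
    )

    results, _ = StubGenerator().run_batch(
        queries=["Who lives in Berlin?"], documents=[[doc]], labels=[label], add_isolated_node_eval=True
    )
    # The isolated pass re-runs prediction on the label documents only.
    assert "answers_isolated" in results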
@@ -61,12 +61,31 @@ class BaseGenerator(BaseComponent):
         queries: List[str],
         documents: Union[List[Document], List[List[Document]]],
         top_k: Optional[int] = None,
+        labels: Optional[List[MultiLabel]] = None,
         batch_size: Optional[int] = None,
+        add_isolated_node_eval: bool = False,
         max_tokens: Optional[int] = None,
     ):
         results = self.predict_batch(
             queries=queries, documents=documents, top_k=top_k, batch_size=batch_size, max_tokens=max_tokens
         )
+
+        # run evaluation with "perfect" labels as node inputs to calculate "upper bound" metrics for just this node
+        if add_isolated_node_eval and labels is not None:
+            relevant_documents = []
+            for labelx in labels:
+                # Deduplicate same Documents in a MultiLabel based on their Document ID and filter out empty Documents
+                relevant_docs_labels = list(
+                    {
+                        label.document.id: label.document
+                        for label in labelx.labels
+                        if not isinstance(label.document.content, str) or label.document.content.strip() != ""
+                    }.values()
+                )
+                relevant_documents.append(relevant_docs_labels)
+            results_label_input = self.predict_batch(queries=queries, documents=relevant_documents, top_k=top_k)
+
+            results["answers_isolated"] = results_label_input["answers"]
         return results, "output_1"
 
     def _flatten_docs(self, documents: List[Document]):
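The deduplication inside the isolated-eval branch leans on dict key semantics: keying a dict comprehension by Document ID keeps one entry per ID (the last occurrence wins) in insertion order on Python 3.7+, while the condition drops string documents that are empty after stripping. A self-contained sketch of the same idiom, with illustrative sample data:

    from haystack.schema import Document

    docs = [
        Document(id="1", content="My name is Carla and I live in Berlin"),
        Document(id="1", content="My name is Carla and I live in Berlin"),  # duplicate ID, collapsed
        Document(id="2", content="   "),  # empty after strip, filtered out
    ]
    unique_docs = list(
        {d.id: d for d in docs if not isinstance(d.content, str) or d.content.strip() != ""}.values()
    )
    assert [d.id for d in unique_docs] == ["1"]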
@@ -47,7 +47,7 @@ from haystack.nodes import (
     PromptTemplate,
 )
 from haystack.nodes.prompt import PromptNode
-from haystack.schema import Document, FilterType
+from haystack.schema import Document, FilterType, MultiLabel, Label, Span
 
 from .mocks import pinecone as pinecone_mock
 
@@ -476,6 +476,43 @@ def gc_cleanup(request):
     gc.collect()
 
 
+@pytest.fixture
+def eval_labels() -> List[MultiLabel]:
+    EVAL_LABELS = [
+        MultiLabel(
+            labels=[
+                Label(
+                    query="Who lives in Berlin?",
+                    answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
+                    document=Document(
+                        id="a0747b83aea0b60c4b114b15476dd32d",
+                        content_type="text",
+                        content="My name is Carla and I live in Berlin",
+                    ),
+                    is_correct_answer=True,
+                    is_correct_document=True,
+                    origin="gold-label",
+                )
+            ]
+        ),
+        MultiLabel(
+            labels=[
+                Label(
+                    query="Who lives in Munich?",
+                    answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
+                    document=Document(
+                        id="something_else", content_type="text", content="My name is Carla and I live in Munich"
+                    ),
+                    is_correct_answer=True,
+                    is_correct_document=True,
+                    origin="gold-label",
+                )
+            ]
+        ),
+    ]
+    return EVAL_LABELS
+
+
 @pytest.fixture
 def deepset_cloud_fixture():
     if MOCK_DC:
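Moving the labels into an eval_labels fixture means pytest rebuilds the List[MultiLabel] for every test that requests it, so in-place mutations can no longer leak between tests. A minimal sketch of how a test consumes the fixture; the test name is illustrative, not part of this diff:

    def test_uses_fresh_labels(eval_labels):
        labels = eval_labels[:1]
        # Mutating is safe: the function-scoped fixture rebuilds the objects per test.
        labels[0].labels[0].document.content = ""
        assert labels[0].labels[0].document.content == ""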
@@ -2,7 +2,7 @@ from unittest.mock import patch, create_autospec
 
 import pytest
 from haystack import Pipeline
-from haystack.schema import Document
+from haystack.schema import Document, Answer
 from haystack.nodes.answer_generator import OpenAIAnswerGenerator
 from haystack.nodes import PromptTemplate
 
@@ -135,3 +135,42 @@ def test_openai_answer_generator_pipeline_max_tokens():
     result = pipeline.run(query=question, documents=nyc_docs, params={"generator": {"max_tokens": 3}})
     assert result["answers"] == mocked_response
     openai_generator.run.assert_called_with(query=question, documents=nyc_docs, max_tokens=3)
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.answer_generator.openai.OpenAIAnswerGenerator.predict")
+def test_openai_answer_generator_run_with_labels_and_isolated_node_eval(patched_predict, eval_labels):
+    label = eval_labels[0]
+    query = label.query
+    document = label.labels[0].document
+
+    patched_predict.return_value = {
+        "answers": [Answer(answer=label.labels[0].answer.answer, document_ids=[document.id])]
+    }
+    with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"):
+        openai_generator = OpenAIAnswerGenerator(api_key="fake_api_key", model="text-babbage-001", top_k=1)
+        result, _ = openai_generator.run(query=query, documents=[document], labels=label, add_isolated_node_eval=True)
+
+    assert "answers_isolated" in result
+
+
+@pytest.mark.unit
+@patch("haystack.nodes.answer_generator.base.BaseGenerator.predict_batch")
+def test_openai_answer_generator_run_batch_with_labels_and_isolated_node_eval(patched_predict_batch, eval_labels):
+    queries = [label.query for label in eval_labels]
+    documents = [[label.labels[0].document] for label in eval_labels]
+
+    patched_predict_batch.return_value = {
+        "queries": queries,
+        "answers": [
+            [Answer(answer=label.labels[0].answer.answer, document_ids=[label.labels[0].document.id])]
+            for label in eval_labels
+        ],
+    }
+    with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"):
+        openai_generator = OpenAIAnswerGenerator(api_key="fake_api_key", model="text-babbage-001", top_k=1)
+        result, _ = openai_generator.run_batch(
+            queries=queries, documents=documents, labels=eval_labels, add_isolated_node_eval=True
+        )
+
+    assert "answers_isolated" in result
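Both tests patch one level beneath the method under test (predict for run(), BaseGenerator.predict_batch for run_batch()), so the new isolated-eval branch executes for real while no OpenAI request is made. Given the diff above, two follow-up assertions could be appended inside the batch test; these are hypothetical additions, not part of this commit:

    # Hypothetical extra checks: the isolated answers are exactly what the
    # mocked predict_batch returned ...
    assert result["answers_isolated"] == patched_predict_batch.return_value["answers"]
    # ... and predict_batch ran twice: once on the input documents,
    # once on the label documents for the isolated pass.
    assert patched_predict_batch.call_count == 2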
@@ -31,14 +31,16 @@ from haystack.schema import Answer, Document, EvaluationResult, Label, MultiLabe
 @pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Causes OOM on windows github runner")
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 @pytest.mark.parametrize("retriever_with_docs", ["embedding"], indirect=True)
-def test_summarizer_calculate_metrics(document_store_with_docs: ElasticsearchDocumentStore, retriever_with_docs):
+def test_summarizer_calculate_metrics(
+    document_store_with_docs: ElasticsearchDocumentStore, retriever_with_docs, eval_labels
+):
     document_store_with_docs.update_embeddings(retriever=retriever_with_docs)
     summarizer = TransformersSummarizer(model_name_or_path="sshleifer/distill-pegasus-xsum-16-4", use_gpu=False)
     pipeline = SearchSummarizationPipeline(
         retriever=retriever_with_docs, summarizer=summarizer, return_in_answer_format=True
     )
     eval_result: EvaluationResult = pipeline.eval(
-        labels=EVAL_LABELS, params={"Retriever": {"top_k": 5}}, context_matching_min_length=10
+        labels=eval_labels, params={"Retriever": {"top_k": 5}}, context_matching_min_length=10
     )
 
     metrics = eval_result.calculate_metrics(document_scope="context")
@@ -249,39 +251,6 @@ def test_eval_data_split_passage(document_store, samples_path):
     assert len(docs[1].content) == 56
 
 
-EVAL_LABELS = [
-    MultiLabel(
-        labels=[
-            Label(
-                query="Who lives in Berlin?",
-                answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
-                document=Document(
-                    id="a0747b83aea0b60c4b114b15476dd32d",
-                    content_type="text",
-                    content="My name is Carla and I live in Berlin",
-                ),
-                is_correct_answer=True,
-                is_correct_document=True,
-                origin="gold-label",
-            )
-        ]
-    ),
-    MultiLabel(
-        labels=[
-            Label(
-                query="Who lives in Munich?",
-                answer=Answer(answer="Carla", offsets_in_context=[Span(11, 16)]),
-                document=Document(
-                    id="something_else", content_type="text", content="My name is Carla and I live in Munich"
-                ),
-                is_correct_answer=True,
-                is_correct_document=True,
-                origin="gold-label",
-            )
-        ]
-    ),
-]
-
 NO_ANSWER_EVAL_LABELS = [
     MultiLabel(
         labels=[
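Dropping the module-level EVAL_LABELS constant closes a test-isolation hole: an object created at import time is shared by every test in the module, so one test's in-place mutation (for example, test_empty_documents_dont_fail_pipeline below blanks a document's content) silently changes the input of every test that runs after it. A self-contained sketch of that hazard, with illustrative names:

    SHARED_LABELS = [{"content": "My name is Carla and I live in Berlin"}]  # module-level, created once

    def test_mutates_shared_state():
        SHARED_LABELS[0]["content"] = ""  # in-place mutation survives the test

    def test_sees_stale_state():
        # Passes only because the previous test already ran: order-dependent and fragile.
        assert SHARED_LABELS[0]["content"] == ""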
@@ -499,8 +468,8 @@ def test_table_qa_eval(table_reader_and_param, document_store, retriever):
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
-def test_extractive_qa_eval(reader, retriever_with_docs, tmp_path):
-    labels = EVAL_LABELS[:1]
+def test_extractive_qa_eval(reader, retriever_with_docs, tmp_path, eval_labels):
+    labels = eval_labels[:1]
 
     pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
     eval_result = pipeline.eval(labels=labels, params={"Retriever": {"top_k": 5}})
@@ -630,8 +599,8 @@ def test_extractive_qa_eval(reader, retriever_with_docs, tmp_path):
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 @responses.activate
-def test_generative_qa_eval(retriever_with_docs, tmp_path):
-    labels = EVAL_LABELS[:1]
+def test_generative_qa_eval(retriever_with_docs, tmp_path, eval_labels):
+    labels = eval_labels[:1]
     responses.add(
         responses.POST,
         "https://api.openai.com/v1/completions",
@@ -732,8 +701,8 @@ def test_generative_qa_eval(retriever_with_docs, tmp_path):
 
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
-def test_generative_qa_w_promptnode_eval(retriever_with_docs, tmp_path):
-    labels = EVAL_LABELS[:1]
+def test_generative_qa_w_promptnode_eval(retriever_with_docs, tmp_path, eval_labels):
+    labels = eval_labels[:1]
     pipeline = Pipeline()
     pipeline.add_node(retriever_with_docs, name="Retriever", inputs=["Query"])
     pipeline.add_node(
@@ -834,9 +803,9 @@ def test_generative_qa_w_promptnode_eval(retriever_with_docs, tmp_path):
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
-def test_extractive_qa_eval_multiple_queries(reader, retriever_with_docs, tmp_path):
+def test_extractive_qa_eval_multiple_queries(reader, retriever_with_docs, tmp_path, eval_labels):
     pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
-    eval_result: EvaluationResult = pipeline.eval(labels=EVAL_LABELS, params={"Retriever": {"top_k": 5}})
+    eval_result: EvaluationResult = pipeline.eval(labels=eval_labels, params={"Retriever": {"top_k": 5}})
 
     metrics = eval_result.calculate_metrics(document_scope="document_id")
 
@@ -981,10 +950,10 @@ def test_extractive_qa_labels_with_filters(reader, retriever_with_docs, tmp_path
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
-def test_extractive_qa_eval_sas(reader, retriever_with_docs):
+def test_extractive_qa_eval_sas(reader, retriever_with_docs, eval_labels):
     pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
     eval_result: EvaluationResult = pipeline.eval(
-        labels=EVAL_LABELS,
+        labels=eval_labels,
         params={"Retriever": {"top_k": 5}},
         sas_model_name_or_path="sentence-transformers/paraphrase-MiniLM-L3-v2",
     )
@@ -1009,12 +978,12 @@ def test_extractive_qa_eval_sas(reader, retriever_with_docs):
 
 
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
-def test_reader_eval_in_pipeline(reader):
+def test_reader_eval_in_pipeline(reader, eval_labels):
     pipeline = Pipeline()
     pipeline.add_node(component=reader, name="Reader", inputs=["Query"])
     eval_result: EvaluationResult = pipeline.eval(
-        labels=EVAL_LABELS,
-        documents=[[label.document for label in multilabel.labels] for multilabel in EVAL_LABELS],
+        labels=eval_labels,
+        documents=[[label.document for label in multilabel.labels] for multilabel in eval_labels],
         params={},
     )
 
@@ -1026,10 +995,10 @@ def test_reader_eval_in_pipeline(reader):
 
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
-def test_extractive_qa_eval_document_scope(retriever_with_docs):
+def test_extractive_qa_eval_document_scope(retriever_with_docs, eval_labels):
     pipeline = DocumentSearchPipeline(retriever=retriever_with_docs)
     eval_result: EvaluationResult = pipeline.eval(
-        labels=EVAL_LABELS,
+        labels=eval_labels,
         params={"Retriever": {"top_k": 5}},
         context_matching_min_length=20,  # artificially set down min_length to see if context matching is working properly
     )
@@ -1312,10 +1281,10 @@ def test_extractive_qa_eval_document_scope_no_answer(retriever_with_docs, docume
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
-def test_extractive_qa_eval_answer_scope(reader, retriever_with_docs):
+def test_extractive_qa_eval_answer_scope(reader, retriever_with_docs, eval_labels):
     pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
     eval_result: EvaluationResult = pipeline.eval(
-        labels=EVAL_LABELS,
+        labels=eval_labels,
         params={"Retriever": {"top_k": 5}},
         sas_model_name_or_path="sentence-transformers/paraphrase-MiniLM-L3-v2",
         context_matching_min_length=20,  # artificially set down min_length to see if context matching is working properly
@@ -1373,10 +1342,10 @@ def test_extractive_qa_eval_answer_scope(reader, retriever_with_docs):
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
-def test_extractive_qa_eval_answer_document_scope_combinations(reader, retriever_with_docs, caplog):
+def test_extractive_qa_eval_answer_document_scope_combinations(reader, retriever_with_docs, caplog, eval_labels):
     pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
     eval_result: EvaluationResult = pipeline.eval(
-        labels=EVAL_LABELS,
+        labels=eval_labels,
         params={"Retriever": {"top_k": 5}},
         sas_model_name_or_path="sentence-transformers/paraphrase-MiniLM-L3-v2",
         context_matching_min_length=20,  # artificially set down min_length to see if context matching is working properly
@@ -1408,10 +1377,10 @@ def test_extractive_qa_eval_answer_document_scope_combinations(reader, retriever
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
-def test_extractive_qa_eval_simulated_top_k_reader(reader, retriever_with_docs):
+def test_extractive_qa_eval_simulated_top_k_reader(reader, retriever_with_docs, eval_labels):
     pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
     eval_result: EvaluationResult = pipeline.eval(
-        labels=EVAL_LABELS,
+        labels=eval_labels,
         params={"Retriever": {"top_k": 5}},
         sas_model_name_or_path="sentence-transformers/paraphrase-MiniLM-L3-v2",
     )
@@ -1456,9 +1425,9 @@ def test_extractive_qa_eval_simulated_top_k_reader(reader, retriever_with_docs):
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
-def test_extractive_qa_eval_simulated_top_k_retriever(reader, retriever_with_docs):
+def test_extractive_qa_eval_simulated_top_k_retriever(reader, retriever_with_docs, eval_labels):
     pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
-    eval_result: EvaluationResult = pipeline.eval(labels=EVAL_LABELS, params={"Retriever": {"top_k": 5}})
+    eval_result: EvaluationResult = pipeline.eval(labels=eval_labels, params={"Retriever": {"top_k": 5}})
 
     metrics_top_10 = eval_result.calculate_metrics(document_scope="document_id")
 
@@ -1508,9 +1477,9 @@ def test_extractive_qa_eval_simulated_top_k_retriever(reader, retriever_with_doc
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
-def test_extractive_qa_eval_simulated_top_k_reader_and_retriever(reader, retriever_with_docs):
+def test_extractive_qa_eval_simulated_top_k_reader_and_retriever(reader, retriever_with_docs, eval_labels):
     pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
-    eval_result: EvaluationResult = pipeline.eval(labels=EVAL_LABELS, params={"Retriever": {"top_k": 10}})
+    eval_result: EvaluationResult = pipeline.eval(labels=eval_labels, params={"Retriever": {"top_k": 10}})
 
     metrics_top_10 = eval_result.calculate_metrics(simulated_top_k_reader=1, document_scope="document_id")
 
@@ -1567,8 +1536,8 @@ def test_extractive_qa_eval_simulated_top_k_reader_and_retriever(reader, retriev
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
-def test_extractive_qa_eval_isolated(reader, retriever_with_docs):
-    labels = deepcopy(EVAL_LABELS)
+def test_extractive_qa_eval_isolated(reader, retriever_with_docs, eval_labels):
+    labels = deepcopy(eval_labels)
     # Copy one of the labels and change only the answer have a label with a different answer but same Document
     label_copy = deepcopy(labels[0].labels[0])
     label_copy.answer = Answer(answer="I", offsets_in_context=[Span(21, 22)])
@@ -1702,9 +1671,9 @@ def test_extractive_qa_print_eval_report(reader, retriever_with_docs):
 
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
-def test_document_search_calculate_metrics(retriever_with_docs):
+def test_document_search_calculate_metrics(retriever_with_docs, eval_labels):
     pipeline = DocumentSearchPipeline(retriever=retriever_with_docs)
-    eval_result: EvaluationResult = pipeline.eval(labels=EVAL_LABELS, params={"Retriever": {"top_k": 5}})
+    eval_result: EvaluationResult = pipeline.eval(labels=eval_labels, params={"Retriever": {"top_k": 5}})
 
     metrics = eval_result.calculate_metrics(document_scope="document_id")
 
@@ -1732,11 +1701,11 @@ def test_document_search_calculate_metrics(retriever_with_docs):
 
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
-def test_document_search_isolated(retriever_with_docs):
+def test_document_search_isolated(retriever_with_docs, eval_labels):
     pipeline = DocumentSearchPipeline(retriever=retriever_with_docs)
     # eval run must not fail even though no node supports add_isolated_node_eval
     eval_result: EvaluationResult = pipeline.eval(
-        labels=EVAL_LABELS, params={"Retriever": {"top_k": 5}}, add_isolated_node_eval=True
+        labels=eval_labels, params={"Retriever": {"top_k": 5}}, add_isolated_node_eval=True
     )
 
     metrics = eval_result.calculate_metrics(document_scope="document_id")
@@ -1769,9 +1738,9 @@ def test_document_search_isolated(retriever_with_docs):
 
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
-def test_faq_calculate_metrics(retriever_with_docs):
+def test_faq_calculate_metrics(retriever_with_docs, eval_labels):
     pipeline = FAQPipeline(retriever=retriever_with_docs)
-    eval_result: EvaluationResult = pipeline.eval(labels=EVAL_LABELS, params={"Retriever": {"top_k": 5}})
+    eval_result: EvaluationResult = pipeline.eval(labels=eval_labels, params={"Retriever": {"top_k": 5}})
 
     metrics = eval_result.calculate_metrics(document_scope="document_id")
 
@@ -1792,7 +1761,7 @@ def test_faq_calculate_metrics(retriever_with_docs):
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
-def test_extractive_qa_eval_translation(reader, retriever_with_docs):
+def test_extractive_qa_eval_translation(reader, retriever_with_docs, eval_labels):
     # FIXME it makes no sense to have DE->EN input and DE->EN output, right?
     # Yet switching direction breaks the test. TO BE FIXED.
     input_translator = TransformersTranslator(model_name_or_path="Helsinki-NLP/opus-mt-de-en")
@@ -1802,7 +1771,7 @@ def test_extractive_qa_eval_translation(reader, retriever_with_docs):
     pipeline = TranslationWrapperPipeline(
         input_translator=input_translator, output_translator=output_translator, pipeline=pipeline
     )
-    eval_result: EvaluationResult = pipeline.eval(labels=EVAL_LABELS, params={"Retriever": {"top_k": 5}})
+    eval_result: EvaluationResult = pipeline.eval(labels=eval_labels, params={"Retriever": {"top_k": 5}})
 
     metrics = eval_result.calculate_metrics(document_scope="document_id")
 
@@ -1832,10 +1801,10 @@ def test_extractive_qa_eval_translation(reader, retriever_with_docs):
 
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
-def test_question_generation_eval(retriever_with_docs, question_generator):
+def test_question_generation_eval(retriever_with_docs, question_generator, eval_labels):
     pipeline = RetrieverQuestionGenerationPipeline(retriever=retriever_with_docs, question_generator=question_generator)
 
-    eval_result: EvaluationResult = pipeline.eval(labels=EVAL_LABELS, params={"Retriever": {"top_k": 5}})
+    eval_result: EvaluationResult = pipeline.eval(labels=eval_labels, params={"Retriever": {"top_k": 5}})
 
     metrics = eval_result.calculate_metrics(document_scope="document_id")
 
@@ -1860,7 +1829,7 @@ def test_question_generation_eval(retriever_with_docs, question_generator):
 
 @pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
-def test_qa_multi_retriever_pipeline_eval(document_store_with_docs, reader):
+def test_qa_multi_retriever_pipeline_eval(document_store_with_docs, reader, eval_labels):
     es_retriever = BM25Retriever(document_store=document_store_with_docs)
     dpr_retriever = DensePassageRetriever(document_store_with_docs)
     document_store_with_docs.update_embeddings(retriever=dpr_retriever)
@@ -1874,7 +1843,7 @@ def test_qa_multi_retriever_pipeline_eval(document_store_with_docs, reader):
 
     # EVAL_QUERIES: 2 go dpr way
     # in Berlin goes es way
-    labels = EVAL_LABELS + [
+    labels = eval_labels + [
         MultiLabel(
             labels=[
                 Label(
@@ -1923,7 +1892,7 @@ def test_qa_multi_retriever_pipeline_eval(document_store_with_docs, reader):
 
 
 @pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
-def test_multi_retriever_pipeline_eval(document_store_with_docs):
+def test_multi_retriever_pipeline_eval(document_store_with_docs, eval_labels):
     es_retriever = BM25Retriever(document_store=document_store_with_docs)
     dpr_retriever = DensePassageRetriever(document_store_with_docs)
     document_store_with_docs.update_embeddings(retriever=dpr_retriever)
@@ -1936,7 +1905,7 @@ def test_multi_retriever_pipeline_eval(document_store_with_docs):
 
     # EVAL_QUERIES: 2 go dpr way
     # in Berlin goes es way
-    labels = EVAL_LABELS + [
+    labels = eval_labels + [
         MultiLabel(
             labels=[
                 Label(
@@ -1982,7 +1951,7 @@ def test_multi_retriever_pipeline_eval(document_store_with_docs):
 
 @pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
 @pytest.mark.parametrize("reader", ["farm"], indirect=True)
-def test_multi_retriever_pipeline_with_asymmetric_qa_eval(document_store_with_docs, reader):
+def test_multi_retriever_pipeline_with_asymmetric_qa_eval(document_store_with_docs, reader, eval_labels):
     es_retriever = BM25Retriever(document_store=document_store_with_docs)
     dpr_retriever = DensePassageRetriever(document_store_with_docs)
     document_store_with_docs.update_embeddings(retriever=dpr_retriever)
@@ -1996,7 +1965,7 @@ def test_multi_retriever_pipeline_with_asymmetric_qa_eval(document_store_with_do
 
     # EVAL_QUERIES: 2 go dpr way
     # in Berlin goes es way
-    labels = EVAL_LABELS + [
+    labels = eval_labels + [
         MultiLabel(
             labels=[
                 Label(
@@ -2047,8 +2016,8 @@ def test_multi_retriever_pipeline_with_asymmetric_qa_eval(document_store_with_do
 @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
 @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
 @pytest.mark.parametrize("reader", ["farm", "transformers"], indirect=True)
-def test_empty_documents_dont_fail_pipeline(reader, retriever_with_docs):
-    multilabels = EVAL_LABELS[:2]
+def test_empty_documents_dont_fail_pipeline(reader, retriever_with_docs, eval_labels):
+    multilabels = eval_labels[:2]
     multilabels[0].labels[0].document.content = ""
     pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
     eval_result_integrated: EvaluationResult = pipeline.eval(labels=multilabels, add_isolated_node_eval=False)
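After this change, generators contribute isolated answers in both single and batch eval runs, and the results surface through the usual metrics path. A sketch of comparing both views, assuming haystack v1's EvaluationResult.calculate_metrics with its eval_mode switch; pipeline and eval_labels stand in for the objects used in the tests above, and the "Generator" node name and metric key are illustrative:

    # Assumes a pipeline and eval_labels as in the tests above (not defined here).
    eval_result = pipeline.eval(labels=eval_labels, add_isolated_node_eval=True)
    integrated = eval_result.calculate_metrics(eval_mode="integrated")  # node fed upstream output
    isolated = eval_result.calculate_metrics(eval_mode="isolated")      # node fed gold documents
    # The gap between the two shows how much upstream errors cost this node.
    print(integrated["Generator"], isolated["Generator"])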