bug: Adds better way of checking query in BaseRetriever and Pipeline.run() (#3304)

* changes how query and queries are checked if they have been passed in BaseRetriever

* Fixes checking query properly in Pipeline run

* Fixes checking query properly in Pipeline run

* Adds test for FilterRetriever using run method when query is empty

* Adds mock filter retriever and adapts test

* Removes old test, adds MockRetriever to test file and test uses document_store

* Logs error when query is not of type string with a new test for run batch

* Update test/nodes/test_retriever.py

* schemas
This commit is contained in:
Unai Garay Maestre 2022-10-17 19:00:13 +02:00 committed by GitHub
parent 101d2bc86c
commit 3a2c8ae3c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 4 deletions

View File

@ -268,10 +268,15 @@ class BaseRetriever(BaseComponent):
scale_score: bool = None,
):
if root_node == "Query":
if not query:
if query is None:
raise HaystackError(
"Must provide a 'query' parameter for retrievers in pipelines where Query is the root node."
)
if not isinstance(query, str):
logger.error(
"The retriever received an unusual query: '%s' This query is likely to produce garbage output.",
query,
)
self.query_count += 1
run_query_timed = self.timing(self.run_query, "query_time")
output, stream = run_query_timed(
@ -296,10 +301,15 @@ class BaseRetriever(BaseComponent):
headers: Optional[Dict[str, str]] = None,
):
if root_node == "Query":
if not queries:
if queries is None:
raise HaystackError(
"Must provide a 'queries' parameter for retrievers in pipelines where Query is the root node."
)
if not all(isinstance(query, str) for query in queries):
logger.error(
"The retriever received an unusual list of queries: '%s' Some of these queries are likely to produce garbage output.",
queries,
)
self.query_count += len(queries) if isinstance(queries, list) else 1
run_query_batch_timed = self.timing(self.run_query_batch, "query_time")
output, stream = run_query_batch_timed(

View File

@ -484,7 +484,7 @@ class Pipeline:
queue: Dict[str, Any] = {
root_node: {"root_node": root_node, "params": params}
} # ordered dict with "node_id" -> "input" mapping that acts as a FIFO queue
if query:
if query is not None:
queue[root_node]["query"] = query
if file_paths:
queue[root_node]["file_paths"] = file_paths

View File

@ -4,6 +4,7 @@ import os
import logging
import os
from math import isclose
from typing import Dict, List, Optional, Union
import pytest
import numpy as np
@ -25,7 +26,7 @@ from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetri
from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
from haystack.nodes.retriever.multimodal import MultiModalRetriever
from ..conftest import SAMPLES_PATH
from ..conftest import SAMPLES_PATH, MockRetriever
# TODO check if we this works with only "memory" arg
@ -88,6 +89,46 @@ def test_retrieval(retriever_with_docs: BaseRetriever, document_store_with_docs:
assert len(result) == 0
class MockBaseRetriever(MockRetriever):
def __init__(self, document_store: BaseDocumentStore, mock_document: Document):
self.document_store = document_store
self.mock_document = mock_document
def retrieve(
self,
query: str,
filters: dict,
top_k: Optional[int],
index: str,
headers: Optional[Dict[str, str]],
scale_score: bool,
):
return [self.mock_document]
def retrieve_batch(
self,
queries: List[str],
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
top_k: Optional[int] = None,
index: str = None,
headers: Optional[Dict[str, str]] = None,
batch_size: Optional[int] = None,
scale_score: bool = None,
):
return [[self.mock_document] for _ in range(len(queries))]
def test_retrieval_empty_query(document_store: BaseDocumentStore):
# test with empty query using the run() method
mock_document = Document(id="0", content="test")
retriever = MockBaseRetriever(document_store=document_store, mock_document=mock_document)
result = retriever.run(root_node="Query", query="", filters={})
assert result[0]["documents"][0] == mock_document
result = retriever.run_batch(root_node="Query", queries=[""], filters={})
assert result[0]["documents"][0][0] == mock_document
def test_batch_retrieval_single_query(retriever_with_docs, document_store_with_docs):
if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever, TfidfRetriever)):
document_store_with_docs.update_embeddings(retriever_with_docs)