From 3a2c8ae3c56065d733d5d384da2133c0a6433ea9 Mon Sep 17 00:00:00 2001 From: Unai Garay Maestre Date: Mon, 17 Oct 2022 19:00:13 +0200 Subject: [PATCH] bug: Adds better way of checking `query` in BaseRetriever and Pipeline.run() (#3304) * changes how query and queries are checked if they have been passed in BaseRetriever * Fixes checking query properly in Pipeline run * Fixes checking query properly in Pipeline run * Adds test for FilterRetriever using run method when query is empty * Adds mock filter retriever and adapts test * Removes old test, adds MockRetriever to test file and test uses document_store * Logs error when query is not of type string with a new test for run batch * Update test/nodes/test_retriever.py * schemas --- haystack/nodes/retriever/base.py | 14 +++++++++-- haystack/pipelines/base.py | 2 +- test/nodes/test_retriever.py | 43 +++++++++++++++++++++++++++++++- 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/haystack/nodes/retriever/base.py b/haystack/nodes/retriever/base.py index 149cb5024..4d144a287 100644 --- a/haystack/nodes/retriever/base.py +++ b/haystack/nodes/retriever/base.py @@ -268,10 +268,15 @@ class BaseRetriever(BaseComponent): scale_score: bool = None, ): if root_node == "Query": - if not query: + if query is None: raise HaystackError( "Must provide a 'query' parameter for retrievers in pipelines where Query is the root node." ) + if not isinstance(query, str): + logger.error( + "The retriever received an unusual query: '%s' This query is likely to produce garbage output.", + query, + ) self.query_count += 1 run_query_timed = self.timing(self.run_query, "query_time") output, stream = run_query_timed( @@ -296,10 +301,15 @@ class BaseRetriever(BaseComponent): headers: Optional[Dict[str, str]] = None, ): if root_node == "Query": - if not queries: + if queries is None: raise HaystackError( "Must provide a 'queries' parameter for retrievers in pipelines where Query is the root node." ) + if not all(isinstance(query, str) for query in queries): + logger.error( + "The retriever received an unusual list of queries: '%s' Some of these queries are likely to produce garbage output.", + queries, + ) self.query_count += len(queries) if isinstance(queries, list) else 1 run_query_batch_timed = self.timing(self.run_query_batch, "query_time") output, stream = run_query_batch_timed( diff --git a/haystack/pipelines/base.py b/haystack/pipelines/base.py index 0ef2d72e9..1803c33df 100644 --- a/haystack/pipelines/base.py +++ b/haystack/pipelines/base.py @@ -484,7 +484,7 @@ class Pipeline: queue: Dict[str, Any] = { root_node: {"root_node": root_node, "params": params} } # ordered dict with "node_id" -> "input" mapping that acts as a FIFO queue - if query: + if query is not None: queue[root_node]["query"] = query if file_paths: queue[root_node]["file_paths"] = file_paths diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py index 5e2e7aadd..792e245e8 100644 --- a/test/nodes/test_retriever.py +++ b/test/nodes/test_retriever.py @@ -4,6 +4,7 @@ import os import logging import os from math import isclose +from typing import Dict, List, Optional, Union import pytest import numpy as np @@ -25,7 +26,7 @@ from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetri from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever from haystack.nodes.retriever.multimodal import MultiModalRetriever -from ..conftest import SAMPLES_PATH +from ..conftest import SAMPLES_PATH, MockRetriever # TODO check if we this works with only "memory" arg @@ -88,6 +89,46 @@ def test_retrieval(retriever_with_docs: BaseRetriever, document_store_with_docs: assert len(result) == 0 +class MockBaseRetriever(MockRetriever): + def __init__(self, document_store: BaseDocumentStore, mock_document: Document): + self.document_store = document_store + self.mock_document = mock_document + + def retrieve( + self, + query: str, + filters: dict, + top_k: Optional[int], + index: str, + headers: Optional[Dict[str, str]], + scale_score: bool, + ): + return [self.mock_document] + + def retrieve_batch( + self, + queries: List[str], + filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None, + top_k: Optional[int] = None, + index: str = None, + headers: Optional[Dict[str, str]] = None, + batch_size: Optional[int] = None, + scale_score: bool = None, + ): + return [[self.mock_document] for _ in range(len(queries))] + + +def test_retrieval_empty_query(document_store: BaseDocumentStore): + # test with empty query using the run() method + mock_document = Document(id="0", content="test") + retriever = MockBaseRetriever(document_store=document_store, mock_document=mock_document) + result = retriever.run(root_node="Query", query="", filters={}) + assert result[0]["documents"][0] == mock_document + + result = retriever.run_batch(root_node="Query", queries=[""], filters={}) + assert result[0]["documents"][0][0] == mock_document + + def test_batch_retrieval_single_query(retriever_with_docs, document_store_with_docs): if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever, TfidfRetriever)): document_store_with_docs.update_embeddings(retriever_with_docs)