bug: Adds better way of checking query in BaseRetriever and Pipeline.run() (#3304)
* Changes how query and queries are checked when they are passed to BaseRetriever
* Fixes checking query properly in Pipeline.run()
* Adds test for FilterRetriever using the run() method when the query is empty
* Adds mock filter retriever and adapts the test
* Removes the old test, adds MockRetriever to the test file, and makes the test use document_store
* Logs an error when the query is not of type string, with a new test for run_batch
* Update test/nodes/test_retriever.py
* schemas
parent 101d2bc86c
commit 3a2c8ae3c5
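The core change is the validation at the top of BaseRetriever.run(): an empty query string is no longer treated the same as a missing query, and a non-string query is logged as suspicious instead of slipping through silently. Below is a minimal standalone sketch of that logic, not the Haystack code itself; plain ValueError stands in for Haystack's HaystackError and check_query is a hypothetical helper.

import logging

logger = logging.getLogger(__name__)


def check_query(query) -> None:
    # Only a missing query is an error; previously `if not query:` also rejected "".
    if query is None:
        raise ValueError(
            "Must provide a 'query' parameter for retrievers in pipelines where Query is the root node."
        )
    # A non-string query is suspicious but tolerated: log it instead of failing.
    if not isinstance(query, str):
        logger.error(
            "The retriever received an unusual query: '%s' This query is likely to produce garbage output.",
            query,
        )


check_query("")  # now passes: an empty string no longer raises
check_query(42)  # logs an error but does not raise
# check_query(None) would still raise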
@@ -268,10 +268,15 @@ class BaseRetriever(BaseComponent):
         scale_score: bool = None,
     ):
         if root_node == "Query":
-            if not query:
+            if query is None:
                 raise HaystackError(
                     "Must provide a 'query' parameter for retrievers in pipelines where Query is the root node."
                 )
+            if not isinstance(query, str):
+                logger.error(
+                    "The retriever received an unusual query: '%s' This query is likely to produce garbage output.",
+                    query,
+                )
             self.query_count += 1
             run_query_timed = self.timing(self.run_query, "query_time")
             output, stream = run_query_timed(
@@ -296,10 +301,15 @@ class BaseRetriever(BaseComponent):
         headers: Optional[Dict[str, str]] = None,
     ):
         if root_node == "Query":
-            if not queries:
+            if queries is None:
                 raise HaystackError(
                     "Must provide a 'queries' parameter for retrievers in pipelines where Query is the root node."
                 )
+            if not all(isinstance(query, str) for query in queries):
+                logger.error(
+                    "The retriever received an unusual list of queries: '%s' Some of these queries are likely to produce garbage output.",
+                    queries,
+                )
             self.query_count += len(queries) if isinstance(queries, list) else 1
             run_query_batch_timed = self.timing(self.run_query_batch, "query_time")
             output, stream = run_query_batch_timed(
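The batch path applies the same idea per element: only a missing `queries` argument raises, while non-string entries in the list are logged. A matching standalone sketch (again a hypothetical helper, with ValueError standing in for HaystackError):

import logging

logger = logging.getLogger(__name__)


def check_queries(queries) -> None:
    # Only a missing `queries` argument is an error; empty strings inside the list are allowed.
    if queries is None:
        raise ValueError(
            "Must provide a 'queries' parameter for retrievers in pipelines where Query is the root node."
        )
    # Log (but tolerate) lists containing non-string entries.
    if not all(isinstance(query, str) for query in queries):
        logger.error(
            "The retriever received an unusual list of queries: '%s' Some of these queries are likely to produce garbage output.",
            queries,
        )


check_queries([""])       # passes: empty strings are fine
check_queries(["ok", 3])  # logs an error because one entry is not a string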
@@ -484,7 +484,7 @@ class Pipeline:
         queue: Dict[str, Any] = {
             root_node: {"root_node": root_node, "params": params}
         }  # ordered dict with "node_id" -> "input" mapping that acts as a FIFO queue
-        if query:
+        if query is not None:
             queue[root_node]["query"] = query
         if file_paths:
             queue[root_node]["file_paths"] = file_paths
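In Pipeline.run() the same truthiness check meant an empty query never made it into the root node's input at all. A standalone sketch of the queue set-up, with hypothetical root_node and params values, shows the difference:

from typing import Any, Dict

root_node = "Query"          # hypothetical root node name
params: Dict[str, Any] = {}  # hypothetical (empty) per-node params
query = ""                   # an intentionally empty query

# Same queue shape as in Pipeline.run(): node_id -> input dict.
queue: Dict[str, Any] = {root_node: {"root_node": root_node, "params": params}}
if query is not None:  # previously `if query:`, which silently dropped ""
    queue[root_node]["query"] = query

assert queue["Query"]["query"] == ""  # the root node now receives the empty query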
@@ -4,6 +4,7 @@ import os
 import logging
 import os
 from math import isclose
+from typing import Dict, List, Optional, Union
 
 import pytest
 import numpy as np
@@ -25,7 +26,7 @@ from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetri
 from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
 from haystack.nodes.retriever.multimodal import MultiModalRetriever
 
-from ..conftest import SAMPLES_PATH
+from ..conftest import SAMPLES_PATH, MockRetriever
 
 
 # TODO check if we this works with only "memory" arg
@@ -88,6 +89,46 @@ def test_retrieval(retriever_with_docs: BaseRetriever, document_store_with_docs:
         assert len(result) == 0
 
 
+class MockBaseRetriever(MockRetriever):
+    def __init__(self, document_store: BaseDocumentStore, mock_document: Document):
+        self.document_store = document_store
+        self.mock_document = mock_document
+
+    def retrieve(
+        self,
+        query: str,
+        filters: dict,
+        top_k: Optional[int],
+        index: str,
+        headers: Optional[Dict[str, str]],
+        scale_score: bool,
+    ):
+        return [self.mock_document]
+
+    def retrieve_batch(
+        self,
+        queries: List[str],
+        filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
+        top_k: Optional[int] = None,
+        index: str = None,
+        headers: Optional[Dict[str, str]] = None,
+        batch_size: Optional[int] = None,
+        scale_score: bool = None,
+    ):
+        return [[self.mock_document] for _ in range(len(queries))]
+
+
+def test_retrieval_empty_query(document_store: BaseDocumentStore):
+    # test with empty query using the run() method
+    mock_document = Document(id="0", content="test")
+    retriever = MockBaseRetriever(document_store=document_store, mock_document=mock_document)
+    result = retriever.run(root_node="Query", query="", filters={})
+    assert result[0]["documents"][0] == mock_document
+
+    result = retriever.run_batch(root_node="Query", queries=[""], filters={})
+    assert result[0]["documents"][0][0] == mock_document
+
+
 def test_batch_retrieval_single_query(retriever_with_docs, document_store_with_docs):
     if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever, TfidfRetriever)):
         document_store_with_docs.update_embeddings(retriever_with_docs)