Mirror of https://github.com/deepset-ai/haystack.git (synced 2026-01-09 13:46:54 +00:00)
refactor: Rework WebRetriever caching, adjust tests (#5566)
* Rework WebRetriever caching, adjust tests
* Add release note
* Better pydocs
* Minor improvements
* Update haystack/nodes/retriever/web.py
  Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
commit 46c9139caf
parent a8d4a99db9
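For context, a minimal usage sketch of the caching this commit reworks. It is an illustration only: the document store choice and the API key value are placeholders, while the parameter names (mode, cache_document_store, cache_time) come from the diff below.

```python
# Hedged usage sketch (placeholder key and store; parameter names taken from this commit).
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import WebRetriever

cache_store = InMemoryDocumentStore()

retriever = WebRetriever(
    api_key="YOUR_SERPERDEV_API_KEY",   # placeholder Serper.dev key
    mode="preprocessed_documents",
    cache_document_store=cache_store,   # fetched page content is cached here
    cache_time=24 * 60 * 60,            # entries older than this are purged on the next retrieval
)

# The first call fetches the pages behind the search results and writes them to the cache;
# later calls that surface the same URLs are served from the cache instead of re-fetching.
docs = retriever.retrieve(query="Who invented the World Wide Web?", top_k=5)
```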
@@ -3,10 +3,9 @@ from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime, timedelta
from multiprocessing import cpu_count
from typing import Dict, Iterator, List, Optional, Literal, Union
from unicodedata import combining, normalize
from typing import Dict, Iterator, List, Optional, Literal, Union, Tuple

from haystack import Document
from haystack.schema import Document
from haystack.document_stores.base import BaseDocumentStore
from haystack.nodes.preprocessor import PreProcessor
from haystack.nodes.retriever.base import BaseRetriever
@@ -80,106 +79,12 @@ class WebRetriever(BaseRetriever):
self.cache_document_store = cache_document_store
self.document_store = cache_document_store
self.cache_index = cache_index
self.top_k = top_k
self.cache_headers = cache_headers
self.cache_time = cache_time
self.top_k = top_k
self.preprocessor = None
if preprocessor is not None:
self.preprocessor = preprocessor
elif mode == "preprocessed_documents":
self.preprocessor = PreProcessor(progress_bar=False)

def _normalize_query(self, query: str) -> str:
return "".join([c for c in normalize("NFKD", query.lower()) if not combining(c)])

def _check_cache(
self,
query: str,
cache_index: Optional[str] = None,
cache_headers: Optional[Dict[str, str]] = None,
cache_time: Optional[int] = None,
) -> List[Document]:
"""
Private method to check if the documents for a given query are already cached. The documents are fetched from
the specified DocumentStore. It retrieves documents that are newer than the cache_time limit.

:param query: The query string to check in the cache.
:param cache_index: Optional index name in the DocumentStore to fetch the documents. Defaults to the instance's
cache_index.
:param cache_headers: Optional headers to be used when fetching documents from the DocumentStore. Defaults to
the instance's cache_headers.
:param cache_time: Optional time limit in seconds to check the cache. Only documents newer than cache_time are
returned. Defaults to the instance's cache_time.
:returns: A list of Document instances fetched from the cache. If no documents are found in the cache, an empty
list is returned.
"""
cache_document_store = self.cache_document_store
documents = []

if cache_document_store is not None:
query_norm = self._normalize_query(query)
cache_filter: FilterType = {"$and": {"search.query": query_norm}}

if cache_time is not None and cache_time > 0:
cache_filter["timestamp"] = {
"$gt": int((datetime.utcnow() - timedelta(seconds=cache_time)).timestamp())
}
logger.debug("Cache filter: %s", cache_filter)

documents = cache_document_store.get_all_documents(
filters=cache_filter, index=cache_index, headers=cache_headers, return_embedding=False
)

logger.debug("Found %d documents in cache", len(documents))

return documents

def _save_cache(
self,
query: str,
documents: List[Document],
cache_index: Optional[str] = None,
cache_headers: Optional[Dict[str, str]] = None,
cache_time: Optional[int] = None,
) -> bool:
"""
Private method to cache the retrieved documents for a given query.
The documents are saved in the specified DocumentStore. If the same document already exists, it is
overwritten.

:param query: The query string for which the documents are being cached.
:param documents: The list of Document instances to be cached.
:param cache_index: Optional index name in the DocumentStore to save the documents. Defaults to the
instance's cache_index.
:param cache_headers: Optional headers to be used when saving documents in the DocumentStore. Defaults to
the instance's cache_headers.
:param cache_time: Optional time limit in seconds to check the cache. Documents older than the
cache_time are deleted. Defaults to the instance's cache_time.
:returns: True if the documents are successfully saved in the cache, False otherwise.
"""
cache_document_store = self.cache_document_store

if cache_document_store is not None:
cache_document_store.write_documents(
documents=documents, index=cache_index, headers=cache_headers, duplicate_documents="overwrite"
)

logger.debug("Saved %d documents in the cache", len(documents))

cache_filter: FilterType = {"$and": {"search.query": query}}

if cache_time is not None and cache_time > 0:
cache_filter["timestamp"] = {
"$lt": int((datetime.utcnow() - timedelta(seconds=cache_time)).timestamp())
}

cache_document_store.delete_documents(index=cache_index, headers=cache_headers, filters=cache_filter)

logger.debug("Deleted documents in the cache using filter: %s", cache_filter)

return True

return False
self.preprocessor = (
preprocessor or PreProcessor(progress_bar=False) if mode == "preprocessed_documents" else None
)

def retrieve( # type: ignore[override]
self,
@@ -201,7 +106,7 @@ class WebRetriever(BaseRetriever):

Optionally, the retrieved documents can be stored in a DocumentStore for future use, saving time
and resources on repeated retrievals. This caching mechanism can significantly improve retrieval times
for frequently accessed information.
for frequently accessed URLs.

:param query: The query string.
:param top_k: The number of Documents to be returned by the retriever.
@@ -209,56 +114,45 @@ class WebRetriever(BaseRetriever):
:param cache_document_store: The DocumentStore to cache the documents to.
:param cache_index: The index name to save the documents to.
:param cache_headers: The headers to save the documents to.
:param cache_time: The time limit in seconds to check the cache. The default is 24 hours.
:param cache_time: The time limit in seconds for the documents in the cache. If objects are older than this time,
they will be deleted from the cache on the next retrieval.
"""

# Initialize default parameters
preprocessor = preprocessor or self.preprocessor
cache_document_store = cache_document_store or self.cache_document_store
cache_index = cache_index or self.cache_index
top_k = top_k or self.top_k
cache_headers = cache_headers or self.cache_headers
cache_time = cache_time or self.cache_time
top_k = top_k or self.top_k

# Normalize query
query_norm = self._normalize_query(query)
search_results, _ = self.web_search.run(query=query)
result_docs = search_results["documents"]

# Check cache for query
extracted_docs = self._check_cache(
query_norm, cache_index=cache_index, cache_headers=cache_headers, cache_time=cache_time
)
if self.mode != "snippets":
# for raw_documents and preprocessed_documents modes, we need to retrieve the links from the search results
links: List[SearchResult] = self._prepare_links(result_docs)

# If query is not cached, fetch from web
if not extracted_docs:
extracted_docs = self._retrieve_from_web(query_norm, preprocessor)
links_found_in_cache, cached_docs = self._check_cache(links)
logger.debug("Found %d links in cache", len(links_found_in_cache))

# Save results to cache
if cache_document_store and extracted_docs:
cached = self._save_cache(query_norm, extracted_docs, cache_index=cache_index, cache_headers=cache_headers)
if not cached:
logger.warning(
"Could not save retrieved Documents to the DocumentStore cache. "
"Check your DocumentStore configuration."
)
return extracted_docs[:top_k]
links_to_fetch = [link for link in links if link not in links_found_in_cache]
logger.debug("Fetching %d links", len(links_to_fetch))
result_docs = self._scrape_links(links_to_fetch)

def _retrieve_from_web(self, query_norm: str, preprocessor: Optional[PreProcessor]) -> List[Document]:
"""
Retrieve Documents from the web based on the query.
# Save result_docs to cache
self._save_to_cache(
result_docs, cache_index=cache_index, cache_headers=cache_headers, cache_time=cache_time
)

:param query_norm: The normalized query string.
:param preprocessor: The PreProcessor to be used to split documents into paragraphs.
:return: List of Document objects.
"""
# join cached_docs and result_docs
result_docs = cached_docs + result_docs

search_results, _ = self.web_search.run(query=query_norm)
search_results_docs = search_results["documents"]
if self.mode == "snippets":
return search_results_docs
else:
links: List[SearchResult] = self._prepare_links(search_results_docs)
logger.debug("Starting to fetch %d links from WebSearch results", len(links))
return self._scrape_links(links, query_norm, preprocessor)
# Preprocess documents
if preprocessor:
result_docs = preprocessor.process(result_docs)

# Return results
return result_docs[:top_k]

def _prepare_links(self, search_results: List[Document]) -> List[SearchResult]:
"""
@@ -277,38 +171,29 @@ class WebRetriever(BaseRetriever):
]
return links

def _scrape_links(
self, links: List[SearchResult], query_norm: str, preprocessor: Optional[PreProcessor]
) -> List[Document]:
def _scrape_links(self, links: List[SearchResult]) -> List[Document]:
"""
Scrape the links and return the documents.

:param links: List of SearchResult objects.
:param query_norm: The normalized query string.
:param preprocessor: The PreProcessor object to be used to split documents into paragraphs.
:return: List of Document objects obtained from scraping the links.
:return: List of Document objects obtained by fetching the content from the links.
"""
if not links:
return []

fetcher = (
LinkContentFetcher(processor=preprocessor, raise_on_failure=True)
if self.mode == "preprocessed_documents" and preprocessor
else LinkContentFetcher(raise_on_failure=True)
)
fetcher = LinkContentFetcher(raise_on_failure=True)

def scrape_link_content(link: SearchResult) -> List[Document]:
def link_fetch(link: SearchResult) -> List[Document]:
"""
Encapsulate the link scraping logic in a function to be used in a ThreadPoolExecutor.
Encapsulate the link fetching logic in a function to be used in a ThreadPoolExecutor.
"""
docs: List[Document] = []
try:
docs = fetcher.fetch(
url=link.url,
doc_kwargs={
"id_hash_keys": ["meta.url"],
"search.score": link.score,
"id_hash_keys": ["meta.url", "meta.search.query"],
"search.query": query_norm,
"search.position": link.position,
"snippet_text": link.snippet,
},
@@ -319,18 +204,72 @@ class WebRetriever(BaseRetriever):

return docs

thread_count = cpu_count() if len(links) > cpu_count() else len(links)
thread_count = min(cpu_count() if len(links) > cpu_count() else len(links), 10) # max 10 threads
with ThreadPoolExecutor(max_workers=thread_count) as executor:
scraped_pages: Iterator[List[Document]] = executor.map(scrape_link_content, links)
fetched_pages: Iterator[List[Document]] = executor.map(link_fetch, links)

# Flatten list of lists to a single list
extracted_docs = [doc for doc_list in scraped_pages for doc in doc_list]
extracted_docs = [doc for doc_list in fetched_pages for doc in doc_list]

# Sort by score
extracted_docs = sorted(extracted_docs, key=lambda x: x.meta["search.score"], reverse=True)

return extracted_docs

def _check_cache(self, links: List[SearchResult]) -> Tuple[List[SearchResult], List[Document]]:
"""
Check the DocumentStore cache for documents.

:param links: List of SearchResult objects.
:return: Tuple of lists of SearchResult and Document objects that were found in the cache.
"""
if not links or not self.cache_document_store:
return [], []

cache_documents: List[Document] = []
cached_links: List[SearchResult] = []

valid_links = [link for link in links if link.url]
for link in valid_links:
cache_filter: FilterType = {"url": link.url}
documents = self.cache_document_store.get_all_documents(filters=cache_filter, return_embedding=False)
if documents:
cache_documents.extend(documents)
cached_links.append(link)

return cached_links, cache_documents

def _save_to_cache(
self,
documents: List[Document],
cache_index: Optional[str] = None,
cache_headers: Optional[Dict[str, str]] = None,
cache_time: Optional[int] = None,
) -> None:
"""
Save the documents to the cache and potentially delete old expired documents from the cache.

:param documents: List of Document objects to be saved to the cache.
:param cache_index: Optional index name to save the documents to.
:param cache_headers: Optional headers made to use when saving the documents to the cache.
:param cache_time: Optional time to live in seconds for the documents in the cache. If objects are older than
this time, they will be deleted from the cache.
"""
cache_document_store = self.cache_document_store

if cache_document_store is not None and documents:
cache_document_store.write_documents(
documents=documents, index=cache_index, headers=cache_headers, duplicate_documents="overwrite"
)

if cache_document_store and cache_time is not None and cache_time > 0:
cache_filter: FilterType = {
"timestamp": {"$lt": int((datetime.utcnow() - timedelta(seconds=cache_time)).timestamp())}
}

cache_document_store.delete_documents(index=cache_index, headers=cache_headers, filters=cache_filter)
logger.debug("Deleted documents in the cache using filter: %s", cache_filter)

def retrieve_batch( # type: ignore[override]
self,
queries: List[str],
@@ -358,7 +297,8 @@ class WebRetriever(BaseRetriever):
DocumentStore is used.
:param cache_index: The index name to save the documents to. If None, the instance's default cache_index is used.
:param cache_headers: The headers to save the documents to. If None, the instance's default cache_headers is used.
:param cache_time: The time limit in seconds to check the cache. If None, the instance's default cache_time is used.
:param cache_time: The time limit in seconds for the documents in the cache.

:returns: A list of lists where each inner list represents the documents fetched for a particular query.
"""
documents = []
@@ -370,8 +310,6 @@ class WebRetriever(BaseRetriever):
preprocessor=preprocessor,
cache_document_store=cache_document_store,
cache_index=cache_index,
cache_headers=cache_headers,
cache_time=cache_time,
)
)
@@ -0,0 +1,5 @@
---
enhancements:
  - |
    The WebRetriever now employs an enhanced caching mechanism that caches web page content based on search engine
    results rather than the query.
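To make the release note concrete, here is a hedged sketch of the behavioral difference (the retriever setup and queries are placeholders; the API follows the diff above):

```python
# Hedged sketch: with the reworked cache, entries are keyed by result URL, not by query.
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import WebRetriever

retriever = WebRetriever(
    api_key="YOUR_SERPERDEV_API_KEY",  # placeholder
    mode="raw_documents",
    cache_document_store=InMemoryDocumentStore(),
)

docs_a = retriever.retrieve(query="haystack WebRetriever caching")
# A differently worded query whose top search results overlap with the first one is now
# served (partly) from the cache, because cache hits are checked per result URL rather
# than per normalized query string.
docs_b = retriever.retrieve(query="how does caching work in the haystack WebRetriever")
```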
@@ -1,13 +1,11 @@
import os
from unittest.mock import MagicMock, patch, Mock
from unittest.mock import patch, Mock
from test.conftest import MockDocumentStore
import pytest

from haystack import Document, Pipeline
from haystack.document_stores.base import BaseDocumentStore
from haystack.nodes import WebRetriever, PromptNode
from haystack.nodes.retriever.link_content import html_content_handler
from haystack.nodes.preprocessor import PreProcessor
from haystack.nodes.retriever.web import SearchResult
from test.nodes.conftest import example_serperdev_response

@@ -46,75 +44,45 @@ def test_init_default_parameters():
assert retriever.preprocessor is None
assert retriever.cache_document_store is None
assert retriever.cache_index is None
assert retriever.cache_headers is None
assert retriever.cache_time == 1 * 24 * 60 * 60


@pytest.mark.unit
def test_init_custom_parameters():
preprocessor = PreProcessor()
document_store = MagicMock(spec=BaseDocumentStore)
headers = {"Test": "Header"}
@pytest.mark.parametrize("mode", ["snippets", "raw_documents", "preprocessed_documents"])
@pytest.mark.parametrize("top_k", [1, 5, 7])
def test_retrieve_from_web_all_params(mock_web_search, mode, top_k):
"""
Test that the retriever returns the correct number of documents in all modes
"""
search_result_len = len(example_serperdev_response["organic"])
wr = WebRetriever(api_key="fake_key", top_k=top_k, mode=mode)

retriever = WebRetriever(
api_key="test_key",
search_engine_provider="SerperDev",
top_search_results=15,
top_k=7,
mode="preprocessed_documents",
preprocessor=preprocessor,
cache_document_store=document_store,
cache_index="custom_index",
cache_headers=headers,
cache_time=2 * 24 * 60 * 60,
)
docs = [Document("test" + str(i)) for i in range(search_result_len)]
with patch("haystack.nodes.retriever.web.WebRetriever._scrape_links", return_value=docs):
retrieved_docs = wr.retrieve(query="who is the boyfriend of olivia wilde?")

assert retriever.top_k == 7
assert retriever.mode == "preprocessed_documents"
assert retriever.preprocessor == preprocessor
assert retriever.cache_document_store == document_store
assert retriever.cache_index == "custom_index"
assert retriever.cache_headers == headers
assert retriever.cache_time == 2 * 24 * 60 * 60


@pytest.mark.unit
def test_retrieve_from_web_all_params(mock_web_search):
wr = WebRetriever(api_key="fake_key")

preprocessor = PreProcessor()

result = wr._retrieve_from_web(query_norm="who is the boyfriend of olivia wilde?", preprocessor=preprocessor)

assert isinstance(result, list)
assert all(isinstance(doc, Document) for doc in result)
assert len(result) == len(example_serperdev_response["organic"])


@pytest.mark.unit
def test_retrieve_from_web_no_preprocessor(mock_web_search):
# tests that we get top_k results when no PreProcessor is provided
wr = WebRetriever(api_key="fake_key")
result = wr._retrieve_from_web("query", None)

assert isinstance(result, list)
assert all(isinstance(doc, Document) for doc in result)
assert len(result) == len(example_serperdev_response["organic"])
assert isinstance(retrieved_docs, list)
assert all(isinstance(doc, Document) for doc in retrieved_docs)
assert len(retrieved_docs) == top_k


@pytest.mark.unit
def test_retrieve_from_web_invalid_query(mock_web_search):
# however, if query is None or empty, we expect an error
"""
Test that the retriever raises an error if the query is invalid
"""
wr = WebRetriever(api_key="fake_key")
with pytest.raises(ValueError, match="WebSearch run requires"):
wr._retrieve_from_web("", None)
wr.retrieve("")

with pytest.raises(ValueError, match="WebSearch run requires"):
wr._retrieve_from_web(None, None)
wr.retrieve(None)


@pytest.mark.unit
def test_prepare_links_empty_list():
"""
Test that the retriever's _prepare_links method returns an empty list if the input is an empty list
"""
wr = WebRetriever(api_key="fake_key")
result = wr._prepare_links([])
assert result == []
@@ -125,8 +93,11 @@ def test_prepare_links_empty_list():

@pytest.mark.unit
def test_scrape_links_empty_list():
"""
Test that the retriever's _scrape_links method returns an empty list if the input is an empty list
"""
wr = WebRetriever(api_key="fake_key")
result = wr._scrape_links([], "query", None)
result = wr._scrape_links([])
assert result == []


@@ -134,13 +105,16 @@ def test_scrape_links_empty_list():
def test_scrape_links_with_search_results(
mocked_requests, mocked_article_extractor, mocked_link_content_fetcher_handler_type
):
"""
Test that the retriever's _scrape_links method returns a list of Documents if the input is a list of SearchResults
"""
wr = WebRetriever(api_key="fake_key")

sr1 = SearchResult("https://pagesix.com", "Some text", "0.43", "1")
sr2 = SearchResult("https://www.yahoo.com/", "Some text", "0.43", "2")
sr1 = SearchResult("https://pagesix.com", "Some text", 0.43, "1")
sr2 = SearchResult("https://www.yahoo.com/", "Some text", 0.43, "2")
fake_search_results = [sr1, sr2]

result = wr._scrape_links(fake_search_results, "query", None)
result = wr._scrape_links(fake_search_results)

assert isinstance(result, list)
assert all(isinstance(r, Document) for r in result)
@@ -151,14 +125,17 @@ def test_scrape_links_with_search_results(
def test_scrape_links_with_search_results_with_preprocessor(
mocked_requests, mocked_article_extractor, mocked_link_content_fetcher_handler_type
):
"""
Test that the retriever's _scrape_links method returns a list of Documents if the input is a list of SearchResults
and a preprocessor is provided
"""
wr = WebRetriever(api_key="fake_key", mode="preprocessed_documents")
preprocessor = PreProcessor(progress_bar=False)

sr1 = SearchResult("https://pagesix.com", "Some text", "0.43", "1")
sr2 = SearchResult("https://www.yahoo.com/", "Some text", "0.43", "2")
sr1 = SearchResult("https://pagesix.com", "Some text", 0.43, "1")
sr2 = SearchResult("https://www.yahoo.com/", "Some text", 0.43, "2")
fake_search_results = [sr1, sr2]

result = wr._scrape_links(fake_search_results, "query", preprocessor)
result = wr._scrape_links(fake_search_results)

assert isinstance(result, list)
assert all(isinstance(r, Document) for r in result)
@@ -168,33 +145,44 @@ def test_scrape_links_with_search_results_with_preprocessor(


@pytest.mark.unit
def test_retrieve_uses_defaults():
wr = WebRetriever(api_key="fake_key")
def test_retrieve_checks_cache(mock_web_search):
"""
Test that the retriever's retrieve method checks the cache
"""
wr = WebRetriever(api_key="fake_key", mode="preprocessed_documents")

with patch.object(wr, "_check_cache", return_value=[]) as mock_check_cache:
with patch.object(wr, "_retrieve_from_web", return_value=[]) as mock_retrieve_from_web:
wr.retrieve("query")
with patch.object(wr, "_check_cache", return_value=([], [])) as mock_check_cache:
wr.retrieve("query")

# cache is checked first, always
mock_check_cache.assert_called_with(
"query", cache_index=wr.cache_index, cache_headers=wr.cache_headers, cache_time=wr.cache_time
)
mock_retrieve_from_web.assert_called_with("query", wr.preprocessor)
# assert cache is checked
mock_check_cache.assert_called()


@pytest.mark.unit
def test_retrieve_batch():
queries = ["query1", "query2"]
wr = WebRetriever(api_key="fake_key")
web_docs = [Document("doc1"), Document("doc2"), Document("doc3")]
with patch.object(wr, "_check_cache", return_value=[]) as mock_check_cache:
with patch.object(wr, "_retrieve_from_web", return_value=web_docs) as mock_retrieve_from_web:
result = wr.retrieve_batch(queries)
def test_retrieve_no_cache_checks_in_snippet_mode(mock_web_search):
"""
Test that the retriever's retrieve method does not check the cache if the mode is snippets
"""
wr = WebRetriever(api_key="fake_key", mode="snippets")

with patch.object(wr, "_check_cache", return_value=([], [])) as mock_check_cache:
wr.retrieve("query")

# assert cache is NOT checked
mock_check_cache.assert_not_called()


@pytest.mark.unit
def test_retrieve_batch(mock_web_search):
"""
Test that the retriever's retrieve_batch method returns a list of lists of Documents
"""
queries = ["query1", "query2"]
wr = WebRetriever(api_key="fake_key", mode="preprocessed_documents")
web_docs = [Document("doc1"), Document("doc2"), Document("doc3")]
with patch("haystack.nodes.retriever.web.WebRetriever._scrape_links", return_value=web_docs):
result = wr.retrieve_batch(queries)

assert mock_check_cache.call_count == len(queries)
assert mock_retrieve_from_web.call_count == len(queries)
# check that the result is a list of lists of Documents
# where each list of Documents is the result of a single query
assert len(result) == len(queries)

# check that the result is a list of lists of Documents
@@ -207,63 +195,44 @@ def test_retrieve_batch():


@pytest.mark.unit
def test_retrieve_uses_cache():
wr = WebRetriever(api_key="fake_key")
def test_retrieve_uses_cache(mock_web_search):
"""
Test that the retriever's retrieve method uses the cache if it is available
"""
wr = WebRetriever(api_key="fake_key", mode="raw_documents", cache_document_store=MockDocumentStore())

cached_links = [
SearchResult("https://pagesix.com", "Some text", 0.43, "1"),
SearchResult("https://www.yahoo.com/", "Some text", 0.43, "2"),
]
cached_docs = [Document("doc1"), Document("doc2")]
with patch.object(wr, "_check_cache", return_value=cached_docs) as mock_check_cache:
with patch.object(wr, "_retrieve_from_web") as mock_retrieve_from_web:
with patch.object(wr, "_save_cache") as mock_save_cache:
with patch.object(wr, "_check_cache", return_value=(cached_links, cached_docs)) as mock_check_cache:
with patch.object(wr, "_save_to_cache") as mock_save_cache:
with patch.object(wr, "_scrape_links", return_value=[]):
result = wr.retrieve("query")

# checking cache is always called
mock_check_cache.assert_called()

# these methods are not called because we found docs in cache
mock_retrieve_from_web.assert_not_called()
mock_save_cache.assert_not_called()

# cache save is called but with empty list of documents
mock_save_cache.assert_called()
assert mock_save_cache.call_args[0][0] == []
assert result == cached_docs


@pytest.mark.unit
def test_retrieve_saves_to_cache():
wr = WebRetriever(api_key="fake_key", cache_document_store=MockDocumentStore())
def test_retrieve_saves_to_cache(mock_web_search):
"""
Test that the retriever's retrieve method saves to the cache if it is available
"""
wr = WebRetriever(api_key="fake_key", cache_document_store=MockDocumentStore(), mode="preprocessed_documents")
web_docs = [Document("doc1"), Document("doc2"), Document("doc3")]

with patch.object(wr, "_check_cache", return_value=[]) as mock_check_cache:
with patch.object(wr, "_retrieve_from_web", return_value=web_docs) as mock_retrieve_from_web:
with patch.object(wr, "_save_cache") as mock_save_cache:
result = wr.retrieve("query")
with patch.object(wr, "_save_to_cache") as mock_save_cache:
with patch.object(wr, "_scrape_links", return_value=web_docs):
wr.retrieve("query")

mock_check_cache.assert_called()

# cache is empty, so we call _retrieve_from_web
mock_retrieve_from_web.assert_called()
# and save the results to cache
mock_save_cache.assert_called_with("query", web_docs, cache_index=wr.cache_index, cache_headers=wr.cache_headers)
assert result == web_docs


@pytest.mark.unit
def test_retrieve_returns_top_k():
wr = WebRetriever(api_key="", top_k=2)

with patch.object(wr, "_check_cache", return_value=[]):
web_docs = [Document("doc1"), Document("doc2"), Document("doc3")]
with patch.object(wr, "_retrieve_from_web", return_value=web_docs):
result = wr.retrieve("query")

assert result == web_docs[:2]


@pytest.mark.unit
@pytest.mark.parametrize("top_k", [1, 3, 6])
def test_top_k_parameter(mock_web_search, top_k):
web_retriever = WebRetriever(api_key="some_invalid_key", mode="snippets")
result = web_retriever.retrieve(query="Who is the boyfriend of Olivia Wilde?", top_k=top_k)
assert len(result) == top_k
assert all(isinstance(doc, Document) for doc in result)
mock_save_cache.assert_called()


@pytest.mark.integration
@@ -277,7 +246,9 @@ def test_top_k_parameter(mock_web_search, top_k):
)
@pytest.mark.parametrize("top_k", [2, 4])
def test_top_k_parameter_in_pipeline(top_k):
# test that WebRetriever top_k param is NOT ignored in a pipeline
"""
Test that the top_k parameter works in the pipeline
"""
prompt_node = PromptNode(
"gpt-3.5-turbo",
api_key=os.environ.get("OPENAI_API_KEY"),
@@ -293,20 +264,3 @@ def test_top_k_parameter_in_pipeline(top_k):
pipe.add_node(component=prompt_node, name="QAwithScoresPrompt", inputs=["WebRetriever"])
result = pipe.run(query="What year was Obama president", params={"WebRetriever": {"top_k": top_k}})
assert len(result["results"]) == top_k


@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("SERPERDEV_API_KEY", None),
reason="Please export an env var called SERPERDEV_API_KEY containing the serper.dev API key to run this test.",
)
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.skip
def test_web_retriever_speed():
retriever = WebRetriever(api_key=os.environ.get("SERPERDEV_API_KEY"), mode="preprocessed_documents")
result = retriever.retrieve(query="What's the meaning of it all?")
assert len(result) >= 5
assert all(isinstance(doc, Document) for doc in result)