refactor: Rework WebRetriever caching, adjust tests (#5566)

* Rework WebRetriever caching, adjust tests

* Add release note

* Better pydocs

* Minor improvements

* Update haystack/nodes/retriever/web.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Vladimir Blagojevic 2023-08-16 17:41:11 +02:00 committed by GitHub
parent a8d4a99db9
commit 46c9139caf
3 changed files with 201 additions and 304 deletions

File: haystack/nodes/retriever/web.py

@@ -3,10 +3,9 @@ from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime, timedelta
from multiprocessing import cpu_count
from typing import Dict, Iterator, List, Optional, Literal, Union
from unicodedata import combining, normalize
from typing import Dict, Iterator, List, Optional, Literal, Union, Tuple
from haystack import Document
from haystack.schema import Document
from haystack.document_stores.base import BaseDocumentStore
from haystack.nodes.preprocessor import PreProcessor
from haystack.nodes.retriever.base import BaseRetriever
@@ -80,106 +79,12 @@ class WebRetriever(BaseRetriever):
self.cache_document_store = cache_document_store
self.document_store = cache_document_store
self.cache_index = cache_index
self.top_k = top_k
self.cache_headers = cache_headers
self.cache_time = cache_time
self.top_k = top_k
self.preprocessor = None
if preprocessor is not None:
self.preprocessor = preprocessor
elif mode == "preprocessed_documents":
self.preprocessor = PreProcessor(progress_bar=False)
def _normalize_query(self, query: str) -> str:
return "".join([c for c in normalize("NFKD", query.lower()) if not combining(c)])
def _check_cache(
self,
query: str,
cache_index: Optional[str] = None,
cache_headers: Optional[Dict[str, str]] = None,
cache_time: Optional[int] = None,
) -> List[Document]:
"""
Private method to check if the documents for a given query are already cached. The documents are fetched from
the specified DocumentStore. It retrieves documents that are newer than the cache_time limit.
:param query: The query string to check in the cache.
:param cache_index: Optional index name in the DocumentStore to fetch the documents. Defaults to the instance's
cache_index.
:param cache_headers: Optional headers to be used when fetching documents from the DocumentStore. Defaults to
the instance's cache_headers.
:param cache_time: Optional time limit in seconds to check the cache. Only documents newer than cache_time are
returned. Defaults to the instance's cache_time.
:returns: A list of Document instances fetched from the cache. If no documents are found in the cache, an empty
list is returned.
"""
cache_document_store = self.cache_document_store
documents = []
if cache_document_store is not None:
query_norm = self._normalize_query(query)
cache_filter: FilterType = {"$and": {"search.query": query_norm}}
if cache_time is not None and cache_time > 0:
cache_filter["timestamp"] = {
"$gt": int((datetime.utcnow() - timedelta(seconds=cache_time)).timestamp())
}
logger.debug("Cache filter: %s", cache_filter)
documents = cache_document_store.get_all_documents(
filters=cache_filter, index=cache_index, headers=cache_headers, return_embedding=False
)
logger.debug("Found %d documents in cache", len(documents))
return documents
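Before this rework, a cache hit meant "this exact normalized query was seen recently": the filter combined the query string with an optional timestamp lower bound. A sketch of the filter the removed method built (field names are taken from the diff; how "$and" and "$gt" are interpreted depends on the concrete DocumentStore backend):

# Illustrative sketch, not part of the commit.
from datetime import datetime, timedelta
from typing import Dict, Optional

def build_query_cache_filter(query_norm: str, cache_time: Optional[int]) -> Dict:
    cache_filter: Dict = {"$and": {"search.query": query_norm}}
    if cache_time is not None and cache_time > 0:
        cutoff = int((datetime.utcnow() - timedelta(seconds=cache_time)).timestamp())
        cache_filter["timestamp"] = {"$gt": cutoff}  # only documents fresher than the TTL
    return cache_filter

print(build_query_cache_filter("who is the ceo of openai", 24 * 60 * 60))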
def _save_cache(
self,
query: str,
documents: List[Document],
cache_index: Optional[str] = None,
cache_headers: Optional[Dict[str, str]] = None,
cache_time: Optional[int] = None,
) -> bool:
"""
Private method to cache the retrieved documents for a given query.
The documents are saved in the specified DocumentStore. If the same document already exists, it is
overwritten.
:param query: The query string for which the documents are being cached.
:param documents: The list of Document instances to be cached.
:param cache_index: Optional index name in the DocumentStore to save the documents. Defaults to the
instance's cache_index.
:param cache_headers: Optional headers to be used when saving documents in the DocumentStore. Defaults to
the instance's cache_headers.
:param cache_time: Optional time limit in seconds to check the cache. Documents older than the
cache_time are deleted. Defaults to the instance's cache_time.
:returns: True if the documents are successfully saved in the cache, False otherwise.
"""
cache_document_store = self.cache_document_store
if cache_document_store is not None:
cache_document_store.write_documents(
documents=documents, index=cache_index, headers=cache_headers, duplicate_documents="overwrite"
)
logger.debug("Saved %d documents in the cache", len(documents))
cache_filter: FilterType = {"$and": {"search.query": query}}
if cache_time is not None and cache_time > 0:
cache_filter["timestamp"] = {
"$lt": int((datetime.utcnow() - timedelta(seconds=cache_time)).timestamp())
}
cache_document_store.delete_documents(index=cache_index, headers=cache_headers, filters=cache_filter)
logger.debug("Deleted documents in the cache using filter: %s", cache_filter)
return True
return False
self.preprocessor = (
preprocessor or PreProcessor(progress_bar=False) if mode == "preprocessed_documents" else None
)
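One subtlety of the new one-liner: the conditional expression binds looser than "or", so it parses as (preprocessor or PreProcessor(...)) if mode == "preprocessed_documents" else None. Unlike the removed if/elif, an explicitly passed preprocessor is now dropped in the other modes. A small sketch of the difference ("default" stands in for PreProcessor(progress_bar=False)):

# Illustrative sketch, not part of the commit.
def new_init(preprocessor, mode):
    return (preprocessor or "default") if mode == "preprocessed_documents" else None

def old_init(preprocessor, mode):
    if preprocessor is not None:
        return preprocessor
    return "default" if mode == "preprocessed_documents" else None

assert new_init("custom", "preprocessed_documents") == old_init("custom", "preprocessed_documents")
assert new_init("custom", "raw_documents") is None      # new behavior: dropped
assert old_init("custom", "raw_documents") == "custom"  # old behavior: kept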
def retrieve( # type: ignore[override]
self,
@@ -201,7 +106,7 @@ class WebRetriever(BaseRetriever):
Optionally, the retrieved documents can be stored in a DocumentStore for future use, saving time
and resources on repeated retrievals. This caching mechanism can significantly improve retrieval times
for frequently accessed information.
for frequently accessed URLs.
:param query: The query string.
:param top_k: The number of Documents to be returned by the retriever.
@@ -209,56 +114,45 @@
:param cache_document_store: The DocumentStore to cache the documents to.
:param cache_index: The index name to save the documents to.
:param cache_headers: The headers to use when saving the documents to the DocumentStore.
:param cache_time: The time limit in seconds to check the cache. The default is 24 hours.
:param cache_time: The time limit in seconds for the documents in the cache. If objects are older than this time,
they will be deleted from the cache on the next retrieval.
"""
# Initialize default parameters
preprocessor = preprocessor or self.preprocessor
cache_document_store = cache_document_store or self.cache_document_store
cache_index = cache_index or self.cache_index
top_k = top_k or self.top_k
cache_headers = cache_headers or self.cache_headers
cache_time = cache_time or self.cache_time
top_k = top_k or self.top_k
# Normalize query
query_norm = self._normalize_query(query)
search_results, _ = self.web_search.run(query=query)
result_docs = search_results["documents"]
# Check cache for query
extracted_docs = self._check_cache(
query_norm, cache_index=cache_index, cache_headers=cache_headers, cache_time=cache_time
)
if self.mode != "snippets":
# for raw_documents and preprocessed_documents modes, we need to retrieve the links from the search results
links: List[SearchResult] = self._prepare_links(result_docs)
# If query is not cached, fetch from web
if not extracted_docs:
extracted_docs = self._retrieve_from_web(query_norm, preprocessor)
links_found_in_cache, cached_docs = self._check_cache(links)
logger.debug("Found %d links in cache", len(links_found_in_cache))
# Save results to cache
if cache_document_store and extracted_docs:
cached = self._save_cache(query_norm, extracted_docs, cache_index=cache_index, cache_headers=cache_headers)
if not cached:
logger.warning(
"Could not save retrieved Documents to the DocumentStore cache. "
"Check your DocumentStore configuration."
)
return extracted_docs[:top_k]
links_to_fetch = [link for link in links if link not in links_found_in_cache]
logger.debug("Fetching %d links", len(links_to_fetch))
result_docs = self._scrape_links(links_to_fetch)
def _retrieve_from_web(self, query_norm: str, preprocessor: Optional[PreProcessor]) -> List[Document]:
"""
Retrieve Documents from the web based on the query.
# Save result_docs to cache
self._save_to_cache(
result_docs, cache_index=cache_index, cache_headers=cache_headers, cache_time=cache_time
)
:param query_norm: The normalized query string.
:param preprocessor: The PreProcessor to be used to split documents into paragraphs.
:return: List of Document objects.
"""
# join cached_docs and result_docs
result_docs = cached_docs + result_docs
search_results, _ = self.web_search.run(query=query_norm)
search_results_docs = search_results["documents"]
if self.mode == "snippets":
return search_results_docs
else:
links: List[SearchResult] = self._prepare_links(search_results_docs)
logger.debug("Starting to fetch %d links from WebSearch results", len(links))
return self._scrape_links(links, query_norm, preprocessor)
# Preprocess documents
if preprocessor:
result_docs = preprocessor.process(result_docs)
# Return results
return result_docs[:top_k]
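Taken together, the hunks above replace the query-keyed flow (check the cache for the query, fetch everything on a miss) with a URL-keyed one: search first, look up each result URL in the cache, fetch only the misses, persist the newly fetched documents, then merge, optionally preprocess, and slice to top_k. A compact sketch of the new control flow with stubbed collaborators (the stub names are placeholders, not Haystack APIs):

# Illustrative sketch, not part of the commit: the reworked retrieve() shape.
from typing import List, Tuple

def search(query: str) -> List[str]:                       # WebSearch + _prepare_links
    return [f"https://{query}-{i}.example" for i in (1, 2, 3)]

def check_cache(links: List[str]) -> Tuple[List[str], List[str]]:
    return ([links[0]], ["cached-doc-for-" + links[0]])    # pretend the first URL is cached

def scrape(links: List[str]) -> List[str]:
    return [f"fetched-doc-for-{link}" for link in links]

def retrieve(query: str, top_k: int) -> List[str]:
    links = search(query)
    cached_links, cached_docs = check_cache(links)
    to_fetch = [link for link in links if link not in cached_links]
    fetched = scrape(to_fetch)           # only the cache misses hit the network
    # _save_to_cache(fetched) would run here, including TTL-based eviction
    return (cached_docs + fetched)[:top_k]

print(retrieve("haystack", top_k=2))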
def _prepare_links(self, search_results: List[Document]) -> List[SearchResult]:
"""
@@ -277,38 +171,29 @@
]
return links
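_prepare_links itself is unchanged context here and its body is elided by the hunk; from the fields used elsewhere in this diff, it maps the WebSearch result Documents onto SearchResult records carrying url, snippet, score, and position. A hypothetical reconstruction, using plain dicts in place of Haystack Documents:

# Illustrative sketch, not part of the commit; the real method consumes
# haystack Documents rather than dicts.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class SearchResult:
    url: str
    snippet: str
    score: Optional[float]
    position: Optional[str]

def prepare_links(results: List[dict]) -> List[SearchResult]:
    return [
        SearchResult(r["url"], r.get("snippet", ""), r.get("score"), r.get("position"))
        for r in results
        if r.get("url")
    ]

print(prepare_links([{"url": "https://example.com", "snippet": "text", "score": 0.9, "position": "1"}]))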
def _scrape_links(
self, links: List[SearchResult], query_norm: str, preprocessor: Optional[PreProcessor]
) -> List[Document]:
def _scrape_links(self, links: List[SearchResult]) -> List[Document]:
"""
Scrape the links and return the documents.
:param links: List of SearchResult objects.
:param query_norm: The normalized query string.
:param preprocessor: The PreProcessor object to be used to split documents into paragraphs.
:return: List of Document objects obtained from scraping the links.
:return: List of Document objects obtained by fetching the content from the links.
"""
if not links:
return []
fetcher = (
LinkContentFetcher(processor=preprocessor, raise_on_failure=True)
if self.mode == "preprocessed_documents" and preprocessor
else LinkContentFetcher(raise_on_failure=True)
)
fetcher = LinkContentFetcher(raise_on_failure=True)
def scrape_link_content(link: SearchResult) -> List[Document]:
def link_fetch(link: SearchResult) -> List[Document]:
"""
Encapsulate the link scraping logic in a function to be used in a ThreadPoolExecutor.
Encapsulate the link fetching logic in a function to be used in a ThreadPoolExecutor.
"""
docs: List[Document] = []
try:
docs = fetcher.fetch(
url=link.url,
doc_kwargs={
"id_hash_keys": ["meta.url"],
"search.score": link.score,
"id_hash_keys": ["meta.url", "meta.search.query"],
"search.query": query_norm,
"search.position": link.position,
"snippet_text": link.snippet,
},
@@ -319,18 +204,72 @@
return docs
thread_count = cpu_count() if len(links) > cpu_count() else len(links)
thread_count = min(cpu_count() if len(links) > cpu_count() else len(links), 10) # max 10 threads
with ThreadPoolExecutor(max_workers=thread_count) as executor:
scraped_pages: Iterator[List[Document]] = executor.map(scrape_link_content, links)
fetched_pages: Iterator[List[Document]] = executor.map(link_fetch, links)
# Flatten list of lists to a single list
extracted_docs = [doc for doc_list in scraped_pages for doc in doc_list]
extracted_docs = [doc for doc_list in fetched_pages for doc in doc_list]
# Sort by score
extracted_docs = sorted(extracted_docs, key=lambda x: x.meta["search.score"], reverse=True)
return extracted_docs
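The fan-out above is a standard map/flatten/sort pattern: the worker count is capped at the smaller of the number of links, cpu_count(), and 10; executor.map preserves input order; each fetch may yield several documents; and the flattened result is re-sorted by search score. A stand-alone sketch (fake_fetch stands in for LinkContentFetcher.fetch):

# Illustrative sketch, not part of the commit.
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import cpu_count

links = [("https://a.example", 0.2), ("https://b.example", 0.9), ("https://c.example", 0.5)]

def fake_fetch(link):
    url, score = link
    return [{"url": url, "search.score": score}]   # a fetch can return multiple docs

thread_count = min(len(links), cpu_count(), 10)    # same cap as the diff
with ThreadPoolExecutor(max_workers=thread_count) as executor:
    pages = executor.map(fake_fetch, links)

docs = [doc for page in pages for doc in page]             # flatten
docs.sort(key=lambda d: d["search.score"], reverse=True)   # best score first
print([d["url"] for d in docs])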
def _check_cache(self, links: List[SearchResult]) -> Tuple[List[SearchResult], List[Document]]:
"""
Check the DocumentStore cache for documents.
:param links: List of SearchResult objects.
:return: Tuple of lists of SearchResult and Document objects that were found in the cache.
"""
if not links or not self.cache_document_store:
return [], []
cache_documents: List[Document] = []
cached_links: List[SearchResult] = []
valid_links = [link for link in links if link.url]
for link in valid_links:
cache_filter: FilterType = {"url": link.url}
documents = self.cache_document_store.get_all_documents(filters=cache_filter, return_embedding=False)
if documents:
cache_documents.extend(documents)
cached_links.append(link)
return cached_links, cache_documents
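The new _check_cache makes the cache per-URL: every link is looked up individually with a {"url": link.url} filter, and the hits are returned alongside the links that produced them, so retrieve() can skip refetching those pages. An in-memory sketch with a dict standing in for the DocumentStore:

# Illustrative sketch, not part of the commit.
from typing import Dict, List, Tuple

def check_cache(links: List[str], store: Dict[str, List[str]]) -> Tuple[List[str], List[str]]:
    cached_links: List[str] = []
    cached_docs: List[str] = []
    for url in links:
        docs = store.get(url, [])   # stands in for get_all_documents(filters={"url": url})
        if docs:
            cached_links.append(url)
            cached_docs.extend(docs)
    return cached_links, cached_docs

store = {"https://a.example": ["doc-a"]}
print(check_cache(["https://a.example", "https://b.example"], store))
# (['https://a.example'], ['doc-a'])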
def _save_to_cache(
self,
documents: List[Document],
cache_index: Optional[str] = None,
cache_headers: Optional[Dict[str, str]] = None,
cache_time: Optional[int] = None,
) -> None:
"""
Save the documents to the cache and potentially delete old expired documents from the cache.
:param documents: List of Document objects to be saved to the cache.
:param cache_index: Optional index name to save the documents to.
:param cache_headers: Optional headers to use when saving the documents to the cache.
:param cache_time: Optional time to live in seconds for the documents in the cache. If objects are older than
this time, they will be deleted from the cache.
"""
cache_document_store = self.cache_document_store
if cache_document_store is not None and documents:
cache_document_store.write_documents(
documents=documents, index=cache_index, headers=cache_headers, duplicate_documents="overwrite"
)
if cache_document_store and cache_time is not None and cache_time > 0:
cache_filter: FilterType = {
"timestamp": {"$lt": int((datetime.utcnow() - timedelta(seconds=cache_time)).timestamp())}
}
cache_document_store.delete_documents(index=cache_index, headers=cache_headers, filters=cache_filter)
logger.debug("Deleted documents in the cache using filter: %s", cache_filter)
def retrieve_batch( # type: ignore[override]
self,
queries: List[str],
@@ -358,7 +297,8 @@
DocumentStore is used.
:param cache_index: The index name to save the documents to. If None, the instance's default cache_index is used.
:param cache_headers: The headers to use when saving the documents. If None, the instance's default cache_headers is used.
:param cache_time: The time limit in seconds to check the cache. If None, the instance's default cache_time is used.
:param cache_time: The time limit in seconds for the documents in the cache.
:returns: A list of lists where each inner list represents the documents fetched for a particular query.
"""
documents = []
@@ -370,8 +310,6 @@
preprocessor=preprocessor,
cache_document_store=cache_document_store,
cache_index=cache_index,
cache_headers=cache_headers,
cache_time=cache_time,
)
)
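retrieve_batch stays a thin loop over retrieve(), one call per query; note that this hunk also shows cache_headers and cache_time dropped from the forwarded arguments. A minimal sketch of the list-of-lists shape the tests below assert on (retrieve() is stubbed):

# Illustrative sketch, not part of the commit.
from typing import List

def retrieve(query: str, top_k: int = 2) -> List[str]:
    return [f"{query}-doc{i}" for i in range(top_k)]

def retrieve_batch(queries: List[str], top_k: int = 2) -> List[List[str]]:
    return [retrieve(q, top_k=top_k) for q in queries]

print(retrieve_batch(["query1", "query2"]))
# [['query1-doc0', 'query1-doc1'], ['query2-doc0', 'query2-doc1']]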

File: release note (new)

@@ -0,0 +1,5 @@
---
enhancements:
- |
The WebRetriever now caches web page content keyed by the URLs of the search-engine results rather than by
the query string, so repeated or overlapping queries reuse already-fetched pages.
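For context, a hedged end-to-end usage sketch matching the parameters exercised in this commit; it assumes a Serper.dev API key and Haystack v1's InMemoryDocumentStore:

# Illustrative sketch, not part of the commit.
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import WebRetriever

retriever = WebRetriever(
    api_key="YOUR_SERPERDEV_API_KEY",        # placeholder
    mode="preprocessed_documents",
    cache_document_store=InMemoryDocumentStore(),
    cache_time=24 * 60 * 60,                 # evict cached pages after a day
    top_k=5,
)

docs = retriever.retrieve(query="What is Haystack?")
# A repeated query (or any query whose search results resolve to already
# cached URLs) is served from the per-URL cache instead of refetching.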

File: WebRetriever tests

@@ -1,13 +1,11 @@
import os
from unittest.mock import MagicMock, patch, Mock
from unittest.mock import patch, Mock
from test.conftest import MockDocumentStore
import pytest
from haystack import Document, Pipeline
from haystack.document_stores.base import BaseDocumentStore
from haystack.nodes import WebRetriever, PromptNode
from haystack.nodes.retriever.link_content import html_content_handler
from haystack.nodes.preprocessor import PreProcessor
from haystack.nodes.retriever.web import SearchResult
from test.nodes.conftest import example_serperdev_response
@@ -46,75 +44,45 @@ def test_init_default_parameters():
assert retriever.preprocessor is None
assert retriever.cache_document_store is None
assert retriever.cache_index is None
assert retriever.cache_headers is None
assert retriever.cache_time == 1 * 24 * 60 * 60
@pytest.mark.unit
def test_init_custom_parameters():
preprocessor = PreProcessor()
document_store = MagicMock(spec=BaseDocumentStore)
headers = {"Test": "Header"}
@pytest.mark.parametrize("mode", ["snippets", "raw_documents", "preprocessed_documents"])
@pytest.mark.parametrize("top_k", [1, 5, 7])
def test_retrieve_from_web_all_params(mock_web_search, mode, top_k):
"""
Test that the retriever returns the correct number of documents in all modes
"""
search_result_len = len(example_serperdev_response["organic"])
wr = WebRetriever(api_key="fake_key", top_k=top_k, mode=mode)
retriever = WebRetriever(
api_key="test_key",
search_engine_provider="SerperDev",
top_search_results=15,
top_k=7,
mode="preprocessed_documents",
preprocessor=preprocessor,
cache_document_store=document_store,
cache_index="custom_index",
cache_headers=headers,
cache_time=2 * 24 * 60 * 60,
)
docs = [Document("test" + str(i)) for i in range(search_result_len)]
with patch("haystack.nodes.retriever.web.WebRetriever._scrape_links", return_value=docs):
retrieved_docs = wr.retrieve(query="who is the boyfriend of olivia wilde?")
assert retriever.top_k == 7
assert retriever.mode == "preprocessed_documents"
assert retriever.preprocessor == preprocessor
assert retriever.cache_document_store == document_store
assert retriever.cache_index == "custom_index"
assert retriever.cache_headers == headers
assert retriever.cache_time == 2 * 24 * 60 * 60
@pytest.mark.unit
def test_retrieve_from_web_all_params(mock_web_search):
wr = WebRetriever(api_key="fake_key")
preprocessor = PreProcessor()
result = wr._retrieve_from_web(query_norm="who is the boyfriend of olivia wilde?", preprocessor=preprocessor)
assert isinstance(result, list)
assert all(isinstance(doc, Document) for doc in result)
assert len(result) == len(example_serperdev_response["organic"])
@pytest.mark.unit
def test_retrieve_from_web_no_preprocessor(mock_web_search):
# tests that we get top_k results when no PreProcessor is provided
wr = WebRetriever(api_key="fake_key")
result = wr._retrieve_from_web("query", None)
assert isinstance(result, list)
assert all(isinstance(doc, Document) for doc in result)
assert len(result) == len(example_serperdev_response["organic"])
assert isinstance(retrieved_docs, list)
assert all(isinstance(doc, Document) for doc in retrieved_docs)
assert len(retrieved_docs) == top_k
@pytest.mark.unit
def test_retrieve_from_web_invalid_query(mock_web_search):
# however, if query is None or empty, we expect an error
"""
Test that the retriever raises an error if the query is invalid
"""
wr = WebRetriever(api_key="fake_key")
with pytest.raises(ValueError, match="WebSearch run requires"):
wr._retrieve_from_web("", None)
wr.retrieve("")
with pytest.raises(ValueError, match="WebSearch run requires"):
wr._retrieve_from_web(None, None)
wr.retrieve(None)
@pytest.mark.unit
def test_prepare_links_empty_list():
"""
Test that the retriever's _prepare_links method returns an empty list if the input is an empty list
"""
wr = WebRetriever(api_key="fake_key")
result = wr._prepare_links([])
assert result == []
@@ -125,8 +93,11 @@ def test_prepare_links_empty_list():
@pytest.mark.unit
def test_scrape_links_empty_list():
"""
Test that the retriever's _scrape_links method returns an empty list if the input is an empty list
"""
wr = WebRetriever(api_key="fake_key")
result = wr._scrape_links([], "query", None)
result = wr._scrape_links([])
assert result == []
@@ -134,13 +105,16 @@ def test_scrape_links_empty_list():
def test_scrape_links_with_search_results(
mocked_requests, mocked_article_extractor, mocked_link_content_fetcher_handler_type
):
"""
Test that the retriever's _scrape_links method returns a list of Documents if the input is a list of SearchResults
"""
wr = WebRetriever(api_key="fake_key")
sr1 = SearchResult("https://pagesix.com", "Some text", "0.43", "1")
sr2 = SearchResult("https://www.yahoo.com/", "Some text", "0.43", "2")
sr1 = SearchResult("https://pagesix.com", "Some text", 0.43, "1")
sr2 = SearchResult("https://www.yahoo.com/", "Some text", 0.43, "2")
fake_search_results = [sr1, sr2]
result = wr._scrape_links(fake_search_results, "query", None)
result = wr._scrape_links(fake_search_results)
assert isinstance(result, list)
assert all(isinstance(r, Document) for r in result)
@@ -151,14 +125,17 @@ def test_scrape_links_with_search_results(
def test_scrape_links_with_search_results_with_preprocessor(
mocked_requests, mocked_article_extractor, mocked_link_content_fetcher_handler_type
):
"""
Test that the retriever's _scrape_links method returns a list of Documents if the input is a list of SearchResults
and a preprocessor is provided
"""
wr = WebRetriever(api_key="fake_key", mode="preprocessed_documents")
preprocessor = PreProcessor(progress_bar=False)
sr1 = SearchResult("https://pagesix.com", "Some text", "0.43", "1")
sr2 = SearchResult("https://www.yahoo.com/", "Some text", "0.43", "2")
sr1 = SearchResult("https://pagesix.com", "Some text", 0.43, "1")
sr2 = SearchResult("https://www.yahoo.com/", "Some text", 0.43, "2")
fake_search_results = [sr1, sr2]
result = wr._scrape_links(fake_search_results, "query", preprocessor)
result = wr._scrape_links(fake_search_results)
assert isinstance(result, list)
assert all(isinstance(r, Document) for r in result)
@@ -168,33 +145,44 @@ def test_scrape_links_with_search_results_with_preprocessor(
@pytest.mark.unit
def test_retrieve_uses_defaults():
wr = WebRetriever(api_key="fake_key")
def test_retrieve_checks_cache(mock_web_search):
"""
Test that the retriever's retrieve method checks the cache
"""
wr = WebRetriever(api_key="fake_key", mode="preprocessed_documents")
with patch.object(wr, "_check_cache", return_value=[]) as mock_check_cache:
with patch.object(wr, "_retrieve_from_web", return_value=[]) as mock_retrieve_from_web:
wr.retrieve("query")
with patch.object(wr, "_check_cache", return_value=([], [])) as mock_check_cache:
wr.retrieve("query")
# cache is checked first, always
mock_check_cache.assert_called_with(
"query", cache_index=wr.cache_index, cache_headers=wr.cache_headers, cache_time=wr.cache_time
)
mock_retrieve_from_web.assert_called_with("query", wr.preprocessor)
# assert cache is checked
mock_check_cache.assert_called()
@pytest.mark.unit
def test_retrieve_batch():
queries = ["query1", "query2"]
wr = WebRetriever(api_key="fake_key")
web_docs = [Document("doc1"), Document("doc2"), Document("doc3")]
with patch.object(wr, "_check_cache", return_value=[]) as mock_check_cache:
with patch.object(wr, "_retrieve_from_web", return_value=web_docs) as mock_retrieve_from_web:
result = wr.retrieve_batch(queries)
def test_retrieve_no_cache_checks_in_snippet_mode(mock_web_search):
"""
Test that the retriever's retrieve method does not check the cache if the mode is snippets
"""
wr = WebRetriever(api_key="fake_key", mode="snippets")
with patch.object(wr, "_check_cache", return_value=([], [])) as mock_check_cache:
wr.retrieve("query")
# assert cache is NOT checked
mock_check_cache.assert_not_called()
@pytest.mark.unit
def test_retrieve_batch(mock_web_search):
"""
Test that the retriever's retrieve_batch method returns a list of lists of Documents
"""
queries = ["query1", "query2"]
wr = WebRetriever(api_key="fake_key", mode="preprocessed_documents")
web_docs = [Document("doc1"), Document("doc2"), Document("doc3")]
with patch("haystack.nodes.retriever.web.WebRetriever._scrape_links", return_value=web_docs):
result = wr.retrieve_batch(queries)
assert mock_check_cache.call_count == len(queries)
assert mock_retrieve_from_web.call_count == len(queries)
# check that the result is a list of lists of Documents
# where each list of Documents is the result of a single query
assert len(result) == len(queries)
# check that the result is a list of lists of Documents
@@ -207,63 +195,44 @@ def test_retrieve_batch():
@pytest.mark.unit
def test_retrieve_uses_cache():
wr = WebRetriever(api_key="fake_key")
def test_retrieve_uses_cache(mock_web_search):
"""
Test that the retriever's retrieve method uses the cache if it is available
"""
wr = WebRetriever(api_key="fake_key", mode="raw_documents", cache_document_store=MockDocumentStore())
cached_links = [
SearchResult("https://pagesix.com", "Some text", 0.43, "1"),
SearchResult("https://www.yahoo.com/", "Some text", 0.43, "2"),
]
cached_docs = [Document("doc1"), Document("doc2")]
with patch.object(wr, "_check_cache", return_value=cached_docs) as mock_check_cache:
with patch.object(wr, "_retrieve_from_web") as mock_retrieve_from_web:
with patch.object(wr, "_save_cache") as mock_save_cache:
with patch.object(wr, "_check_cache", return_value=(cached_links, cached_docs)) as mock_check_cache:
with patch.object(wr, "_save_to_cache") as mock_save_cache:
with patch.object(wr, "_scrape_links", return_value=[]):
result = wr.retrieve("query")
# checking cache is always called
mock_check_cache.assert_called()
# these methods are not called because we found docs in cache
mock_retrieve_from_web.assert_not_called()
mock_save_cache.assert_not_called()
# cache save is called but with empty list of documents
mock_save_cache.assert_called()
assert mock_save_cache.call_args[0][0] == []
assert result == cached_docs
@pytest.mark.unit
def test_retrieve_saves_to_cache():
wr = WebRetriever(api_key="fake_key", cache_document_store=MockDocumentStore())
def test_retrieve_saves_to_cache(mock_web_search):
"""
Test that the retriever's retrieve method saves to the cache if it is available
"""
wr = WebRetriever(api_key="fake_key", cache_document_store=MockDocumentStore(), mode="preprocessed_documents")
web_docs = [Document("doc1"), Document("doc2"), Document("doc3")]
with patch.object(wr, "_check_cache", return_value=[]) as mock_check_cache:
with patch.object(wr, "_retrieve_from_web", return_value=web_docs) as mock_retrieve_from_web:
with patch.object(wr, "_save_cache") as mock_save_cache:
result = wr.retrieve("query")
with patch.object(wr, "_save_to_cache") as mock_save_cache:
with patch.object(wr, "_scrape_links", return_value=web_docs):
wr.retrieve("query")
mock_check_cache.assert_called()
# cache is empty, so we call _retrieve_from_web
mock_retrieve_from_web.assert_called()
# and save the results to cache
mock_save_cache.assert_called_with("query", web_docs, cache_index=wr.cache_index, cache_headers=wr.cache_headers)
assert result == web_docs
@pytest.mark.unit
def test_retrieve_returns_top_k():
wr = WebRetriever(api_key="", top_k=2)
with patch.object(wr, "_check_cache", return_value=[]):
web_docs = [Document("doc1"), Document("doc2"), Document("doc3")]
with patch.object(wr, "_retrieve_from_web", return_value=web_docs):
result = wr.retrieve("query")
assert result == web_docs[:2]
@pytest.mark.unit
@pytest.mark.parametrize("top_k", [1, 3, 6])
def test_top_k_parameter(mock_web_search, top_k):
web_retriever = WebRetriever(api_key="some_invalid_key", mode="snippets")
result = web_retriever.retrieve(query="Who is the boyfriend of Olivia Wilde?", top_k=top_k)
assert len(result) == top_k
assert all(isinstance(doc, Document) for doc in result)
mock_save_cache.assert_called()
@pytest.mark.integration
@@ -277,7 +246,9 @@ def test_top_k_parameter(mock_web_search, top_k):
)
@pytest.mark.parametrize("top_k", [2, 4])
def test_top_k_parameter_in_pipeline(top_k):
# test that WebRetriever top_k param is NOT ignored in a pipeline
"""
Test that the top_k parameter works in the pipeline
"""
prompt_node = PromptNode(
"gpt-3.5-turbo",
api_key=os.environ.get("OPENAI_API_KEY"),
@@ -293,20 +264,3 @@ def test_top_k_parameter_in_pipeline(top_k):
pipe.add_node(component=prompt_node, name="QAwithScoresPrompt", inputs=["WebRetriever"])
result = pipe.run(query="What year was Obama president", params={"WebRetriever": {"top_k": top_k}})
assert len(result["results"]) == top_k
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("SERPERDEV_API_KEY", None),
reason="Please export an env var called SERPERDEV_API_KEY containing the serper.dev API key to run this test.",
)
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.skip
def test_web_retriever_speed():
retriever = WebRetriever(api_key=os.environ.get("SERPERDEV_API_KEY"), mode="preprocessed_documents")
result = retriever.retrieve(query="What's the meaning of it all?")
assert len(result) >= 5
assert all(isinstance(doc, Document) for doc in result)