diff --git a/haystack/nodes/retriever/web.py b/haystack/nodes/retriever/web.py
index 14b56370c..106ba7cf6 100644
--- a/haystack/nodes/retriever/web.py
+++ b/haystack/nodes/retriever/web.py
@@ -3,10 +3,9 @@ from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from datetime import datetime, timedelta
 from multiprocessing import cpu_count
-from typing import Dict, Iterator, List, Optional, Literal, Union
-from unicodedata import combining, normalize
+from typing import Dict, Iterator, List, Optional, Literal, Union, Tuple
 
-from haystack import Document
+from haystack.schema import Document
 from haystack.document_stores.base import BaseDocumentStore
 from haystack.nodes.preprocessor import PreProcessor
 from haystack.nodes.retriever.base import BaseRetriever
@@ -80,106 +79,12 @@ class WebRetriever(BaseRetriever):
         self.cache_document_store = cache_document_store
         self.document_store = cache_document_store
         self.cache_index = cache_index
+        self.top_k = top_k
         self.cache_headers = cache_headers
         self.cache_time = cache_time
-        self.top_k = top_k
-        self.preprocessor = None
-        if preprocessor is not None:
-            self.preprocessor = preprocessor
-        elif mode == "preprocessed_documents":
-            self.preprocessor = PreProcessor(progress_bar=False)
-
-    def _normalize_query(self, query: str) -> str:
-        return "".join([c for c in normalize("NFKD", query.lower()) if not combining(c)])
-
-    def _check_cache(
-        self,
-        query: str,
-        cache_index: Optional[str] = None,
-        cache_headers: Optional[Dict[str, str]] = None,
-        cache_time: Optional[int] = None,
-    ) -> List[Document]:
-        """
-        Private method to check if the documents for a given query are already cached. The documents are fetched from
-        the specified DocumentStore. It retrieves documents that are newer than the cache_time limit.
-
-        :param query: The query string to check in the cache.
-        :param cache_index: Optional index name in the DocumentStore to fetch the documents. Defaults to the instance's
-        cache_index.
-        :param cache_headers: Optional headers to be used when fetching documents from the DocumentStore. Defaults to
-        the instance's cache_headers.
-        :param cache_time: Optional time limit in seconds to check the cache. Only documents newer than cache_time are
-        returned. Defaults to the instance's cache_time.
-        :returns: A list of Document instances fetched from the cache. If no documents are found in the cache, an empty
-        list is returned.
-        """
-        cache_document_store = self.cache_document_store
-        documents = []
-
-        if cache_document_store is not None:
-            query_norm = self._normalize_query(query)
-            cache_filter: FilterType = {"$and": {"search.query": query_norm}}
-
-            if cache_time is not None and cache_time > 0:
-                cache_filter["timestamp"] = {
-                    "$gt": int((datetime.utcnow() - timedelta(seconds=cache_time)).timestamp())
-                }
-            logger.debug("Cache filter: %s", cache_filter)
-
-            documents = cache_document_store.get_all_documents(
-                filters=cache_filter, index=cache_index, headers=cache_headers, return_embedding=False
-            )
-
-        logger.debug("Found %d documents in cache", len(documents))
-
-        return documents
-
-    def _save_cache(
-        self,
-        query: str,
-        documents: List[Document],
-        cache_index: Optional[str] = None,
-        cache_headers: Optional[Dict[str, str]] = None,
-        cache_time: Optional[int] = None,
-    ) -> bool:
-        """
-        Private method to cache the retrieved documents for a given query.
-        The documents are saved in the specified DocumentStore. If the same document already exists, it is
-        overwritten.
-
-        :param query: The query string for which the documents are being cached.
-        :param documents: The list of Document instances to be cached.
-        :param cache_index: Optional index name in the DocumentStore to save the documents. Defaults to the
-        instance's cache_index.
-        :param cache_headers: Optional headers to be used when saving documents in the DocumentStore. Defaults to
-        the instance's cache_headers.
-        :param cache_time: Optional time limit in seconds to check the cache. Documents older than the
-        cache_time are deleted. Defaults to the instance's cache_time.
-        :returns: True if the documents are successfully saved in the cache, False otherwise.
-        """
-        cache_document_store = self.cache_document_store
-
-        if cache_document_store is not None:
-            cache_document_store.write_documents(
-                documents=documents, index=cache_index, headers=cache_headers, duplicate_documents="overwrite"
-            )
-
-            logger.debug("Saved %d documents in the cache", len(documents))
-
-            cache_filter: FilterType = {"$and": {"search.query": query}}
-
-            if cache_time is not None and cache_time > 0:
-                cache_filter["timestamp"] = {
-                    "$lt": int((datetime.utcnow() - timedelta(seconds=cache_time)).timestamp())
-                }
-
-                cache_document_store.delete_documents(index=cache_index, headers=cache_headers, filters=cache_filter)
-
-                logger.debug("Deleted documents in the cache using filter: %s", cache_filter)
-
-            return True
-
-        return False
+        self.preprocessor = (
+            (preprocessor or PreProcessor(progress_bar=False)) if mode == "preprocessed_documents" else None
+        )
 
     def retrieve(  # type: ignore[override]
         self,
@@ -201,7 +106,7 @@ class WebRetriever(BaseRetriever):
 
         Optionally, the retrieved documents can be stored in a DocumentStore for future use, saving time and
         resources on repeated retrievals. This caching mechanism can significantly improve retrieval times
-        for frequently accessed information.
+        for frequently accessed URLs.
 
         :param query: The query string.
         :param top_k: The number of Documents to be returned by the retriever.
@@ -209,56 +114,45 @@ class WebRetriever(BaseRetriever):
         :param cache_document_store: The DocumentStore to cache the documents to.
         :param cache_index: The index name to save the documents to.
         :param cache_headers: The headers to save the documents to.
-        :param cache_time: The time limit in seconds to check the cache. The default is 24 hours.
+        :param cache_time: The time limit in seconds for the documents in the cache. If objects are older than this
+        time, they will be deleted from the cache on the next retrieval.
""" # Initialize default parameters preprocessor = preprocessor or self.preprocessor - cache_document_store = cache_document_store or self.cache_document_store cache_index = cache_index or self.cache_index + top_k = top_k or self.top_k cache_headers = cache_headers or self.cache_headers cache_time = cache_time or self.cache_time - top_k = top_k or self.top_k - # Normalize query - query_norm = self._normalize_query(query) + search_results, _ = self.web_search.run(query=query) + result_docs = search_results["documents"] - # Check cache for query - extracted_docs = self._check_cache( - query_norm, cache_index=cache_index, cache_headers=cache_headers, cache_time=cache_time - ) + if self.mode != "snippets": + # for raw_documents and preprocessed_documents modes, we need to retrieve the links from the search results + links: List[SearchResult] = self._prepare_links(result_docs) - # If query is not cached, fetch from web - if not extracted_docs: - extracted_docs = self._retrieve_from_web(query_norm, preprocessor) + links_found_in_cache, cached_docs = self._check_cache(links) + logger.debug("Found %d links in cache", len(links_found_in_cache)) - # Save results to cache - if cache_document_store and extracted_docs: - cached = self._save_cache(query_norm, extracted_docs, cache_index=cache_index, cache_headers=cache_headers) - if not cached: - logger.warning( - "Could not save retrieved Documents to the DocumentStore cache. " - "Check your DocumentStore configuration." - ) - return extracted_docs[:top_k] + links_to_fetch = [link for link in links if link not in links_found_in_cache] + logger.debug("Fetching %d links", len(links_to_fetch)) + result_docs = self._scrape_links(links_to_fetch) - def _retrieve_from_web(self, query_norm: str, preprocessor: Optional[PreProcessor]) -> List[Document]: - """ - Retrieve Documents from the web based on the query. + # Save result_docs to cache + self._save_to_cache( + result_docs, cache_index=cache_index, cache_headers=cache_headers, cache_time=cache_time + ) - :param query_norm: The normalized query string. - :param preprocessor: The PreProcessor to be used to split documents into paragraphs. - :return: List of Document objects. - """ + # join cached_docs and result_docs + result_docs = cached_docs + result_docs - search_results, _ = self.web_search.run(query=query_norm) - search_results_docs = search_results["documents"] - if self.mode == "snippets": - return search_results_docs - else: - links: List[SearchResult] = self._prepare_links(search_results_docs) - logger.debug("Starting to fetch %d links from WebSearch results", len(links)) - return self._scrape_links(links, query_norm, preprocessor) + # Preprocess documents + if preprocessor: + result_docs = preprocessor.process(result_docs) + + # Return results + return result_docs[:top_k] def _prepare_links(self, search_results: List[Document]) -> List[SearchResult]: """ @@ -277,38 +171,29 @@ class WebRetriever(BaseRetriever): ] return links - def _scrape_links( - self, links: List[SearchResult], query_norm: str, preprocessor: Optional[PreProcessor] - ) -> List[Document]: + def _scrape_links(self, links: List[SearchResult]) -> List[Document]: """ Scrape the links and return the documents. :param links: List of SearchResult objects. - :param query_norm: The normalized query string. - :param preprocessor: The PreProcessor object to be used to split documents into paragraphs. - :return: List of Document objects obtained from scraping the links. 
+        :return: List of Document objects obtained by fetching the content from the links.
         """
         if not links:
            return []
 
-        fetcher = (
-            LinkContentFetcher(processor=preprocessor, raise_on_failure=True)
-            if self.mode == "preprocessed_documents" and preprocessor
-            else LinkContentFetcher(raise_on_failure=True)
-        )
+        fetcher = LinkContentFetcher(raise_on_failure=True)
 
-        def scrape_link_content(link: SearchResult) -> List[Document]:
+        def link_fetch(link: SearchResult) -> List[Document]:
             """
-            Encapsulate the link scraping logic in a function to be used in a ThreadPoolExecutor.
+            Encapsulate the link fetching logic in a function to be used in a ThreadPoolExecutor.
             """
             docs: List[Document] = []
             try:
                 docs = fetcher.fetch(
                     url=link.url,
                     doc_kwargs={
+                        "id_hash_keys": ["meta.url"],
                         "search.score": link.score,
-                        "id_hash_keys": ["meta.url", "meta.search.query"],
-                        "search.query": query_norm,
                         "search.position": link.position,
                         "snippet_text": link.snippet,
                     },
@@ -319,18 +204,72 @@ class WebRetriever(BaseRetriever):
 
             return docs
 
-        thread_count = cpu_count() if len(links) > cpu_count() else len(links)
+        thread_count = min(cpu_count(), len(links), 10)  # use at most 10 threads
         with ThreadPoolExecutor(max_workers=thread_count) as executor:
-            scraped_pages: Iterator[List[Document]] = executor.map(scrape_link_content, links)
+            fetched_pages: Iterator[List[Document]] = executor.map(link_fetch, links)
 
         # Flatten list of lists to a single list
-        extracted_docs = [doc for doc_list in scraped_pages for doc in doc_list]
+        extracted_docs = [doc for doc_list in fetched_pages for doc in doc_list]
 
         # Sort by score
         extracted_docs = sorted(extracted_docs, key=lambda x: x.meta["search.score"], reverse=True)
 
         return extracted_docs
 
+    def _check_cache(self, links: List[SearchResult]) -> Tuple[List[SearchResult], List[Document]]:
+        """
+        Check the DocumentStore cache for documents.
+
+        :param links: List of SearchResult objects.
+        :return: A tuple of the SearchResults found in the cache and the corresponding cached Documents.
+        """
+        if not links or not self.cache_document_store:
+            return [], []
+
+        cache_documents: List[Document] = []
+        cached_links: List[SearchResult] = []
+
+        valid_links = [link for link in links if link.url]
+        for link in valid_links:
+            cache_filter: FilterType = {"url": link.url}
+            documents = self.cache_document_store.get_all_documents(filters=cache_filter, return_embedding=False)
+            if documents:
+                cache_documents.extend(documents)
+                cached_links.append(link)
+
+        return cached_links, cache_documents
+
+    def _save_to_cache(
+        self,
+        documents: List[Document],
+        cache_index: Optional[str] = None,
+        cache_headers: Optional[Dict[str, str]] = None,
+        cache_time: Optional[int] = None,
+    ) -> None:
+        """
+        Save the documents to the cache and delete any expired documents from the cache.
+
+        :param documents: List of Document objects to be saved to the cache.
+        :param cache_index: Optional index name to save the documents to.
+        :param cache_headers: Optional headers to use when saving the documents to the cache.
+        :param cache_time: Optional time to live in seconds for the documents in the cache. If objects are older than
+        this time, they will be deleted from the cache.
+ """ + cache_document_store = self.cache_document_store + + if cache_document_store is not None and documents: + cache_document_store.write_documents( + documents=documents, index=cache_index, headers=cache_headers, duplicate_documents="overwrite" + ) + + if cache_document_store and cache_time is not None and cache_time > 0: + cache_filter: FilterType = { + "timestamp": {"$lt": int((datetime.utcnow() - timedelta(seconds=cache_time)).timestamp())} + } + + cache_document_store.delete_documents(index=cache_index, headers=cache_headers, filters=cache_filter) + logger.debug("Deleted documents in the cache using filter: %s", cache_filter) + def retrieve_batch( # type: ignore[override] self, queries: List[str], @@ -358,7 +297,8 @@ class WebRetriever(BaseRetriever): DocumentStore is used. :param cache_index: The index name to save the documents to. If None, the instance's default cache_index is used. :param cache_headers: The headers to save the documents to. If None, the instance's default cache_headers is used. - :param cache_time: The time limit in seconds to check the cache. If None, the instance's default cache_time is used. + :param cache_time: The time limit in seconds for the documents in the cache. + :returns: A list of lists where each inner list represents the documents fetched for a particular query. """ documents = [] @@ -370,8 +310,6 @@ class WebRetriever(BaseRetriever): preprocessor=preprocessor, cache_document_store=cache_document_store, cache_index=cache_index, - cache_headers=cache_headers, - cache_time=cache_time, ) ) diff --git a/releasenotes/notes/adjust-web-retriever-caching-logic-2e05fbc972a86f29.yaml b/releasenotes/notes/adjust-web-retriever-caching-logic-2e05fbc972a86f29.yaml new file mode 100644 index 000000000..41f6bc498 --- /dev/null +++ b/releasenotes/notes/adjust-web-retriever-caching-logic-2e05fbc972a86f29.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + The WebRetriever now employs an enhanced caching mechanism that caches web page content based on search engine + results rather than the query. 
diff --git a/test/nodes/test_web_retriever.py b/test/nodes/test_web_retriever.py
index 73961b33e..c7a68b5d7 100644
--- a/test/nodes/test_web_retriever.py
+++ b/test/nodes/test_web_retriever.py
@@ -1,13 +1,11 @@
 import os
-from unittest.mock import MagicMock, patch, Mock
+from unittest.mock import patch, Mock
 from test.conftest import MockDocumentStore
 
 import pytest
 
 from haystack import Document, Pipeline
-from haystack.document_stores.base import BaseDocumentStore
 from haystack.nodes import WebRetriever, PromptNode
 from haystack.nodes.retriever.link_content import html_content_handler
-from haystack.nodes.preprocessor import PreProcessor
 from haystack.nodes.retriever.web import SearchResult
 from test.nodes.conftest import example_serperdev_response
 
@@ -46,75 +44,45 @@ def test_init_default_parameters():
     assert retriever.preprocessor is None
     assert retriever.cache_document_store is None
     assert retriever.cache_index is None
-    assert retriever.cache_headers is None
-    assert retriever.cache_time == 1 * 24 * 60 * 60
 
 
 @pytest.mark.unit
-def test_init_custom_parameters():
-    preprocessor = PreProcessor()
-    document_store = MagicMock(spec=BaseDocumentStore)
-    headers = {"Test": "Header"}
+@pytest.mark.parametrize("mode", ["snippets", "raw_documents", "preprocessed_documents"])
+@pytest.mark.parametrize("top_k", [1, 5, 7])
+def test_retrieve_from_web_all_params(mock_web_search, mode, top_k):
+    """
+    Test that the retriever returns the correct number of documents in all modes
+    """
+    search_result_len = len(example_serperdev_response["organic"])
+    wr = WebRetriever(api_key="fake_key", top_k=top_k, mode=mode)
 
-    retriever = WebRetriever(
-        api_key="test_key",
-        search_engine_provider="SerperDev",
-        top_search_results=15,
-        top_k=7,
-        mode="preprocessed_documents",
-        preprocessor=preprocessor,
-        cache_document_store=document_store,
-        cache_index="custom_index",
-        cache_headers=headers,
-        cache_time=2 * 24 * 60 * 60,
-    )
+    docs = [Document("test" + str(i)) for i in range(search_result_len)]
+    with patch("haystack.nodes.retriever.web.WebRetriever._scrape_links", return_value=docs):
+        retrieved_docs = wr.retrieve(query="who is the boyfriend of olivia wilde?")
 
-    assert retriever.top_k == 7
-    assert retriever.mode == "preprocessed_documents"
-    assert retriever.preprocessor == preprocessor
-    assert retriever.cache_document_store == document_store
-    assert retriever.cache_index == "custom_index"
-    assert retriever.cache_headers == headers
-    assert retriever.cache_time == 2 * 24 * 60 * 60
-
-
-@pytest.mark.unit
-def test_retrieve_from_web_all_params(mock_web_search):
-    wr = WebRetriever(api_key="fake_key")
-
-    preprocessor = PreProcessor()
-
-    result = wr._retrieve_from_web(query_norm="who is the boyfriend of olivia wilde?", preprocessor=preprocessor)
-
-    assert isinstance(result, list)
-    assert all(isinstance(doc, Document) for doc in result)
-    assert len(result) == len(example_serperdev_response["organic"])
-
-
-@pytest.mark.unit
-def test_retrieve_from_web_no_preprocessor(mock_web_search):
-    # tests that we get top_k results when no PreProcessor is provided
-    wr = WebRetriever(api_key="fake_key")
-    result = wr._retrieve_from_web("query", None)
-
-    assert isinstance(result, list)
-    assert all(isinstance(doc, Document) for doc in result)
-    assert len(result) == len(example_serperdev_response["organic"])
+    assert isinstance(retrieved_docs, list)
+    assert all(isinstance(doc, Document) for doc in retrieved_docs)
+    assert len(retrieved_docs) == top_k
 
 
 @pytest.mark.unit
 def test_retrieve_from_web_invalid_query(mock_web_search):
-    # however, if query is None or empty, we expect an error
+    """
+    Test that the retriever raises an error if the query is invalid
+    """
     wr = WebRetriever(api_key="fake_key")
     with pytest.raises(ValueError, match="WebSearch run requires"):
-        wr._retrieve_from_web("", None)
+        wr.retrieve("")
     with pytest.raises(ValueError, match="WebSearch run requires"):
-        wr._retrieve_from_web(None, None)
+        wr.retrieve(None)
 
 
 @pytest.mark.unit
 def test_prepare_links_empty_list():
+    """
+    Test that the retriever's _prepare_links method returns an empty list if the input is an empty list
+    """
     wr = WebRetriever(api_key="fake_key")
     result = wr._prepare_links([])
     assert result == []
@@ -125,8 +93,11 @@ def test_prepare_links_empty_list():
 
 @pytest.mark.unit
 def test_scrape_links_empty_list():
+    """
+    Test that the retriever's _scrape_links method returns an empty list if the input is an empty list
+    """
     wr = WebRetriever(api_key="fake_key")
-    result = wr._scrape_links([], "query", None)
+    result = wr._scrape_links([])
     assert result == []
 
 
@@ -134,13 +105,16 @@ def test_scrape_links_empty_list():
 def test_scrape_links_with_search_results(
     mocked_requests, mocked_article_extractor, mocked_link_content_fetcher_handler_type
 ):
+    """
+    Test that the retriever's _scrape_links method returns a list of Documents if the input is a list of SearchResults
+    """
     wr = WebRetriever(api_key="fake_key")
 
-    sr1 = SearchResult("https://pagesix.com", "Some text", "0.43", "1")
-    sr2 = SearchResult("https://www.yahoo.com/", "Some text", "0.43", "2")
+    sr1 = SearchResult("https://pagesix.com", "Some text", 0.43, "1")
+    sr2 = SearchResult("https://www.yahoo.com/", "Some text", 0.43, "2")
     fake_search_results = [sr1, sr2]
 
-    result = wr._scrape_links(fake_search_results, "query", None)
+    result = wr._scrape_links(fake_search_results)
 
     assert isinstance(result, list)
     assert all(isinstance(r, Document) for r in result)
@@ -151,14 +125,17 @@ def test_scrape_links_with_search_results(
 def test_scrape_links_with_search_results_with_preprocessor(
     mocked_requests, mocked_article_extractor, mocked_link_content_fetcher_handler_type
 ):
+    """
+    Test that the retriever's _scrape_links method returns a list of Documents if the input is a list of SearchResults
+    and a preprocessor is provided
+    """
     wr = WebRetriever(api_key="fake_key", mode="preprocessed_documents")
-    preprocessor = PreProcessor(progress_bar=False)
 
-    sr1 = SearchResult("https://pagesix.com", "Some text", "0.43", "1")
-    sr2 = SearchResult("https://www.yahoo.com/", "Some text", "0.43", "2")
+    sr1 = SearchResult("https://pagesix.com", "Some text", 0.43, "1")
+    sr2 = SearchResult("https://www.yahoo.com/", "Some text", 0.43, "2")
     fake_search_results = [sr1, sr2]
 
-    result = wr._scrape_links(fake_search_results, "query", preprocessor)
+    result = wr._scrape_links(fake_search_results)
 
     assert isinstance(result, list)
     assert all(isinstance(r, Document) for r in result)
@@ -168,33 +145,44 @@ def test_scrape_links_with_search_results_with_preprocessor(
 
 @pytest.mark.unit
-def test_retrieve_uses_defaults():
-    wr = WebRetriever(api_key="fake_key")
+def test_retrieve_checks_cache(mock_web_search):
+    """
+    Test that the retriever's retrieve method checks the cache
+    """
+    wr = WebRetriever(api_key="fake_key", mode="preprocessed_documents")
 
-    with patch.object(wr, "_check_cache", return_value=[]) as mock_check_cache:
-        with patch.object(wr, "_retrieve_from_web", return_value=[]) as mock_retrieve_from_web:
-            wr.retrieve("query")
+    with patch.object(wr, "_check_cache", return_value=([], [])) as mock_check_cache:
"_check_cache", return_value=([], [])) as mock_check_cache: + wr.retrieve("query") - # cache is checked first, always - mock_check_cache.assert_called_with( - "query", cache_index=wr.cache_index, cache_headers=wr.cache_headers, cache_time=wr.cache_time - ) - mock_retrieve_from_web.assert_called_with("query", wr.preprocessor) + # assert cache is checked + mock_check_cache.assert_called() @pytest.mark.unit -def test_retrieve_batch(): - queries = ["query1", "query2"] - wr = WebRetriever(api_key="fake_key") - web_docs = [Document("doc1"), Document("doc2"), Document("doc3")] - with patch.object(wr, "_check_cache", return_value=[]) as mock_check_cache: - with patch.object(wr, "_retrieve_from_web", return_value=web_docs) as mock_retrieve_from_web: - result = wr.retrieve_batch(queries) +def test_retrieve_no_cache_checks_in_snippet_mode(mock_web_search): + """ + Test that the retriever's retrieve method does not check the cache if the mode is snippets + """ + wr = WebRetriever(api_key="fake_key", mode="snippets") + + with patch.object(wr, "_check_cache", return_value=([], [])) as mock_check_cache: + wr.retrieve("query") + + # assert cache is NOT checked + mock_check_cache.assert_not_called() + + +@pytest.mark.unit +def test_retrieve_batch(mock_web_search): + """ + Test that the retriever's retrieve_batch method returns a list of lists of Documents + """ + queries = ["query1", "query2"] + wr = WebRetriever(api_key="fake_key", mode="preprocessed_documents") + web_docs = [Document("doc1"), Document("doc2"), Document("doc3")] + with patch("haystack.nodes.retriever.web.WebRetriever._scrape_links", return_value=web_docs): + result = wr.retrieve_batch(queries) - assert mock_check_cache.call_count == len(queries) - assert mock_retrieve_from_web.call_count == len(queries) - # check that the result is a list of lists of Documents - # where each list of Documents is the result of a single query assert len(result) == len(queries) # check that the result is a list of lists of Documents @@ -207,63 +195,44 @@ def test_retrieve_batch(): @pytest.mark.unit -def test_retrieve_uses_cache(): - wr = WebRetriever(api_key="fake_key") +def test_retrieve_uses_cache(mock_web_search): + """ + Test that the retriever's retrieve method uses the cache if it is available + """ + wr = WebRetriever(api_key="fake_key", mode="raw_documents", cache_document_store=MockDocumentStore()) + cached_links = [ + SearchResult("https://pagesix.com", "Some text", 0.43, "1"), + SearchResult("https://www.yahoo.com/", "Some text", 0.43, "2"), + ] cached_docs = [Document("doc1"), Document("doc2")] - with patch.object(wr, "_check_cache", return_value=cached_docs) as mock_check_cache: - with patch.object(wr, "_retrieve_from_web") as mock_retrieve_from_web: - with patch.object(wr, "_save_cache") as mock_save_cache: + with patch.object(wr, "_check_cache", return_value=(cached_links, cached_docs)) as mock_check_cache: + with patch.object(wr, "_save_to_cache") as mock_save_cache: + with patch.object(wr, "_scrape_links", return_value=[]): result = wr.retrieve("query") # checking cache is always called mock_check_cache.assert_called() - # these methods are not called because we found docs in cache - mock_retrieve_from_web.assert_not_called() - mock_save_cache.assert_not_called() - + # cache save is called but with empty list of documents + mock_save_cache.assert_called() + assert mock_save_cache.call_args[0][0] == [] assert result == cached_docs @pytest.mark.unit -def test_retrieve_saves_to_cache(): - wr = WebRetriever(api_key="fake_key", 
+def test_retrieve_saves_to_cache(mock_web_search):
+    """
+    Test that the retriever's retrieve method saves to the cache if it is available
+    """
+    wr = WebRetriever(api_key="fake_key", cache_document_store=MockDocumentStore(), mode="preprocessed_documents")
     web_docs = [Document("doc1"), Document("doc2"), Document("doc3")]
 
-    with patch.object(wr, "_check_cache", return_value=[]) as mock_check_cache:
-        with patch.object(wr, "_retrieve_from_web", return_value=web_docs) as mock_retrieve_from_web:
-            with patch.object(wr, "_save_cache") as mock_save_cache:
-                result = wr.retrieve("query")
+    with patch.object(wr, "_save_to_cache") as mock_save_cache:
+        with patch.object(wr, "_scrape_links", return_value=web_docs):
+            wr.retrieve("query")
 
-    mock_check_cache.assert_called()
-
-    # cache is empty, so we call _retrieve_from_web
-    mock_retrieve_from_web.assert_called()
-    # and save the results to cache
-    mock_save_cache.assert_called_with("query", web_docs, cache_index=wr.cache_index, cache_headers=wr.cache_headers)
-    assert result == web_docs
-
-
-@pytest.mark.unit
-def test_retrieve_returns_top_k():
-    wr = WebRetriever(api_key="", top_k=2)
-
-    with patch.object(wr, "_check_cache", return_value=[]):
-        web_docs = [Document("doc1"), Document("doc2"), Document("doc3")]
-        with patch.object(wr, "_retrieve_from_web", return_value=web_docs):
-            result = wr.retrieve("query")
-
-    assert result == web_docs[:2]
-
-
-@pytest.mark.unit
-@pytest.mark.parametrize("top_k", [1, 3, 6])
-def test_top_k_parameter(mock_web_search, top_k):
-    web_retriever = WebRetriever(api_key="some_invalid_key", mode="snippets")
-    result = web_retriever.retrieve(query="Who is the boyfriend of Olivia Wilde?", top_k=top_k)
-    assert len(result) == top_k
-    assert all(isinstance(doc, Document) for doc in result)
+    mock_save_cache.assert_called()
 
 
 @pytest.mark.integration
@@ -277,7 +246,9 @@ def test_top_k_parameter(mock_web_search, top_k):
 )
 @pytest.mark.parametrize("top_k", [2, 4])
 def test_top_k_parameter_in_pipeline(top_k):
-    # test that WebRetriever top_k param is NOT ignored in a pipeline
+    """
+    Test that the top_k parameter works in the pipeline
+    """
     prompt_node = PromptNode(
         "gpt-3.5-turbo",
         api_key=os.environ.get("OPENAI_API_KEY"),
@@ -293,20 +264,3 @@ def test_top_k_parameter_in_pipeline(top_k):
     pipe.add_node(component=prompt_node, name="QAwithScoresPrompt", inputs=["WebRetriever"])
     result = pipe.run(query="What year was Obama president", params={"WebRetriever": {"top_k": top_k}})
     assert len(result["results"]) == top_k
-
-
-@pytest.mark.integration
-@pytest.mark.skipif(
-    not os.environ.get("SERPERDEV_API_KEY", None),
-    reason="Please export an env var called SERPERDEV_API_KEY containing the serper.dev API key to run this test.",
-)
-@pytest.mark.skipif(
-    not os.environ.get("OPENAI_API_KEY", None),
-    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
-)
-@pytest.mark.skip
-def test_web_retriever_speed():
-    retriever = WebRetriever(api_key=os.environ.get("SERPERDEV_API_KEY"), mode="preprocessed_documents")
-    result = retriever.retrieve(query="What's the meaning of it all?")
-    assert len(result) >= 5
-    assert all(isinstance(doc, Document) for doc in result)
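To close the loop on the caching tests, here is a compact sketch of exercising the new per-URL cache check outside pytest, using the same mocking style as the tests above (the URL, score, and the MagicMock cache store are illustrative stand-ins, not part of the patch):

    from unittest.mock import MagicMock, patch

    from haystack.nodes import WebRetriever
    from haystack.nodes.retriever.web import SearchResult
    from haystack.schema import Document

    # Stand-in cache backend; the tests above use MockDocumentStore instead.
    wr = WebRetriever(api_key="fake_key", mode="raw_documents", cache_document_store=MagicMock())

    cached_links = [SearchResult("https://pagesix.com", "Some text", 0.43, "1")]
    cached_docs = [Document("cached page content")]

    # Mock the search engine call and report every link as a cache hit.
    with patch.object(wr.web_search, "run", return_value=({"documents": []}, "output_1")), patch.object(
        wr, "_check_cache", return_value=(cached_links, cached_docs)
    ), patch.object(wr, "_scrape_links", return_value=[]) as mock_scrape:
        result = wr.retrieve("any query")

    # No links were left to fetch, and the cached documents come back unchanged.
    assert mock_scrape.call_args[0][0] == []
    assert result == cached_docs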