From 4da43b6b057e7a2022d05f4cd1749a5a2a0610ee Mon Sep 17 00:00:00 2001 From: MichelBartels Date: Mon, 25 Sep 2023 10:03:01 +0200 Subject: [PATCH] Add link output to `SerperDevWebSearch` (#5853) * add link output * adjust tests * fix test * remove print statements --------- Co-authored-by: ZanSara --- .../components/websearch/serper_dev.py | 9 +++++---- .../components/websearch/test_serperdev.py | 20 +++++++++++++------ 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/haystack/preview/components/websearch/serper_dev.py b/haystack/preview/components/websearch/serper_dev.py index 265fd3b24..c7f6b4595 100644 --- a/haystack/preview/components/websearch/serper_dev.py +++ b/haystack/preview/components/websearch/serper_dev.py @@ -65,13 +65,12 @@ class SerperDevWebSearch: """ return default_from_dict(cls, data) - @component.output_types(documents=List[Document]) + @component.output_types(documents=List[Document], links=List[str]) def run(self, query: str): """ - Search the SerperDev API for the given query and return the results as a list of Documents. + Search the SerperDev API for the given query and return the results as a list of Documents and a list of links. :param query: Query string. 
- :return: List[Document] """ query_prepend = "OR ".join(f"site:{domain} " for domain in self.allowed_domains) if self.allowed_domains else "" @@ -136,5 +135,7 @@ class SerperDevWebSearch: documents = answer_box + organic + people_also_ask + links = [result["link"] for result in json_result["organic"]] + logger.debug("Serper Dev returned %s documents for the query '%s'", len(documents), query) - return {"documents": documents[: self.top_k]} + return {"documents": documents[: self.top_k], "links": links[: self.top_k]} diff --git a/test/preview/components/websearch/test_serperdev.py b/test/preview/components/websearch/test_serperdev.py index 605003673..4194da81a 100644 --- a/test/preview/components/websearch/test_serperdev.py +++ b/test/preview/components/websearch/test_serperdev.py @@ -145,9 +145,13 @@ class TestSerperDevSearchAPI: @pytest.mark.parametrize("top_k", [1, 5, 7]) def test_web_search_top_k(self, mock_serper_dev_search_result, top_k: int): ws = SerperDevWebSearch(api_key="some_invalid_key", top_k=top_k) - results = ws.run(query="Who is the boyfriend of Olivia Wilde?")["documents"] - assert len(results) == top_k - assert all(isinstance(doc, Document) for doc in results) + results = ws.run(query="Who is the boyfriend of Olivia Wilde?") + documents = results["documents"] + links = results["links"] + assert len(documents) == len(links) == top_k + assert all(isinstance(doc, Document) for doc in documents) + assert all(isinstance(link, str) for link in links) + assert all(link.startswith("http") for link in links) @pytest.mark.unit @patch("requests.post") @@ -183,8 +187,12 @@ class TestSerperDevSearchAPI: reason="Export an env var called SERPERDEV_API_KEY containing the SerperDev API key to run this test.", ) @pytest.mark.integration - def test_web_search(): + def test_web_search(self): ws = SerperDevWebSearch(api_key=os.environ.get("SERPERDEV_API_KEY", None), top_k=10) - results = ws.run(query="Who is the boyfriend of Olivia Wilde?")["documents"] - assert 
len(results) == 10 + results = ws.run(query="Who is the boyfriend of Olivia Wilde?") + documents = results["documents"] + links = results["links"] + assert len(documents) == len(links) == 10 - assert all(isinstance(doc, Document) for doc in results) + assert all(isinstance(doc, Document) for doc in documents) + assert all(isinstance(link, str) for link in links) + assert all(link.startswith("http") for link in links)