Add link output to SerperDevWebSearch (#5853)

* add link output

* adjust tests

* fix test

* remove print statements

---------

Co-authored-by: ZanSara <sara.zanzottera@deepset.ai>
This commit is contained in:
MichelBartels 2023-09-25 10:03:01 +02:00 committed by GitHub
parent c0f22372d4
commit 4da43b6b05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 10 deletions

View File

@ -65,13 +65,12 @@ class SerperDevWebSearch:
"""
return default_from_dict(cls, data)
@component.output_types(documents=List[Document])
@component.output_types(documents=List[Document], links=List[str])
def run(self, query: str):
"""
Search the SerperDev API for the given query and return the results as a list of Documents.
Search the SerperDev API for the given query and return the results as a list of Documents and a list of links.
:param query: Query string.
:return: List[Document]
"""
query_prepend = "OR ".join(f"site:{domain} " for domain in self.allowed_domains) if self.allowed_domains else ""
@ -136,5 +135,7 @@ class SerperDevWebSearch:
documents = answer_box + organic + people_also_ask
links = [result["link"] for result in json_result["organic"]]
logger.debug("Serper Dev returned %s documents for the query '%s'", len(documents), query)
return {"documents": documents[: self.top_k]}
return {"documents": documents[: self.top_k], "links": links[: self.top_k]}

View File

@ -145,9 +145,13 @@ class TestSerperDevSearchAPI:
@pytest.mark.parametrize("top_k", [1, 5, 7])
def test_web_search_top_k(self, mock_serper_dev_search_result, top_k: int):
ws = SerperDevWebSearch(api_key="some_invalid_key", top_k=top_k)
results = ws.run(query="Who is the boyfriend of Olivia Wilde?")["documents"]
assert len(results) == top_k
assert all(isinstance(doc, Document) for doc in results)
results = ws.run(query="Who is the boyfriend of Olivia Wilde?")
documents = results["documents"]
links = results["links"]
assert len(documents) == len(links) == top_k
assert all(isinstance(doc, Document) for doc in documents)
assert all(isinstance(link, str) for link in links)
assert all(link.startswith("http") for link in links)
@pytest.mark.unit
@patch("requests.post")
@ -183,8 +187,12 @@ class TestSerperDevSearchAPI:
reason="Export an env var called SERPERDEV_API_KEY containing the SerperDev API key to run this test.",
)
@pytest.mark.integration
def test_web_search():
def test_web_search(self):
ws = SerperDevWebSearch(api_key=os.environ.get("SERPERDEV_API_KEY", None), top_k=10)
results = ws.run(query="Who is the boyfriend of Olivia Wilde?")["documents"]
assert len(results) == 10
results = ws.run(query="Who is the boyfriend of Olivia Wilde?")
documents = results["documents"]
links = results["documents"]
assert len(documents) == len(links) == 10
assert all(isinstance(doc, Document) for doc in results)
assert all(isinstance(link, str) for link in links)
assert all(link.startswith("http") for link in links)