mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-30 01:09:43 +00:00
Add link output to SerperDevWebSearch (#5853)
* add link output
* adjust tests
* fix test
* remove print statements

---------

Co-authored-by: ZanSara <sara.zanzottera@deepset.ai>
This commit is contained in:
parent
c0f22372d4
commit
4da43b6b05
@ -65,13 +65,12 @@ class SerperDevWebSearch:
|
||||
"""
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
@component.output_types(documents=List[Document])
|
||||
@component.output_types(documents=List[Document], links=List[str])
|
||||
def run(self, query: str):
|
||||
"""
|
||||
Search the SerperDev API for the given query and return the results as a list of Documents.
|
||||
Search the SerperDev API for the given query and return the results as a list of Documents and a list of links.
|
||||
|
||||
:param query: Query string.
|
||||
:return: List[Document]
|
||||
"""
|
||||
query_prepend = "OR ".join(f"site:{domain} " for domain in self.allowed_domains) if self.allowed_domains else ""
|
||||
|
||||
@ -136,5 +135,7 @@ class SerperDevWebSearch:
|
||||
|
||||
documents = answer_box + organic + people_also_ask
|
||||
|
||||
links = [result["link"] for result in json_result["organic"]]
|
||||
|
||||
logger.debug("Serper Dev returned %s documents for the query '%s'", len(documents), query)
|
||||
return {"documents": documents[: self.top_k]}
|
||||
return {"documents": documents[: self.top_k], "links": links[: self.top_k]}
|
||||
|
||||
@ -145,9 +145,13 @@ class TestSerperDevSearchAPI:
|
||||
@pytest.mark.parametrize("top_k", [1, 5, 7])
|
||||
def test_web_search_top_k(self, mock_serper_dev_search_result, top_k: int):
|
||||
ws = SerperDevWebSearch(api_key="some_invalid_key", top_k=top_k)
|
||||
results = ws.run(query="Who is the boyfriend of Olivia Wilde?")["documents"]
|
||||
assert len(results) == top_k
|
||||
assert all(isinstance(doc, Document) for doc in results)
|
||||
results = ws.run(query="Who is the boyfriend of Olivia Wilde?")
|
||||
documents = results["documents"]
|
||||
links = results["links"]
|
||||
assert len(documents) == len(links) == top_k
|
||||
assert all(isinstance(doc, Document) for doc in documents)
|
||||
assert all(isinstance(link, str) for link in links)
|
||||
assert all(link.startswith("http") for link in links)
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("requests.post")
|
||||
@ -183,8 +187,12 @@ class TestSerperDevSearchAPI:
|
||||
reason="Export an env var called SERPERDEV_API_KEY containing the SerperDev API key to run this test.",
|
||||
)
|
||||
@pytest.mark.integration
|
||||
def test_web_search():
|
||||
def test_web_search(self):
|
||||
ws = SerperDevWebSearch(api_key=os.environ.get("SERPERDEV_API_KEY", None), top_k=10)
|
||||
results = ws.run(query="Who is the boyfriend of Olivia Wilde?")["documents"]
|
||||
assert len(results) == 10
|
||||
results = ws.run(query="Who is the boyfriend of Olivia Wilde?")
|
||||
documents = results["documents"]
|
||||
links = results["documents"]
|
||||
assert len(documents) == len(links) == 10
|
||||
assert all(isinstance(doc, Document) for doc in results)
|
||||
assert all(isinstance(link, str) for link in links)
|
||||
assert all(link.startswith("http") for link in links)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user