docs: review and normalize haystack.components.websearch (#7236)

* docs: review and normalize `haystack.components.websearch`

* fix: use correct type annotations

* refactor: use type from protocol

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* Revert "refactor: use type from protocol"

This reverts commit 23d6f45cd763c39b98be1bff03639a90f2a01fac.

* docs: refactor according to comments

* build: correctly pin to 4.7

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
This commit is contained in:
Tobias Wochinger 2024-02-28 16:43:08 +01:00 committed by GitHub
parent 20ebb46fa5
commit f22d49944d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 78 additions and 30 deletions

View File

@ -1,7 +1,7 @@
loaders: loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader - type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/components/websearch] search_path: [../../../haystack/components/websearch]
modules: ["serper_dev"] modules: ["serper_dev", "searchapi"]
ignore_when_discovered: ["__init__"] ignore_when_discovered: ["__init__"]
processors: processors:
- type: filter - type: filter

View File

@ -1,6 +1,6 @@
import json import json
import logging import logging
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any, Union
import requests import requests
@ -20,9 +20,21 @@ class SearchApiError(ComponentError):
@component @component
class SearchApiWebSearch: class SearchApiWebSearch:
""" """
Search engine using SearchApi API. Given a query, it returns a list of URLs that are the most relevant. Uses [SearchApi](https://www.searchapi.io/) to search the web for relevant documents.
See the [SearchApi website](https://www.searchapi.io/) for more details. See the [SearchApi website](https://www.searchapi.io/) for more details.
Usage example:
```python
from haystack.components.websearch import SearchApiWebSearch
from haystack.utils import Secret
websearch = SearchApiWebSearch(top_k=10, api_key=Secret.from_token("test-api-key"))
results = websearch.run(query="Who is the boyfriend of Olivia Wilde?")
assert results["documents"]
assert results["links"]
```
""" """
def __init__( def __init__(
@ -37,8 +49,8 @@ class SearchApiWebSearch:
:param top_k: Number of documents to return. :param top_k: Number of documents to return.
:param allowed_domains: List of domains to limit the search to. :param allowed_domains: List of domains to limit the search to.
:param search_params: Additional parameters passed to the SearchApi API. :param search_params: Additional parameters passed to the SearchApi API.
For example, you can set 'num' to 100 to increase the number of search results. For example, you can set 'num' to 100 to increase the number of search results.
See the [SearchApi website](https://www.searchapi.io/) for more details. See the [SearchApi website](https://www.searchapi.io/) for more details.
""" """
self.api_key = api_key self.api_key = api_key
@ -51,7 +63,10 @@ class SearchApiWebSearch:
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
""" """
Serialize this component to a dictionary. Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
""" """
return default_to_dict( return default_to_dict(
self, self,
@ -64,17 +79,27 @@ class SearchApiWebSearch:
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SearchApiWebSearch": def from_dict(cls, data: Dict[str, Any]) -> "SearchApiWebSearch":
""" """
Deserialize this component from a dictionary. Deserializes the component from a dictionary.
:param data:
The dictionary to deserialize from.
:returns:
The deserialized component.
""" """
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data) return default_from_dict(cls, data)
@component.output_types(documents=List[Document], links=List[str]) @component.output_types(documents=List[Document], links=Union[List[Document], List[str]])
def run(self, query: str): def run(self, query: str) -> Dict[str, Union[List[Document], List[str]]]:
""" """
Search the SearchApi API for the given query and return the results as a list of Documents and a list of links. Uses [SearchApi](https://www.searchapi.io/) to search the web.
:param query: Query string. :param query: Search query.
:returns: A dictionary with the following keys:
- "documents": List of documents returned by the search engine.
- "links": List of links returned by the search engine.
:raises TimeoutError: If the request to the SearchApi API times out.
:raises SearchApiError: If an error occurs while querying the SearchApi API.
""" """
query_prepend = "OR ".join(f"site:{domain} " for domain in self.allowed_domains) if self.allowed_domains else "" query_prepend = "OR ".join(f"site:{domain} " for domain in self.allowed_domains) if self.allowed_domains else ""
@ -84,8 +109,8 @@ class SearchApiWebSearch:
try: try:
response = requests.get(SEARCHAPI_BASE_URL, headers=headers, params=payload, timeout=90) response = requests.get(SEARCHAPI_BASE_URL, headers=headers, params=payload, timeout=90)
response.raise_for_status() # Will raise an HTTPError for bad responses response.raise_for_status() # Will raise an HTTPError for bad responses
except requests.Timeout: except requests.Timeout as error:
raise TimeoutError(f"Request to {self.__class__.__name__} timed out.") raise TimeoutError(f"Request to {self.__class__.__name__} timed out.") from error
except requests.RequestException as e: except requests.RequestException as e:
raise SearchApiError(f"An error occurred while querying {self.__class__.__name__}. Error: {e}") from e raise SearchApiError(f"An error occurred while querying {self.__class__.__name__}. Error: {e}") from e

View File

@ -1,6 +1,6 @@
import json import json
import logging import logging
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any, Union
import requests import requests
@ -20,9 +20,21 @@ class SerperDevError(ComponentError):
@component @component
class SerperDevWebSearch: class SerperDevWebSearch:
""" """
Search engine using SerperDev API. Given a query, it returns a list of URLs that are the most relevant. Uses [Serper](https://serper.dev/) to search the web for relevant documents.
See the [Serper Dev website](https://serper.dev/) for more details. See the [Serper Dev website](https://serper.dev/) for more details.
Usage example:
```python
from haystack.components.websearch import SerperDevWebSearch
from haystack.utils import Secret
websearch = SerperDevWebSearch(top_k=10, api_key=Secret.from_token("test-api-key"))
results = websearch.run(query="Who is the boyfriend of Olivia Wilde?")
assert results["documents"]
assert results["links"]
```
""" """
def __init__( def __init__(
@ -33,12 +45,12 @@ class SerperDevWebSearch:
search_params: Optional[Dict[str, Any]] = None, search_params: Optional[Dict[str, Any]] = None,
): ):
""" """
:param api_key: API key for the SerperDev API. :param api_key: API key for the Serper API.
:param top_k: Number of documents to return. :param top_k: Number of documents to return.
:param allowed_domains: List of domains to limit the search to. :param allowed_domains: List of domains to limit the search to.
:param search_params: Additional parameters passed to the SerperDev API. :param search_params: Additional parameters passed to the Serper API.
For example, you can set 'num' to 20 to increase the number of search results. For example, you can set 'num' to 20 to increase the number of search results.
See the [Serper Dev website](https://serper.dev/) for more details. See the [Serper website](https://serper.dev/) for more details.
""" """
self.api_key = api_key self.api_key = api_key
self.top_k = top_k self.top_k = top_k
@ -50,7 +62,10 @@ class SerperDevWebSearch:
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
""" """
Serialize this component to a dictionary. Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
""" """
return default_to_dict( return default_to_dict(
self, self,
@ -63,17 +78,25 @@ class SerperDevWebSearch:
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SerperDevWebSearch": def from_dict(cls, data: Dict[str, Any]) -> "SerperDevWebSearch":
""" """
Deserialize this component from a dictionary. Deserializes the component from a dictionary.
:param data:
The dictionary to deserialize from.
:returns:
The deserialized component.
""" """
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data) return default_from_dict(cls, data)
@component.output_types(documents=List[Document], links=List[str]) @component.output_types(documents=List[Document], links=Union[List[Document], List[str]])
def run(self, query: str): def run(self, query: str) -> Dict[str, Union[List[Document], List[str]]]:
""" """
Search the SerperDev API for the given query and return the results as a list of Documents and a list of links. Uses [Serper](https://serper.dev/) to search the web.
:param query: Query string. :param query: Search query.
:returns: A dictionary with the following keys:
- "documents": List of documents returned by the search engine.
- "links": List of links returned by the search engine.
:raises SerperDevError: If an error occurs while querying the SerperDev API.
:raises TimeoutError: If the request to the SerperDev API times out.
""" """
query_prepend = "OR ".join(f"site:{domain} " for domain in self.allowed_domains) if self.allowed_domains else "" query_prepend = "OR ".join(f"site:{domain} " for domain in self.allowed_domains) if self.allowed_domains else ""
@ -85,8 +108,8 @@ class SerperDevWebSearch:
try: try:
response = requests.post(SERPERDEV_BASE_URL, headers=headers, data=payload, timeout=30) # type: ignore response = requests.post(SERPERDEV_BASE_URL, headers=headers, data=payload, timeout=30) # type: ignore
response.raise_for_status() # Will raise an HTTPError for bad responses response.raise_for_status() # Will raise an HTTPError for bad responses
except requests.Timeout: except requests.Timeout as error:
raise TimeoutError(f"Request to {self.__class__.__name__} timed out.") raise TimeoutError(f"Request to {self.__class__.__name__} timed out.") from error
except requests.RequestException as e: except requests.RequestException as e:
raise SerperDevError(f"An error occurred while querying {self.__class__.__name__}. Error: {e}") from e raise SerperDevError(f"An error occurred while querying {self.__class__.__name__}. Error: {e}") from e

View File

@ -57,7 +57,7 @@ dependencies = [
"pyyaml", "pyyaml",
"more-itertools", # TextDocumentSplitter "more-itertools", # TextDocumentSplitter
"networkx", # Pipeline graphs "networkx", # Pipeline graphs
"typing_extensions>=3.7", # typing support for Python 3.8 "typing_extensions>=4.7", # typing support for Python 3.8
"boilerpy3", # Fulltext extraction from HTML pages "boilerpy3", # Fulltext extraction from HTML pages
] ]

View File

@ -1,5 +1,5 @@
--- ---
fixes: fixes:
- | - |
Pin the `typing-extensions` package to versions >= 3.7 to avoid Pin the `typing-extensions` package to versions >= 4.7 to avoid
[incompatibilities with the `openai` package](https://community.openai.com/t/error-while-importing-openai-from-open-import-openai/578166/26). [incompatibilities with the `openai` package](https://community.openai.com/t/error-while-importing-openai-from-open-import-openai/578166/26).

View File

@ -174,7 +174,7 @@ class TestSerperDevSearchAPI:
ws = SerperDevWebSearch(top_k=10) ws = SerperDevWebSearch(top_k=10)
results = ws.run(query="Who is the boyfriend of Olivia Wilde?") results = ws.run(query="Who is the boyfriend of Olivia Wilde?")
documents = results["documents"] documents = results["documents"]
links = results["documents"] links = results["links"]
assert len(documents) == len(links) == 10 assert len(documents) == len(links) == 10
assert all(isinstance(doc, Document) for doc in results) assert all(isinstance(doc, Document) for doc in results)
assert all(isinstance(link, str) for link in links) assert all(isinstance(link, str) for link in links)