# GoogleAPI provider from haystack/nodes/search_engine/providers.py (upstream commit 75ff768c, PR #4722).
class GoogleAPI(SearchEngine):
    """
    Search engine using the Google API.

    See [Google Search API](https://developers.google.com/custom-search/v1/overview) for more details.
    Requires the optional `google-api-python-client` package; it is imported lazily the first time
    a search is executed, so constructing this class does not need the package.
    """

    def __init__(
        self,
        top_k: Optional[int] = 10,
        api_key: Optional[str] = None,
        engine_id: Optional[str] = None,
        search_engine_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        :param top_k: Number of documents to return.
        :param api_key: API key for the Google API.
        :param engine_id: Engine ID for the Google API. It can also be supplied through
            `search_engine_kwargs` or per call to `search()` under the key 'engine_id'.
        :param search_engine_kwargs: Additional parameters passed to the Google API. As an example, you can pass
            the hl parameter to specify the language to use for the query: 'hl':'en'.
        """
        super().__init__()
        self.api_key = api_key
        self.engine_id = engine_id
        self.top_k = top_k
        self.kwargs = search_engine_kwargs if search_engine_kwargs else {}

    def _validate_environment(self, engine_id: Optional[str] = None):
        """
        Check that credentials are available and create the Custom Search service client.

        :param engine_id: Engine ID resolved for the current call. Falls back to `self.engine_id`
            when omitted, so existing zero-argument callers keep working.
        :raises ValueError: If no API key or no engine ID is available.
        :raises ImportError: If `google-api-python-client` is not installed.
        """
        if not self.api_key:
            raise ValueError(
                "You need to provide an API key for the Google API. See https://developers.google.com/custom-search/v1/overview"
            )
        if not (engine_id or self.engine_id):
            raise ValueError(
                "You need to provide an engine ID for the Google API. See https://developers.google.com/custom-search/v1/overview"
            )

        # check if google api is installed
        try:
            from googleapiclient.discovery import build
        except ImportError as err:
            # Chain the original error so the real import failure stays visible in the traceback.
            raise ImportError(
                "You need to install the Google API client. You can do so by running 'pip install google-api-python-client'."
            ) from err
        # create a custom search service
        self.service = build("customsearch", "v1", developerKey=self.api_key)

    def search(self, query: str, **kwargs) -> List[Document]:
        """
        :param query: Query string.
        :param kwargs: Additional parameters passed to the Google API.
            As an example, you can pass the hl parameter to specify the language to use for the query: 'hl':'en'.
            If you don't specify the hl parameter, the default language for the user's location is used.
            For a complete list of the language codes, see [Language Codes](https://developers.google.com/custom-search/docs/xml_results#languageCollections).
            You can also pass the num parameter to specify the number of results to return: 'num':10.
            You can find a full list of parameters at [Query Parameters](https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list).
            The keys 'engine_id' and 'top_k' are consumed here and are not forwarded to the API.
        :return: List[Document]
        :raises ValueError: If no API key or no engine ID is available.
        :raises ImportError: If `google-api-python-client` is not installed.
        """
        # Per-call kwargs take precedence over the ones given at construction time.
        merged: Dict[str, Any] = {**self.kwargs, **kwargs}
        # Resolve the engine ID into a local so a one-off override does not stick to the instance.
        engine_id = merged.pop("engine_id", None) or self.engine_id
        top_k = merged.pop("top_k", self.top_k)

        self._validate_environment(engine_id)

        params: Dict[str, Union[str, int, float]] = {"num": 10, **merged}
        res = self.service.cse().list(q=query, cx=engine_id, **params).execute()

        # The API omits the "items" key entirely when a query has no results.
        items = res.get("items", [])
        documents: List[Document] = [
            Document.from_dict(
                {"title": result["title"], "content": result["snippet"], "position": i, "link": result["link"]}
            )
            for i, result in enumerate(items)
        ]
        logger.debug("Google API returned %s documents for the query '%s'", len(documents), query)
        return documents[:top_k]
# GoogleAPI provider tests from test/nodes/test_web_search.py.
@pytest.mark.unit
def test_web_search_with_google_api_provider():
    """
    WebSearch can be constructed with the "GoogleAPI" provider name.

    NOTE(review): `WebSearch.run` itself is patched here, so the assertions only exercise the mock;
    the only production code that really runs is the `WebSearch(...)` constructor.
    """
    pytest.importorskip("googleapiclient", reason="google-api-python-client is not installed, skipping test.")

    GOOGLE_API_KEY = "dummy_api_key"
    SEARCH_ENGINE_ID = "dummy_search_engine_id"
    query = "The founder of Python"

    with patch("haystack.nodes.search_engine.WebSearch.run") as mock_run:
        mock_run.return_value = ([{"content": "Guido van Rossum"}], None)
        ws = WebSearch(
            api_key=GOOGLE_API_KEY,
            search_engine_provider="GoogleAPI",
            search_engine_kwargs={"engine_id": SEARCH_ENGINE_ID},
        )
        result, _ = ws.run(query=query)

        mock_run.assert_called_once_with(query=query)

    assert "guido" in result[0]["content"].lower()


@pytest.mark.unit
def test_web_search_with_google_api_client():
    """
    A WebSearch run with the GoogleAPI provider drives the Google client as expected.

    The `googleapiclient.discovery.build` factory is replaced with a mock, so no network call is made.
    The test checks both the calls issued to the client and the documents produced from its response.
    """
    pytest.importorskip("googleapiclient", reason="google-api-python-client is not installed, skipping test.")

    GOOGLE_API_KEY = "dummy_api_key"
    SEARCH_ENGINE_ID = "dummy_search_engine_id"
    query = "The founder of Python"
    item = {
        "title": "Guido van Rossum",
        "snippet": "The founder of Python programming language.",
        "link": "https://example.com/guido",
    }

    with patch("googleapiclient.discovery.build") as mock_build:
        mock_service = MagicMock()
        mock_cse = MagicMock()
        mock_list = MagicMock()

        # Mirror the call chain used by the provider: build(...).cse().list(...).execute()
        mock_build.return_value = mock_service
        mock_service.cse.return_value = mock_cse
        mock_cse.list.return_value = mock_list
        mock_list.execute.return_value = {"items": [item]}

        ws = WebSearch(
            api_key=GOOGLE_API_KEY,
            search_engine_provider="GoogleAPI",
            search_engine_kwargs={"engine_id": SEARCH_ENGINE_ID},
        )
        result, _ = ws.run(query=query)

        mock_build.assert_called_once_with("customsearch", "v1", developerKey=GOOGLE_API_KEY)
        mock_service.cse.assert_called_once()
        mock_cse.list.assert_called_once_with(q=query, cx=SEARCH_ENGINE_ID, num=10)
        mock_list.execute.assert_called_once()

    # The mocked API response must come back as documents, not just trigger the right calls.
    documents = result["documents"]
    assert len(documents) == 1
    assert documents[0].content == item["snippet"]
    assert documents[0].meta["link"] == item["link"]