diff --git a/examples/talk_to_website.py b/examples/talk_to_website.py new file mode 100644 index 000000000..f4a957583 --- /dev/null +++ b/examples/talk_to_website.py @@ -0,0 +1,77 @@ +import logging +import os +from typing import Dict, Any + +from haystack import Pipeline +from haystack.document_stores import InMemoryDocumentStore +from haystack.nodes import PromptNode, PromptTemplate, TopPSampler +from haystack.nodes.ranker.lost_in_the_middle import LostInTheMiddleRanker +from haystack.nodes.retriever.web import WebRetriever + +search_key = os.environ.get("SERPERDEV_API_KEY") +if not search_key: + raise ValueError("Please set the SERPERDEV_API_KEY environment variable") + +models_config: Dict[str, Any] = { + "openai": {"api_key": os.environ.get("OPENAI_API_KEY"), "model_name": "gpt-3.5-turbo"}, + "anthropic": {"api_key": os.environ.get("ANTHROPIC_API_KEY"), "model_name": "claude-instant-1"}, + "hf": {"api_key": os.environ.get("HF_API_KEY"), "model_name": "tiiuae/falcon-7b-instruct"}, +} +prompt_text = """ +Synthesize a comprehensive answer from the provided paragraphs and the given question.\n +Focus on the question and avoid unnecessary information in your answer.\n +\n\n Paragraphs: {join(documents)} \n\n Question: {query} \n\n Answer: +""" + +stream = True +model: Dict[str, str] = models_config["openai"] +prompt_node = PromptNode( + model["model_name"], + default_prompt_template=PromptTemplate(prompt_text), + api_key=model["api_key"], + max_length=768, + model_kwargs={"stream": stream}, +) + +web_retriever = WebRetriever( + api_key=search_key, + allowed_domains=["haystack.deepset.ai"], + top_search_results=10, + mode="preprocessed_documents", + top_k=50, + cache_document_store=InMemoryDocumentStore(), +) + +pipeline = Pipeline() +pipeline.add_node(component=web_retriever, name="Retriever", inputs=["Query"]) +pipeline.add_node(component=TopPSampler(top_p=0.90), name="Sampler", inputs=["Retriever"]) +pipeline.add_node(component=LostInTheMiddleRanker(1024), 
name="LostInTheMiddleRanker", inputs=["Sampler"]) +pipeline.add_node(component=prompt_node, name="PromptNode", inputs=["LostInTheMiddleRanker"]) + +logging.disable(logging.CRITICAL) + +test = False +questions = [ + "What are the main benefits of using pipelines in Haystack?", + "Are there any ready-made pipelines available and why should I use them?", +] + +print(f"Running pipeline with {model['model_name']}\n") + +if test: + for question in questions: + if stream: + print("Answer:") + response = pipeline.run(query=question) + if not stream: + print(f"Answer: {response['results'][0]}") +else: + while True: + user_input = input("\nAsk question (type 'exit' or 'quit' to quit): ") + if user_input.lower() == "exit" or user_input.lower() == "quit": + break + if stream: + print("Answer:") + response = pipeline.run(query=user_input) + if not stream: + print(f"Answer: {response['results'][0]}") diff --git a/haystack/nodes/retriever/web.py b/haystack/nodes/retriever/web.py index 106ba7cf6..e93e32d26 100644 --- a/haystack/nodes/retriever/web.py +++ b/haystack/nodes/retriever/web.py @@ -50,6 +50,7 @@ class WebRetriever(BaseRetriever): self, api_key: str, search_engine_provider: Union[str, SearchEngine] = "SerperDev", + allowed_domains: Optional[List[str]] = None, top_search_results: Optional[int] = 10, top_k: Optional[int] = 5, mode: Literal["snippets", "raw_documents", "preprocessed_documents"] = "snippets", @@ -62,6 +63,7 @@ class WebRetriever(BaseRetriever): """ :param api_key: API key for the search engine provider. :param search_engine_provider: Name of the search engine provider class, see `providers.py` for a list of supported providers. + :param allowed_domains: List of domains to restrict the search to. If not provided, the search is unrestricted. :param top_search_results: Number of top search results to be retrieved. :param top_k: Top k documents to be returned by the retriever. :param mode: Whether to return snippets, raw documents, or preprocessed documents. 
Snippets are the default. @@ -73,7 +75,10 @@ class WebRetriever(BaseRetriever): """ super().__init__() self.web_search = WebSearch( - api_key=api_key, top_k=top_search_results, search_engine_provider=search_engine_provider + api_key=api_key, + top_k=top_search_results, + allowed_domains=allowed_domains, + search_engine_provider=search_engine_provider, ) self.mode = mode self.cache_document_store = cache_document_store diff --git a/haystack/nodes/search_engine/providers.py b/haystack/nodes/search_engine/providers.py index 1979db830..6446ee512 100644 --- a/haystack/nodes/search_engine/providers.py +++ b/haystack/nodes/search_engine/providers.py @@ -20,12 +20,14 @@ class SerpAPI(SearchEngine): self, api_key: str, top_k: Optional[int] = 10, + allowed_domains: Optional[List[str]] = None, engine: Optional[str] = "google", search_engine_kwargs: Optional[Dict[str, Any]] = None, ): """ :param api_key: API key for SerpAPI. :param top_k: Number of results to return. + :param allowed_domains: List of domains to limit the search to. :param engine: Search engine to use, for example google, bing, baidu, duckduckgo, yahoo, yandex. See the [SerpAPI documentation](https://serpapi.com/search-api) for the full list of supported engines. :param search_engine_kwargs: Additional parameters passed to the SerperDev API. 
For example, you can set 'lr' to 'lang_en' @@ -38,6 +40,7 @@ class SerpAPI(SearchEngine): self.kwargs = search_engine_kwargs if search_engine_kwargs else {} self.engine = engine self.top_k = top_k + self.allowed_domains = allowed_domains def search(self, query: str, **kwargs) -> List[Document]: """ @@ -51,7 +54,9 @@ class SerpAPI(SearchEngine): top_k = kwargs.pop("top_k", self.top_k) url = "https://serpapi.com/search" - params = {"source": "python", "serp_api_key": self.api_key, "q": query, **kwargs} + allowed_domains = kwargs.pop("allowed_domains", self.allowed_domains) + query_prepend = "OR ".join(f"site:{domain} " for domain in allowed_domains) if allowed_domains else "" + params = {"source": "python", "serp_api_key": self.api_key, "q": query_prepend + query, **kwargs} if self.engine: params["engine"] = self.engine @@ -124,16 +129,24 @@ class SerperDev(SearchEngine): Search engine using SerperDev API. See the [Serper Dev website](https://serper.dev/) for more details. """ - def __init__(self, api_key: str, top_k: Optional[int] = 10, search_engine_kwargs: Optional[Dict[str, Any]] = None): + def __init__( + self, + api_key: str, + top_k: Optional[int] = 10, + allowed_domains: Optional[List[str]] = None, + search_engine_kwargs: Optional[Dict[str, Any]] = None, + ): """ :param api_key: API key for the SerperDev API. :param top_k: Number of documents to return. + :param allowed_domains: List of domains to limit the search to. :param search_engine_kwargs: Additional parameters passed to the SerperDev API. For example, you can set 'num' to 20 to increase the number of search results. 
""" super().__init__() self.api_key = api_key self.top_k = top_k + self.allowed_domains = allowed_domains self.kwargs = search_engine_kwargs if search_engine_kwargs else {} def search(self, query: str, **kwargs) -> List[Document]: @@ -144,10 +157,12 @@ class SerperDev(SearchEngine): """ kwargs = {**self.kwargs, **kwargs} top_k = kwargs.pop("top_k", self.top_k) + allowed_domains = kwargs.pop("allowed_domains", self.allowed_domains) + query_prepend = "OR ".join(f"site:{domain} " for domain in allowed_domains) if allowed_domains else "" url = "https://google.serper.dev/search" - payload = json.dumps({"q": query, "gl": "us", "hl": "en", "autocorrect": True, **kwargs}) + payload = json.dumps({"q": query_prepend + query, "gl": "us", "hl": "en", "autocorrect": True, **kwargs}) headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"} response = requests.request("POST", url, headers=headers, data=payload, timeout=30) @@ -211,21 +226,29 @@ class BingAPI(SearchEngine): Search engine using the Bing API. See [Bing Web Search API](https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/overview) for more details. """ - def __init__(self, api_key: str, top_k: Optional[int] = 10, search_engine_kwargs: Optional[Dict[str, Any]] = None): + def __init__( + self, + api_key: str, + top_k: Optional[int] = 10, + allowed_domains: Optional[List[str]] = None, + search_engine_kwargs: Optional[Dict[str, Any]] = None, + ): """ :param api_key: API key for the Bing API. :param top_k: Number of documents to return. - :param search_engine_kwargs: Additional parameters passed to the SerperDev API. As an example, you can pass the market parameter to specify the market to use for the query: 'mkt':'en-US'. + :param allowed_domains: List of domains to limit the search to. + :param search_engine_kwargs: Additional parameters passed to the Bing. As an example, you can pass the market parameter to specify the market to use for the query: 'mkt':'en-US'. 
""" super().__init__() self.api_key = api_key self.top_k = top_k + self.allowed_domains = allowed_domains self.kwargs = search_engine_kwargs if search_engine_kwargs else {} def search(self, query: str, **kwargs) -> List[Document]: """ :param query: Query string. - :param kwargs: Additional parameters passed to the SerperDev API. + :param kwargs: Additional parameters passed to the BingAPI. As an example, you can pass the market parameter to specify the market to use for the query: 'mkt':'en-US'. If you don't specify the market parameter, the default market for the user's location is used. For a complete list of the market codes, see [Market Codes](https://learn.microsoft.com/en-us/rest/api/cognitiveservices-bingsearch/bing-web-api-v7-reference#market-codes). @@ -237,7 +260,10 @@ class BingAPI(SearchEngine): top_k = kwargs.pop("top_k", self.top_k) url = "https://api.bing.microsoft.com/v7.0/search" - params: Dict[str, Union[str, int, float]] = {"q": query, "count": 50, **kwargs} + allowed_domains = kwargs.pop("allowed_domains", self.allowed_domains) + query_prepend = "OR ".join(f"site:{domain} " for domain in allowed_domains) if allowed_domains else "" + + params: Dict[str, Union[str, int, float]] = {"q": query_prepend + query, "count": 50, **kwargs} headers = {"Ocp-Apim-Subscription-Key": self.api_key} @@ -282,12 +308,14 @@ class GoogleAPI(SearchEngine): def __init__( self, top_k: Optional[int] = 10, + allowed_domains: Optional[List[str]] = None, api_key: Optional[str] = None, engine_id: Optional[str] = None, search_engine_kwargs: Optional[Dict[str, Any]] = None, ): """ :param top_k: Number of documents to return. + :param allowed_domains: List of domains to limit the search to. :param api_key: API key for the Google API. :param engine_id: Engine ID for the Google API. :param search_engine_kwargs: Additional parameters passed to the Google API. As an example, you can pass the hl parameter to specify the language to use for the query: 'hl':'en'. 
@@ -296,6 +324,7 @@ class GoogleAPI(SearchEngine): self.api_key = api_key self.engine_id = engine_id self.top_k = top_k + self.allowed_domains = allowed_domains self.kwargs = search_engine_kwargs if search_engine_kwargs else {} def _validate_environment(self): @@ -338,8 +367,11 @@ class GoogleAPI(SearchEngine): self._validate_environment() top_k = kwargs.pop("top_k", self.top_k) + allowed_domains = kwargs.pop("allowed_domains", self.allowed_domains) + query_prepend = "OR ".join(f"site:{domain} " for domain in allowed_domains) if allowed_domains else "" + params: Dict[str, Union[str, int, float]] = {"num": 10, **kwargs} - res = self.service.cse().list(q=query, cx=self.engine_id, **params).execute() + res = self.service.cse().list(q=query_prepend + query, cx=self.engine_id, **params).execute() documents: List[Document] = [] for i, result in enumerate(res["items"]): documents.append( diff --git a/haystack/nodes/search_engine/web.py b/haystack/nodes/search_engine/web.py index e4de1eee4..eaa21b36c 100644 --- a/haystack/nodes/search_engine/web.py +++ b/haystack/nodes/search_engine/web.py @@ -26,6 +26,7 @@ class WebSearch(BaseComponent): self, api_key: str, top_k: Optional[int] = 10, + allowed_domains: Optional[List[str]] = None, search_engine_provider: Union[str, SearchEngine] = "SerperDev", search_engine_kwargs: Optional[Dict[str, Any]] = None, ): @@ -48,7 +49,7 @@ class WebSearch(BaseComponent): ) if not issubclass(klass, SearchEngine): raise ValueError(f"Class {search_engine_provider} is not a subclass of SearchEngine.") - self.search_engine = klass(api_key=api_key, top_k=top_k, search_engine_kwargs=search_engine_kwargs) # type: ignore + self.search_engine = klass(api_key=api_key, top_k=top_k, allowed_domains=allowed_domains, search_engine_kwargs=search_engine_kwargs) # type: ignore elif isinstance(search_engine_provider, SearchEngine): self.search_engine = search_engine_provider else: diff --git 
a/releasenotes/notes/web-retriever-add-domain-scoping-6594425e0c0ace3c.yaml b/releasenotes/notes/web-retriever-add-domain-scoping-6594425e0c0ace3c.yaml new file mode 100644 index 000000000..51b1b4ed4 --- /dev/null +++ b/releasenotes/notes/web-retriever-add-domain-scoping-6594425e0c0ace3c.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Introduced the `allowed_domains` parameter in `WebRetriever` (and in `WebSearch` and the search engine providers) for domain-specific searches, + thus enabling "talk to a website" and "talk to docs" scenarios.