feat: Add domain scoping to WebRetriever (#5587)

* WebSearch: add allowed_domains scoped search

* Add talk to website example

* Add release note

* Add allowed_domains to WebSearch

* Minor fix

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
This commit is contained in:
Vladimir Blagojevic 2023-08-28 20:02:02 +02:00 committed by GitHub
parent 81f3aaf3e5
commit 2118f68769
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 130 additions and 10 deletions

View File

@ -0,0 +1,77 @@
import logging
import os
from typing import Dict, Any
from haystack import Pipeline
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import PromptNode, PromptTemplate, TopPSampler
from haystack.nodes.ranker.lost_in_the_middle import LostInTheMiddleRanker
from haystack.nodes.retriever.web import WebRetriever
# Example: "talk to a website" -- answer questions using only web search
# results restricted to a single domain.

# The web search provider key is mandatory; fail early when it is absent.
search_key = os.getenv("SERPERDEV_API_KEY")
if not search_key:
    raise ValueError("Please set the SERPERDEV_API_KEY environment variable")

# Per-provider credentials and model names; exactly one entry is selected below.
models_config: Dict[str, Any] = {
    "openai": {
        "api_key": os.getenv("OPENAI_API_KEY"),
        "model_name": "gpt-3.5-turbo",
    },
    "anthropic": {
        "api_key": os.getenv("ANTHROPIC_API_KEY"),
        "model_name": "claude-instant-1",
    },
    "hf": {
        "api_key": os.getenv("HF_API_KEY"),
        "model_name": "tiiuae/falcon-7b-instruct",
    },
}

# Prompt template text; {join(documents)} and {query} are placeholders handed
# to PromptTemplate.
prompt_text = """
Synthesize a comprehensive answer from the provided paragraphs and the given question.\n
Focus on the question and avoid unnecessary information in your answer.\n
\n\n Paragraphs: {join(documents)} \n\n Question: {query} \n\n Answer:
"""

# When True, "stream" is forwarded to the model via model_kwargs and the
# script does not print the final result itself.
stream = True
model: Dict[str, str] = models_config["openai"]

prompt_node = PromptNode(
    model["model_name"],
    default_prompt_template=PromptTemplate(prompt_text),
    api_key=model["api_key"],
    max_length=768,
    model_kwargs={"stream": stream},
)

# Search is scoped to the Haystack documentation site through allowed_domains.
web_retriever = WebRetriever(
    api_key=search_key,
    allowed_domains=["haystack.deepset.ai"],
    top_search_results=10,
    mode="preprocessed_documents",
    top_k=50,
    cache_document_store=InMemoryDocumentStore(),
)

# Wire a linear pipeline: each node consumes the output of the previous one,
# starting from the built-in "Query" input.
pipeline = Pipeline()
previous_node = "Query"
for node_name, node_component in (
    ("Retriever", web_retriever),
    ("Sampler", TopPSampler(top_p=0.90)),
    ("LostInTheMiddleRanker", LostInTheMiddleRanker(1024)),
    ("PromptNode", prompt_node),
):
    pipeline.add_node(component=node_component, name=node_name, inputs=[previous_node])
    previous_node = node_name

# Suppress all log records of severity CRITICAL and below to keep output clean.
logging.disable(logging.CRITICAL)

# Set to True to run the canned questions instead of the interactive prompt.
test = False
questions = [
    "What are the main benefits of using pipelines in Haystack?",
    "Are there any ready-made pipelines available and why should I use them?",
]


def _run_and_print(query: str) -> None:
    """Run the pipeline for one query and print the answer.

    In streaming mode only the "Answer:" header is printed here; otherwise the
    first entry of the pipeline's results is printed after the run completes.
    """
    if stream:
        print("Answer:")
    response = pipeline.run(query=query)
    if not stream:
        print(f"Answer: {response['results'][0]}")


print(f"Running pipeline with {model['model_name']}\n")

if test:
    for question in questions:
        _run_and_print(question)
else:
    while True:
        user_input = input("\nAsk question (type 'exit' or 'quit' to quit): ")
        if user_input.lower() in ("exit", "quit"):
            break
        _run_and_print(user_input)

View File

@ -50,6 +50,7 @@ class WebRetriever(BaseRetriever):
self,
api_key: str,
search_engine_provider: Union[str, SearchEngine] = "SerperDev",
allowed_domains: Optional[List[str]] = None,
top_search_results: Optional[int] = 10,
top_k: Optional[int] = 5,
mode: Literal["snippets", "raw_documents", "preprocessed_documents"] = "snippets",
@ -62,6 +63,7 @@ class WebRetriever(BaseRetriever):
"""
:param api_key: API key for the search engine provider.
:param search_engine_provider: Name of the search engine provider class, see `providers.py` for a list of supported providers.
:param allowed_domains: List of domains to restrict the search to. If not provided, the search is unrestricted.
:param top_search_results: Number of top search results to be retrieved.
:param top_k: Top k documents to be returned by the retriever.
:param mode: Whether to return snippets, raw documents, or preprocessed documents. Snippets are the default.
@ -73,7 +75,10 @@ class WebRetriever(BaseRetriever):
"""
super().__init__()
self.web_search = WebSearch(
api_key=api_key, top_k=top_search_results, search_engine_provider=search_engine_provider
api_key=api_key,
top_k=top_search_results,
allowed_domains=allowed_domains,
search_engine_provider=search_engine_provider,
)
self.mode = mode
self.cache_document_store = cache_document_store

View File

@ -20,12 +20,14 @@ class SerpAPI(SearchEngine):
self,
api_key: str,
top_k: Optional[int] = 10,
allowed_domains: Optional[List[str]] = None,
engine: Optional[str] = "google",
search_engine_kwargs: Optional[Dict[str, Any]] = None,
):
"""
:param api_key: API key for SerpAPI.
:param top_k: Number of results to return.
:param allowed_domains: List of domains to limit the search to.
:param engine: Search engine to use, for example google, bing, baidu, duckduckgo, yahoo, yandex.
See the [SerpAPI documentation](https://serpapi.com/search-api) for the full list of supported engines.
:param search_engine_kwargs: Additional parameters passed to SerpAPI. For example, you can set 'lr' to 'lang_en'
@ -38,6 +40,7 @@ class SerpAPI(SearchEngine):
self.kwargs = search_engine_kwargs if search_engine_kwargs else {}
self.engine = engine
self.top_k = top_k
self.allowed_domains = allowed_domains
def search(self, query: str, **kwargs) -> List[Document]:
"""
@ -51,7 +54,9 @@ class SerpAPI(SearchEngine):
top_k = kwargs.pop("top_k", self.top_k)
url = "https://serpapi.com/search"
params = {"source": "python", "serp_api_key": self.api_key, "q": query, **kwargs}
allowed_domains = kwargs.pop("allowed_domains", self.allowed_domains)
query_prepend = "OR ".join(f"site:{domain} " for domain in allowed_domains) if allowed_domains else ""
params = {"source": "python", "serp_api_key": self.api_key, "q": query_prepend + query, **kwargs}
if self.engine:
params["engine"] = self.engine
@ -124,16 +129,24 @@ class SerperDev(SearchEngine):
Search engine using SerperDev API. See the [Serper Dev website](https://serper.dev/) for more details.
"""
def __init__(self, api_key: str, top_k: Optional[int] = 10, search_engine_kwargs: Optional[Dict[str, Any]] = None):
def __init__(
self,
api_key: str,
top_k: Optional[int] = 10,
allowed_domains: Optional[List[str]] = None,
search_engine_kwargs: Optional[Dict[str, Any]] = None,
):
"""
:param api_key: API key for the SerperDev API.
:param top_k: Number of documents to return.
:param allowed_domains: List of domains to limit the search to.
:param search_engine_kwargs: Additional parameters passed to the SerperDev API.
For example, you can set 'num' to 20 to increase the number of search results.
"""
super().__init__()
self.api_key = api_key
self.top_k = top_k
self.allowed_domains = allowed_domains
self.kwargs = search_engine_kwargs if search_engine_kwargs else {}
def search(self, query: str, **kwargs) -> List[Document]:
@ -144,10 +157,12 @@ class SerperDev(SearchEngine):
"""
kwargs = {**self.kwargs, **kwargs}
top_k = kwargs.pop("top_k", self.top_k)
allowed_domains = kwargs.pop("allowed_domains", self.allowed_domains)
query_prepend = "OR ".join(f"site:{domain} " for domain in allowed_domains) if allowed_domains else ""
url = "https://google.serper.dev/search"
payload = json.dumps({"q": query, "gl": "us", "hl": "en", "autocorrect": True, **kwargs})
payload = json.dumps({"q": query_prepend + query, "gl": "us", "hl": "en", "autocorrect": True, **kwargs})
headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"}
response = requests.request("POST", url, headers=headers, data=payload, timeout=30)
@ -211,21 +226,29 @@ class BingAPI(SearchEngine):
Search engine using the Bing API. See [Bing Web Search API](https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/overview) for more details.
"""
def __init__(self, api_key: str, top_k: Optional[int] = 10, search_engine_kwargs: Optional[Dict[str, Any]] = None):
def __init__(
self,
api_key: str,
top_k: Optional[int] = 10,
allowed_domains: Optional[List[str]] = None,
search_engine_kwargs: Optional[Dict[str, Any]] = None,
):
"""
:param api_key: API key for the Bing API.
:param top_k: Number of documents to return.
:param search_engine_kwargs: Additional parameters passed to the SerperDev API. As an example, you can pass the market parameter to specify the market to use for the query: 'mkt':'en-US'.
:param allowed_domains: List of domains to limit the search to.
:param search_engine_kwargs: Additional parameters passed to the Bing API. As an example, you can pass the market parameter to specify the market to use for the query: 'mkt':'en-US'.
"""
super().__init__()
self.api_key = api_key
self.top_k = top_k
self.allowed_domains = allowed_domains
self.kwargs = search_engine_kwargs if search_engine_kwargs else {}
def search(self, query: str, **kwargs) -> List[Document]:
"""
:param query: Query string.
:param kwargs: Additional parameters passed to the SerperDev API.
:param kwargs: Additional parameters passed to the Bing API.
As an example, you can pass the market parameter to specify the market to use for the query: 'mkt':'en-US'.
If you don't specify the market parameter, the default market for the user's location is used.
For a complete list of the market codes, see [Market Codes](https://learn.microsoft.com/en-us/rest/api/cognitiveservices-bingsearch/bing-web-api-v7-reference#market-codes).
@ -237,7 +260,10 @@ class BingAPI(SearchEngine):
top_k = kwargs.pop("top_k", self.top_k)
url = "https://api.bing.microsoft.com/v7.0/search"
params: Dict[str, Union[str, int, float]] = {"q": query, "count": 50, **kwargs}
allowed_domains = kwargs.pop("allowed_domains", self.allowed_domains)
query_prepend = "OR ".join(f"site:{domain} " for domain in allowed_domains) if allowed_domains else ""
params: Dict[str, Union[str, int, float]] = {"q": query_prepend + query, "count": 50, **kwargs}
headers = {"Ocp-Apim-Subscription-Key": self.api_key}
@ -282,12 +308,14 @@ class GoogleAPI(SearchEngine):
def __init__(
self,
top_k: Optional[int] = 10,
allowed_domains: Optional[List[str]] = None,
api_key: Optional[str] = None,
engine_id: Optional[str] = None,
search_engine_kwargs: Optional[Dict[str, Any]] = None,
):
"""
:param top_k: Number of documents to return.
:param allowed_domains: List of domains to limit the search to.
:param api_key: API key for the Google API.
:param engine_id: Engine ID for the Google API.
:param search_engine_kwargs: Additional parameters passed to the Google API. As an example, you can pass the hl parameter to specify the language to use for the query: 'hl':'en'.
@ -296,6 +324,7 @@ class GoogleAPI(SearchEngine):
self.api_key = api_key
self.engine_id = engine_id
self.top_k = top_k
self.allowed_domains = allowed_domains
self.kwargs = search_engine_kwargs if search_engine_kwargs else {}
def _validate_environment(self):
@ -338,8 +367,11 @@ class GoogleAPI(SearchEngine):
self._validate_environment()
top_k = kwargs.pop("top_k", self.top_k)
allowed_domains = kwargs.pop("allowed_domains", self.allowed_domains)
query_prepend = "OR ".join(f"site:{domain} " for domain in allowed_domains) if allowed_domains else ""
params: Dict[str, Union[str, int, float]] = {"num": 10, **kwargs}
res = self.service.cse().list(q=query, cx=self.engine_id, **params).execute()
res = self.service.cse().list(q=query_prepend + query, cx=self.engine_id, **params).execute()
documents: List[Document] = []
for i, result in enumerate(res["items"]):
documents.append(

View File

@ -26,6 +26,7 @@ class WebSearch(BaseComponent):
self,
api_key: str,
top_k: Optional[int] = 10,
allowed_domains: Optional[List[str]] = None,
search_engine_provider: Union[str, SearchEngine] = "SerperDev",
search_engine_kwargs: Optional[Dict[str, Any]] = None,
):
@ -48,7 +49,7 @@ class WebSearch(BaseComponent):
)
if not issubclass(klass, SearchEngine):
raise ValueError(f"Class {search_engine_provider} is not a subclass of SearchEngine.")
self.search_engine = klass(api_key=api_key, top_k=top_k, search_engine_kwargs=search_engine_kwargs) # type: ignore
self.search_engine = klass(api_key=api_key, top_k=top_k, allowed_domains=allowed_domains, search_engine_kwargs=search_engine_kwargs) # type: ignore
elif isinstance(search_engine_provider, SearchEngine):
self.search_engine = search_engine_provider
else:

View File

@ -0,0 +1,5 @@
---
features:
- |
Introduced `allowed_domains` parameter in `WebRetriever` for domain-specific searches,
thus enabling "talk to a website" and "talk to docs" scenarios.