mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-07 20:46:31 +00:00
feat: Add domain scoping to WebRetriever (#5587)
* WebSearch: add allowed_domains scoped search * Add talk to website example * Add release note * Add allowed_domains to WebSearch * Minor fix --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
This commit is contained in:
parent
81f3aaf3e5
commit
2118f68769
77
examples/talk_to_website.py
Normal file
77
examples/talk_to_website.py
Normal file
@ -0,0 +1,77 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
from haystack import Pipeline
|
||||
from haystack.document_stores import InMemoryDocumentStore
|
||||
from haystack.nodes import PromptNode, PromptTemplate, TopPSampler
|
||||
from haystack.nodes.ranker.lost_in_the_middle import LostInTheMiddleRanker
|
||||
from haystack.nodes.retriever.web import WebRetriever
|
||||
|
||||
search_key = os.environ.get("SERPERDEV_API_KEY")
|
||||
if not search_key:
|
||||
raise ValueError("Please set the SERPERDEV_API_KEY environment variable")
|
||||
|
||||
models_config: Dict[str, Any] = {
|
||||
"openai": {"api_key": os.environ.get("OPENAI_API_KEY"), "model_name": "gpt-3.5-turbo"},
|
||||
"anthropic": {"api_key": os.environ.get("ANTHROPIC_API_KEY"), "model_name": "claude-instant-1"},
|
||||
"hf": {"api_key": os.environ.get("HF_API_KEY"), "model_name": "tiiuae/falcon-7b-instruct"},
|
||||
}
|
||||
prompt_text = """
|
||||
Synthesize a comprehensive answer from the provided paragraphs and the given question.\n
|
||||
Focus on the question and avoid unnecessary information in your answer.\n
|
||||
\n\n Paragraphs: {join(documents)} \n\n Question: {query} \n\n Answer:
|
||||
"""
|
||||
|
||||
stream = True
|
||||
model: Dict[str, str] = models_config["openai"]
|
||||
prompt_node = PromptNode(
|
||||
model["model_name"],
|
||||
default_prompt_template=PromptTemplate(prompt_text),
|
||||
api_key=model["api_key"],
|
||||
max_length=768,
|
||||
model_kwargs={"stream": stream},
|
||||
)
|
||||
|
||||
web_retriever = WebRetriever(
|
||||
api_key=search_key,
|
||||
allowed_domains=["haystack.deepset.ai"],
|
||||
top_search_results=10,
|
||||
mode="preprocessed_documents",
|
||||
top_k=50,
|
||||
cache_document_store=InMemoryDocumentStore(),
|
||||
)
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(component=web_retriever, name="Retriever", inputs=["Query"])
|
||||
pipeline.add_node(component=TopPSampler(top_p=0.90), name="Sampler", inputs=["Retriever"])
|
||||
pipeline.add_node(component=LostInTheMiddleRanker(1024), name="LostInTheMiddleRanker", inputs=["Sampler"])
|
||||
pipeline.add_node(component=prompt_node, name="PromptNode", inputs=["LostInTheMiddleRanker"])
|
||||
|
||||
logging.disable(logging.CRITICAL)
|
||||
|
||||
test = False
|
||||
questions = [
|
||||
"What are the main benefits of using pipelines in Haystack?",
|
||||
"Are there any ready-made pipelines available and why should I use them?",
|
||||
]
|
||||
|
||||
print(f"Running pipeline with {model['model_name']}\n")
|
||||
|
||||
if test:
|
||||
for question in questions:
|
||||
if stream:
|
||||
print("Answer:")
|
||||
response = pipeline.run(query=question)
|
||||
if not stream:
|
||||
print(f"Answer: {response['results'][0]}")
|
||||
else:
|
||||
while True:
|
||||
user_input = input("\nAsk question (type 'exit' or 'quit' to quit): ")
|
||||
if user_input.lower() == "exit" or user_input.lower() == "quit":
|
||||
break
|
||||
if stream:
|
||||
print("Answer:")
|
||||
response = pipeline.run(query=user_input)
|
||||
if not stream:
|
||||
print(f"Answer: {response['results'][0]}")
|
||||
@ -50,6 +50,7 @@ class WebRetriever(BaseRetriever):
|
||||
self,
|
||||
api_key: str,
|
||||
search_engine_provider: Union[str, SearchEngine] = "SerperDev",
|
||||
allowed_domains: Optional[List[str]] = None,
|
||||
top_search_results: Optional[int] = 10,
|
||||
top_k: Optional[int] = 5,
|
||||
mode: Literal["snippets", "raw_documents", "preprocessed_documents"] = "snippets",
|
||||
@ -62,6 +63,7 @@ class WebRetriever(BaseRetriever):
|
||||
"""
|
||||
:param api_key: API key for the search engine provider.
|
||||
:param search_engine_provider: Name of the search engine provider class, see `providers.py` for a list of supported providers.
|
||||
:param allowed_domains: List of domains to restrict the search to. If not provided, the search is unrestricted.
|
||||
:param top_search_results: Number of top search results to be retrieved.
|
||||
:param top_k: Top k documents to be returned by the retriever.
|
||||
:param mode: Whether to return snippets, raw documents, or preprocessed documents. Snippets are the default.
|
||||
@ -73,7 +75,10 @@ class WebRetriever(BaseRetriever):
|
||||
"""
|
||||
super().__init__()
|
||||
self.web_search = WebSearch(
|
||||
api_key=api_key, top_k=top_search_results, search_engine_provider=search_engine_provider
|
||||
api_key=api_key,
|
||||
top_k=top_search_results,
|
||||
allowed_domains=allowed_domains,
|
||||
search_engine_provider=search_engine_provider,
|
||||
)
|
||||
self.mode = mode
|
||||
self.cache_document_store = cache_document_store
|
||||
|
||||
@ -20,12 +20,14 @@ class SerpAPI(SearchEngine):
|
||||
self,
|
||||
api_key: str,
|
||||
top_k: Optional[int] = 10,
|
||||
allowed_domains: Optional[List[str]] = None,
|
||||
engine: Optional[str] = "google",
|
||||
search_engine_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
:param api_key: API key for SerpAPI.
|
||||
:param top_k: Number of results to return.
|
||||
:param allowed_domains: List of domains to limit the search to.
|
||||
:param engine: Search engine to use, for example google, bing, baidu, duckduckgo, yahoo, yandex.
|
||||
See the [SerpAPI documentation](https://serpapi.com/search-api) for the full list of supported engines.
|
||||
:param search_engine_kwargs: Additional parameters passed to the SerperDev API. For example, you can set 'lr' to 'lang_en'
|
||||
@ -38,6 +40,7 @@ class SerpAPI(SearchEngine):
|
||||
self.kwargs = search_engine_kwargs if search_engine_kwargs else {}
|
||||
self.engine = engine
|
||||
self.top_k = top_k
|
||||
self.allowed_domains = allowed_domains
|
||||
|
||||
def search(self, query: str, **kwargs) -> List[Document]:
|
||||
"""
|
||||
@ -51,7 +54,9 @@ class SerpAPI(SearchEngine):
|
||||
top_k = kwargs.pop("top_k", self.top_k)
|
||||
url = "https://serpapi.com/search"
|
||||
|
||||
params = {"source": "python", "serp_api_key": self.api_key, "q": query, **kwargs}
|
||||
allowed_domains = kwargs.pop("allowed_domains", self.allowed_domains)
|
||||
query_prepend = "OR ".join(f"site:{domain} " for domain in allowed_domains) if allowed_domains else ""
|
||||
params = {"source": "python", "serp_api_key": self.api_key, "q": query_prepend + query, **kwargs}
|
||||
|
||||
if self.engine:
|
||||
params["engine"] = self.engine
|
||||
@ -124,16 +129,24 @@ class SerperDev(SearchEngine):
|
||||
Search engine using SerperDev API. See the [Serper Dev website](https://serper.dev/) for more details.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str, top_k: Optional[int] = 10, search_engine_kwargs: Optional[Dict[str, Any]] = None):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
top_k: Optional[int] = 10,
|
||||
allowed_domains: Optional[List[str]] = None,
|
||||
search_engine_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
:param api_key: API key for the SerperDev API.
|
||||
:param top_k: Number of documents to return.
|
||||
:param allowed_domains: List of domains to limit the search to.
|
||||
:param search_engine_kwargs: Additional parameters passed to the SerperDev API.
|
||||
For example, you can set 'num' to 20 to increase the number of search results.
|
||||
"""
|
||||
super().__init__()
|
||||
self.api_key = api_key
|
||||
self.top_k = top_k
|
||||
self.allowed_domains = allowed_domains
|
||||
self.kwargs = search_engine_kwargs if search_engine_kwargs else {}
|
||||
|
||||
def search(self, query: str, **kwargs) -> List[Document]:
|
||||
@ -144,10 +157,12 @@ class SerperDev(SearchEngine):
|
||||
"""
|
||||
kwargs = {**self.kwargs, **kwargs}
|
||||
top_k = kwargs.pop("top_k", self.top_k)
|
||||
allowed_domains = kwargs.pop("allowed_domains", self.allowed_domains)
|
||||
query_prepend = "OR ".join(f"site:{domain} " for domain in allowed_domains) if allowed_domains else ""
|
||||
|
||||
url = "https://google.serper.dev/search"
|
||||
|
||||
payload = json.dumps({"q": query, "gl": "us", "hl": "en", "autocorrect": True, **kwargs})
|
||||
payload = json.dumps({"q": query_prepend + query, "gl": "us", "hl": "en", "autocorrect": True, **kwargs})
|
||||
headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"}
|
||||
|
||||
response = requests.request("POST", url, headers=headers, data=payload, timeout=30)
|
||||
@ -211,21 +226,29 @@ class BingAPI(SearchEngine):
|
||||
Search engine using the Bing API. See [Bing Web Search API](https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/overview) for more details.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str, top_k: Optional[int] = 10, search_engine_kwargs: Optional[Dict[str, Any]] = None):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
top_k: Optional[int] = 10,
|
||||
allowed_domains: Optional[List[str]] = None,
|
||||
search_engine_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
:param api_key: API key for the Bing API.
|
||||
:param top_k: Number of documents to return.
|
||||
:param search_engine_kwargs: Additional parameters passed to the SerperDev API. As an example, you can pass the market parameter to specify the market to use for the query: 'mkt':'en-US'.
|
||||
:param allowed_domains: List of domains to limit the search to.
|
||||
:param search_engine_kwargs: Additional parameters passed to the Bing. As an example, you can pass the market parameter to specify the market to use for the query: 'mkt':'en-US'.
|
||||
"""
|
||||
super().__init__()
|
||||
self.api_key = api_key
|
||||
self.top_k = top_k
|
||||
self.allowed_domains = allowed_domains
|
||||
self.kwargs = search_engine_kwargs if search_engine_kwargs else {}
|
||||
|
||||
def search(self, query: str, **kwargs) -> List[Document]:
|
||||
"""
|
||||
:param query: Query string.
|
||||
:param kwargs: Additional parameters passed to the SerperDev API.
|
||||
:param kwargs: Additional parameters passed to the BingAPI.
|
||||
As an example, you can pass the market parameter to specify the market to use for the query: 'mkt':'en-US'.
|
||||
If you don't specify the market parameter, the default market for the user's location is used.
|
||||
For a complete list of the market codes, see [Market Codes](https://learn.microsoft.com/en-us/rest/api/cognitiveservices-bingsearch/bing-web-api-v7-reference#market-codes).
|
||||
@ -237,7 +260,10 @@ class BingAPI(SearchEngine):
|
||||
top_k = kwargs.pop("top_k", self.top_k)
|
||||
url = "https://api.bing.microsoft.com/v7.0/search"
|
||||
|
||||
params: Dict[str, Union[str, int, float]] = {"q": query, "count": 50, **kwargs}
|
||||
allowed_domains = kwargs.pop("allowed_domains", self.allowed_domains)
|
||||
query_prepend = "OR ".join(f"site:{domain} " for domain in allowed_domains) if allowed_domains else ""
|
||||
|
||||
params: Dict[str, Union[str, int, float]] = {"q": query_prepend + query, "count": 50, **kwargs}
|
||||
|
||||
headers = {"Ocp-Apim-Subscription-Key": self.api_key}
|
||||
|
||||
@ -282,12 +308,14 @@ class GoogleAPI(SearchEngine):
|
||||
def __init__(
|
||||
self,
|
||||
top_k: Optional[int] = 10,
|
||||
allowed_domains: Optional[List[str]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
engine_id: Optional[str] = None,
|
||||
search_engine_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
:param top_k: Number of documents to return.
|
||||
:param allowed_domains: List of domains to limit the search to.
|
||||
:param api_key: API key for the Google API.
|
||||
:param engine_id: Engine ID for the Google API.
|
||||
:param search_engine_kwargs: Additional parameters passed to the Google API. As an example, you can pass the hl parameter to specify the language to use for the query: 'hl':'en'.
|
||||
@ -296,6 +324,7 @@ class GoogleAPI(SearchEngine):
|
||||
self.api_key = api_key
|
||||
self.engine_id = engine_id
|
||||
self.top_k = top_k
|
||||
self.allowed_domains = allowed_domains
|
||||
self.kwargs = search_engine_kwargs if search_engine_kwargs else {}
|
||||
|
||||
def _validate_environment(self):
|
||||
@ -338,8 +367,11 @@ class GoogleAPI(SearchEngine):
|
||||
self._validate_environment()
|
||||
|
||||
top_k = kwargs.pop("top_k", self.top_k)
|
||||
allowed_domains = kwargs.pop("allowed_domains", self.allowed_domains)
|
||||
query_prepend = "OR ".join(f"site:{domain} " for domain in allowed_domains) if allowed_domains else ""
|
||||
|
||||
params: Dict[str, Union[str, int, float]] = {"num": 10, **kwargs}
|
||||
res = self.service.cse().list(q=query, cx=self.engine_id, **params).execute()
|
||||
res = self.service.cse().list(q=query_prepend + query, cx=self.engine_id, **params).execute()
|
||||
documents: List[Document] = []
|
||||
for i, result in enumerate(res["items"]):
|
||||
documents.append(
|
||||
|
||||
@ -26,6 +26,7 @@ class WebSearch(BaseComponent):
|
||||
self,
|
||||
api_key: str,
|
||||
top_k: Optional[int] = 10,
|
||||
allowed_domains: Optional[List[str]] = None,
|
||||
search_engine_provider: Union[str, SearchEngine] = "SerperDev",
|
||||
search_engine_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
@ -48,7 +49,7 @@ class WebSearch(BaseComponent):
|
||||
)
|
||||
if not issubclass(klass, SearchEngine):
|
||||
raise ValueError(f"Class {search_engine_provider} is not a subclass of SearchEngine.")
|
||||
self.search_engine = klass(api_key=api_key, top_k=top_k, search_engine_kwargs=search_engine_kwargs) # type: ignore
|
||||
self.search_engine = klass(api_key=api_key, top_k=top_k, allowed_domains=allowed_domains, search_engine_kwargs=search_engine_kwargs) # type: ignore
|
||||
elif isinstance(search_engine_provider, SearchEngine):
|
||||
self.search_engine = search_engine_provider
|
||||
else:
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Introduced `allowed_domains` parameter in `WebRetriever` for domain-specific searches,
|
||||
thus enabling "talk to a website" and "talk to docs" scenarios.
|
||||
Loading…
x
Reference in New Issue
Block a user