refactor: Extract link retrieval from WebRetriever, introduce LinkContentFetcher (#5227)

* Extract link retrieval from WebRetriever, introduce LinkContentFetcher

* Add example
---------

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Co-authored-by: Daria Fokina <daria.f93@gmail.com>
Vladimir Blagojevic 2023-07-13 12:54:40 +02:00 committed by GitHub
parent fd350bbb8f
commit f21005f8ea
6 changed files with 615 additions and 1 deletion
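For orientation before the file-by-file diff, a minimal usage sketch of the node this commit introduces, assuming only the public API added below (fetch() returns a list of Document objects, or an empty list if nothing could be extracted):

from haystack.nodes import LinkContentFetcher

fetcher = LinkContentFetcher()
docs = fetcher.fetch(url="https://docs.haystack.deepset.ai/", timeout=3)  # timeout defaults to 3 seconds
for doc in docs:
    print(doc.content[:100])  # first 100 characters of the extracted page text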


@@ -1,7 +1,7 @@
 loaders:
   - type: python
     search_path: [../../../haystack/nodes/retriever]
-    modules: ["base", "sparse", "dense", "multimodal/retriever", "web"]
+    modules: ["base", "sparse", "dense", "multimodal/retriever", "web", "link_content"]
     ignore_when_discovered: ["__init__"]
 processors:
   - type: filter


@@ -0,0 +1,36 @@
import os

from haystack.nodes import PromptNode, LinkContentFetcher, PromptTemplate
from haystack import Pipeline

openai_key = os.environ.get("OPENAI_API_KEY")
if not openai_key:
    raise ValueError("Please set the OPENAI_API_KEY environment variable")

retriever = LinkContentFetcher()
pt = PromptTemplate(
    "Given the paragraphs of the blog post, "
    "provide the main learnings and the final conclusion using short bullet points format."
    "\n\nParagraphs: {documents}"
)
prompt_node = PromptNode(
    "gpt-3.5-turbo-16k-0613",
    api_key=openai_key,
    max_length=512,
    default_prompt_template=pt,
    model_kwargs={"stream": True},
)

pipeline = Pipeline()
pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
pipeline.add_node(component=prompt_node, name="PromptNode", inputs=["Retriever"])

blog_posts = [
    "https://pythonspeed.com/articles/base-image-python-docker-images/",
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
]

for blog_post in blog_posts:
    print(f"Blog post summary: {blog_post}")
    pipeline.run(query=blog_post)
    print("\n\n\n")


@@ -39,6 +39,7 @@ from haystack.nodes.retriever import (
     TfidfRetriever,
     TableTextRetriever,
     MultiModalRetriever,
+    LinkContentFetcher,
     WebRetriever,
 )


@@ -8,4 +8,5 @@ from haystack.nodes.retriever.dense import (
 )
 from haystack.nodes.retriever.multimodal import MultiModalRetriever
 from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
+from haystack.nodes.retriever.link_content import LinkContentFetcher
 from haystack.nodes.retriever.web import WebRetriever


@@ -0,0 +1,208 @@
import logging
from datetime import datetime
from http import HTTPStatus
from typing import Optional, Dict, List, Union, Callable, Any, Tuple
from urllib.parse import urlparse

import requests
from boilerpy3 import extractors
from requests import Response
from requests.exceptions import InvalidURL

from haystack import __version__
from haystack.nodes import PreProcessor, BaseComponent
from haystack.schema import Document, MultiLabel

logger = logging.getLogger(__name__)


def html_content_handler(response: Response, raise_on_failure: bool = False) -> Optional[str]:
    """
    Extracts content from the response text using the boilerpy3 extractor.

    :param response: Response object from the request.
    :param raise_on_failure: A boolean indicating whether to raise an exception when a failure occurs
        during content extraction.
    """
    extractor = extractors.ArticleExtractor(raise_on_failure=raise_on_failure)
    content = ""
    try:
        content = extractor.get_content(response.text)
    except Exception as e:
        if raise_on_failure:
            raise e
    return content


def pdf_content_handler(response: Response, raise_on_failure: bool = False) -> Optional[str]:
    # TODO: implement this
    return None


class LinkContentFetcher(BaseComponent):
    """
    LinkContentFetcher fetches content from a URL and converts it into a list of Document objects.

    LinkContentFetcher supports the following content types:
    - HTML
    """

    outgoing_edges = 1

    REQUEST_HEADERS = {
        "accept": "*/*",
        "User-Agent": f"haystack/LinkContentFetcher/{__version__}",
        "Accept-Language": "en-US,en;q=0.9,it;q=0.8,es;q=0.7",
        "referer": "https://www.google.com/",
    }

    def __init__(self, processor: Optional[PreProcessor] = None, raise_on_failure: Optional[bool] = False):
        """
        Creates a LinkContentFetcher instance.

        :param processor: PreProcessor to apply to the extracted text.
        :param raise_on_failure: A boolean indicating whether to raise an exception when a failure occurs
            during content extraction. If False, the error is simply logged and the program continues.
            Defaults to False.
        """
        super().__init__()
        self.processor = processor
        self.raise_on_failure = raise_on_failure
        self.handlers: Dict[str, Callable] = {"html": html_content_handler, "pdf": pdf_content_handler}

    def fetch(self, url: str, timeout: Optional[int] = 3, doc_kwargs: Optional[dict] = None) -> List[Document]:
        """
        Fetches content from a URL and converts it into a list of Document objects. If no content is extracted,
        an empty list is returned.

        :param url: URL to fetch content from.
        :param timeout: Timeout in seconds for the request.
        :param doc_kwargs: Optional kwargs to pass to the Document constructor.
        :return: List of Document objects or an empty list if no content is extracted.
        """
        if not url or not self._is_valid_url(url):
            raise InvalidURL("Invalid or missing URL: {}".format(url))

        doc_kwargs = doc_kwargs or {}
        extracted_doc: Dict[str, Union[str, dict]] = {
            "meta": {"url": url, "timestamp": int(datetime.utcnow().timestamp())}
        }
        extracted_doc.update(doc_kwargs)

        response = self._get_response(url, timeout=timeout)
        has_content = response.status_code == HTTPStatus.OK and response.text
        fetched_documents = []
        if has_content:
            handler = "html"  # will handle non-HTML content types soon, add content type resolution here
            if handler in self.handlers:
                extracted_content = self.handlers[handler](response, self.raise_on_failure)
                if extracted_content:
                    extracted_doc["content"] = extracted_content
                    logger.debug("%s handler extracted content from %s", handler, url)
                else:
                    logger.warning("%s handler failed to extract content from %s", handler, url)
                    # perhaps we have a snippet from web search, if so, use it as content
                    snippet_text = extracted_doc.get("snippet_text", "")
                    if snippet_text:
                        extracted_doc["content"] = snippet_text

        if "content" in extracted_doc:
            document = Document.from_dict(extracted_doc)
            if self.processor:
                fetched_documents = self.processor.process(documents=[document])
            else:
                fetched_documents = [document]

        return fetched_documents

    def run(
        self,
        query: Optional[str] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[MultiLabel] = None,
        documents: Optional[List[Document]] = None,
        meta: Optional[dict] = None,
    ) -> Tuple[Dict, str]:
        """
        Fetches content from a URL specified by the query parameter and converts it into a list of Document objects.

        :param query: The query - a URL to fetch content from.
        :param file_paths: Not used.
        :param labels: Not used.
        :param documents: Not used.
        :param meta: Not used.
        :return: List of Document objects.
        """
        if not query:
            raise ValueError("LinkContentFetcher run requires the `query` parameter")
        documents = self.fetch(url=query)
        return {"documents": documents}, "output_1"

    def run_batch(
        self,
        queries: Optional[Union[str, List[str]]] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
        documents: Optional[Union[List[Document], List[List[Document]]]] = None,
        meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
        params: Optional[dict] = None,
        debug: Optional[bool] = None,
    ) -> Tuple[Dict, str]:
        """
        Takes a list of queries, where each query is expected to be a URL. For each query, the method
        fetches content from the specified URL and transforms it into a list of Document objects. The output is a
        list of these document lists, where each individual list of Document objects corresponds to the content
        retrieved from a specific query URL.

        :param queries: List of queries - URLs to fetch content from.
        :param file_paths: Not used.
        :param labels: Not used.
        :param documents: Not used.
        :param meta: Not used.
        :param params: Not used.
        :param debug: Not used.
        :return: List of lists of Document objects.
        """
        results = []
        if isinstance(queries, str):
            queries = [queries]
        elif not isinstance(queries, list):
            raise ValueError(
                "LinkContentFetcher run_batch requires the `queries` parameter to be Union[str, List[str]]"
            )
        for query in queries:
            results.append(self.fetch(url=query))

        return {"documents": results}, "output_1"

    def _get_response(self, url: str, timeout: Optional[int]) -> requests.Response:
        """
        Fetches content from a URL. Returns a response object.

        :param url: The URL to fetch content from.
        :param timeout: The timeout in seconds.
        :return: A response object.
        """
        try:
            response = requests.get(url, headers=LinkContentFetcher.REQUEST_HEADERS, timeout=timeout)
            response.raise_for_status()
        except Exception as e:
            if self.raise_on_failure:
                raise e
            logger.warning("Couldn't retrieve content from %s", url)
            response = requests.Response()
        return response

    def _is_valid_url(self, url: str) -> bool:
        """
        Checks if a URL is valid.

        :param url: The URL to check.
        :return: True if the URL is valid, False otherwise.
        """
        result = urlparse(url)
        # schema is http or https and netloc is not empty
        return all([result.scheme in ["http", "https"], result.netloc])
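The handlers attribute above is a plain dict mapping a content-type key to a callable with the same signature as html_content_handler, so a custom handler can be swapped in per instance. A sketch under that assumption; plain_text_handler is a hypothetical name, not part of this diff:

from requests import Response

from haystack.nodes import LinkContentFetcher

def plain_text_handler(response: Response, raise_on_failure: bool = False):
    # Hypothetical handler: return the raw response body instead of boilerpy3-extracted text.
    return response.text

fetcher = LinkContentFetcher()
fetcher.handlers["html"] = plain_text_handler  # replaces the default boilerpy3-based handler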


@@ -0,0 +1,368 @@
from unittest.mock import Mock, patch

import logging
import pytest
import requests
from requests import Response

from haystack import Document
from haystack.nodes import LinkContentFetcher


@pytest.fixture
def mocked_requests():
    with patch("haystack.nodes.retriever.link_content.requests") as mock_requests:
        mock_response = Mock()
        mock_requests.get.return_value = mock_response
        mock_response.status_code = 200
        mock_response.text = "Sample content from webpage"
        yield mock_requests


@pytest.fixture
def mocked_article_extractor():
    with patch("boilerpy3.extractors.ArticleExtractor.get_content", return_value="Sample content from webpage"):
        yield


@pytest.mark.unit
def test_init():
    """
    Checks the initialization of the LinkContentFetcher without a preprocessor.
    """
    r = LinkContentFetcher()
    assert r.processor is None
    assert isinstance(r.handlers, dict)
    assert "html" in r.handlers


@pytest.mark.unit
def test_init_with_preprocessor():
    """
    Checks the initialization of the LinkContentFetcher with a preprocessor.
    """
    pre_processor_mock = Mock()
    r = LinkContentFetcher(processor=pre_processor_mock)
    assert r.processor == pre_processor_mock
    assert isinstance(r.handlers, dict)
    assert "html" in r.handlers


@pytest.mark.unit
def test_fetch(mocked_requests, mocked_article_extractor):
    """
    Checks if the LinkContentFetcher is able to fetch content.
    """
    url = "https://haystack.deepset.ai/"
    pre_processor_mock = Mock()
    pre_processor_mock.process.return_value = [Document("Sample content from webpage")]

    r = LinkContentFetcher(pre_processor_mock)
    result = r.fetch(url=url, doc_kwargs={"text": "Sample content from webpage"})

    assert len(result) == 1
    assert isinstance(result[0], Document)
    assert result[0].content == "Sample content from webpage"


@pytest.mark.unit
def test_fetch_no_url(mocked_requests, mocked_article_extractor):
    """
    Ensures an InvalidURL exception is raised when the URL is missing.
    """
    pre_processor_mock = Mock()
    pre_processor_mock.process.return_value = [Document("Sample content from webpage")]

    retriever_no_url = LinkContentFetcher(processor=pre_processor_mock)
    with pytest.raises(requests.exceptions.InvalidURL, match="Invalid or missing URL"):
        retriever_no_url.fetch(url="")


@pytest.mark.unit
def test_fetch_invalid_url(caplog, mocked_requests, mocked_article_extractor):
    """
    Ensures an InvalidURL exception is raised when the URL is invalid.
    """
    url = "this-is-invalid-url"
    r = LinkContentFetcher()
    with pytest.raises(requests.exceptions.InvalidURL):
        r.fetch(url=url)


@pytest.mark.unit
def test_fetch_no_preprocessor(mocked_requests, mocked_article_extractor):
    """
    Checks if the LinkContentFetcher can fetch content without a preprocessor.
    """
    url = "https://www.example.com"
    r = LinkContentFetcher()
    result_no_preprocessor = r.fetch(url=url)

    assert len(result_no_preprocessor) == 1
    assert isinstance(result_no_preprocessor[0], Document)
    assert result_no_preprocessor[0].content == "Sample content from webpage"


@pytest.mark.unit
def test_fetch_correct_arguments(mocked_requests, mocked_article_extractor):
    """
    Ensures that requests.get is called with the correct arguments.
    """
    url = "https://www.example.com"
    r = LinkContentFetcher()
    r.fetch(url=url)

    # Check the arguments that requests.get was called with
    args, kwargs = mocked_requests.get.call_args
    assert args[0] == url
    assert kwargs["timeout"] == 3
    assert kwargs["headers"] == r.REQUEST_HEADERS

    # another variant
    url = "https://deepset.ai"
    r.fetch(url=url, timeout=10)

    # Check the arguments that requests.get was called with
    args, kwargs = mocked_requests.get.call_args
    assert args[0] == url
    assert kwargs["timeout"] == 10
    assert kwargs["headers"] == r.REQUEST_HEADERS


@pytest.mark.unit
def test_fetch_default_empty_content(mocked_requests):
    """
    Checks handling of content extraction returning empty content.
    """
    url = "https://www.example.com"
    timeout = 10
    content_text = ""
    r = LinkContentFetcher()

    with patch("boilerpy3.extractors.ArticleExtractor.get_content", return_value=content_text):
        result = r.fetch(url=url, timeout=timeout)

    assert "text" not in result
    assert isinstance(result, list) and len(result) == 0


@pytest.mark.unit
def test_fetch_exception_during_content_extraction_no_raise_on_failure(caplog, mocked_requests):
    """
    Checks the behavior when there's an exception during content extraction and raise_on_failure is set to False.
    """
    caplog.set_level(logging.WARNING)
    url = "https://www.example.com"
    r = LinkContentFetcher()

    with patch("boilerpy3.extractors.ArticleExtractor.get_content", side_effect=Exception("Could not extract content")):
        result = r.fetch(url=url)

    assert "text" not in result
    assert "failed to extract content from" in caplog.text


@pytest.mark.unit
def test_fetch_exception_during_content_extraction_raise_on_failure(caplog, mocked_requests):
    """
    Checks the behavior when there's an exception during content extraction and raise_on_failure is set to True.
    """
    caplog.set_level(logging.WARNING)
    url = "https://www.example.com"
    r = LinkContentFetcher(raise_on_failure=True)

    with patch("boilerpy3.extractors.ArticleExtractor.get_content", side_effect=Exception("Could not extract content")):
        with pytest.raises(Exception, match="Could not extract content"):
            r.fetch(url=url)


@pytest.mark.unit
def test_fetch_exception_during_request_get_no_raise_on_failure(caplog):
    """
    Checks the behavior when there's an exception during requests.get and raise_on_failure is set to False.
    """
    caplog.set_level(logging.WARNING)
    url = "https://www.example.com"
    r = LinkContentFetcher()

    with patch("haystack.nodes.retriever.link_content.requests.get", side_effect=requests.RequestException()):
        r.fetch(url=url)

    assert f"Couldn't retrieve content from {url}" in caplog.text


@pytest.mark.unit
def test_fetch_exception_during_request_get_raise_on_failure(caplog):
    """
    Checks the behavior when there's an exception during requests.get and raise_on_failure is set to True.
    """
    caplog.set_level(logging.WARNING)
    url = "https://www.example.com"
    r = LinkContentFetcher(raise_on_failure=True)

    with patch("haystack.nodes.retriever.link_content.requests.get", side_effect=requests.RequestException()):
        with pytest.raises(requests.RequestException):
            r.fetch(url=url)


@pytest.mark.unit
@pytest.mark.parametrize("error_code", [403, 404, 500])
def test_handle_various_response_errors(caplog, mocked_requests, error_code: int):
    """
    Tests the handling of various HTTP error responses.
    """
    caplog.set_level(logging.WARNING)
    url = "https://some-problematic-url.com"

    # we don't throw exceptions, there might be many of them
    # we log them on debug level
    mock_response = Response()
    mock_response.status_code = error_code
    mocked_requests.get.return_value = mock_response

    r = LinkContentFetcher()
    docs = r.fetch(url=url)

    assert f"Couldn't retrieve content from {url}" in caplog.text
    assert docs == []


@pytest.mark.unit
@pytest.mark.parametrize("error_code", [403, 404, 500])
def test_handle_http_error(mocked_requests, error_code: int):
    """
    Checks the behavior when an HTTPError is raised and raise_on_failure is set to True.
    """
    url = "https://some-problematic-url.com"

    # with raise_on_failure=True, HTTP error responses should raise
    mock_response = Response()
    mock_response.status_code = error_code
    mocked_requests.get.return_value = mock_response

    r = LinkContentFetcher(raise_on_failure=True)
    with pytest.raises(requests.HTTPError):
        r.fetch(url=url)


@pytest.mark.unit
def test_is_valid_url():
    """
    Checks the _is_valid_url function with a set of valid URLs.
    """
    retriever = LinkContentFetcher()
    valid_urls = [
        "http://www.google.com",
        "https://www.google.com",
        "http://google.com",
        "https://google.com",
        "http://localhost",
        "https://localhost",
        "http://127.0.0.1",
        "https://127.0.0.1",
        "http://[::1]",
        "https://[::1]",
        "http://example.com/path/to/page?name=value",
        "https://example.com/path/to/page?name=value",
        "http://example.com:8000",
        "https://example.com:8000",
    ]

    for url in valid_urls:
        assert retriever._is_valid_url(url), f"Expected {url} to be valid"


@pytest.mark.unit
def test_is_invalid_url():
    """
    Checks the _is_valid_url function with a set of invalid URLs.
    """
    retriever = LinkContentFetcher()
    invalid_urls = [
        "http://",
        "https://",
        "http:",
        "https:",
        "www.google.com",
        "google.com",
        "localhost",
        "127.0.0.1",
        "[::1]",
        "/path/to/page",
        "/path/to/page?name=value",
        ":8000",
        "example.com",
        "http:///example.com",
        "https:///example.com",
        "",
        None,
    ]

    for url in invalid_urls:
        assert not retriever._is_valid_url(url), f"Expected {url} to be invalid"


@pytest.mark.integration
def test_call_with_valid_url_on_live_web():
    """
    Test that LinkContentFetcher can fetch content from a valid URL.
    """
    retriever = LinkContentFetcher()
    docs = retriever.fetch(url="https://docs.haystack.deepset.ai/", timeout=2)

    assert len(docs) >= 1
    assert isinstance(docs[0], Document)
    assert "Haystack" in docs[0].content


@pytest.mark.integration
def test_retrieve_with_valid_url_on_live_web():
    """
    Test that LinkContentFetcher can fetch content from a valid URL using the run method.
    """
    retriever = LinkContentFetcher()
    docs, _ = retriever.run(query="https://docs.haystack.deepset.ai/")
    docs = docs["documents"]

    assert len(docs) >= 1
    assert isinstance(docs[0], Document)
    assert "Haystack" in docs[0].content


@pytest.mark.integration
def test_retrieve_with_invalid_url():
    """
    Test that LinkContentFetcher raises ValueError when trying to fetch content from an invalid URL.
    """
    retriever = LinkContentFetcher()
    with pytest.raises(ValueError):
        retriever.run(query="")


@pytest.mark.integration
def test_retrieve_batch():
    """
    Test that LinkContentFetcher can fetch content from multiple valid URLs using the run_batch method.
    """
    retriever = LinkContentFetcher()
    docs, _ = retriever.run_batch(queries=["https://docs.haystack.deepset.ai/", "https://deepset.ai/"])
    assert docs

    docs = docs["documents"]
    # no processor is applied, so each query should return a list of documents with one entry
    assert len(docs) == 2 and len(docs[0]) == 1 and len(docs[1]) == 1

    # each query should return a list of documents
    assert isinstance(docs[0], list) and isinstance(docs[0][0], Document)
    assert isinstance(docs[1], list) and isinstance(docs[1][0], Document)
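As a compact reference for the shapes these tests assert, a sketch of the run and run_batch outputs, taken from the implementations in this diff:

from haystack.nodes import LinkContentFetcher

fetcher = LinkContentFetcher()

# run() returns a (dict, edge) tuple; "documents" holds a flat List[Document].
output, edge = fetcher.run(query="https://docs.haystack.deepset.ai/")
assert edge == "output_1"
flat_docs = output["documents"]

# run_batch() returns one List[Document] per input URL: List[List[Document]].
batch_output, _ = fetcher.run_batch(queries=["https://docs.haystack.deepset.ai/", "https://deepset.ai/"])
nested_docs = batch_output["documents"]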