feat: Improve UrlCacheChecker, make it more generic (#6699)

* Rename UrlCacheChecker to CacheChecker, make it field generic

* Add release note
This commit is contained in:
Vladimir Blagojevic 2024-01-08 16:15:27 +01:00 committed by GitHub
parent ae96c2ee83
commit 9e0b58784f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 45 additions and 38 deletions

View File

@ -1,3 +1,3 @@
from haystack.components.caching.url_cache_checker import UrlCacheChecker
from haystack.components.caching.cache_checker import CacheChecker
__all__ = ["UrlCacheChecker"]
__all__ = ["CacheChecker"]

View File

@ -12,27 +12,27 @@ logger = logging.getLogger(__name__)
@component
class UrlCacheChecker:
class CacheChecker:
"""
A component checks for the presence of a document from a specific URL in the store. UrlCacheChecker can thus
implement caching functionality within web retrieval pipelines that use a Document Store.
CacheChecker is a component that checks for the presence of documents in a Document Store based on a specified
cache field.
"""
def __init__(self, document_store: DocumentStore, url_field: str = "url"):
def __init__(self, document_store: DocumentStore, cache_field: str):
"""
Create a UrlCacheChecker component.
"""
self.document_store = document_store
self.url_field = url_field
self.cache_field = cache_field
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(self, document_store=self.document_store.to_dict(), url_field=self.url_field)
return default_to_dict(self, document_store=self.document_store.to_dict(), cache_field=self.cache_field)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "UrlCacheChecker":
def from_dict(cls, data: Dict[str, Any]) -> "CacheChecker":
"""
Deserialize this component from a dictionary.
"""
@ -57,22 +57,24 @@ class UrlCacheChecker:
data["init_parameters"]["document_store"] = docstore
return default_from_dict(cls, data)
@component.output_types(hits=List[Document], misses=List[str])
def run(self, urls: List[str]):
@component.output_types(hits=List[Document], misses=List[Any])
def run(self, items: List[Any]):
"""
Checks if any document coming from the given URL is already present in the store. If matching documents are
found, they are returned. If not, the URL is returned as a miss.
Checks if any document associated with the specified field is already present in the store. If matching documents
are found, they are returned as hits. If not, the items are returned as misses, indicating they are not in the cache.
:param urls: All the URLs the documents may be coming from to hit this cache.
:param items: A list of values associated with the cache_field to be checked against the cache.
:return: A dictionary with two keys: "hits" and "misses". The values are lists of documents that were found in
the cache and items that were not, respectively.
"""
found_documents = []
missing_urls = []
misses = []
for url in urls:
filters = {self.url_field: url}
for item in items:
filters = {self.cache_field: item}
found = self.document_store.filter_documents(filters=filters)
if found:
found_documents.extend(found)
else:
missing_urls.append(url)
return {"hits": found_documents, "misses": missing_urls}
misses.append(item)
return {"hits": found_documents, "misses": misses}

View File

@ -0,0 +1,5 @@
---
enhancements:
- |
Improve the URLCacheChecker so that it can work with any type of data in the DocumentStore, not just URL caching.
Rename the component to CacheChecker.

View File

@ -3,69 +3,69 @@ import pytest
from haystack import Document, DeserializationError
from haystack.testing.factory import document_store_class
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.caching.url_cache_checker import UrlCacheChecker
from haystack.components.caching.cache_checker import CacheChecker
class TestUrlCacheChecker:
def test_to_dict(self):
mocked_docstore_class = document_store_class("MockedDocumentStore")
component = UrlCacheChecker(document_store=mocked_docstore_class())
component = CacheChecker(document_store=mocked_docstore_class(), cache_field="url")
data = component.to_dict()
assert data == {
"type": "haystack.components.caching.url_cache_checker.UrlCacheChecker",
"type": "haystack.components.caching.cache_checker.CacheChecker",
"init_parameters": {
"document_store": {"type": "haystack.testing.factory.MockedDocumentStore", "init_parameters": {}},
"url_field": "url",
"cache_field": "url",
},
}
def test_to_dict_with_custom_init_parameters(self):
mocked_docstore_class = document_store_class("MockedDocumentStore")
component = UrlCacheChecker(document_store=mocked_docstore_class(), url_field="my_url_field")
component = CacheChecker(document_store=mocked_docstore_class(), cache_field="my_url_field")
data = component.to_dict()
assert data == {
"type": "haystack.components.caching.url_cache_checker.UrlCacheChecker",
"type": "haystack.components.caching.cache_checker.CacheChecker",
"init_parameters": {
"document_store": {"type": "haystack.testing.factory.MockedDocumentStore", "init_parameters": {}},
"url_field": "my_url_field",
"cache_field": "my_url_field",
},
}
def test_from_dict(self):
data = {
"type": "haystack.components.caching.url_cache_checker.UrlCacheChecker",
"type": "haystack.components.caching.cache_checker.CacheChecker",
"init_parameters": {
"document_store": {
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
"init_parameters": {},
},
"url_field": "my_url_field",
"cache_field": "my_url_field",
},
}
component = UrlCacheChecker.from_dict(data)
component = CacheChecker.from_dict(data)
assert isinstance(component.document_store, InMemoryDocumentStore)
assert component.url_field == "my_url_field"
assert component.cache_field == "my_url_field"
def test_from_dict_without_docstore(self):
data = {"type": "haystack.components.caching.url_cache_checker.UrlCacheChecker", "init_parameters": {}}
data = {"type": "haystack.components.caching.cache_checker.CacheChecker", "init_parameters": {}}
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
UrlCacheChecker.from_dict(data)
CacheChecker.from_dict(data)
def test_from_dict_without_docstore_type(self):
data = {
"type": "haystack.components.caching.url_cache_checker.UrlCacheChecker",
"type": "haystack.components.caching.cache_checker.UrlCacheChecker",
"init_parameters": {"document_store": {"init_parameters": {}}},
}
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
UrlCacheChecker.from_dict(data)
CacheChecker.from_dict(data)
def test_from_dict_nonexisting_docstore(self):
data = {
"type": "haystack.components.caching.url_cache_checker.UrlCacheChecker",
"type": "haystack.components.caching.cache_checker.UrlCacheChecker",
"init_parameters": {"document_store": {"type": "Nonexisting.DocumentStore", "init_parameters": {}}},
}
with pytest.raises(DeserializationError):
UrlCacheChecker.from_dict(data)
CacheChecker.from_dict(data)
def test_run(self):
docstore = InMemoryDocumentStore()
@ -76,6 +76,6 @@ class TestUrlCacheChecker:
Document(content="doc4", meta={"url": "https://example.com/2"}),
]
docstore.write_documents(documents)
checker = UrlCacheChecker(docstore)
results = checker.run(urls=["https://example.com/1", "https://example.com/5"])
checker = CacheChecker(docstore, cache_field="url")
results = checker.run(items=["https://example.com/1", "https://example.com/5"])
assert results == {"hits": [documents[0], documents[2]], "misses": ["https://example.com/5"]}