feat: Add filter_policy init parameter to in memory retrievers (#7795)

* Add filter_policy init parameter to in-memory retrievers
This commit is contained in:
Vladimir Blagojevic 2024-06-04 17:51:16 +02:00 committed by GitHub
parent fd838fc573
commit 678f193f10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 78 additions and 9 deletions

View File

@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional
from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import FilterPolicy
@component
@ -40,6 +41,7 @@ class InMemoryBM25Retriever:
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
scale_score: bool = False,
filter_policy: FilterPolicy = FilterPolicy.REPLACE,
):
"""
Create the InMemoryBM25Retriever component.
@ -52,7 +54,7 @@ class InMemoryBM25Retriever:
The maximum number of documents to retrieve.
:param scale_score:
Scales the BM25 score to a unit interval in the range of 0 to 1, where 1 means extremely relevant. If set to `False`, uses raw similarity scores.
:param filter_policy: The filter policy to apply during retrieval.
:raises ValueError:
If the specified `top_k` is not > 0.
"""
@ -67,6 +69,7 @@ class InMemoryBM25Retriever:
self.filters = filters
self.top_k = top_k
self.scale_score = scale_score
self.filter_policy = filter_policy
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@ -83,7 +86,12 @@ class InMemoryBM25Retriever:
"""
docstore = self.document_store.to_dict()
return default_to_dict(
self, document_store=docstore, filters=self.filters, top_k=self.top_k, scale_score=self.scale_score
self,
document_store=docstore,
filters=self.filters,
top_k=self.top_k,
scale_score=self.scale_score,
filter_policy=self.filter_policy.value,
)
@classmethod
@ -101,6 +109,8 @@ class InMemoryBM25Retriever:
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")
if "filter_policy" in init_params:
init_params["filter_policy"] = FilterPolicy.from_str(init_params["filter_policy"])
data["init_parameters"]["document_store"] = InMemoryDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
@ -132,8 +142,10 @@ class InMemoryBM25Retriever:
:raises ValueError:
If the specified DocumentStore is not found or is not a InMemoryDocumentStore instance.
"""
if filters is None:
filters = self.filters
if self.filter_policy == FilterPolicy.MERGE and filters:
filters = {**(self.filters or {}), **filters}
else:
filters = filters or self.filters
if top_k is None:
top_k = self.top_k
if scale_score is None:

View File

@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional
from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import FilterPolicy
@component
@ -50,6 +51,7 @@ class InMemoryEmbeddingRetriever:
top_k: int = 10,
scale_score: bool = False,
return_embedding: bool = False,
filter_policy: FilterPolicy = FilterPolicy.REPLACE,
):
"""
Create the InMemoryEmbeddingRetriever component.
@ -64,7 +66,7 @@ class InMemoryEmbeddingRetriever:
Scales the BM25 score to a unit interval in the range of 0 to 1, where 1 means extremely relevant. If set to `False`, uses raw similarity scores.
:param return_embedding:
Whether to return the embedding of the retrieved Documents.
:param filter_policy: The filter policy to apply during retrieval.
:raises ValueError:
If the specified top_k is not > 0.
"""
@ -80,6 +82,7 @@ class InMemoryEmbeddingRetriever:
self.top_k = top_k
self.scale_score = scale_score
self.return_embedding = return_embedding
self.filter_policy = filter_policy
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@ -102,6 +105,7 @@ class InMemoryEmbeddingRetriever:
top_k=self.top_k,
scale_score=self.scale_score,
return_embedding=self.return_embedding,
filter_policy=self.filter_policy.value,
)
@classmethod
@ -119,6 +123,8 @@ class InMemoryEmbeddingRetriever:
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")
if "filter_policy" in init_params:
init_params["filter_policy"] = FilterPolicy.from_str(init_params["filter_policy"])
data["init_parameters"]["document_store"] = InMemoryDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
@ -153,8 +159,10 @@ class InMemoryEmbeddingRetriever:
:raises ValueError:
If the specified DocumentStore is not found or is not an InMemoryDocumentStore instance.
"""
if filters is None:
filters = self.filters
if self.filter_policy == FilterPolicy.MERGE and filters:
filters = {**(self.filters or {}), **filters}
else:
filters = filters or self.filters
if top_k is None:
top_k = self.top_k
if scale_score is None:

View File

@ -2,7 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0
from .filter_policy import FilterPolicy
from .policy import DuplicatePolicy
from .protocol import DocumentStore
__all__ = ["DocumentStore", "DuplicatePolicy"]
__all__ = ["DocumentStore", "DuplicatePolicy", "FilterPolicy"]

View File

@ -0,0 +1,35 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from enum import Enum
class FilterPolicy(Enum):
"""
Policy to determine how filters are applied in retrievers interacting with document stores.
"""
# Runtime filters replace init filters during retriever run invocation.
REPLACE = "replace"
# Runtime filters are merged with init filters, with runtime filters overwriting init values.
MERGE = "merge"
def __str__(self):
return self.value
@staticmethod
def from_str(filter_policy: str) -> "FilterPolicy":
"""
Convert a string to a FilterPolicy enum.
:param filter_policy: The string to convert.
:return: The corresponding FilterPolicy enum.
"""
enum_map = {e.value: e for e in FilterPolicy}
policy = enum_map.get(filter_policy)
if policy is None:
msg = f"Unknown FilterPolicy type '{filter_policy}'. Supported types are: {list(enum_map.keys())}"
raise ValueError(msg)
return policy

View File

@ -1,4 +1,4 @@
---
enhancements:
- |
Provides users the ability to customize text extraction from PDF files. It is particularly useful for PDFs with unusual layouts, such as those containing multiple text columns. For instance, users can configure the object to retain the reading order.
Provides users the ability to customize text extraction from PDF files. It is particularly useful for PDFs with unusual layouts, such as those containing multiple text columns. For instance, users can configure the object to retain the reading order.

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Introduced a 'filter_policy' init parameter for both InMemoryBM25Retriever and InMemoryEmbeddingRetriever, allowing users to define how runtime filters should be applied with options to either 'replace' the initial filters or 'merge' them, providing greater flexibility in filtering query results.

View File

@ -6,6 +6,7 @@ from typing import Dict, Any
import pytest
from haystack import Pipeline, DeserializationError
from haystack.document_stores.types import FilterPolicy
from haystack.testing.factory import document_store_class
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.dataclasses import Document
@ -56,6 +57,7 @@ class TestMemoryBM25Retriever:
"filters": None,
"top_k": 10,
"scale_score": False,
"filter_policy": "replace",
},
}
@ -77,6 +79,7 @@ class TestMemoryBM25Retriever:
"filters": {"name": "test.txt"},
"top_k": 5,
"scale_score": True,
"filter_policy": "replace",
},
}
@ -99,6 +102,7 @@ class TestMemoryBM25Retriever:
assert component.filters == {"name": "test.txt"}
assert component.top_k == 5
assert component.scale_score is False
assert component.filter_policy == FilterPolicy.REPLACE
def test_from_dict_without_docstore(self):
data = {"type": "InMemoryBM25Retriever", "init_parameters": {}}

View File

@ -7,6 +7,7 @@ import pytest
import numpy as np
from haystack import Pipeline, DeserializationError
from haystack.document_stores.types import FilterPolicy
from haystack.testing.factory import document_store_class
from haystack.components.retrievers.in_memory.embedding_retriever import InMemoryEmbeddingRetriever
from haystack.dataclasses import Document
@ -47,6 +48,7 @@ class TestMemoryEmbeddingRetriever:
"top_k": 10,
"scale_score": False,
"return_embedding": False,
"filter_policy": "replace",
},
}
@ -70,6 +72,7 @@ class TestMemoryEmbeddingRetriever:
"top_k": 5,
"scale_score": True,
"return_embedding": True,
"filter_policy": "replace",
},
}
@ -83,6 +86,7 @@ class TestMemoryEmbeddingRetriever:
},
"filters": {"name": "test.txt"},
"top_k": 5,
"filter_policy": "merge",
},
}
component = InMemoryEmbeddingRetriever.from_dict(data)
@ -90,6 +94,7 @@ class TestMemoryEmbeddingRetriever:
assert component.filters == {"name": "test.txt"}
assert component.top_k == 5
assert component.scale_score is False
assert component.filter_policy == FilterPolicy.MERGE
def test_from_dict_without_docstore(self):
data = {