mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-29 16:08:38 +00:00
feat: Add filter_policy init parameter to in memory retrievers (#7795)
* Add filter_policy init parameter to in-memory retrievers
This commit is contained in:
parent
fd838fc573
commit
678f193f10
@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict
|
||||
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
||||
from haystack.document_stores.types import FilterPolicy
|
||||
|
||||
|
||||
@component
|
||||
@ -40,6 +41,7 @@ class InMemoryBM25Retriever:
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
top_k: int = 10,
|
||||
scale_score: bool = False,
|
||||
filter_policy: FilterPolicy = FilterPolicy.REPLACE,
|
||||
):
|
||||
"""
|
||||
Create the InMemoryBM25Retriever component.
|
||||
@ -52,7 +54,7 @@ class InMemoryBM25Retriever:
|
||||
The maximum number of documents to retrieve.
|
||||
:param scale_score:
|
||||
Scales the BM25 score to a unit interval in the range of 0 to 1, where 1 means extremely relevant. If set to `False`, uses raw similarity scores.
|
||||
|
||||
:param filter_policy: The filter policy to apply during retrieval.
|
||||
:raises ValueError:
|
||||
If the specified `top_k` is not > 0.
|
||||
"""
|
||||
@ -67,6 +69,7 @@ class InMemoryBM25Retriever:
|
||||
self.filters = filters
|
||||
self.top_k = top_k
|
||||
self.scale_score = scale_score
|
||||
self.filter_policy = filter_policy
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -83,7 +86,12 @@ class InMemoryBM25Retriever:
|
||||
"""
|
||||
docstore = self.document_store.to_dict()
|
||||
return default_to_dict(
|
||||
self, document_store=docstore, filters=self.filters, top_k=self.top_k, scale_score=self.scale_score
|
||||
self,
|
||||
document_store=docstore,
|
||||
filters=self.filters,
|
||||
top_k=self.top_k,
|
||||
scale_score=self.scale_score,
|
||||
filter_policy=self.filter_policy.value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -101,6 +109,8 @@ class InMemoryBM25Retriever:
|
||||
raise DeserializationError("Missing 'document_store' in serialization data")
|
||||
if "type" not in init_params["document_store"]:
|
||||
raise DeserializationError("Missing 'type' in document store's serialization data")
|
||||
if "filter_policy" in init_params:
|
||||
init_params["filter_policy"] = FilterPolicy.from_str(init_params["filter_policy"])
|
||||
data["init_parameters"]["document_store"] = InMemoryDocumentStore.from_dict(
|
||||
data["init_parameters"]["document_store"]
|
||||
)
|
||||
@ -132,8 +142,10 @@ class InMemoryBM25Retriever:
|
||||
:raises ValueError:
|
||||
If the specified DocumentStore is not found or is not a InMemoryDocumentStore instance.
|
||||
"""
|
||||
if filters is None:
|
||||
filters = self.filters
|
||||
if self.filter_policy == FilterPolicy.MERGE and filters:
|
||||
filters = {**(self.filters or {}), **filters}
|
||||
else:
|
||||
filters = filters or self.filters
|
||||
if top_k is None:
|
||||
top_k = self.top_k
|
||||
if scale_score is None:
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict
|
||||
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
||||
from haystack.document_stores.types import FilterPolicy
|
||||
|
||||
|
||||
@component
|
||||
@ -50,6 +51,7 @@ class InMemoryEmbeddingRetriever:
|
||||
top_k: int = 10,
|
||||
scale_score: bool = False,
|
||||
return_embedding: bool = False,
|
||||
filter_policy: FilterPolicy = FilterPolicy.REPLACE,
|
||||
):
|
||||
"""
|
||||
Create the InMemoryEmbeddingRetriever component.
|
||||
@ -64,7 +66,7 @@ class InMemoryEmbeddingRetriever:
|
||||
Scales the BM25 score to a unit interval in the range of 0 to 1, where 1 means extremely relevant. If set to `False`, uses raw similarity scores.
|
||||
:param return_embedding:
|
||||
Whether to return the embedding of the retrieved Documents.
|
||||
|
||||
:param filter_policy: The filter policy to apply during retrieval.
|
||||
:raises ValueError:
|
||||
If the specified top_k is not > 0.
|
||||
"""
|
||||
@ -80,6 +82,7 @@ class InMemoryEmbeddingRetriever:
|
||||
self.top_k = top_k
|
||||
self.scale_score = scale_score
|
||||
self.return_embedding = return_embedding
|
||||
self.filter_policy = filter_policy
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -102,6 +105,7 @@ class InMemoryEmbeddingRetriever:
|
||||
top_k=self.top_k,
|
||||
scale_score=self.scale_score,
|
||||
return_embedding=self.return_embedding,
|
||||
filter_policy=self.filter_policy.value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -119,6 +123,8 @@ class InMemoryEmbeddingRetriever:
|
||||
raise DeserializationError("Missing 'document_store' in serialization data")
|
||||
if "type" not in init_params["document_store"]:
|
||||
raise DeserializationError("Missing 'type' in document store's serialization data")
|
||||
if "filter_policy" in init_params:
|
||||
init_params["filter_policy"] = FilterPolicy.from_str(init_params["filter_policy"])
|
||||
data["init_parameters"]["document_store"] = InMemoryDocumentStore.from_dict(
|
||||
data["init_parameters"]["document_store"]
|
||||
)
|
||||
@ -153,8 +159,10 @@ class InMemoryEmbeddingRetriever:
|
||||
:raises ValueError:
|
||||
If the specified DocumentStore is not found or is not an InMemoryDocumentStore instance.
|
||||
"""
|
||||
if filters is None:
|
||||
filters = self.filters
|
||||
if self.filter_policy == FilterPolicy.MERGE and filters:
|
||||
filters = {**(self.filters or {}), **filters}
|
||||
else:
|
||||
filters = filters or self.filters
|
||||
if top_k is None:
|
||||
top_k = self.top_k
|
||||
if scale_score is None:
|
||||
|
||||
@ -2,7 +2,8 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .filter_policy import FilterPolicy
|
||||
from .policy import DuplicatePolicy
|
||||
from .protocol import DocumentStore
|
||||
|
||||
__all__ = ["DocumentStore", "DuplicatePolicy"]
|
||||
__all__ = ["DocumentStore", "DuplicatePolicy", "FilterPolicy"]
|
||||
|
||||
35
haystack/document_stores/types/filter_policy.py
Normal file
35
haystack/document_stores/types/filter_policy.py
Normal file
@ -0,0 +1,35 @@
|
||||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class FilterPolicy(Enum):
    """
    Policy to determine how filters are applied in retrievers interacting with document stores.
    """

    # Runtime filters replace init filters during retriever run invocation.
    REPLACE = "replace"

    # Runtime filters are merged with init filters, with runtime filters overwriting init values.
    MERGE = "merge"

    def __str__(self) -> str:
        """
        Return the string value of the policy (for example `"merge"`).
        """
        return self.value

    @staticmethod
    def from_str(filter_policy: str) -> "FilterPolicy":
        """
        Convert a string to a FilterPolicy enum.

        The lookup is case-insensitive, so `"merge"`, `"Merge"` and `"MERGE"` all resolve to
        `FilterPolicy.MERGE`.

        :param filter_policy: The string to convert.
        :return: The corresponding FilterPolicy enum.
        :raises ValueError: If `filter_policy` does not match the value of any FilterPolicy member.
        """
        enum_map = {e.value: e for e in FilterPolicy}
        # Normalise case only for real strings; any other input falls through to the
        # lookup unchanged and is reported as unknown below.
        key = filter_policy.lower() if isinstance(filter_policy, str) else filter_policy
        policy = enum_map.get(key)
        if policy is None:
            msg = f"Unknown FilterPolicy type '{filter_policy}'. Supported types are: {list(enum_map.keys())}"
            raise ValueError(msg)
        return policy
|
||||
@ -1,4 +1,4 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Provides users the ability to customize text extraction from PDF files. It is particularly useful for PDFs with unusual layouts, such as those containing multiple text columns. For instance, users can configure the object to retain the reading order.
|
||||
Provides users the ability to customize text extraction from PDF files. It is particularly useful for PDFs with unusual layouts, such as those containing multiple text columns. For instance, users can configure the object to retain the reading order.
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Introduced a 'filter_policy' init parameter for both InMemoryBM25Retriever and InMemoryEmbeddingRetriever, allowing users to define how runtime filters should be applied with options to either 'replace' the initial filters or 'merge' them, providing greater flexibility in filtering query results.
|
||||
@ -6,6 +6,7 @@ from typing import Dict, Any
|
||||
import pytest
|
||||
|
||||
from haystack import Pipeline, DeserializationError
|
||||
from haystack.document_stores.types import FilterPolicy
|
||||
from haystack.testing.factory import document_store_class
|
||||
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
|
||||
from haystack.dataclasses import Document
|
||||
@ -56,6 +57,7 @@ class TestMemoryBM25Retriever:
|
||||
"filters": None,
|
||||
"top_k": 10,
|
||||
"scale_score": False,
|
||||
"filter_policy": "replace",
|
||||
},
|
||||
}
|
||||
|
||||
@ -77,6 +79,7 @@ class TestMemoryBM25Retriever:
|
||||
"filters": {"name": "test.txt"},
|
||||
"top_k": 5,
|
||||
"scale_score": True,
|
||||
"filter_policy": "replace",
|
||||
},
|
||||
}
|
||||
|
||||
@ -99,6 +102,7 @@ class TestMemoryBM25Retriever:
|
||||
assert component.filters == {"name": "test.txt"}
|
||||
assert component.top_k == 5
|
||||
assert component.scale_score is False
|
||||
assert component.filter_policy == FilterPolicy.REPLACE
|
||||
|
||||
def test_from_dict_without_docstore(self):
|
||||
data = {"type": "InMemoryBM25Retriever", "init_parameters": {}}
|
||||
|
||||
@ -7,6 +7,7 @@ import pytest
|
||||
import numpy as np
|
||||
|
||||
from haystack import Pipeline, DeserializationError
|
||||
from haystack.document_stores.types import FilterPolicy
|
||||
from haystack.testing.factory import document_store_class
|
||||
from haystack.components.retrievers.in_memory.embedding_retriever import InMemoryEmbeddingRetriever
|
||||
from haystack.dataclasses import Document
|
||||
@ -47,6 +48,7 @@ class TestMemoryEmbeddingRetriever:
|
||||
"top_k": 10,
|
||||
"scale_score": False,
|
||||
"return_embedding": False,
|
||||
"filter_policy": "replace",
|
||||
},
|
||||
}
|
||||
|
||||
@ -70,6 +72,7 @@ class TestMemoryEmbeddingRetriever:
|
||||
"top_k": 5,
|
||||
"scale_score": True,
|
||||
"return_embedding": True,
|
||||
"filter_policy": "replace",
|
||||
},
|
||||
}
|
||||
|
||||
@ -83,6 +86,7 @@ class TestMemoryEmbeddingRetriever:
|
||||
},
|
||||
"filters": {"name": "test.txt"},
|
||||
"top_k": 5,
|
||||
"filter_policy": "merge",
|
||||
},
|
||||
}
|
||||
component = InMemoryEmbeddingRetriever.from_dict(data)
|
||||
@ -90,6 +94,7 @@ class TestMemoryEmbeddingRetriever:
|
||||
assert component.filters == {"name": "test.txt"}
|
||||
assert component.top_k == 5
|
||||
assert component.scale_score is False
|
||||
assert component.filter_policy == FilterPolicy.MERGE
|
||||
|
||||
def test_from_dict_without_docstore(self):
|
||||
data = {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user