mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-14 00:54:22 +00:00
Add filtering by tags for InMemoryDocumentStore (#108)
This commit is contained in:
parent
73aed42f14
commit
72a3b70d7a
@ -223,3 +223,8 @@ You will find the Swagger API documentation at http://127.0.0.1:80/docs
|
|||||||
* Coming soon: more file formats for document upload, metrics for label quality ...
|
* Coming soon: more file formats for document upload, metrics for label quality ...
|
||||||
|
|
||||||
.. image:: https://raw.githubusercontent.com/deepset-ai/haystack/master/docs/img/annotation_tool.png
|
.. image:: https://raw.githubusercontent.com/deepset-ai/haystack/master/docs/img/annotation_tool.png
|
||||||
|
|
||||||
|
|
||||||
|
7. Development
|
||||||
|
-------------------
|
||||||
|
* Unit tests are executed by running ``tox``
|
||||||
@ -38,9 +38,4 @@ class Document(BaseModel):
|
|||||||
question: Optional[str] = Field(None, description="Question text for FAQs.")
|
question: Optional[str] = Field(None, description="Question text for FAQs.")
|
||||||
query_score: Optional[float] = Field(None, description="Elasticsearch query score for a retrieved document")
|
query_score: Optional[float] = Field(None, description="Elasticsearch query score for a retrieved document")
|
||||||
meta: Optional[Dict[str, Any]] = Field(None, description="")
|
meta: Optional[Dict[str, Any]] = Field(None, description="")
|
||||||
|
tags: Optional[Dict[str, Any]] = Field(None, description="Tags that allow filtering of the data")
|
||||||
def __getitem__(self, item):
|
|
||||||
if item == 'text':
|
|
||||||
return self.text
|
|
||||||
if item == 'id':
|
|
||||||
return self.id
|
|
||||||
|
|||||||
@ -12,6 +12,10 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
|||||||
|
|
||||||
def write_documents(self, documents):
|
def write_documents(self, documents):
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
|
if documents is None:
|
||||||
|
return
|
||||||
|
|
||||||
for document in documents:
|
for document in documents:
|
||||||
name = document.get("name", None)
|
name = document.get("name", None)
|
||||||
text = document.get("text", None)
|
text = document.get("text", None)
|
||||||
@ -20,10 +24,30 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
signature = name + text
|
signature = name + text
|
||||||
|
|
||||||
hash = hashlib.md5(signature.encode("utf-8")).hexdigest()
|
hash = hashlib.md5(signature.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
self.docs[hash] = document
|
self.docs[hash] = document
|
||||||
|
|
||||||
|
tags = document.get('tags', [])
|
||||||
|
|
||||||
|
self._map_tags_to_ids(hash, tags)
|
||||||
|
|
||||||
|
def _map_tags_to_ids(self, hash, tags):
    """
    Index a document id under every (tag key, tag value) pair it carries.

    Each pair is stored in ``self.doc_tags`` under the composite key
    ``str((tag_key, tag_value))``; the value is the list of document ids
    that carry that pair.

    :param hash: id of the document, i.e. its key in ``self.docs``.
    :param tags: a list of dicts, or a single dict, shaped like
                 ``{"tag-1": ["value-1", "value-2"]}`` or ``{"tag-1": "value-1"}``.
                 Any other type, and any non-dict list entry, is ignored.
    """
    # Accept a bare dict as well as a list of dicts.
    if isinstance(tags, dict):
        tags = [tags]
    if not isinstance(tags, list):
        return

    for tag in tags:
        if not isinstance(tag, dict):
            continue
        for tag_key, tag_values in tag.items():
            # Empty or missing values carry nothing to index.
            if not tag_values:
                continue
            # A scalar (e.g. the string "value-1") is ONE value; iterating it
            # would index every single character as its own tag value.
            if not isinstance(tag_values, (list, tuple, set)):
                tag_values = [tag_values]
            for tag_value in tag_values:
                comp_key = str((tag_key, tag_value))
                doc_ids = self.doc_tags.setdefault(comp_key, [])
                # Indexing the same document again must not list it twice.
                if hash not in doc_ids:
                    doc_ids.append(hash)
|
||||||
|
|
||||||
def get_document_by_id(self, id):
|
def get_document_by_id(self, id):
|
||||||
return self.docs[id]
|
return self.docs[id]
|
||||||
|
|
||||||
@ -32,10 +56,27 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
|||||||
The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...}
|
The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...}
|
||||||
The format for the dict is {"tag-1": ["value-1","value-2"], "tag-2": ["value-3"] ...}
|
The format for the dict is {"tag-1": ["value-1","value-2"], "tag-2": ["value-3"] ...}
|
||||||
"""
|
"""
|
||||||
pass
|
if not isinstance(tags, list):
|
||||||
|
tags = [tags]
|
||||||
|
result = self._find_ids_by_tags(tags)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _find_ids_by_tags(self, tags):
    """
    Return the stored documents that carry at least one requested tag value.

    Lookups go through ``self.doc_tags``, keyed by
    ``str((tag_key, tag_value))``, and resolve the ids found there against
    ``self.docs``.

    :param tags: iterable of dicts shaped like
                 ``{"tag-1": ["value-1", "value-2"]}`` or ``{"tag-1": "value-1"}``.
    :return: list of matching entries of ``self.docs``, each at most once,
             in the order in which they were first matched.
    """
    result = []
    seen_ids = set()
    for tag in tags:
        for tag_key, tag_values in tag.items():
            if not tag_values:
                continue
            # A scalar (e.g. the string "value-1") is ONE value; iterating it
            # would look up every single character separately.
            if not isinstance(tag_values, (list, tuple, set)):
                tag_values = [tag_values]
            for tag_value in tag_values:
                comp_key = str((tag_key, tag_value))
                for doc_id in self.doc_tags.get(comp_key, []):
                    # Skip documents already collected through another value,
                    # and ids that no longer resolve to a stored document.
                    if doc_id in seen_ids or doc_id not in self.docs:
                        continue
                    seen_ids.add(doc_id)
                    result.append(self.docs[doc_id])
    return result
|
||||||
|
|
||||||
def get_document_count(self):
|
def get_document_count(self):
|
||||||
return len(self.docs.items())
|
return len(self.docs.items())
|
||||||
|
|
||||||
def get_all_documents(self):
|
def get_all_documents(self):
|
||||||
return [Document(id=item[0], text=item[1]['text'], name=item[1]['name']) for item in self.docs.items()]
|
return [Document(id=item[0], text=item[1]['text'], name=item[1]['name'], meta=item[1].get('meta', {})) for item in self.docs.items()]
|
||||||
|
|||||||
@ -20,6 +20,7 @@ def test_elasticsearch_write_read(elasticsearch_fixture):
|
|||||||
write_documents_to_db(document_store=document_store, document_dir="samples/docs")
|
write_documents_to_db(document_store=document_store, document_dir="samples/docs")
|
||||||
sleep(2) # wait for documents to be available for query
|
sleep(2) # wait for documents to be available for query
|
||||||
documents = document_store.get_all_documents()
|
documents = document_store.get_all_documents()
|
||||||
|
print(documents)
|
||||||
assert len(documents) == 2
|
assert len(documents) == 2
|
||||||
assert documents[0].id
|
assert documents[0].id
|
||||||
assert documents[0].text
|
assert documents[0].text
|
||||||
|
|||||||
@ -4,4 +4,3 @@ from haystack.database.base import Document
|
|||||||
def test_document_data_access():
|
def test_document_data_access():
|
||||||
doc = Document(id=1, text="test")
|
doc = Document(id=1, text="test")
|
||||||
assert doc.text == "test"
|
assert doc.text == "test"
|
||||||
assert doc['text'] == "test"
|
|
||||||
|
|||||||
@ -1,5 +1,3 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from haystack.reader.farm import FARMReader
|
from haystack.reader.farm import FARMReader
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -5,14 +5,15 @@ from haystack.retriever.tfidf import TfidfRetriever
|
|||||||
|
|
||||||
def test_finder_get_answers_with_in_memory_store():
|
def test_finder_get_answers_with_in_memory_store():
|
||||||
test_docs = [
|
test_docs = [
|
||||||
{"name": "testing the finder 1", "text": "testing the finder with pyhton unit test 1"},
|
{"name": "testing the finder 1", "text": "testing the finder with pyhton unit test 1", 'meta': {'url': 'url'}},
|
||||||
{"name": "testing the finder 2", "text": "testing the finder with pyhton unit test 2"},
|
{"name": "testing the finder 2", "text": "testing the finder with pyhton unit test 2", 'meta': {'url': 'url'}},
|
||||||
{"name": "testing the finder 3", "text": "testing the finder with pyhton unit test 3"}
|
{"name": "testing the finder 3", "text": "testing the finder with pyhton unit test 3", 'meta': {'url': 'url'}}
|
||||||
]
|
]
|
||||||
|
|
||||||
from haystack.database.memory import InMemoryDocumentStore
|
from haystack.database.memory import InMemoryDocumentStore
|
||||||
document_store = InMemoryDocumentStore()
|
document_store = InMemoryDocumentStore()
|
||||||
document_store.write_documents(test_docs)
|
document_store.write_documents(test_docs)
|
||||||
|
|
||||||
retriever = TfidfRetriever(document_store=document_store)
|
retriever = TfidfRetriever(document_store=document_store)
|
||||||
reader = TransformersReader(model="distilbert-base-uncased-distilled-squad",
|
reader = TransformersReader(model="distilbert-base-uncased-distilled-squad",
|
||||||
tokenizer="distilbert-base-uncased", use_gpu=-1)
|
tokenizer="distilbert-base-uncased", use_gpu=-1)
|
||||||
@ -20,3 +21,66 @@ def test_finder_get_answers_with_in_memory_store():
|
|||||||
prediction = finder.get_answers(question="testing finder", top_k_retriever=10,
|
prediction = finder.get_answers(question="testing finder", top_k_retriever=10,
|
||||||
top_k_reader=5)
|
top_k_reader=5)
|
||||||
assert prediction is not None
|
assert prediction is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_store_get_by_tags():
    """Querying by a tag that none of the written documents carries returns an empty list."""
    from haystack.database.memory import InMemoryDocumentStore

    first_doc = {"name": "testing the finder 1", "text": "testing the finder with pyhton unit test 1", "meta": {"url": "url"}}
    second_doc = {"name": "testing the finder 2", "text": "testing the finder with pyhton unit test 2", "meta": {"url": None}}
    third_doc = {"name": "testing the finder 3", "text": "testing the finder with pyhton unit test 3", "meta": {"url": "url"}}

    # None of the documents has a 'tags' entry.
    store = InMemoryDocumentStore()
    store.write_documents([first_doc, second_doc, third_doc])

    assert store.get_document_ids_by_tags({"has_url": "false"}) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_store_get_by_tag_lists_union():
    """Querying tag2 == "1" returns both documents that list "1" among their tag2 values."""
    from haystack.database.memory import InMemoryDocumentStore

    tagged_docs = [
        {
            "name": "testing the finder 1",
            "text": "testing the finder with pyhton unit test 1",
            "meta": {"url": "url"},
            "tags": [{"tag2": ["1"]}],
        },
        {
            "name": "testing the finder 2",
            "text": "testing the finder with pyhton unit test 2",
            "meta": {"url": None},
            "tags": [{"tag1": ["1"]}],
        },
        {
            "name": "testing the finder 3",
            "text": "testing the finder with pyhton unit test 3",
            "meta": {"url": "url"},
            "tags": [{"tag2": ["1", "2"]}],
        },
    ]
    store = InMemoryDocumentStore()
    store.write_documents(tagged_docs)

    # Documents 1 and 3 carry tag2 == "1"; document 2 only carries tag1.
    expected = [
        {
            "name": "testing the finder 1",
            "text": "testing the finder with pyhton unit test 1",
            "meta": {"url": "url"},
            "tags": [{"tag2": ["1"]}],
        },
        {
            "name": "testing the finder 3",
            "text": "testing the finder with pyhton unit test 3",
            "meta": {"url": "url"},
            "tags": [{"tag2": ["1", "2"]}],
        },
    ]
    matches = store.get_document_ids_by_tags({"tag2": ["1"]})
    assert matches == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_store_get_by_tag_lists_non_existent_tag():
    """Querying a tag value that no document carries returns an empty list."""
    from haystack.database.memory import InMemoryDocumentStore

    only_doc = {
        "name": "testing the finder 1",
        "text": "testing the finder with pyhton unit test 1",
        "meta": {"url": "url"},
        "tags": [{"tag1": ["1"]}],
    }
    store = InMemoryDocumentStore()
    store.write_documents([only_doc])

    # tag1 exists, but only with value "1".
    assert store.get_document_ids_by_tags({"tag1": ["3"]}) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_store_get_by_tag_lists_disjoint():
    """Querying tag3 == "3" returns only the one document whose tag3 values include "3"."""
    from haystack.database.memory import InMemoryDocumentStore

    tagged_docs = [
        {
            "name": "testing the finder 1",
            "text": "testing the finder with pyhton unit test 1",
            "meta": {"url": "url"},
            "tags": [{"tag1": ["1"]}],
        },
        {
            "name": "testing the finder 2",
            "text": "testing the finder with pyhton unit test 2",
            "meta": {"url": None},
            "tags": [{"tag2": ["1"]}],
        },
        {
            "name": "testing the finder 3",
            "text": "testing the finder with pyhton unit test 3",
            "meta": {"url": "url"},
            "tags": [{"tag3": ["1", "2"]}],
        },
        {
            "name": "testing the finder 4",
            "text": "testing the finder with pyhton unit test 3",
            "meta": {"url": "url"},
            "tags": [{"tag3": ["1", "3"]}],
        },
    ]
    store = InMemoryDocumentStore()
    store.write_documents(tagged_docs)

    # Only document 4 lists "3" under tag3.
    expected = [
        {
            "name": "testing the finder 4",
            "text": "testing the finder with pyhton unit test 3",
            "meta": {"url": "url"},
            "tags": [{"tag3": ["1", "3"]}],
        },
    ]
    matches = store.get_document_ids_by_tags({"tag3": ["3"]})
    assert matches == expected
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user