mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-07 21:33:39 +00:00
Add InMemoryDocumentStore (#76)
This commit is contained in:
parent
a78659f234
commit
6038d40a53
5
.gitignore
vendored
5
.gitignore
vendored
@ -1,3 +1,8 @@
|
|||||||
|
# Local run files
|
||||||
|
qa.db
|
||||||
|
**/qa.db
|
||||||
|
**/*qa*.db
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
|
|||||||
@ -18,7 +18,7 @@ class BaseDocumentStore:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_document_ids_by_tag(self, tag):
|
def get_document_ids_by_tags(self, tag):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -38,3 +38,9 @@ class Document(BaseModel):
|
|||||||
question: Optional[str] = Field(None, description="Question text for FAQs.")
|
question: Optional[str] = Field(None, description="Question text for FAQs.")
|
||||||
query_score: Optional[int] = Field(None, description="Elasticsearch query score for a retrieved document")
|
query_score: Optional[int] = Field(None, description="Elasticsearch query score for a retrieved document")
|
||||||
meta: Optional[Dict[str, Optional[str]]] = Field(None, description="")
|
meta: Optional[Dict[str, Optional[str]]] = Field(None, description="")
|
||||||
|
|
||||||
|
def __getitem__(self, item):
    """
    Allow dict-style read access to the document, e.g. ``doc['text']``.

    Only the keys ``'text'`` and ``'id'`` are supported; they return
    ``self.text`` and ``self.id`` respectively.

    :param item: key to look up
    :return: value of the matching attribute
    :raises KeyError: if ``item`` is neither ``'text'`` nor ``'id'``
    """
    if item == 'text':
        return self.text
    if item == 'id':
        return self.id
    # Unsupported keys must fail loudly: falling off the end would hand the
    # caller None, which is indistinguishable from a field that is unset.
    raise KeyError(item)
|
||||||
|
|||||||
41
haystack/database/memory.py
Normal file
41
haystack/database/memory.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from haystack.database.base import BaseDocumentStore, Document
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryDocumentStore(BaseDocumentStore):
    """
    Document store that keeps all documents in a plain dict in memory.

    Documents are stored as the raw dicts handed to ``write_documents``,
    keyed by the MD5 hex digest of their ``name`` + ``text``.
    """

    def __init__(self):
        # Maps document id (MD5 hex digest) -> original document dict.
        self.docs = {}
        # Intended for tag bookkeeping; nothing in this class fills it yet.
        self.doc_tags = {}

    def write_documents(self, documents):
        """
        Add documents to the store.

        :param documents: iterable of dicts, each holding the keys ``"name"``
                          and ``"text"``. Entries where either key is missing
                          or ``None`` are skipped. ``None`` is accepted and
                          treated like an empty iterable.
        """
        import hashlib

        if documents is None:
            return

        for document in documents:
            name = document.get("name")
            text = document.get("text")

            if name is None or text is None:
                continue

            # The id is derived from the content, so writing the same
            # name/text pair again overwrites the existing entry.
            signature = name + text
            doc_id = hashlib.md5(signature.encode("utf-8")).hexdigest()

            self.docs[doc_id] = document

    def get_document_by_id(self, id):
        """
        Return the stored document dict for ``id``.

        :param id: document id as generated by ``write_documents``
        :raises KeyError: if no document with this id is stored
        """
        return self.docs[id]

    def get_document_ids_by_tags(self, tags):
        """
        Look up document ids by tags.

        NOTE: not implemented yet -- the method currently always returns None.

        The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...}
        or, with several values per tag,
        {"tag-1": ["value-1", "value-2"], "tag-2": ["value-3"] ...}

        :param tags: dict of tags in one of the formats above
        """
        return None

    def get_document_count(self):
        """Return the number of documents currently stored."""
        return len(self.docs)

    def get_all_documents(self):
        """
        Return every stored document as a ``Document`` object.

        :return: list of ``Document`` built from each entry's id, ``text``
                 and ``name``
        """
        return [
            Document(id=doc_id, text=doc['text'], name=doc['name'])
            for doc_id, doc in self.docs.items()
        ]
|
||||||
@ -2,6 +2,7 @@ import logging
|
|||||||
from scipy.special import expit
|
from scipy.special import expit
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from haystack.database.base import Document
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -37,7 +38,6 @@ class Finder:
|
|||||||
results = {"question": question, "answers": []}
|
results = {"question": question, "answers": []}
|
||||||
return results
|
return results
|
||||||
|
|
||||||
# 2) Apply reader to get granular answer(s)
|
|
||||||
len_chars = sum([len(d.text) for d in documents])
|
len_chars = sum([len(d.text) for d in documents])
|
||||||
logger.info(f"Reader is looking for detailed answer in {len_chars} chars ...")
|
logger.info(f"Reader is looking for detailed answer in {len_chars} chars ...")
|
||||||
results = self.reader.predict(question=question,
|
results = self.reader.predict(question=question,
|
||||||
|
|||||||
@ -8,4 +8,5 @@ psycopg2-binary
|
|||||||
sklearn
|
sklearn
|
||||||
elasticsearch
|
elasticsearch
|
||||||
elastic-apm
|
elastic-apm
|
||||||
|
tox
|
||||||
# optional: sentence-transformers
|
# optional: sentence-transformers
|
||||||
|
|||||||
40
test/test_finder.py
Normal file
40
test/test_finder.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from haystack import Finder
|
||||||
|
from haystack.database.sql import SQLDocumentStore
|
||||||
|
from haystack.reader.transformers import TransformersReader
|
||||||
|
from haystack.retriever.tfidf import TfidfRetriever
|
||||||
|
|
||||||
|
|
||||||
|
def test_finder_get_answers():
    """Smoke test: a Finder backed by a SQLite store yields a prediction."""
    docs = [
        {"name": "testing the finder 1", "text": "testing the finder with pyhton unit test 1"},
        {"name": "testing the finder 2", "text": "testing the finder with pyhton unit test 2"},
        {"name": "testing the finder 3", "text": "testing the finder with pyhton unit test 3"},
    ]

    # Index the fixture documents.
    store = SQLDocumentStore(url="sqlite:///qa_test.db")
    store.write_documents(docs)

    # Wire up the retriever/reader pipeline on top of the store.
    tfidf_retriever = TfidfRetriever(document_store=store)
    transformers_reader = TransformersReader(
        model="distilbert-base-uncased-distilled-squad",
        tokenizer="distilbert-base-uncased",
        use_gpu=-1,
    )
    pipeline = Finder(transformers_reader, tfidf_retriever)

    answers = pipeline.get_answers(
        question="testing finder",
        top_k_retriever=10,
        top_k_reader=5,
    )
    assert answers is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_finder_get_answers_single_result():
    """Smoke test: the Finder still yields a prediction when both top_k values are 1."""
    docs = [
        {"name": "testing the finder 1", "text": "testing the finder with pyhton unit test 1"},
        {"name": "testing the finder 2", "text": "testing the finder with pyhton unit test 2"},
        {"name": "testing the finder 3", "text": "testing the finder with pyhton unit test 3"},
    ]

    # Index the fixture documents.
    store = SQLDocumentStore(url="sqlite:///qa_test.db")
    store.write_documents(docs)

    # Wire up the retriever/reader pipeline on top of the store.
    tfidf_retriever = TfidfRetriever(document_store=store)
    transformers_reader = TransformersReader(
        model="distilbert-base-uncased-distilled-squad",
        tokenizer="distilbert-base-uncased",
        use_gpu=-1,
    )
    pipeline = Finder(transformers_reader, tfidf_retriever)

    answers = pipeline.get_answers(
        question="testing finder",
        top_k_retriever=1,
        top_k_reader=1,
    )
    assert answers is not None
|
||||||
19
test/test_imports.py
Normal file
19
test/test_imports.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
def test_module_imports():
    """Check that the package's main entry points can be imported."""
    from haystack import Finder
    from haystack.database.sql import SQLDocumentStore
    from haystack.indexing.cleaning import clean_wiki_text
    from haystack.indexing.io import write_documents_to_db, fetch_archive_from_http
    from haystack.reader.farm import FARMReader
    from haystack.reader.transformers import TransformersReader
    from haystack.retriever.tfidf import TfidfRetriever
    from haystack.utils import print_answers

    # Same order as the imports above; each name must resolve to a real object.
    imported = (
        Finder,
        SQLDocumentStore,
        clean_wiki_text,
        write_documents_to_db,
        fetch_archive_from_http,
        FARMReader,
        TransformersReader,
        TfidfRetriever,
        print_answers,
    )
    for obj in imported:
        assert obj is not None
|
||||||
22
test/test_in_memory_store.py
Normal file
22
test/test_in_memory_store.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from haystack import Finder
|
||||||
|
from haystack.reader.transformers import TransformersReader
|
||||||
|
from haystack.retriever.tfidf import TfidfRetriever
|
||||||
|
|
||||||
|
|
||||||
|
def test_finder_get_answers_with_in_memory_store():
    """Smoke test: a Finder backed by the in-memory store yields a prediction."""
    docs = [
        {"name": "testing the finder 1", "text": "testing the finder with pyhton unit test 1"},
        {"name": "testing the finder 2", "text": "testing the finder with pyhton unit test 2"},
        {"name": "testing the finder 3", "text": "testing the finder with pyhton unit test 3"},
    ]

    from haystack.database.memory import InMemoryDocumentStore

    # Index the fixture documents.
    store = InMemoryDocumentStore()
    store.write_documents(docs)

    # Wire up the retriever/reader pipeline on top of the store.
    tfidf_retriever = TfidfRetriever(document_store=store)
    transformers_reader = TransformersReader(
        model="distilbert-base-uncased-distilled-squad",
        tokenizer="distilbert-base-uncased",
        use_gpu=-1,
    )
    pipeline = Finder(transformers_reader, tfidf_retriever)

    answers = pipeline.get_answers(
        question="testing finder",
        top_k_retriever=10,
        top_k_reader=5,
    )
    assert answers is not None
|
||||||
Loading…
x
Reference in New Issue
Block a user