mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-01 10:19:23 +00:00
Add embedding query for InMemoryDocumentStore
This commit is contained in:
parent
5eee61a47b
commit
bf8e506c45
@ -25,6 +25,10 @@ class BaseDocumentStore:
|
||||
def get_document_count(self):
    """Return the number of documents held by the store.

    Stub on the base class; concrete document stores provide the implementation.
    """
    pass
|
||||
|
||||
@abstractmethod
def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None):
    """Return documents ranked by similarity of their embeddings to ``query_emb``.

    Abstract interface method — each document store backend implements its own
    similarity search.

    :param query_emb: embedding vector of the query.
    :param top_k: maximum number of documents to return.
    :param candidate_doc_ids: optional collection of document ids to restrict the
        search to (whether it is honoured depends on the concrete implementation).
    """
    pass
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
id: str = Field(..., description="_id field from Elasticsearch")
|
||||
|
||||
@ -51,6 +51,33 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
||||
def get_document_by_id(self, id):
    """Return the stored document registered under ``id``.

    Raises ``KeyError`` when no document with that id exists.
    """
    stored_docs = self.docs
    return stored_docs[id]
|
||||
|
||||
def _convert_memory_hit_to_document(self, hit, doc_id=None) -> Document:
    """Convert an internal ``(raw_doc, score)`` pair into a ``Document`` instance.

    :param hit: tuple of the raw stored document (a dict) and its query score.
    :param doc_id: id to assign to the resulting ``Document``.
    :return: populated ``Document`` model.
    """
    raw_doc, score = hit[0], hit[1]
    return Document(
        id=doc_id,
        text=raw_doc.get('text', None),
        meta=raw_doc.get('meta', {}),
        query_score=score,
    )
|
||||
|
||||
def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None) -> "List[Document]":
    """Rank stored documents by cosine similarity of their embeddings to ``query_emb``.

    :param query_emb: embedding vector of the query; if ``None``, returns ``[]``.
    :param top_k: maximum number of documents to return.
    :param candidate_doc_ids: accepted for interface compatibility with the base
        class; currently not used to filter results.
    :return: up to ``top_k`` ``Document`` objects sorted by descending
        ``query_score`` (cosine similarity).
    """
    from haystack.api import config
    from numpy import dot
    from numpy.linalg import norm

    embedding_field_name = config.EMBEDDING_FIELD_NAME
    # Without a configured embedding field or a query vector there is nothing to score.
    if embedding_field_name is None or query_emb is None:
        return []

    # Hoisted out of the loop: the query norm is invariant across documents.
    query_norm = norm(query_emb)
    if query_norm == 0:
        # A zero vector has no direction; cosine similarity would divide by zero.
        return []

    candidate_docs = []
    for doc_id, doc in self.docs.items():
        doc_emb = doc.get(embedding_field_name)
        if doc_emb is None:
            # Skip documents written without an embedding instead of raising KeyError.
            continue
        score = dot(query_emb, doc_emb) / (query_norm * norm(doc_emb))
        candidate_docs.append(
            self._convert_memory_hit_to_document((doc, score), doc_id=doc_id)
        )

    return sorted(candidate_docs, key=lambda d: d.query_score, reverse=True)[:top_k]
|
||||
|
||||
def get_document_ids_by_tags(self, tags):
|
||||
"""
|
||||
The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...}
|
||||
|
||||
@ -72,8 +72,6 @@ class TfidfRetriever(BaseRetriever):
|
||||
# get scores
|
||||
indices_and_scores = self._calc_scores(query)
|
||||
|
||||
print(indices_and_scores)
|
||||
|
||||
# rank paragraphs
|
||||
df_sliced = self.df.loc[indices_and_scores.keys()]
|
||||
df_sliced = df_sliced[:top_k]
|
||||
|
||||
@ -20,7 +20,6 @@ def test_elasticsearch_write_read(elasticsearch_fixture):
|
||||
write_documents_to_db(document_store=document_store, document_dir="samples/docs")
|
||||
sleep(2) # wait for documents to be available for query
|
||||
documents = document_store.get_all_documents()
|
||||
print(documents)
|
||||
assert len(documents) == 2
|
||||
assert documents[0].id
|
||||
assert documents[0].text
|
||||
|
||||
38
test/test_faq_retriever.py
Normal file
38
test/test_faq_retriever.py
Normal file
@ -0,0 +1,38 @@
|
||||
from haystack import Finder
|
||||
|
||||
|
||||
def test_faq_retriever_in_memory_store(monkeypatch):
    """End-to-end FAQ-style retrieval against the in-memory document store.

    Writes one relevant FAQ entry plus identical filler entries (with their
    question embeddings) and checks that similar-question search returns
    exactly the single requested answer.
    """
    monkeypatch.setenv("EMBEDDING_FIELD_NAME", "embedding")

    # Imported after the env var is set so module-level config picks it up.
    from haystack.database.memory import InMemoryDocumentStore
    from haystack.retriever.elasticsearch import EmbeddingRetriever

    document_store = InMemoryDocumentStore()

    # One relevant entry followed by ten identical distractors — generated
    # instead of copy-pasted; the comprehension creates a distinct dict each time.
    documents = [
        {'name': 'How to test this library?', 'text': 'By running tox in the command line!', 'meta': {'question': 'How to test this library?'}},
    ]
    documents += [
        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}}
        for _ in range(10)
    ]

    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", gpu=False)

    # Attach a question embedding to every document before writing.
    embedded = []
    for doc in documents:
        doc['embedding'] = retriever.create_embedding([doc['meta']['question']])[0]
        embedded.append(doc)

    document_store.write_documents(embedded)

    finder = Finder(reader=None, retriever=retriever)
    prediction = finder.get_answers_via_similar_questions(question="How to test this?", top_k_retriever=1)

    assert len(prediction.get('answers', [])) == 1
|
||||
Loading…
x
Reference in New Issue
Block a user