mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-13 16:43:44 +00:00
Add embedding query for InMemoryDocumentStore
This commit is contained in:
parent
5eee61a47b
commit
bf8e506c45
@ -25,6 +25,10 @@ class BaseDocumentStore:
|
|||||||
def get_document_count(self):
|
def get_document_count(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None):
    """
    Find the stored documents that best match an embedding vector.

    Concrete document stores must override this method; the base
    declaration does nothing and returns None.

    :param query_emb: embedding vector of the query
    :param top_k: maximum number of documents to return
    :param candidate_doc_ids: optional collection of document ids
        (presumably used to restrict the search -- confirm against
        the concrete stores)
    """
|
||||||
|
|
||||||
|
|
||||||
class Document(BaseModel):
|
class Document(BaseModel):
|
||||||
id: str = Field(..., description="_id field from Elasticsearch")
|
id: str = Field(..., description="_id field from Elasticsearch")
|
||||||
|
|||||||
@ -51,6 +51,33 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
|||||||
def get_document_by_id(self, id):
|
def get_document_by_id(self, id):
|
||||||
return self.docs[id]
|
return self.docs[id]
|
||||||
|
|
||||||
|
def _convert_memory_hit_to_document(self, hit, doc_id=None) -> Document:
    """
    Wrap an in-memory search hit in a Document.

    :param hit: pair whose first item is the stored document (a dict-like
        object read via ``.get``) and whose second item is its score
    :param doc_id: id to assign to the resulting Document
    :return: Document carrying the hit's text, meta and query score
    """
    stored_doc = hit[0]
    score = hit[1]
    return Document(
        id=doc_id,
        text=stored_doc.get('text', None),
        meta=stored_doc.get('meta', {}),
        query_score=score,
    )
|
||||||
|
|
||||||
|
def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None) -> [Document]:
    """
    Rank the stored documents by cosine similarity to ``query_emb``.

    The name of the field holding each document's embedding is read from
    ``haystack.api.config.EMBEDDING_FIELD_NAME``.

    :param query_emb: embedding vector of the query
    :param top_k: maximum number of documents to return
    :param candidate_doc_ids: optional collection of document ids; when
        given, only documents stored under one of these ids are scored
    :return: up to ``top_k`` Documents, highest similarity first. Empty
        when no embedding field is configured or ``query_emb`` is None.
    """
    # Imported lazily, as in the rest of this store, so that merely loading
    # the module requires neither the API config nor numpy.
    from haystack.api import config
    from numpy import dot
    from numpy.linalg import norm

    embedding_field_name = config.EMBEDDING_FIELD_NAME
    if embedding_field_name is None or query_emb is None:
        return []

    # NOTE(review): ids are compared against the keys of self.docs, so
    # callers must pass ids of the same type as those keys.
    allowed_ids = None if candidate_doc_ids is None else set(candidate_doc_ids)

    # Loop-invariant: the query norm is the same for every document.
    query_norm = norm(query_emb)

    scored_hits = []
    for doc_id, doc in self.docs.items():
        if allowed_ids is not None and doc_id not in allowed_ids:
            continue

        doc_emb = doc.get(embedding_field_name)
        if doc_emb is None:
            # Documents written without an embedding cannot be ranked.
            continue

        denominator = query_norm * norm(doc_emb)
        # Cosine similarity is undefined for a zero-length vector; score it
        # as 0.0 instead of letting NaN/inf leak into the ranking.
        score = float(dot(query_emb, doc_emb) / denominator) if denominator else 0.0
        scored_hits.append((doc, score, doc_id))

    # Stable sort: equally scored documents keep their storage order.
    scored_hits.sort(key=lambda entry: entry[1], reverse=True)

    # Only the hits that are actually returned get converted to Documents.
    return [
        self._convert_memory_hit_to_document((doc, score), doc_id=doc_id)
        for doc, score, doc_id in scored_hits[0:top_k]
    ]
|
||||||
|
|
||||||
def get_document_ids_by_tags(self, tags):
|
def get_document_ids_by_tags(self, tags):
|
||||||
"""
|
"""
|
||||||
The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...}
|
The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...}
|
||||||
|
|||||||
@ -72,8 +72,6 @@ class TfidfRetriever(BaseRetriever):
|
|||||||
# get scores
|
# get scores
|
||||||
indices_and_scores = self._calc_scores(query)
|
indices_and_scores = self._calc_scores(query)
|
||||||
|
|
||||||
print(indices_and_scores)
|
|
||||||
|
|
||||||
# rank paragraphs
|
# rank paragraphs
|
||||||
df_sliced = self.df.loc[indices_and_scores.keys()]
|
df_sliced = self.df.loc[indices_and_scores.keys()]
|
||||||
df_sliced = df_sliced[:top_k]
|
df_sliced = df_sliced[:top_k]
|
||||||
|
|||||||
@ -20,7 +20,6 @@ def test_elasticsearch_write_read(elasticsearch_fixture):
|
|||||||
write_documents_to_db(document_store=document_store, document_dir="samples/docs")
|
write_documents_to_db(document_store=document_store, document_dir="samples/docs")
|
||||||
sleep(2) # wait for documents to be available for query
|
sleep(2) # wait for documents to be available for query
|
||||||
documents = document_store.get_all_documents()
|
documents = document_store.get_all_documents()
|
||||||
print(documents)
|
|
||||||
assert len(documents) == 2
|
assert len(documents) == 2
|
||||||
assert documents[0].id
|
assert documents[0].id
|
||||||
assert documents[0].text
|
assert documents[0].text
|
||||||
|
|||||||
38
test/test_faq_retriever.py
Normal file
38
test/test_faq_retriever.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from haystack import Finder
|
||||||
|
|
||||||
|
|
||||||
|
def test_faq_retriever_in_memory_store(monkeypatch):
    """
    FAQ-style retrieval over an InMemoryDocumentStore returns one answer.

    Eleven documents (one real question plus ten identical fillers) are
    embedded by question and written to the store; asking a similar
    question with ``top_k_retriever=1`` must produce exactly one answer.
    """
    # Set before the haystack imports below.
    monkeypatch.setenv("EMBEDDING_FIELD_NAME", "embedding")

    from haystack.database.memory import InMemoryDocumentStore
    from haystack.retriever.elasticsearch import EmbeddingRetriever

    document_store = InMemoryDocumentStore()

    answer_text = 'By running tox in the command line!'
    questions = ['How to test this library?'] + ['blah blah blah'] * 10
    documents = [
        {'name': question, 'text': answer_text, 'meta': {'question': question}}
        for question in questions
    ]

    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", gpu=False)

    # Each document is indexed by the embedding of its question.
    for doc in documents:
        question_embeddings = retriever.create_embedding([doc['meta']['question']])
        doc['embedding'] = question_embeddings[0]

    document_store.write_documents(documents)

    finder = Finder(reader=None, retriever=retriever)
    prediction = finder.get_answers_via_similar_questions(question="How to test this?", top_k_retriever=1)

    answers = prediction.get('answers', [])
    assert len(answers) == 1
|
||||||
Loading…
x
Reference in New Issue
Block a user