Add embedding query for InMemoryDocumentStore

Stan Kirdey 2020-05-18 05:47:41 -07:00 committed by GitHub
parent 5eee61a47b
commit bf8e506c45
5 changed files with 69 additions and 3 deletions

View File

@@ -25,6 +25,10 @@ class BaseDocumentStore:
    def get_document_count(self):
        pass

    @abstractmethod
    def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None):
        pass


class Document(BaseModel):
    id: str = Field(..., description="_id field from Elasticsearch")
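
The new abstract method defines the contract every document store has to offer: given a query embedding, return the top_k most similar documents. A minimal caller-side sketch (not part of the diff; it assumes a store already filled with embedded documents and an EmbeddingRetriever like the one in the test further down):

# Illustrative usage of the new interface, assuming `retriever` and `document_store` already exist.
query_emb = retriever.create_embedding(["How do I test this library?"])[0]  # embed the question text
results = document_store.query_by_embedding(query_emb, top_k=5)            # rank stored docs by similarity
for doc in results:
    print(doc.id, doc.query_score, doc.text)                               # query_score is filled in by the store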

View File

@@ -51,6 +51,33 @@ class InMemoryDocumentStore(BaseDocumentStore):
    def get_document_by_id(self, id):
        return self.docs[id]

    def _convert_memory_hit_to_document(self, hit, doc_id=None) -> Document:
        document = Document(
            id=doc_id,
            text=hit[0].get('text', None),
            meta=hit[0].get('meta', {}),
            query_score=hit[1],
        )
        return document

    def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None) -> [Document]:
        from haystack.api import config
        from numpy import dot
        from numpy.linalg import norm

        embedding_field_name = config.EMBEDDING_FIELD_NAME
        if embedding_field_name is None:
            return []

        if query_emb is None:
            return []

        candidate_docs = [self._convert_memory_hit_to_document(
            (doc, dot(query_emb, doc[embedding_field_name]) / (norm(query_emb) * norm(doc[embedding_field_name]))), doc_id=idx) for idx, doc in self.docs.items()
        ]

        return sorted(candidate_docs, key=lambda x: x.query_score, reverse=True)[0:top_k]

    def get_document_ids_by_tags(self, tags):
        """
        The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...}
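
The query_score assigned to each hit is plain cosine similarity between the query embedding and the stored document's embedding. A standalone numpy sketch of the same formula with made-up vectors (illustration only, not part of the diff):

from numpy import array, dot
from numpy.linalg import norm

query_emb = array([1.0, 0.0, 1.0])  # toy query vector
doc_emb = array([0.5, 0.5, 1.0])    # toy stored document vector

# Same expression as in query_by_embedding: dot product divided by the product of the norms.
score = dot(query_emb, doc_emb) / (norm(query_emb) * norm(doc_emb))
print(round(score, 3))  # 0.866 -- documents are then sorted by this score and cut off at top_k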

View File

@@ -72,8 +72,6 @@ class TfidfRetriever(BaseRetriever):
        # get scores
        indices_and_scores = self._calc_scores(query)
        print(indices_and_scores)
        # rank paragraphs
        df_sliced = self.df.loc[indices_and_scores.keys()]
        df_sliced = df_sliced[:top_k]

View File

@@ -20,7 +20,6 @@ def test_elasticsearch_write_read(elasticsearch_fixture):
    write_documents_to_db(document_store=document_store, document_dir="samples/docs")
    sleep(2)  # wait for documents to be available for query
    documents = document_store.get_all_documents()
    print(documents)
    assert len(documents) == 2
    assert documents[0].id
    assert documents[0].text

View File

@@ -0,0 +1,38 @@
from haystack import Finder


def test_faq_retriever_in_memory_store(monkeypatch):
    monkeypatch.setenv("EMBEDDING_FIELD_NAME", "embedding")

    from haystack.database.memory import InMemoryDocumentStore
    from haystack.retriever.elasticsearch import EmbeddingRetriever

    document_store = InMemoryDocumentStore()

    documents = [
        {'name': 'How to test this library?', 'text': 'By running tox in the command line!', 'meta': {'question': 'How to test this library?'}},
        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
    ]

    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", gpu=False)

    embedded = []
    for doc in documents:
        doc['embedding'] = retriever.create_embedding([doc['meta']['question']])[0]
        embedded.append(doc)

    document_store.write_documents(embedded)

    finder = Finder(reader=None, retriever=retriever)
    prediction = finder.get_answers_via_similar_questions(question="How to test this?", top_k_retriever=1)

    assert len(prediction.get('answers', [])) == 1
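
The test needs monkeypatch.setenv because query_by_embedding resolves the embedding field name through haystack.api.config, which picks it up from the environment. Outside pytest the same setup can be sketched as below (an assumption: the variable only has to be set before the config module is first imported, which happens lazily on the first query):

import os

# Must match the key under which the documents' vectors are stored ('embedding' in the test above).
os.environ["EMBEDDING_FIELD_NAME"] = "embedding"

from haystack.database.memory import InMemoryDocumentStore

document_store = InMemoryDocumentStore()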