From bf8e506c45a9c7084e5f122d0ada706ac070e1e3 Mon Sep 17 00:00:00 2001
From: Stan Kirdey
Date: Mon, 18 May 2020 05:47:41 -0700
Subject: [PATCH] Add embedding query for InMemoryDocumentStore

---
 haystack/database/base.py   |  4 ++++
 haystack/database/memory.py | 27 ++++++++++++++++++++++++++
 haystack/retriever/tfidf.py |  2 --
 test/test_db.py             |  1 -
 test/test_faq_retriever.py  | 38 +++++++++++++++++++++++++++++++++++++
 5 files changed, 69 insertions(+), 3 deletions(-)
 create mode 100644 test/test_faq_retriever.py

diff --git a/haystack/database/base.py b/haystack/database/base.py
index 5ede723a9..15a28500c 100644
--- a/haystack/database/base.py
+++ b/haystack/database/base.py
@@ -25,6 +25,10 @@ class BaseDocumentStore:
     def get_document_count(self):
         pass
 
+    @abstractmethod
+    def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None):
+        pass
+
 
 class Document(BaseModel):
     id: str = Field(..., description="_id field from Elasticsearch")
diff --git a/haystack/database/memory.py b/haystack/database/memory.py
index fb7abf070..ed620e441 100644
--- a/haystack/database/memory.py
+++ b/haystack/database/memory.py
@@ -51,6 +51,33 @@ class InMemoryDocumentStore(BaseDocumentStore):
     def get_document_by_id(self, id):
         return self.docs[id]
 
+    def _convert_memory_hit_to_document(self, hit, doc_id=None) -> Document:
+        document = Document(
+            id=doc_id,
+            text=hit[0].get('text', None),
+            meta=hit[0].get('meta', {}),
+            query_score=hit[1],
+        )
+        return document
+
+    def query_by_embedding(self, query_emb, top_k=10, candidate_doc_ids=None) -> [Document]:
+        from haystack.api import config
+        from numpy import dot
+        from numpy.linalg import norm
+
+        embedding_field_name = config.EMBEDDING_FIELD_NAME
+        if embedding_field_name is None:
+            return []
+
+        if query_emb is None:
+            return []
+
+        candidate_docs = [self._convert_memory_hit_to_document(
+            (doc, dot(query_emb, doc[embedding_field_name]) / (norm(query_emb) * norm(doc[embedding_field_name]))), doc_id=idx) for idx, doc in self.docs.items()
+        ]
+
+        return sorted(candidate_docs, key=lambda x: x.query_score, reverse=True)[0:top_k]
+
     def get_document_ids_by_tags(self, tags):
         """
         The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...}
diff --git a/haystack/retriever/tfidf.py b/haystack/retriever/tfidf.py
index 03257c86a..9a304ce24 100644
--- a/haystack/retriever/tfidf.py
+++ b/haystack/retriever/tfidf.py
@@ -72,8 +72,6 @@ class TfidfRetriever(BaseRetriever):
         # get scores
         indices_and_scores = self._calc_scores(query)
 
-        print(indices_and_scores)
-
         # rank paragraphs
         df_sliced = self.df.loc[indices_and_scores.keys()]
         df_sliced = df_sliced[:top_k]
diff --git a/test/test_db.py b/test/test_db.py
index 5977ddbcf..89eba0f37 100644
--- a/test/test_db.py
+++ b/test/test_db.py
@@ -20,7 +20,6 @@ def test_elasticsearch_write_read(elasticsearch_fixture):
     write_documents_to_db(document_store=document_store, document_dir="samples/docs")
     sleep(2)  # wait for documents to be available for query
     documents = document_store.get_all_documents()
-    print(documents)
     assert len(documents) == 2
     assert documents[0].id
     assert documents[0].text
diff --git a/test/test_faq_retriever.py b/test/test_faq_retriever.py
new file mode 100644
index 000000000..986317faa
--- /dev/null
+++ b/test/test_faq_retriever.py
@@ -0,0 +1,38 @@
+from haystack import Finder
+
+
+def test_faq_retriever_in_memory_store(monkeypatch):
+    monkeypatch.setenv("EMBEDDING_FIELD_NAME", "embedding")
+
+    from haystack.database.memory import InMemoryDocumentStore
+    from haystack.retriever.elasticsearch import EmbeddingRetriever
+
+    document_store = InMemoryDocumentStore()
+
+    documents = [
+        {'name': 'How to test this library?', 'text': 'By running tox in the command line!', 'meta': {'question': 'How to test this library?'}},
+        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
+        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
+        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
+        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
+        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
+        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
+        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
+        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
+        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
+        {'name': 'blah blah blah', 'text': 'By running tox in the command line!', 'meta': {'question': 'blah blah blah'}},
+    ]
+
+    retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert", gpu=False)
+
+    embedded = []
+    for doc in documents:
+        doc['embedding'] = retriever.create_embedding([doc['meta']['question']])[0]
+        embedded.append(doc)
+
+    document_store.write_documents(embedded)
+
+    finder = Finder(reader=None, retriever=retriever)
+    prediction = finder.get_answers_via_similar_questions(question="How to test this?", top_k_retriever=1)
+
+    assert len(prediction.get('answers', [])) == 1