Custom fields for indexing in ElasticsearchDocumentStore (#297)

This commit is contained in:
Karim Jana 2020-08-10 05:34:39 -04:00 committed by GitHub
parent 2d27f19a71
commit c7078a36c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 48 additions and 9 deletions

View File

@ -44,23 +44,34 @@ class Document:
self.meta = meta
self.embedding = embedding
def to_dict(self, field_map=None):
    """Return the document's attributes as a new dict, optionally renaming keys.

    :param field_map: Optional mapping of external (storage) field name ->
                      Document attribute name, e.g. ``{"custom_text": "text"}``.
                      Attributes named as a value in this mapping are emitted
                      under the corresponding external name; all other
                      attributes keep their own name.
    :return: A fresh dict; mutating it does not touch the instance.
    """
    # field_map goes external -> attribute; serialising needs the reverse.
    # `None` replaces the former mutable `{}` default.
    external_names = {attr: external for external, attr in (field_map or {}).items()}
    return {external_names.get(attr, attr): value for attr, value in self.__dict__.items()}
@classmethod
def from_dict(cls, dict, field_map=None):
    """Build an instance from a plain dict, translating custom field names.

    Keys are handled in this order of precedence:

    1. a key that is already a constructor argument is passed through as is;
    2. a key listed in ``field_map`` is renamed to the mapped argument name;
    3. any other key is stored inside ``meta``.

    :param dict: Source mapping (the name shadows the builtin, but is kept
                 because it is part of the public signature).
    :param field_map: Optional mapping of external (storage) field name ->
                      constructor argument name, e.g. ``{"custom_text": "text"}``.
    :return: ``cls(**kwargs)`` built from the translated keys.

    The incoming ``meta`` is copied, so neither ``dict`` nor its ``meta`` is
    modified. A missing or ``None`` ``meta`` is treated as an empty dict.
    """
    # `None` replaces the former mutable `{}` default.
    field_map = field_map or {}
    init_args = ("text", "id", "query_score", "question", "meta", "embedding")

    # A shallow copy of the whole input would still share the nested meta
    # dict with the caller, so copy meta itself before adding fields to it.
    meta = {**(dict.get("meta") or {})}
    init_kwargs = {"meta": meta}
    for key, value in dict.items():
        if key == "meta":
            continue
        if key in init_args:
            init_kwargs[key] = value
        elif key in field_map:
            init_kwargs[field_map[key]] = value
        else:
            # unknown top-level fields are preserved as metadata
            meta[key] = value
    return cls(**init_kwargs)
class Label:

View File

@ -128,6 +128,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
}
self.client.indices.create(index=index_name, ignore=400, body=mapping)
# TODO: Add flexibility to define other non-meta and meta fields expected by the Document class
def _create_document_field_map(self) -> Dict:
return {
self.text_field: "text",
self.embedding_field: "embedding",
self.faq_question_field if self.faq_question_field else "question": "question"
}
def get_document_by_id(self, id: str, index=None) -> Optional[Document]:
index = index or self.index
documents = self.get_documents_by_id([id], index=index)
@ -167,7 +175,8 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
index = self.index
# Make sure we comply to Document class format
documents_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
documents_objects = [Document.from_dict(d, field_map=self._create_document_field_map())
if isinstance(d, dict) else d for d in documents]
documents_to_index = []
for doc in documents_objects:
@ -175,7 +184,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
_doc = {
"_op_type": "index" if self.update_existing_documents else "create",
"_index": index,
**doc.to_dict()
**doc.to_dict(field_map=self._create_document_field_map())
} # type: Dict[str, Any]
# rename id for elastic

View File

@ -20,6 +20,7 @@ DB_INDEX = os.getenv("DB_INDEX", "document")
# Each setting below is read from the environment variable of the same name
# and falls back to the literal default given as the second argument.
DB_INDEX_FEEDBACK = os.getenv("DB_INDEX_FEEDBACK", "label")
ES_CONN_SCHEME = os.getenv("ES_CONN_SCHEME", "http")
TEXT_FIELD_NAME = os.getenv("TEXT_FIELD_NAME", "text")
NAME_FIELD_NAME = os.getenv("NAME_FIELD_NAME", "name")
SEARCH_FIELD_NAME = os.getenv("SEARCH_FIELD_NAME", "text")
FAQ_QUESTION_FIELD_NAME = os.getenv("FAQ_QUESTION_FIELD_NAME", "question")
# No fallback name here: this stays None unless the variable is set.
EMBEDDING_FIELD_NAME = os.getenv("EMBEDDING_FIELD_NAME", None)

View File

@ -14,7 +14,7 @@ from rest_api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, ES_CONN_
EMBEDDING_DIM, EMBEDDING_FIELD_NAME, EXCLUDE_META_DATA_FIELDS, RETRIEVER_TYPE, EMBEDDING_MODEL_PATH, USE_GPU, READER_MODEL_PATH, \
BATCHSIZE, CONTEXT_WINDOW_SIZE, TOP_K_PER_CANDIDATE, NO_ANS_BOOST, MAX_PROCESSES, MAX_SEQ_LEN, DOC_STRIDE, \
DEFAULT_TOP_K_READER, DEFAULT_TOP_K_RETRIEVER, CONCURRENT_REQUEST_PER_WORKER, FAQ_QUESTION_FIELD_NAME, \
EMBEDDING_MODEL_FORMAT, READER_TYPE, READER_TOKENIZER, GPU_NUMBER
EMBEDDING_MODEL_FORMAT, READER_TYPE, READER_TOKENIZER, GPU_NUMBER, NAME_FIELD_NAME
from rest_api.controller.utils import RequestLimiter
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.reader.farm import FARMReader
@ -37,6 +37,7 @@ document_store = ElasticsearchDocumentStore(
ca_certs=False,
verify_certs=False,
text_field=TEXT_FIELD_NAME,
name_field=NAME_FIELD_NAME,
search_fields=SEARCH_FIELD_NAME,
embedding_dim=EMBEDDING_DIM,
embedding_field=EMBEDDING_FIELD_NAME,

View File

@ -1,6 +1,9 @@
import numpy as np
import pytest
from elasticsearch import Elasticsearch
from haystack.database.base import Document, Label
from haystack.database.elasticsearch import ElasticsearchDocumentStore
def test_get_all_documents_without_filters(document_store_with_docs):
@ -64,3 +67,17 @@ def test_elasticsearch_update_meta(document_store_with_docs):
document_store_with_docs.update_document_meta(document.id, meta={"meta_field": "updated_meta"})
updated_document = document_store_with_docs.query(query=None, filters={"name": ["filename1"]})[0]
assert updated_document.meta["meta_field"] == "updated_meta"
def test_elasticsearch_custom_fields(elasticsearch_fixture):
    """Write one document under custom text/embedding field names and read it back."""
    index_name = "haystack_test_custom"
    # Drop any index left over from a previous run; a 404 (no such index) is ignored.
    Elasticsearch().indices.delete(index=index_name, ignore=[404])

    store = ElasticsearchDocumentStore(
        index=index_name,
        text_field="custom_text_field",
        embedding_field="custom_embedding_field",
    )
    payload = {
        "custom_text_field": "test",
        "custom_embedding_field": np.random.rand(768).astype(np.float32),
    }
    store.write_documents([payload])

    stored = store.get_all_documents()
    assert len(stored) == 1
    # The custom field names must come back as the regular Document attributes.
    assert stored[0].text == "test"
    np.testing.assert_array_equal(payload["custom_embedding_field"], stored[0].embedding)