diff --git a/haystack/database/base.py b/haystack/database/base.py index 10b0361de..7ffe219cd 100644 --- a/haystack/database/base.py +++ b/haystack/database/base.py @@ -44,23 +44,34 @@ class Document: self.meta = meta self.embedding = embedding - def to_dict(self): - return self.__dict__ + def to_dict(self, field_map={}): + inv_field_map = {v:k for k, v in field_map.items()} + _doc: Dict[str, str] = {} + for k, v in self.__dict__.items(): + k = k if k not in inv_field_map else inv_field_map[k] + _doc[k] = v + return _doc @classmethod - def from_dict(cls, dict): + def from_dict(cls, dict, field_map={}): _doc = dict.copy() init_args = ["text", "id", "query_score", "question", "meta", "embedding"] if "meta" not in _doc.keys(): _doc["meta"] = {} # copy additional fields into "meta" for k, v in _doc.items(): - if k not in init_args: + if k not in init_args and k not in field_map: _doc["meta"][k] = v # remove additional fields from top level - _doc = {k: v for k, v in _doc.items() if k in init_args} + _new_doc = {} + for k, v in _doc.items(): + if k in init_args: + _new_doc[k] = v + elif k in field_map: + k = field_map[k] + _new_doc[k] = v - return cls(**_doc) + return cls(**_new_doc) class Label: diff --git a/haystack/database/elasticsearch.py b/haystack/database/elasticsearch.py index d9e1d8d78..31058f7ac 100644 --- a/haystack/database/elasticsearch.py +++ b/haystack/database/elasticsearch.py @@ -128,6 +128,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore): } self.client.indices.create(index=index_name, ignore=400, body=mapping) + # TODO: Add flexibility to define other non-meta and meta fields expected by the Document class + def _create_document_field_map(self) -> Dict: + return { + self.text_field: "text", + self.embedding_field: "embedding", + self.faq_question_field if self.faq_question_field else "question": "question" + } + def get_document_by_id(self, id: str, index=None) -> Optional[Document]: index = index or self.index documents = self.get_documents_by_id([id], index=index) @@ -167,7 +175,8 @@ class ElasticsearchDocumentStore(BaseDocumentStore): index = self.index # Make sure we comply to Document class format - documents_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents] + documents_objects = [Document.from_dict(d, field_map=self._create_document_field_map()) + if isinstance(d, dict) else d for d in documents] documents_to_index = [] for doc in documents_objects: @@ -175,7 +184,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore): _doc = { "_op_type": "index" if self.update_existing_documents else "create", "_index": index, - **doc.to_dict() + **doc.to_dict(field_map=self._create_document_field_map()) } # type: Dict[str, Any] # rename id for elastic diff --git a/rest_api/config.py b/rest_api/config.py index 1be706242..d54785594 100644 --- a/rest_api/config.py +++ b/rest_api/config.py @@ -20,6 +20,7 @@ DB_INDEX = os.getenv("DB_INDEX", "document") DB_INDEX_FEEDBACK = os.getenv("DB_INDEX_FEEDBACK", "label") ES_CONN_SCHEME = os.getenv("ES_CONN_SCHEME", "http") TEXT_FIELD_NAME = os.getenv("TEXT_FIELD_NAME", "text") +NAME_FIELD_NAME = os.getenv("NAME_FIELD_NAME", "name") SEARCH_FIELD_NAME = os.getenv("SEARCH_FIELD_NAME", "text") FAQ_QUESTION_FIELD_NAME = os.getenv("FAQ_QUESTION_FIELD_NAME", "question") EMBEDDING_FIELD_NAME = os.getenv("EMBEDDING_FIELD_NAME", None) diff --git a/rest_api/controller/search.py b/rest_api/controller/search.py index d56bc282b..4b6e7827c 100644 --- a/rest_api/controller/search.py +++ b/rest_api/controller/search.py @@ -14,7 +14,7 @@ from rest_api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, ES_CONN_ EMBEDDING_DIM, EMBEDDING_FIELD_NAME, EXCLUDE_META_DATA_FIELDS, RETRIEVER_TYPE, EMBEDDING_MODEL_PATH, USE_GPU, READER_MODEL_PATH, \ BATCHSIZE, CONTEXT_WINDOW_SIZE, TOP_K_PER_CANDIDATE, NO_ANS_BOOST, MAX_PROCESSES, MAX_SEQ_LEN, DOC_STRIDE, \ DEFAULT_TOP_K_READER, DEFAULT_TOP_K_RETRIEVER, CONCURRENT_REQUEST_PER_WORKER, FAQ_QUESTION_FIELD_NAME, \ - EMBEDDING_MODEL_FORMAT, READER_TYPE, READER_TOKENIZER, GPU_NUMBER + EMBEDDING_MODEL_FORMAT, READER_TYPE, READER_TOKENIZER, GPU_NUMBER, NAME_FIELD_NAME from rest_api.controller.utils import RequestLimiter from haystack.database.elasticsearch import ElasticsearchDocumentStore from haystack.reader.farm import FARMReader @@ -37,6 +37,7 @@ document_store = ElasticsearchDocumentStore( ca_certs=False, verify_certs=False, text_field=TEXT_FIELD_NAME, + name_field=NAME_FIELD_NAME, search_fields=SEARCH_FIELD_NAME, embedding_dim=EMBEDDING_DIM, embedding_field=EMBEDDING_FIELD_NAME, diff --git a/test/test_db.py b/test/test_db.py index 65979be12..5b0a6a7ef 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -1,6 +1,9 @@ +import numpy as np import pytest +from elasticsearch import Elasticsearch from haystack.database.base import Document, Label +from haystack.database.elasticsearch import ElasticsearchDocumentStore def test_get_all_documents_without_filters(document_store_with_docs): @@ -64,3 +67,17 @@ def test_elasticsearch_update_meta(document_store_with_docs): document_store_with_docs.update_document_meta(document.id, meta={"meta_field": "updated_meta"}) updated_document = document_store_with_docs.query(query=None, filters={"name": ["filename1"]})[0] assert updated_document.meta["meta_field"] == "updated_meta" + + +def test_elasticsearch_custom_fields(elasticsearch_fixture): + client = Elasticsearch() + client.indices.delete(index='haystack_test_custom', ignore=[404]) + document_store = ElasticsearchDocumentStore(index="haystack_test_custom", text_field="custom_text_field", + embedding_field="custom_embedding_field") + + doc_to_write = {"custom_text_field": "test", "custom_embedding_field": np.random.rand(768).astype(np.float32)} + document_store.write_documents([doc_to_write]) + documents = document_store.get_all_documents() + assert len(documents) == 1 + assert documents[0].text == "test" + np.testing.assert_array_equal(doc_to_write["custom_embedding_field"], documents[0].embedding)