mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-17 02:48:30 +00:00
Custom fields for indexing in ElasticsearchDocumentStore (#297)
This commit is contained in:
parent
2d27f19a71
commit
c7078a36c0
@ -44,23 +44,34 @@ class Document:
|
||||
self.meta = meta
|
||||
self.embedding = embedding
|
||||
|
||||
def to_dict(self, field_map=None):
    """
    Return the document's attributes as a new dict, optionally renaming keys.

    :param field_map: Mapping from an external field name to the Document attribute
                      it stands for, e.g. ``{"custom_text_field": "text"}``. It is
                      applied in reverse here: the ``text`` attribute is emitted under
                      the key ``custom_text_field``. ``None`` or an empty mapping
                      leaves every key unchanged.
    :return: A new (shallow) dict of attribute values keyed by the possibly renamed
             attribute names.
    """
    # `None` replaces the former `{}` default: a mutable default is a single object
    # shared by every call of the method.
    inv_field_map = {v: k for k, v in (field_map or {}).items()}
    # Attributes without an entry in the inverse map keep their own name. Values are
    # whatever the attributes hold, so the result is not annotated as Dict[str, str].
    return {inv_field_map.get(k, k): v for k, v in self.__dict__.items()}
|
||||
|
||||
@classmethod
def from_dict(cls, dict, field_map=None):
    """
    Build an instance from a plain dict, translating custom field names.

    Keys that are constructor arguments are passed through unchanged. Keys listed in
    ``field_map`` are renamed to the constructor argument they map to. Every other
    key is moved into ``meta``.

    :param dict: Source mapping. It is not modified, and neither is a ``meta`` dict
                 nested inside it. (The name shadows the builtin within this method;
                 it is kept because it is part of the existing signature.)
    :param field_map: Mapping from an external field name to the constructor
                      argument it stands for, e.g. ``{"custom_text_field": "text"}``.
                      ``None`` or an empty mapping disables renaming.
    :return: ``cls(**kwargs)`` built from the translated fields.
    """
    # `None` replaces the former `{}` default: a mutable default is a single object
    # shared by every call of the method.
    field_map = field_map or {}
    init_args = ("text", "id", "query_score", "question", "meta", "embedding")

    _doc = dict.copy()
    # dict.copy() is shallow, so a nested "meta" would still be the caller's object.
    # Copy it before adding fields; a missing or None "meta" becomes an empty dict.
    meta = _doc.get("meta")
    meta = meta.copy() if meta else {}
    _doc["meta"] = meta

    _new_doc = {}
    for k, v in _doc.items():
        if k in init_args:
            _new_doc[k] = v
        elif k in field_map:
            _new_doc[field_map[k]] = v
        else:
            # Unknown top-level field: keep it, but under "meta". `meta` is the same
            # object stored in _new_doc["meta"], so the order of keys does not matter.
            meta[k] = v

    return cls(**_new_doc)
|
||||
|
||||
|
||||
class Label:
|
||||
|
@ -128,6 +128,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
}
|
||||
self.client.indices.create(index=index_name, ignore=400, body=mapping)
|
||||
|
||||
# TODO: Add flexibility to define other non-meta and meta fields expected by the Document class
|
||||
def _create_document_field_map(self) -> Dict:
    """
    Map this store's configured index field names to Document attribute names.

    :return: Dict of ``{index_field_name: document_attribute}`` covering the text,
             embedding and question fields. A field whose name is not configured
             (``None`` or empty) is left out. The question field falls back to the
             name ``"question"`` when ``faq_question_field`` is not set.
    """
    candidates = (
        (self.text_field, "text"),
        (self.embedding_field, "embedding"),
        (self.faq_question_field or "question", "question"),
    )
    # An unset field name must not become a key: Document.to_dict() inverts this map,
    # so a None key would make it emit the attribute under the key None.
    return {index_field: attribute for index_field, attribute in candidates if index_field}
|
||||
|
||||
def get_document_by_id(self, id: str, index=None) -> Optional[Document]:
|
||||
index = index or self.index
|
||||
documents = self.get_documents_by_id([id], index=index)
|
||||
@ -167,7 +175,8 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
index = self.index
|
||||
|
||||
# Make sure we comply to Document class format
|
||||
documents_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
|
||||
documents_objects = [Document.from_dict(d, field_map=self._create_document_field_map())
|
||||
if isinstance(d, dict) else d for d in documents]
|
||||
|
||||
documents_to_index = []
|
||||
for doc in documents_objects:
|
||||
@ -175,7 +184,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
_doc = {
|
||||
"_op_type": "index" if self.update_existing_documents else "create",
|
||||
"_index": index,
|
||||
**doc.to_dict()
|
||||
**doc.to_dict(field_map=self._create_document_field_map())
|
||||
} # type: Dict[str, Any]
|
||||
|
||||
# rename id for elastic
|
||||
|
@ -20,6 +20,7 @@ DB_INDEX = os.getenv("DB_INDEX", "document")
|
||||
DB_INDEX_FEEDBACK = os.getenv("DB_INDEX_FEEDBACK", "label")
|
||||
ES_CONN_SCHEME = os.getenv("ES_CONN_SCHEME", "http")
|
||||
TEXT_FIELD_NAME = os.getenv("TEXT_FIELD_NAME", "text")
|
||||
NAME_FIELD_NAME = os.getenv("NAME_FIELD_NAME", "name")
|
||||
SEARCH_FIELD_NAME = os.getenv("SEARCH_FIELD_NAME", "text")
|
||||
FAQ_QUESTION_FIELD_NAME = os.getenv("FAQ_QUESTION_FIELD_NAME", "question")
|
||||
EMBEDDING_FIELD_NAME = os.getenv("EMBEDDING_FIELD_NAME", None)
|
||||
|
@ -14,7 +14,7 @@ from rest_api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, ES_CONN_
|
||||
EMBEDDING_DIM, EMBEDDING_FIELD_NAME, EXCLUDE_META_DATA_FIELDS, RETRIEVER_TYPE, EMBEDDING_MODEL_PATH, USE_GPU, READER_MODEL_PATH, \
|
||||
BATCHSIZE, CONTEXT_WINDOW_SIZE, TOP_K_PER_CANDIDATE, NO_ANS_BOOST, MAX_PROCESSES, MAX_SEQ_LEN, DOC_STRIDE, \
|
||||
DEFAULT_TOP_K_READER, DEFAULT_TOP_K_RETRIEVER, CONCURRENT_REQUEST_PER_WORKER, FAQ_QUESTION_FIELD_NAME, \
|
||||
EMBEDDING_MODEL_FORMAT, READER_TYPE, READER_TOKENIZER, GPU_NUMBER
|
||||
EMBEDDING_MODEL_FORMAT, READER_TYPE, READER_TOKENIZER, GPU_NUMBER, NAME_FIELD_NAME
|
||||
from rest_api.controller.utils import RequestLimiter
|
||||
from haystack.database.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.reader.farm import FARMReader
|
||||
@ -37,6 +37,7 @@ document_store = ElasticsearchDocumentStore(
|
||||
ca_certs=False,
|
||||
verify_certs=False,
|
||||
text_field=TEXT_FIELD_NAME,
|
||||
name_field=NAME_FIELD_NAME,
|
||||
search_fields=SEARCH_FIELD_NAME,
|
||||
embedding_dim=EMBEDDING_DIM,
|
||||
embedding_field=EMBEDDING_FIELD_NAME,
|
||||
|
@ -1,6 +1,9 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from haystack.database.base import Document, Label
|
||||
from haystack.database.elasticsearch import ElasticsearchDocumentStore
|
||||
|
||||
|
||||
def test_get_all_documents_without_filters(document_store_with_docs):
|
||||
@ -64,3 +67,17 @@ def test_elasticsearch_update_meta(document_store_with_docs):
|
||||
document_store_with_docs.update_document_meta(document.id, meta={"meta_field": "updated_meta"})
|
||||
updated_document = document_store_with_docs.query(query=None, filters={"name": ["filename1"]})[0]
|
||||
assert updated_document.meta["meta_field"] == "updated_meta"
|
||||
|
||||
|
||||
def test_elasticsearch_custom_fields(elasticsearch_fixture):
    """Documents written under custom text/embedding field names read back as Document attributes."""
    # NOTE(review): `elasticsearch_fixture` is not used directly; presumably it makes
    # sure an Elasticsearch instance is reachable - confirm against the fixture.
    index_name = "haystack_test_custom"

    # Drop any index left over from a previous run; a missing index (404) is fine.
    Elasticsearch().indices.delete(index=index_name, ignore=[404])

    store = ElasticsearchDocumentStore(
        index=index_name,
        text_field="custom_text_field",
        embedding_field="custom_embedding_field",
    )

    payload = {
        "custom_text_field": "test",
        "custom_embedding_field": np.random.rand(768).astype(np.float32),
    }
    store.write_documents([payload])

    stored = store.get_all_documents()
    assert len(stored) == 1

    # The custom index field names must come back as the standard attributes.
    only_doc = stored[0]
    assert only_doc.text == "test"
    np.testing.assert_array_equal(payload["custom_embedding_field"], only_doc.embedding)
|
||||
|
Loading…
x
Reference in New Issue
Block a user