feat: Add support for meta fields that are lists when using embed_meta_fields (#5307)

* Add support for meta fields that are lists when using embed_meta_fields

* Make sure unit test doesn't download model

* Adding more unit tests
This commit is contained in:
Sebastian Husch Lee 2023-07-11 17:32:33 +02:00 committed by GitHub
parent 6632505540
commit b5aef24a7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 87 additions and 5 deletions

View File

@ -1857,7 +1857,16 @@ class EmbeddingRetriever(DenseRetriever):
doc.content = doc.content.to_csv(index=False)
else:
raise HaystackError("Documents of type 'table' need to have a pd.DataFrame as content field")
meta_data_fields = [doc.meta[key] for key in self.embed_meta_fields if key in doc.meta and doc.meta[key]]
# Gather all relevant metadata fields
meta_data_fields = []
for key in self.embed_meta_fields:
if key in doc.meta and doc.meta[key]:
if isinstance(doc.meta[key], list):
meta_data_fields.extend([item for item in doc.meta[key]])
else:
meta_data_fields.append(doc.meta[key])
# Convert to type string (e.g. for ints or floats)
meta_data_fields = [str(field) for field in meta_data_fields]
doc.content = "\n".join(meta_data_fields + [doc.content])
linearized_docs.append(doc)
return linearized_docs

View File

@ -440,24 +440,49 @@ def docs_all_formats() -> List[Union[Document, Dict[str, Any]]]:
"name": "filename2",
"date_field": "2019-10-01",
"numeric_field": 5.0,
"list_field": ["item0.1", "item0.2"],
},
# "dict" format
{
"content": "My name is Carla and I live in Berlin",
"meta": {"meta_field": "test1", "name": "filename1", "date_field": "2020-03-01", "numeric_field": 5.5},
"meta": {
"meta_field": "test1",
"name": "filename1",
"date_field": "2020-03-01",
"numeric_field": 5.5,
"list_field": ["item1.1", "item1.2"],
},
},
# Document object
Document(
content="My name is Christelle and I live in Paris",
meta={"meta_field": "test3", "name": "filename3", "date_field": "2018-10-01", "numeric_field": 4.5},
meta={
"meta_field": "test3",
"name": "filename3",
"date_field": "2018-10-01",
"numeric_field": 4.5,
"list_field": ["item2.1", "item2.2"],
},
),
Document(
content="My name is Camila and I live in Madrid",
meta={"meta_field": "test4", "name": "filename4", "date_field": "2021-02-01", "numeric_field": 3.0},
meta={
"meta_field": "test4",
"name": "filename4",
"date_field": "2021-02-01",
"numeric_field": 3.0,
"list_field": ["item3.1", "item3.2"],
},
),
Document(
content="My name is Matteo and I live in Rome",
meta={"meta_field": "test5", "name": "filename5", "date_field": "2019-01-01", "numeric_field": 0.0},
meta={
"meta_field": "test5",
"name": "filename5",
"date_field": "2019-01-01",
"numeric_field": 0.0,
"list_field": ["item4.1", "item4.2"],
},
),
]

View File

@ -220,6 +220,54 @@ def test_batch_retrieval_multiple_queries_with_filters(retriever_with_docs, docu
assert res[1][0].meta["name"] == "filename2"
@pytest.mark.unit
def test_embed_meta_fields(docs_with_ids):
with patch(
"haystack.nodes.retriever._embedding_encoder._SentenceTransformersEmbeddingEncoder.__init__"
) as mock_init:
mock_init.return_value = None
retriever = EmbeddingRetriever(
embedding_model="sentence-transformers/all-mpnet-base-v2",
model_format="sentence_transformers",
embed_meta_fields=["date_field", "numeric_field", "list_field"],
)
docs_with_embedded_meta = retriever._preprocess_documents(docs=docs_with_ids[:2])
assert docs_with_embedded_meta[0].content.startswith("2019-10-01\n5.0\nitem0.1\nitem0.2")
assert docs_with_embedded_meta[1].content.startswith("2020-03-01\n5.5\nitem1.1\nitem1.2")
@pytest.mark.unit
def test_embed_meta_fields_empty():
doc = Document(content="My name is Matteo and I live in Rome", meta={"meta_field": "", "list_field": []})
with patch(
"haystack.nodes.retriever._embedding_encoder._SentenceTransformersEmbeddingEncoder.__init__"
) as mock_init:
mock_init.return_value = None
retriever = EmbeddingRetriever(
embedding_model="sentence-transformers/all-mpnet-base-v2",
model_format="sentence_transformers",
embed_meta_fields=["meta_field", "list_field"],
)
docs_with_embedded_meta = retriever._preprocess_documents(docs=[doc])
assert docs_with_embedded_meta[0].content == "My name is Matteo and I live in Rome"
@pytest.mark.unit
def test_embed_meta_fields_list_with_one_item():
doc = Document(content="My name is Matteo and I live in Rome", meta={"list_field": ["one_item"]})
with patch(
"haystack.nodes.retriever._embedding_encoder._SentenceTransformersEmbeddingEncoder.__init__"
) as mock_init:
mock_init.return_value = None
retriever = EmbeddingRetriever(
embedding_model="sentence-transformers/all-mpnet-base-v2",
model_format="sentence_transformers",
embed_meta_fields=["list_field"],
)
docs_with_embedded_meta = retriever._preprocess_documents(docs=[doc])
assert docs_with_embedded_meta[0].content == "one_item\nMy name is Matteo and I live in Rome"
@pytest.mark.elasticsearch
def test_elasticsearch_custom_query():
client = Elasticsearch()