diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py index 737673a5d..a1125d26c 100644 --- a/haystack/nodes/retriever/dense.py +++ b/haystack/nodes/retriever/dense.py @@ -1857,7 +1857,16 @@ class EmbeddingRetriever(DenseRetriever): doc.content = doc.content.to_csv(index=False) else: raise HaystackError("Documents of type 'table' need to have a pd.DataFrame as content field") - meta_data_fields = [doc.meta[key] for key in self.embed_meta_fields if key in doc.meta and doc.meta[key]] + # Gather all relevant metadata fields + meta_data_fields = [] + for key in self.embed_meta_fields: + if key in doc.meta and doc.meta[key]: + if isinstance(doc.meta[key], list): + meta_data_fields.extend([item for item in doc.meta[key]]) + else: + meta_data_fields.append(doc.meta[key]) + # Convert to type string (e.g. for ints or floats) + meta_data_fields = [str(field) for field in meta_data_fields] doc.content = "\n".join(meta_data_fields + [doc.content]) linearized_docs.append(doc) return linearized_docs diff --git a/test/conftest.py b/test/conftest.py index 9cdf34422..aec71800e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -440,24 +440,49 @@ def docs_all_formats() -> List[Union[Document, Dict[str, Any]]]: "name": "filename2", "date_field": "2019-10-01", "numeric_field": 5.0, + "list_field": ["item0.1", "item0.2"], }, # "dict" format { "content": "My name is Carla and I live in Berlin", - "meta": {"meta_field": "test1", "name": "filename1", "date_field": "2020-03-01", "numeric_field": 5.5}, + "meta": { + "meta_field": "test1", + "name": "filename1", + "date_field": "2020-03-01", + "numeric_field": 5.5, + "list_field": ["item1.1", "item1.2"], + }, }, # Document object Document( content="My name is Christelle and I live in Paris", - meta={"meta_field": "test3", "name": "filename3", "date_field": "2018-10-01", "numeric_field": 4.5}, + meta={ + "meta_field": "test3", + "name": "filename3", + "date_field": "2018-10-01", + "numeric_field": 4.5, + "list_field": ["item2.1", "item2.2"], + }, ), Document( content="My name is Camila and I live in Madrid", - meta={"meta_field": "test4", "name": "filename4", "date_field": "2021-02-01", "numeric_field": 3.0}, + meta={ + "meta_field": "test4", + "name": "filename4", + "date_field": "2021-02-01", + "numeric_field": 3.0, + "list_field": ["item3.1", "item3.2"], + }, ), Document( content="My name is Matteo and I live in Rome", - meta={"meta_field": "test5", "name": "filename5", "date_field": "2019-01-01", "numeric_field": 0.0}, + meta={ + "meta_field": "test5", + "name": "filename5", + "date_field": "2019-01-01", + "numeric_field": 0.0, + "list_field": ["item4.1", "item4.2"], + }, ), ] diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py index d86b30dfc..9f50cf8f7 100644 --- a/test/nodes/test_retriever.py +++ b/test/nodes/test_retriever.py @@ -220,6 +220,54 @@ def test_batch_retrieval_multiple_queries_with_filters(retriever_with_docs, docu assert res[1][0].meta["name"] == "filename2" +@pytest.mark.unit +def test_embed_meta_fields(docs_with_ids): + with patch( + "haystack.nodes.retriever._embedding_encoder._SentenceTransformersEmbeddingEncoder.__init__" + ) as mock_init: + mock_init.return_value = None + retriever = EmbeddingRetriever( + embedding_model="sentence-transformers/all-mpnet-base-v2", + model_format="sentence_transformers", + embed_meta_fields=["date_field", "numeric_field", "list_field"], + ) + docs_with_embedded_meta = retriever._preprocess_documents(docs=docs_with_ids[:2]) + assert docs_with_embedded_meta[0].content.startswith("2019-10-01\n5.0\nitem0.1\nitem0.2") + assert docs_with_embedded_meta[1].content.startswith("2020-03-01\n5.5\nitem1.1\nitem1.2") + + +@pytest.mark.unit +def test_embed_meta_fields_empty(): + doc = Document(content="My name is Matteo and I live in Rome", meta={"meta_field": "", "list_field": []}) + with patch( + "haystack.nodes.retriever._embedding_encoder._SentenceTransformersEmbeddingEncoder.__init__" + ) as mock_init: + mock_init.return_value = None + retriever = EmbeddingRetriever( + embedding_model="sentence-transformers/all-mpnet-base-v2", + model_format="sentence_transformers", + embed_meta_fields=["meta_field", "list_field"], + ) + docs_with_embedded_meta = retriever._preprocess_documents(docs=[doc]) + assert docs_with_embedded_meta[0].content == "My name is Matteo and I live in Rome" + + +@pytest.mark.unit +def test_embed_meta_fields_list_with_one_item(): + doc = Document(content="My name is Matteo and I live in Rome", meta={"list_field": ["one_item"]}) + with patch( + "haystack.nodes.retriever._embedding_encoder._SentenceTransformersEmbeddingEncoder.__init__" + ) as mock_init: + mock_init.return_value = None + retriever = EmbeddingRetriever( + embedding_model="sentence-transformers/all-mpnet-base-v2", + model_format="sentence_transformers", + embed_meta_fields=["list_field"], + ) + docs_with_embedded_meta = retriever._preprocess_documents(docs=[doc]) + assert docs_with_embedded_meta[0].content == "one_item\nMy name is Matteo and I live in Rome" + + @pytest.mark.elasticsearch def test_elasticsearch_custom_query(): client = Elasticsearch()