haystack/test/nodes/test_retriever.py
Sara Zan 101d2bc86c
feat: MultiModalRetriever (#2891)
2022-10-17 18:58:35 +02:00


from typing import List

import os
import logging
from math import isclose

import pytest
import numpy as np
import pandas as pd
from pandas.testing import assert_frame_equal
from elasticsearch import Elasticsearch
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast

from haystack.document_stores import WeaviateDocumentStore, MilvusDocumentStore
from haystack.document_stores.base import BaseDocumentStore
from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.nodes.retriever.base import BaseRetriever
from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever
from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
from haystack.nodes.retriever.multimodal import MultiModalRetriever
from haystack.pipelines import DocumentSearchPipeline
from haystack.schema import Document

from ..conftest import SAMPLES_PATH


# TODO: check if this works with only the "memory" arg
@pytest.mark.parametrize(
"retriever_with_docs,document_store_with_docs",
[
("mdr", "elasticsearch"),
("mdr", "faiss"),
("mdr", "memory"),
("mdr", "milvus1"),
("dpr", "elasticsearch"),
("dpr", "faiss"),
("dpr", "memory"),
("dpr", "milvus1"),
("embedding", "elasticsearch"),
("embedding", "faiss"),
("embedding", "memory"),
("embedding", "milvus1"),
("elasticsearch", "elasticsearch"),
("es_filter_only", "elasticsearch"),
("tfidf", "memory"),
],
indirect=True,
)
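# retriever_with_docs and document_store_with_docs are indirect fixtures (see conftest):
# each tuple above selects a retriever type and a document store preloaded with the same sample documents.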
def test_retrieval(retriever_with_docs: BaseRetriever, document_store_with_docs: BaseDocumentStore):
if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever, TfidfRetriever)):
document_store_with_docs.update_embeddings(retriever_with_docs)
# test without filters
# NOTE: FilterRetriever simply returns all documents matching a filter,
# so without filters applied it does nothing
if not isinstance(retriever_with_docs, FilterRetriever):
res = retriever_with_docs.retrieve(query="Who lives in Berlin?")
assert res[0].content == "My name is Carla and I live in Berlin"
assert len(res) == 5
assert res[0].meta["name"] == "filename1"
# test with filters
if not isinstance(document_store_with_docs, (FAISSDocumentStore, MilvusDocumentStore)) and not isinstance(
retriever_with_docs, TfidfRetriever
):
# single filter
result = retriever_with_docs.retrieve(query="Christelle", filters={"name": ["filename3"]}, top_k=5)
assert len(result) == 1
assert type(result[0]) == Document
assert result[0].content == "My name is Christelle and I live in Paris"
assert result[0].meta["name"] == "filename3"
# multiple filters
result = retriever_with_docs.retrieve(
query="Paul", filters={"name": ["filename2"], "meta_field": ["test2", "test3"]}, top_k=5
)
assert len(result) == 1
assert type(result[0]) == Document
assert result[0].meta["name"] == "filename2"
result = retriever_with_docs.retrieve(
query="Carla", filters={"name": ["filename1"], "meta_field": ["test2", "test3"]}, top_k=5
)
assert len(result) == 0


def test_batch_retrieval_single_query(retriever_with_docs, document_store_with_docs):
if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever, TfidfRetriever)):
document_store_with_docs.update_embeddings(retriever_with_docs)
res = retriever_with_docs.retrieve_batch(queries=["Who lives in Berlin?"])
# Expected return type: List of lists of Documents
assert isinstance(res, list)
assert isinstance(res[0], list)
assert isinstance(res[0][0], Document)
assert len(res) == 1
assert len(res[0]) == 5
assert res[0][0].content == "My name is Carla and I live in Berlin"
assert res[0][0].meta["name"] == "filename1"


def test_batch_retrieval_multiple_queries(retriever_with_docs, document_store_with_docs):
if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever, TfidfRetriever)):
document_store_with_docs.update_embeddings(retriever_with_docs)
res = retriever_with_docs.retrieve_batch(queries=["Who lives in Berlin?", "Who lives in New York?"])
# Expected return type: list of lists of Documents
assert isinstance(res, list)
assert isinstance(res[0], list)
assert isinstance(res[0][0], Document)
assert res[0][0].content == "My name is Carla and I live in Berlin"
assert len(res[0]) == 5
assert res[0][0].meta["name"] == "filename1"
assert res[1][0].content == "My name is Paul and I live in New York"
assert len(res[1]) == 5
assert res[1][0].meta["name"] == "filename2"


@pytest.mark.elasticsearch
def test_elasticsearch_custom_query():
client = Elasticsearch()
client.indices.delete(index="haystack_test_custom", ignore=[404])
document_store = ElasticsearchDocumentStore(
index="haystack_test_custom", content_field="custom_text_field", embedding_field="custom_embedding_field"
)
documents = [
{"content": "test_1", "meta": {"year": "2019"}},
{"content": "test_2", "meta": {"year": "2020"}},
{"content": "test_3", "meta": {"year": "2021"}},
{"content": "test_4", "meta": {"year": "2021"}},
{"content": "test_5", "meta": {"year": "2021"}},
]
document_store.write_documents(documents)
# test custom "terms" query
retriever = BM25Retriever(
document_store=document_store,
custom_query="""
{
"size": 10,
"query": {
"bool": {
"should": [{
"multi_match": {"query": ${query}, "type": "most_fields", "fields": ["content"]}}],
"filter": [{"terms": {"year": ${years}}}]}}}""",
)
results = retriever.retrieve(query="test", filters={"years": ["2020", "2021"]})
assert len(results) == 4
# test custom "term" query
retriever = BM25Retriever(
document_store=document_store,
custom_query="""
{
"size": 10,
"query": {
"bool": {
"should": [{
"multi_match": {"query": ${query}, "type": "most_fields", "fields": ["content"]}}],
"filter": [{"term": {"year": ${years}}}]}}}""",
)
results = retriever.retrieve(query="test", filters={"years": "2021"})
assert len(results) == 3


@pytest.mark.integration
@pytest.mark.parametrize(
"document_store", ["elasticsearch", "faiss", "memory", "milvus1", "milvus", "weaviate", "pinecone"], indirect=True
)
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
def test_dpr_embedding(document_store: BaseDocumentStore, retriever, docs_with_ids):
document_store.return_embedding = True
document_store.write_documents(docs_with_ids)
document_store.update_embeddings(retriever=retriever)
docs = document_store.get_all_documents()
docs.sort(key=lambda d: d.id)
print([doc.id for doc in docs])
expected_values = [0.00892, 0.00780, 0.00482, -0.00626, 0.010966]
for doc, expected_value in zip(docs, expected_values):
embedding = doc.embedding
# always normalize vector as faiss returns normalized vectors and other document stores do not
embedding /= np.linalg.norm(embedding)
assert len(embedding) == 768
assert isclose(embedding[0], expected_value, rel_tol=0.01)


@pytest.mark.integration
@pytest.mark.parametrize(
"document_store", ["elasticsearch", "faiss", "memory", "milvus1", "milvus", "weaviate", "pinecone"], indirect=True
)
@pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
@pytest.mark.embedding_dim(128)
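# The custom embedding_dim marker is read by the document_store fixture (see conftest)
# so that the store is created with a matching vector dimension.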
def test_retribert_embedding(document_store, retriever, docs_with_ids):
if isinstance(document_store, WeaviateDocumentStore):
# Weaviate sets the embedding dimension to 768 as soon as it is initialized.
# We need 128 here and therefore initialize a new WeaviateDocumentStore.
document_store = WeaviateDocumentStore(index="haystack_test", embedding_dim=128, recreate_index=True)
document_store.return_embedding = True
document_store.write_documents(docs_with_ids)
document_store.update_embeddings(retriever=retriever)
docs = document_store.get_all_documents()
docs = sorted(docs, key=lambda d: d.id)
expected_values = [0.14017, 0.05975, 0.14267, 0.15099, 0.14383]
for doc, expected_value in zip(docs, expected_values):
embedding = doc.embedding
assert len(embedding) == 128
# always normalize vector as faiss returns normalized vectors and other document stores do not
embedding /= np.linalg.norm(embedding)
assert isclose(embedding[0], expected_value, rel_tol=0.001)


@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["openai"], indirect=True)
@pytest.mark.embedding_dim(1024)
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_openai_embedding(document_store, retriever, docs_with_ids):
document_store.return_embedding = True
document_store.write_documents(docs_with_ids)
document_store.update_embeddings(retriever=retriever)
docs = document_store.get_all_documents()
docs = sorted(docs, key=lambda d: d.id)
for doc in docs:
assert len(doc.embedding) == 1024


@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["openai"], indirect=True)
@pytest.mark.embedding_dim(1024)
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_retriever_basic_search(document_store, retriever, docs_with_ids):
document_store.return_embedding = True
document_store.write_documents(docs_with_ids)
document_store.update_embeddings(retriever=retriever)
p_retrieval = DocumentSearchPipeline(retriever)
res = p_retrieval.run(query="Madrid", params={"Retriever": {"top_k": 1}})
assert len(res["documents"]) == 1
assert "Madrid" in res["documents"][0].content


@pytest.mark.integration
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
@pytest.mark.embedding_dim(512)
def test_table_text_retriever_embedding(document_store, retriever, docs):
document_store.return_embedding = True
document_store.write_documents(docs)
table_data = {
"Mountain": ["Mount Everest", "K2", "Kangchenjunga", "Lhotse", "Makalu"],
"Height": ["8848m", "8,611 m", "8 586m", "8 516 m", "8,485m"],
}
table = pd.DataFrame(table_data)
table_doc = Document(content=table, content_type="table", id="6")
document_store.write_documents([table_doc])
document_store.update_embeddings(retriever=retriever)
docs = document_store.get_all_documents()
docs = sorted(docs, key=lambda d: d.id)
expected_values = [0.061191384, 0.038075786, 0.27447605, 0.09399721, 0.0959682]
for doc, expected_value in zip(docs, expected_values):
assert len(doc.embedding) == 512
assert isclose(doc.embedding[0], expected_value, rel_tol=0.001)
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
def test_dpr_saving_and_loading(tmp_path, retriever, document_store):
retriever.save(f"{tmp_path}/test_dpr_save")
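    # sum_params computes a single scalar over all model weights; it is used below as a
    # cheap fingerprint to verify that saving and loading preserves the parameters.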
def sum_params(model):
s = []
for p in model.parameters():
n = p.cpu().data.numpy()
s.append(np.sum(n))
return sum(s)
original_sum_query = sum_params(retriever.query_encoder)
original_sum_passage = sum_params(retriever.passage_encoder)
del retriever
loaded_retriever = DensePassageRetriever.load(f"{tmp_path}/test_dpr_save", document_store)
loaded_sum_query = sum_params(loaded_retriever.query_encoder)
loaded_sum_passage = sum_params(loaded_retriever.passage_encoder)
assert abs(original_sum_query - loaded_sum_query) < 0.1
assert abs(original_sum_passage - loaded_sum_passage) < 0.1
    # comparison of weights (RAM intensive!)
# for p1, p2 in zip(retriever.query_encoder.parameters(), loaded_retriever.query_encoder.parameters()):
# assert (p1.data.ne(p2.data).sum() == 0)
#
# for p1, p2 in zip(retriever.passage_encoder.parameters(), loaded_retriever.passage_encoder.parameters()):
# assert (p1.data.ne(p2.data).sum() == 0)
# attributes
assert loaded_retriever.processor.embed_title == True
assert loaded_retriever.batch_size == 16
assert loaded_retriever.processor.max_seq_len_passage == 256
assert loaded_retriever.processor.max_seq_len_query == 64
# Tokenizer
assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast)
assert isinstance(loaded_retriever.query_tokenizer, DPRQuestionEncoderTokenizerFast)
assert loaded_retriever.passage_tokenizer.do_lower_case == True
assert loaded_retriever.query_tokenizer.do_lower_case == True
assert loaded_retriever.passage_tokenizer.vocab_size == 30522
assert loaded_retriever.query_tokenizer.vocab_size == 30522
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
@pytest.mark.embedding_dim(512)
def test_table_text_retriever_saving_and_loading(tmp_path, retriever, document_store):
retriever.save(f"{tmp_path}/test_table_text_retriever_save")
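    # Same weight-fingerprint check as in test_dpr_saving_and_loading above: sum all
    # parameters before saving and compare against the sums after loading.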
def sum_params(model):
s = []
for p in model.parameters():
n = p.cpu().data.numpy()
s.append(np.sum(n))
return sum(s)
original_sum_query = sum_params(retriever.query_encoder)
original_sum_passage = sum_params(retriever.passage_encoder)
original_sum_table = sum_params(retriever.table_encoder)
del retriever
loaded_retriever = TableTextRetriever.load(f"{tmp_path}/test_table_text_retriever_save", document_store)
loaded_sum_query = sum_params(loaded_retriever.query_encoder)
loaded_sum_passage = sum_params(loaded_retriever.passage_encoder)
loaded_sum_table = sum_params(loaded_retriever.table_encoder)
assert abs(original_sum_query - loaded_sum_query) < 0.1
assert abs(original_sum_passage - loaded_sum_passage) < 0.1
assert abs(original_sum_table - loaded_sum_table) < 0.01
# attributes
assert loaded_retriever.processor.embed_meta_fields == ["name", "section_title", "caption"]
assert loaded_retriever.batch_size == 16
assert loaded_retriever.processor.max_seq_len_passage == 256
assert loaded_retriever.processor.max_seq_len_table == 256
assert loaded_retriever.processor.max_seq_len_query == 64
# Tokenizer
assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast)
assert isinstance(loaded_retriever.table_tokenizer, DPRContextEncoderTokenizerFast)
assert isinstance(loaded_retriever.query_tokenizer, DPRQuestionEncoderTokenizerFast)
assert loaded_retriever.passage_tokenizer.do_lower_case == True
assert loaded_retriever.table_tokenizer.do_lower_case == True
assert loaded_retriever.query_tokenizer.do_lower_case == True
assert loaded_retriever.passage_tokenizer.vocab_size == 30522
assert loaded_retriever.table_tokenizer.vocab_size == 30522
assert loaded_retriever.query_tokenizer.vocab_size == 30522


@pytest.mark.embedding_dim(128)
def test_table_text_retriever_training(document_store):
retriever = TableTextRetriever(
document_store=document_store,
query_embedding_model="deepset/bert-small-mm_retrieval-question_encoder",
passage_embedding_model="deepset/bert-small-mm_retrieval-passage_encoder",
table_embedding_model="deepset/bert-small-mm_retrieval-table_encoder",
use_gpu=False,
)
retriever.train(
data_dir=SAMPLES_PATH / "mmr",
train_filename="sample.json",
n_epochs=1,
n_gpu=0,
save_dir="test_table_text_retriever_train",
)
# Load trained model
retriever = TableTextRetriever.load(load_dir="test_table_text_retriever_train", document_store=document_store)


@pytest.mark.elasticsearch
def test_elasticsearch_highlight():
client = Elasticsearch()
client.indices.delete(index="haystack_hl_test", ignore=[404])
    # Map both the "content" and "title" fields as "text" so that search is performed on both fields.
document_store = ElasticsearchDocumentStore(
index="haystack_hl_test",
content_field="title",
custom_mapping={"mappings": {"properties": {"content": {"type": "text"}, "title": {"type": "text"}}}},
)
documents = [
{
"title": "Green tea components",
"meta": {
"content": "The green tea plant contains a range of healthy compounds that make it into the final drink"
},
"id": "1",
},
{
"title": "Green tea catechin",
"meta": {"content": "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG)."},
"id": "2",
},
{
"title": "Minerals in Green tea",
"meta": {"content": "Green tea also has small amounts of minerals that can benefit your health."},
"id": "3",
},
{
"title": "Green tea Benefits",
"meta": {"content": "Green tea does more than just keep you alert, it may also help boost brain function."},
"id": "4",
},
]
document_store.write_documents(documents)
# Enabled highlighting on "title"&"content" field only using custom query
retriever_1 = BM25Retriever(
document_store=document_store,
custom_query="""{
"size": 20,
"query": {
"bool": {
"should": [
{
"multi_match": {
"query": ${query},
"fields": [
"content^3",
"title^5"
]
}
}
]
}
},
"highlight": {
"pre_tags": [
"**"
],
"post_tags": [
"**"
],
"number_of_fragments": 3,
"fragment_size": 5,
"fields": {
"content": {},
"title": {}
}
}
}""",
)
results = retriever_1.retrieve(query="is green tea healthy")
assert len(results[0].meta["highlighted"]) == 2
assert results[0].meta["highlighted"]["title"] == ["**Green**", "**tea** components"]
assert results[0].meta["highlighted"]["content"] == ["The **green**", "**tea** plant", "range of **healthy**"]
# Enabled highlighting on "title" field only using custom query
retriever_2 = BM25Retriever(
document_store=document_store,
custom_query="""{
"size": 20,
"query": {
"bool": {
"should": [
{
"multi_match": {
"query": ${query},
"fields": [
"content^3",
"title^5"
]
}
}
]
}
},
"highlight": {
"pre_tags": [
"**"
],
"post_tags": [
"**"
],
"number_of_fragments": 3,
"fragment_size": 5,
"fields": {
"title": {}
}
}
}""",
)
results = retriever_2.retrieve(query="is green tea healthy")
assert len(results[0].meta["highlighted"]) == 1
assert results[0].meta["highlighted"]["title"] == ["**Green**", "**tea** components"]


def test_elasticsearch_filter_must_not_increase_results():
index = "filter_must_not_increase_results"
client = Elasticsearch()
client.indices.delete(index=index, ignore=[404])
documents = [
{
"content": "The green tea plant contains a range of healthy compounds that make it into the final drink",
"meta": {"content_type": "text"},
"id": "1",
},
{
"content": "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG).",
"meta": {"content_type": "text"},
"id": "2",
},
{
"content": "Green tea also has small amounts of minerals that can benefit your health.",
"meta": {"content_type": "text"},
"id": "3",
},
{
"content": "Green tea does more than just keep you alert, it may also help boost brain function.",
"meta": {"content_type": "text"},
"id": "4",
},
]
doc_store = ElasticsearchDocumentStore(index=index)
doc_store.write_documents(documents)
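    # A filter may only narrow the result set: the filtered query must not return more
    # documents than the unfiltered one.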
results_wo_filter = doc_store.query(query="drink")
assert len(results_wo_filter) == 1
results_w_filter = doc_store.query(query="drink", filters={"content_type": "text"})
assert len(results_w_filter) == 1
doc_store.delete_index(index)


def test_elasticsearch_all_terms_must_match():
index = "all_terms_must_match"
client = Elasticsearch()
client.indices.delete(index=index, ignore=[404])
documents = [
{
"content": "The green tea plant contains a range of healthy compounds that make it into the final drink",
"meta": {"content_type": "text"},
"id": "1",
},
{
"content": "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG).",
"meta": {"content_type": "text"},
"id": "2",
},
{
"content": "Green tea also has small amounts of minerals that can benefit your health.",
"meta": {"content_type": "text"},
"id": "3",
},
{
"content": "Green tea does more than just keep you alert, it may also help boost brain function.",
"meta": {"content_type": "text"},
"id": "4",
},
]
doc_store = ElasticsearchDocumentStore(index=index)
doc_store.write_documents(documents)
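    # With the default OR semantics, any document mentioning "drink", "green" or "tea" matches;
    # with all_terms_must_match=True, only documents containing all three terms are returned.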
results_wo_all_terms_must_match = doc_store.query(query="drink green tea")
assert len(results_wo_all_terms_must_match) == 4
results_w_all_terms_must_match = doc_store.query(query="drink green tea", all_terms_must_match=True)
assert len(results_w_all_terms_must_match) == 1
doc_store.delete_index(index)


def test_embeddings_encoder_of_embedding_retriever_should_warn_about_model_format(caplog):
document_store = InMemoryDocumentStore()
with caplog.at_level(logging.WARNING):
EmbeddingRetriever(
document_store=document_store,
embedding_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
model_format="farm",
)
assert (
"You may need to set model_format='sentence_transformers' to ensure correct loading of model."
in caplog.text
)
@pytest.mark.parametrize("retriever", ["es_filter_only"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
def test_es_filter_only(document_store, retriever):
docs = [
Document(content="Doc1", meta={"f1": "0"}),
Document(content="Doc2", meta={"f1": "0"}),
Document(content="Doc3", meta={"f1": "0"}),
Document(content="Doc4", meta={"f1": "0"}),
Document(content="Doc5", meta={"f1": "0"}),
Document(content="Doc6", meta={"f1": "0"}),
Document(content="Doc7", meta={"f1": "1"}),
Document(content="Doc8", meta={"f1": "0"}),
Document(content="Doc9", meta={"f1": "0"}),
Document(content="Doc10", meta={"f1": "0"}),
Document(content="Doc11", meta={"f1": "0"}),
Document(content="Doc12", meta={"f1": "0"}),
]
document_store.write_documents(docs)
retrieved_docs = retriever.retrieve(query="", filters={"f1": ["0"]})
assert len(retrieved_docs) == 11


#
# MultiModal
#
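# The fixtures and tests below exercise MultiModalRetriever with one query embedding model
# and one document embedding model per content type (text, table, image).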


@pytest.fixture
def text_docs() -> List[Document]:
return [
Document(
content="My name is Paul and I live in New York",
meta={
"meta_field": "test2",
"name": "filename2",
"date_field": "2019-10-01",
"numeric_field": 5.0,
"odd_field": 0,
},
),
Document(
content="My name is Carla and I live in Berlin",
meta={
"meta_field": "test1",
"name": "filename1",
"date_field": "2020-03-01",
"numeric_field": 5.5,
"odd_field": 1,
},
),
Document(
content="My name is Christelle and I live in Paris",
meta={
"meta_field": "test3",
"name": "filename3",
"date_field": "2018-10-01",
"numeric_field": 4.5,
"odd_field": 1,
},
),
Document(
content="My name is Camila and I live in Madrid",
meta={
"meta_field": "test4",
"name": "filename4",
"date_field": "2021-02-01",
"numeric_field": 3.0,
"odd_field": 0,
},
),
Document(
content="My name is Matteo and I live in Rome",
meta={
"meta_field": "test5",
"name": "filename5",
"date_field": "2019-01-01",
"numeric_field": 0.0,
"odd_field": 1,
},
),
]


@pytest.fixture
def table_docs() -> List[Document]:
return [
Document(
content=pd.DataFrame(
{
"Mountain": ["Mount Everest", "K2", "Kangchenjunga", "Lhotse", "Makalu"],
"Height": ["8848m", "8,611 m", "8 586m", "8 516 m", "8,485m"],
}
),
content_type="table",
),
Document(
content=pd.DataFrame(
{
"City": ["Paris", "Lyon", "Marseille", "Lille", "Toulouse", "Bordeaux"],
"Population": ["13,114,718", "2,280,845", "1,873,270 ", "1,510,079", "1,454,158", "1,363,711"],
}
),
content_type="table",
),
Document(
content=pd.DataFrame(
{
"City": ["Berlin", "Hamburg", "Munich", "Cologne"],
"Population": ["3,644,826", "1,841,179", "1,471,508", "1,085,664"],
}
),
content_type="table",
),
]


@pytest.fixture
def image_docs() -> List[Document]:
return [
Document(content=str(SAMPLES_PATH / "images" / imagefile), content_type="image")
for imagefile in os.listdir(SAMPLES_PATH / "images")
]


@pytest.mark.integration
def test_multimodal_text_retrieval(text_docs: List[Document]):
retriever = MultiModalRetriever(
document_store=InMemoryDocumentStore(return_embedding=True),
query_embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
document_embedding_models={"text": "sentence-transformers/multi-qa-mpnet-base-dot-v1"},
)
retriever.document_store.write_documents(text_docs)
retriever.document_store.update_embeddings(retriever=retriever)
results = retriever.retrieve(query="Who lives in Paris?")
assert results[0].content == "My name is Christelle and I live in Paris"


@pytest.mark.integration
def test_multimodal_table_retrieval(table_docs: List[Document]):
retriever = MultiModalRetriever(
document_store=InMemoryDocumentStore(return_embedding=True),
query_embedding_model="deepset/all-mpnet-base-v2-table",
document_embedding_models={"table": "deepset/all-mpnet-base-v2-table"},
)
retriever.document_store.write_documents(table_docs)
retriever.document_store.update_embeddings(retriever=retriever)
results = retriever.retrieve(query="How many people live in Hamburg?")
assert_frame_equal(
results[0].content,
pd.DataFrame(
{
"City": ["Berlin", "Hamburg", "Munich", "Cologne"],
"Population": ["3,644,826", "1,841,179", "1,471,508", "1,085,664"],
}
),
)


@pytest.mark.integration
def test_multimodal_image_retrieval(image_docs: List[Document]):
retriever = MultiModalRetriever(
document_store=InMemoryDocumentStore(return_embedding=True, embedding_dim=512),
query_embedding_model="sentence-transformers/clip-ViT-B-32",
document_embedding_models={"image": "sentence-transformers/clip-ViT-B-32"},
)
retriever.document_store.write_documents(image_docs)
retriever.document_store.update_embeddings(retriever=retriever)
results = retriever.retrieve(query="What's a cat?")
assert str(results[0].content) == str(SAMPLES_PATH / "images" / "cat.jpg")


@pytest.mark.skip("Not working yet as intended")
@pytest.mark.integration
def test_multimodal_text_image_retrieval(text_docs: List[Document], image_docs: List[Document]):
retriever = MultiModalRetriever(
document_store=InMemoryDocumentStore(return_embedding=True, embedding_dim=512),
query_embedding_model="sentence-transformers/clip-ViT-B-32",
document_embedding_models={
"text": "sentence-transformers/clip-ViT-B-32",
"image": "sentence-transformers/clip-ViT-B-32",
},
)
retriever.document_store.write_documents(image_docs)
retriever.document_store.write_documents(text_docs)
retriever.document_store.update_embeddings(retriever=retriever)
results = retriever.retrieve(query="What's Paris?")
text_results = [result for result in results if result.content_type == "text"]
image_results = [result for result in results if result.content_type == "image"]
assert str(image_results[0].content) == str(SAMPLES_PATH / "images" / "paris.jpg")
assert text_results[0].content == "My name is Christelle and I live in Paris"