2022-04-12 11:52:27 +02:00
|
|
|
import logging
|
2020-08-18 14:04:31 +02:00
|
|
|
import time
|
2022-06-10 18:22:48 +02:00
|
|
|
from math import isclose
|
2020-06-30 19:05:45 +02:00
|
|
|
|
2021-02-12 14:57:06 +01:00
|
|
|
import numpy as np
|
2021-10-25 12:27:02 +02:00
|
|
|
import pandas as pd
|
2022-06-10 18:22:48 +02:00
|
|
|
from haystack.document_stores.base import BaseDocumentStore
|
2022-04-12 11:52:27 +02:00
|
|
|
from haystack.document_stores.memory import InMemoryDocumentStore
|
2021-02-12 14:57:06 +01:00
|
|
|
import pytest
|
2022-01-26 18:12:55 +01:00
|
|
|
from pathlib import Path
|
2021-02-12 14:57:06 +01:00
|
|
|
from elasticsearch import Elasticsearch
|
2021-11-04 09:27:12 +01:00
|
|
|
|
|
|
|
from haystack.document_stores import WeaviateDocumentStore
|
2022-07-22 16:29:30 +02:00
|
|
|
from haystack.nodes.retriever.base import BaseRetriever
|
2021-10-25 15:50:23 +02:00
|
|
|
from haystack.schema import Document
|
|
|
|
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
|
|
|
|
from haystack.document_stores.faiss import FAISSDocumentStore
|
2022-02-24 17:43:38 +01:00
|
|
|
from haystack.document_stores import MilvusDocumentStore
|
2022-07-05 11:31:11 +02:00
|
|
|
from haystack.nodes.retriever.dense import (
|
|
|
|
DensePassageRetriever,
|
|
|
|
EmbeddingRetriever,
|
|
|
|
TableTextRetriever,
|
|
|
|
MultihopEmbeddingRetriever,
|
|
|
|
)
|
2022-04-29 10:16:02 +02:00
|
|
|
from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
|
2022-07-04 18:18:14 +02:00
|
|
|
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
|
2020-06-30 19:05:45 +02:00
|
|
|
|
2022-05-17 10:55:53 +02:00
|
|
|
from ..conftest import SAMPLES_PATH
|
2022-01-26 18:12:55 +01:00
|
|
|
|
2021-10-25 12:27:02 +02:00
|
|
|
|
2022-02-03 13:43:18 +01:00
|
|
|
# TODO: check whether this works with only the "memory" arg
@pytest.mark.parametrize(
    "retriever_with_docs,document_store_with_docs",
    [
        ("mdr", "elasticsearch"),
        ("mdr", "faiss"),
        ("mdr", "memory"),
        ("mdr", "milvus1"),
        ("dpr", "elasticsearch"),
        ("dpr", "faiss"),
        ("dpr", "memory"),
        ("dpr", "milvus1"),
        ("embedding", "elasticsearch"),
        ("embedding", "faiss"),
        ("embedding", "memory"),
        ("embedding", "milvus1"),
        ("elasticsearch", "elasticsearch"),
        ("es_filter_only", "elasticsearch"),
        ("tfidf", "memory"),
    ],
    indirect=True,
)
def test_retrieval(retriever_with_docs: BaseRetriever, document_store_with_docs: BaseDocumentStore):
    """Check unfiltered and filtered retrieval for each retriever / document store pair.

    Retrievers other than BM25, Filter and TF-IDF need the store's embeddings
    computed before they can retrieve anything.
    """
    if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever, TfidfRetriever)):
        document_store_with_docs.update_embeddings(retriever_with_docs)

    # test without filters
    # NOTE: FilterRetriever simply returns all documents matching a filter,
    # so without filters applied it does nothing
    if not isinstance(retriever_with_docs, FilterRetriever):
        res = retriever_with_docs.retrieve(query="Who lives in Berlin?")
        assert res[0].content == "My name is Carla and I live in Berlin"
        assert len(res) == 5
        assert res[0].meta["name"] == "filename1"

    # test with filters
    # (skipped for FAISS / Milvus stores and for the TF-IDF retriever)
    if not isinstance(document_store_with_docs, (FAISSDocumentStore, MilvusDocumentStore)) and not isinstance(
        retriever_with_docs, TfidfRetriever
    ):
        # single filter
        result = retriever_with_docs.retrieve(query="Christelle", filters={"name": ["filename3"]}, top_k=5)
        assert len(result) == 1
        assert isinstance(result[0], Document)
        assert result[0].content == "My name is Christelle and I live in Paris"
        assert result[0].meta["name"] == "filename3"

        # multiple filters
        result = retriever_with_docs.retrieve(
            query="Paul", filters={"name": ["filename2"], "meta_field": ["test2", "test3"]}, top_k=5
        )
        assert len(result) == 1
        assert isinstance(result[0], Document)
        assert result[0].meta["name"] == "filename2"

        # filters that no document satisfies together must yield nothing
        result = retriever_with_docs.retrieve(
            query="Carla", filters={"name": ["filename1"], "meta_field": ["test2", "test3"]}, top_k=5
        )
        assert len(result) == 0
|
|
|
|
|
|
|
|
|
2022-05-11 11:11:00 +02:00
|
|
|
def test_batch_retrieval_single_query(retriever_with_docs, document_store_with_docs):
    """retrieve_batch with a single query returns one ranked list of Documents."""
    needs_embeddings = not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever, TfidfRetriever))
    if needs_embeddings:
        document_store_with_docs.update_embeddings(retriever_with_docs)

    batch = retriever_with_docs.retrieve_batch(queries=["Who lives in Berlin?"])

    # Shape: a list (one entry per query) of lists of Document objects
    assert isinstance(batch, list)
    hits = batch[0]
    assert isinstance(hits, list)
    top_hit = hits[0]
    assert isinstance(top_hit, Document)

    # Content: exactly one query answered, five hits, best hit is Carla's document
    assert len(batch) == 1
    assert len(hits) == 5
    assert top_hit.content == "My name is Carla and I live in Berlin"
    assert top_hit.meta["name"] == "filename1"
|
2022-05-11 11:11:00 +02:00
|
|
|
|
|
|
|
|
|
|
|
def test_batch_retrieval_multiple_queries(retriever_with_docs, document_store_with_docs):
    """retrieve_batch answers each query with its own ranked list of Documents."""
    if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever, TfidfRetriever)):
        document_store_with_docs.update_embeddings(retriever_with_docs)

    batch = retriever_with_docs.retrieve_batch(queries=["Who lives in Berlin?", "Who lives in New York?"])

    # Shape: a list of lists of Document objects
    assert isinstance(batch, list)
    assert isinstance(batch[0], list)
    assert isinstance(batch[0][0], Document)

    # (expected top content, expected top file name) for each query, in query order
    expectations = [
        ("My name is Carla and I live in Berlin", "filename1"),
        ("My name is Paul and I live in New York", "filename2"),
    ]
    for position, (top_content, top_name) in enumerate(expectations):
        # index explicitly so a missing per-query list still fails loudly
        hits = batch[position]
        assert hits[0].content == top_content
        assert len(hits) == 5
        assert hits[0].meta["name"] == top_name
|
|
|
|
|
|
|
|
|
2021-02-12 14:57:06 +01:00
|
|
|
@pytest.mark.elasticsearch
def test_elasticsearch_custom_query():
    """BM25Retriever fills ${query} and filter placeholders into a user-supplied Elasticsearch query."""
    index_name = "haystack_test_custom"
    Elasticsearch().indices.delete(index=index_name, ignore=[404])
    document_store = ElasticsearchDocumentStore(
        index=index_name, content_field="custom_text_field", embedding_field="custom_embedding_field"
    )

    # test_1 .. test_5, one per year below (three of them in 2021)
    years = ["2019", "2020", "2021", "2021", "2021"]
    document_store.write_documents(
        [{"content": f"test_{number}", "meta": {"year": year}} for number, year in enumerate(years, start=1)]
    )

    # "terms" clause: ${years} receives a list of accepted values
    terms_query = """
            {
                "size": 10,
                "query": {
                    "bool": {
                        "should": [{
                            "multi_match": {"query": ${query}, "type": "most_fields", "fields": ["content"]}}],
                            "filter": [{"terms": {"year": ${years}}}]}}}"""

    # "term" clause: ${years} receives a single value
    term_query = """
            {
                "size": 10,
                "query": {
                    "bool": {
                        "should": [{
                            "multi_match": {"query": ${query}, "type": "most_fields", "fields": ["content"]}}],
                            "filter": [{"term": {"year": ${years}}}]}}}"""

    cases = [
        (terms_query, {"years": ["2020", "2021"]}, 4),
        (term_query, {"years": "2021"}, 3),
    ]
    for custom_query, filters, expected_hits in cases:
        retriever = BM25Retriever(document_store=document_store, custom_query=custom_query)
        hits = retriever.retrieve(query="test", filters=filters)
        assert len(hits) == expected_hits
|
|
|
|
|
|
|
|
|
2022-06-07 09:23:03 +02:00
|
|
|
@pytest.mark.integration
@pytest.mark.parametrize(
    "document_store", ["elasticsearch", "faiss", "memory", "milvus1", "milvus", "weaviate", "pinecone"], indirect=True
)
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
def test_dpr_embedding(document_store: BaseDocumentStore, retriever, docs_with_ids):
    """DPR passage embeddings are 768-dimensional and match reference values.

    Documents are compared in id order. Only the first component of each
    L2-normalised embedding is checked against ``expected_values``.
    """
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    docs = document_store.get_all_documents()
    # sort so the comparison against expected_values is independent of store ordering
    docs.sort(key=lambda d: d.id)

    expected_values = [0.00892, 0.00780, 0.00482, -0.00626, 0.010966]
    for doc, expected_value in zip(docs, expected_values):
        embedding = doc.embedding
        # always normalize vector as faiss returns normalized vectors and other document stores do not
        embedding /= np.linalg.norm(embedding)
        assert len(embedding) == 768
        assert isclose(embedding[0], expected_value, rel_tol=0.001)
|
2021-01-20 12:52:52 +01:00
|
|
|
|
2020-11-05 13:29:23 +01:00
|
|
|
|
2022-06-07 09:23:03 +02:00
|
|
|
@pytest.mark.integration
@pytest.mark.parametrize(
    "document_store", ["elasticsearch", "faiss", "memory", "milvus1", "milvus", "weaviate", "pinecone"], indirect=True
)
@pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
@pytest.mark.embedding_dim(128)
def test_retribert_embedding(document_store, retriever, docs_with_ids):
    """RetriBERT embeddings are 128-dimensional and match reference first components in id order."""
    if isinstance(document_store, WeaviateDocumentStore):
        # Weaviate fixes its embedding dimension to 768 when it is created, but this
        # model produces 128-dim vectors, so swap in a freshly built 128-dim store.
        document_store = WeaviateDocumentStore(index="haystack_test", embedding_dim=128, recreate_index=True)

    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    ordered_docs = sorted(document_store.get_all_documents(), key=lambda document: document.id)
    reference_first_values = [0.14017, 0.05975, 0.14267, 0.15099, 0.14383]

    for document, reference in zip(ordered_docs, reference_first_values):
        vector = document.embedding
        assert len(vector) == 128
        # faiss hands back unit-length vectors while other stores do not, so normalise everywhere
        vector /= np.linalg.norm(vector)
        assert isclose(vector[0], reference, rel_tol=0.001)
|
2021-06-14 17:53:43 +02:00
|
|
|
|
|
|
|
|
2022-06-07 09:23:03 +02:00
|
|
|
@pytest.mark.integration
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
@pytest.mark.embedding_dim(512)
def test_table_text_retriever_embedding(document_store, retriever, docs):
    """Text docs plus one table doc (id "6") get 512-dim embeddings matching reference values.

    Only the first component of each embedding is compared, for the first
    ``len(reference_first_values)`` documents in id order.
    """
    document_store.return_embedding = True
    document_store.write_documents(docs)

    mountains = pd.DataFrame(
        {
            "Mountain": ["Mount Everest", "K2", "Kangchenjunga", "Lhotse", "Makalu"],
            "Height": ["8848m", "8,611 m", "8 586m", "8 516 m", "8,485m"],
        }
    )
    document_store.write_documents([Document(content=mountains, content_type="table", id="6")])
    document_store.update_embeddings(retriever=retriever)

    ordered_docs = sorted(document_store.get_all_documents(), key=lambda document: document.id)
    reference_first_values = [0.061191384, 0.038075786, 0.27447605, 0.09399721, 0.0959682]

    for document, reference in zip(ordered_docs, reference_first_values):
        vector = document.embedding
        assert len(vector) == 512
        assert isclose(vector[0], reference, rel_tol=0.001)
|
2021-10-25 12:27:02 +02:00
|
|
|
|
|
|
|
|
2020-11-05 13:29:23 +01:00
|
|
|
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
def test_dpr_saving_and_loading(tmp_path, retriever, document_store):
    """A saved DensePassageRetriever reloads with equivalent weights, settings and tokenizers.

    Weights are compared through the sum of all parameters of each encoder. This is a
    cheap fingerprint; an element-wise comparison would be far more RAM intensive.
    """
    retriever.save(f"{tmp_path}/test_dpr_save")

    def sum_params(model):
        # Sum of every parameter tensor of ``model``, used as a weight fingerprint.
        return sum(np.sum(p.cpu().data.numpy()) for p in model.parameters())

    original_sum_query = sum_params(retriever.query_encoder)
    original_sum_passage = sum_params(retriever.passage_encoder)
    # presumably dropped to free the original model's memory before loading the copy
    del retriever

    loaded_retriever = DensePassageRetriever.load(f"{tmp_path}/test_dpr_save", document_store)

    loaded_sum_query = sum_params(loaded_retriever.query_encoder)
    loaded_sum_passage = sum_params(loaded_retriever.passage_encoder)

    assert abs(original_sum_query - loaded_sum_query) < 0.1
    assert abs(original_sum_passage - loaded_sum_passage) < 0.1

    # attributes
    assert loaded_retriever.processor.embed_title
    assert loaded_retriever.batch_size == 16
    assert loaded_retriever.processor.max_seq_len_passage == 256
    assert loaded_retriever.processor.max_seq_len_query == 64

    # Tokenizer
    assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast)
    assert isinstance(loaded_retriever.query_tokenizer, DPRQuestionEncoderTokenizerFast)
    assert loaded_retriever.passage_tokenizer.do_lower_case
    assert loaded_retriever.query_tokenizer.do_lower_case
    assert loaded_retriever.passage_tokenizer.vocab_size == 30522
    assert loaded_retriever.query_tokenizer.vocab_size == 30522
|
2021-10-25 12:27:02 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
@pytest.mark.embedding_dim(512)
def test_table_text_retriever_saving_and_loading(tmp_path, retriever, document_store):
    """A saved TableTextRetriever reloads with equivalent weights, settings and tokenizers.

    Weights of the query, passage and table encoders are compared through the
    sum of all their parameters, a cheap fingerprint of the model.
    """
    retriever.save(f"{tmp_path}/test_table_text_retriever_save")

    def sum_params(model):
        # Sum of every parameter tensor of ``model``, used as a weight fingerprint.
        return sum(np.sum(p.cpu().data.numpy()) for p in model.parameters())

    original_sum_query = sum_params(retriever.query_encoder)
    original_sum_passage = sum_params(retriever.passage_encoder)
    original_sum_table = sum_params(retriever.table_encoder)
    # presumably dropped to free the original model's memory before loading the copy
    del retriever

    loaded_retriever = TableTextRetriever.load(f"{tmp_path}/test_table_text_retriever_save", document_store)

    loaded_sum_query = sum_params(loaded_retriever.query_encoder)
    loaded_sum_passage = sum_params(loaded_retriever.passage_encoder)
    loaded_sum_table = sum_params(loaded_retriever.table_encoder)

    assert abs(original_sum_query - loaded_sum_query) < 0.1
    assert abs(original_sum_passage - loaded_sum_passage) < 0.1
    assert abs(original_sum_table - loaded_sum_table) < 0.01

    # attributes
    assert loaded_retriever.processor.embed_meta_fields == ["name", "section_title", "caption"]
    assert loaded_retriever.batch_size == 16
    assert loaded_retriever.processor.max_seq_len_passage == 256
    assert loaded_retriever.processor.max_seq_len_table == 256
    assert loaded_retriever.processor.max_seq_len_query == 64

    # Tokenizer
    assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast)
    assert isinstance(loaded_retriever.table_tokenizer, DPRContextEncoderTokenizerFast)
    assert isinstance(loaded_retriever.query_tokenizer, DPRQuestionEncoderTokenizerFast)
    assert loaded_retriever.passage_tokenizer.do_lower_case
    assert loaded_retriever.table_tokenizer.do_lower_case
    assert loaded_retriever.query_tokenizer.do_lower_case
    assert loaded_retriever.passage_tokenizer.vocab_size == 30522
    assert loaded_retriever.table_tokenizer.vocab_size == 30522
    assert loaded_retriever.query_tokenizer.vocab_size == 30522
|
|
|
|
|
|
|
|
|
2022-01-10 17:10:32 +00:00
|
|
|
@pytest.mark.embedding_dim(128)
def test_table_text_retriever_training(document_store):
    """Train a small TableTextRetriever for one CPU epoch on sample data, then reload it.

    The trained model is written to a temporary directory, so the test leaves no
    artefacts in the current working directory.
    """
    import tempfile

    retriever = TableTextRetriever(
        document_store=document_store,
        query_embedding_model="deepset/bert-small-mm_retrieval-question_encoder",
        passage_embedding_model="deepset/bert-small-mm_retrieval-passage_encoder",
        table_embedding_model="deepset/bert-small-mm_retrieval-table_encoder",
        use_gpu=False,
    )

    with tempfile.TemporaryDirectory() as save_dir:
        retriever.train(
            data_dir=SAMPLES_PATH / "mmr",
            train_filename="sample.json",
            n_epochs=1,
            n_gpu=0,
            save_dir=save_dir,
        )

        # Load trained model (must happen before the temporary directory is removed)
        retriever = TableTextRetriever.load(load_dir=save_dir, document_store=document_store)
|
2022-01-26 22:05:33 +05:30
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.elasticsearch
def test_elasticsearch_highlight():
    """Highlight blocks in a BM25Retriever custom query surface as ``meta["highlighted"]``.

    Two custom queries are run against the same index. The first highlights both
    "content" and "title"; the second highlights "title" only. The test checks the
    highlighted fields and fragments (wrapped in ``**``) of the top result.
    """
    client = Elasticsearch()
    client.indices.delete(index="haystack_hl_test", ignore=[404])

    # Map both the "content" and "title" fields as "text" so searches can run on both of them.
    # Note: "title" is the store's content_field; each document's "content" text is passed in via "meta".
    document_store = ElasticsearchDocumentStore(
        index="haystack_hl_test",
        content_field="title",
        custom_mapping={"mappings": {"properties": {"content": {"type": "text"}, "title": {"type": "text"}}}},
    )
    documents = [
        {
            "title": "Green tea components",
            "meta": {
                "content": "The green tea plant contains a range of healthy compounds that make it into the final drink"
            },
            "id": "1",
        },
        {
            "title": "Green tea catechin",
            "meta": {"content": "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG)."},
            "id": "2",
        },
        {
            "title": "Minerals in Green tea",
            "meta": {"content": "Green tea also has small amounts of minerals that can benefit your health."},
            "id": "3",
        },
        {
            "title": "Green tea Benefits",
            "meta": {"content": "Green tea does more than just keep you alert, it may also help boost brain function."},
            "id": "4",
        },
    ]
    document_store.write_documents(documents)

    # Enable highlighting on the "title" and "content" fields via a custom query
    retriever_1 = BM25Retriever(
        document_store=document_store,
        custom_query="""{
            "size": 20,
            "query": {
                "bool": {
                    "should": [
                        {
                            "multi_match": {
                                "query": ${query},
                                "fields": [
                                    "content^3",
                                    "title^5"
                                ]
                            }
                        }
                    ]
                }
            },
            "highlight": {
                "pre_tags": [
                    "**"
                ],
                "post_tags": [
                    "**"
                ],
                "number_of_fragments": 3,
                "fragment_size": 5,
                "fields": {
                    "content": {},
                    "title": {}
                }
            }
        }""",
    )
    results = retriever_1.retrieve(query="is green tea healthy")

    # Both highlighted fields are present on the top result
    assert len(results[0].meta["highlighted"]) == 2
    assert results[0].meta["highlighted"]["title"] == ["**Green**", "**tea** components"]
    assert results[0].meta["highlighted"]["content"] == ["The **green**", "**tea** plant", "range of **healthy**"]

    # Enable highlighting on the "title" field only via a custom query
    retriever_2 = BM25Retriever(
        document_store=document_store,
        custom_query="""{
            "size": 20,
            "query": {
                "bool": {
                    "should": [
                        {
                            "multi_match": {
                                "query": ${query},
                                "fields": [
                                    "content^3",
                                    "title^5"
                                ]
                            }
                        }
                    ]
                }
            },
            "highlight": {
                "pre_tags": [
                    "**"
                ],
                "post_tags": [
                    "**"
                ],
                "number_of_fragments": 3,
                "fragment_size": 5,
                "fields": {
                    "title": {}
                }
            }
        }""",
    )
    results = retriever_2.retrieve(query="is green tea healthy")

    # Only the "title" highlights are returned this time
    assert len(results[0].meta["highlighted"]) == 1
    assert results[0].meta["highlighted"]["title"] == ["**Green**", "**tea** components"]
|
2022-03-25 17:53:42 +01:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.elasticsearch
def test_elasticsearch_filter_must_not_increase_results():
    """A filter that every document satisfies must not widen the hits of a query.

    All four documents have ``content_type == "text"``. Only one contains "drink",
    so the query returns one hit both with and without the filter.
    """
    index = "filter_must_not_increase_results"
    client = Elasticsearch()
    client.indices.delete(index=index, ignore=[404])
    documents = [
        {
            "content": "The green tea plant contains a range of healthy compounds that make it into the final drink",
            "meta": {"content_type": "text"},
            "id": "1",
        },
        {
            "content": "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG).",
            "meta": {"content_type": "text"},
            "id": "2",
        },
        {
            "content": "Green tea also has small amounts of minerals that can benefit your health.",
            "meta": {"content_type": "text"},
            "id": "3",
        },
        {
            "content": "Green tea does more than just keep you alert, it may also help boost brain function.",
            "meta": {"content_type": "text"},
            "id": "4",
        },
    ]
    doc_store = ElasticsearchDocumentStore(index=index)
    try:
        doc_store.write_documents(documents)
        results_wo_filter = doc_store.query(query="drink")
        assert len(results_wo_filter) == 1
        results_w_filter = doc_store.query(query="drink", filters={"content_type": "text"})
        assert len(results_w_filter) == 1
    finally:
        # Drop the index even if an assertion fails, so the shared Elasticsearch instance is not left dirty
        doc_store.delete_index(index)
|
2022-03-28 22:10:50 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.elasticsearch
def test_elasticsearch_all_terms_must_match():
    """``all_terms_must_match=True`` narrows a multi-term query to documents containing every term.

    "green tea" appears in all four documents, but "drink" appears in only one,
    so the default query hits 4 documents and the strict query hits 1.
    """
    index = "all_terms_must_match"
    client = Elasticsearch()
    client.indices.delete(index=index, ignore=[404])
    documents = [
        {
            "content": "The green tea plant contains a range of healthy compounds that make it into the final drink",
            "meta": {"content_type": "text"},
            "id": "1",
        },
        {
            "content": "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG).",
            "meta": {"content_type": "text"},
            "id": "2",
        },
        {
            "content": "Green tea also has small amounts of minerals that can benefit your health.",
            "meta": {"content_type": "text"},
            "id": "3",
        },
        {
            "content": "Green tea does more than just keep you alert, it may also help boost brain function.",
            "meta": {"content_type": "text"},
            "id": "4",
        },
    ]
    doc_store = ElasticsearchDocumentStore(index=index)
    try:
        doc_store.write_documents(documents)
        results_wo_all_terms_must_match = doc_store.query(query="drink green tea")
        assert len(results_wo_all_terms_must_match) == 4
        results_w_all_terms_must_match = doc_store.query(query="drink green tea", all_terms_must_match=True)
        assert len(results_w_all_terms_must_match) == 1
    finally:
        # Drop the index even if an assertion fails, so the shared Elasticsearch instance is not left dirty
        doc_store.delete_index(index)
|
2022-04-12 11:52:27 +02:00
|
|
|
|
|
|
|
|
|
|
|
def test_embeddings_encoder_of_embedding_retriever_should_warn_about_model_format(caplog):
    """Creating an EmbeddingRetriever for a sentence-transformers model with
    ``model_format="farm"`` must log a warning suggesting the right format."""
    expected_hint = "You may need to set model_format='sentence_transformers' to ensure correct loading of model."
    store = InMemoryDocumentStore()

    # capture everything logged at WARNING level or above while the retriever is built
    with caplog.at_level(logging.WARNING):
        EmbeddingRetriever(
            document_store=store,
            embedding_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
            model_format="farm",
        )

    assert expected_hint in caplog.text
|
2022-04-25 00:53:48 -07:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("retriever", ["es_filter_only"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
def test_es_filter_only(document_store, retriever):
    """The filter-only retriever returns every document matching the filter, with an empty query."""
    # Doc1 .. Doc12, all with f1 == "0" except Doc7 (f1 == "1"): eleven documents match the filter
    docs = [Document(content=f"Doc{number}", meta={"f1": "1" if number == 7 else "0"}) for number in range(1, 13)]
    document_store.write_documents(docs)

    matches = retriever.retrieve(query="", filters={"f1": ["0"]})
    assert len(matches) == 11
|