import logging
import time
from pathlib import Path

import numpy as np
import pandas as pd
import pytest
from elasticsearch import Elasticsearch
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast

from haystack.document_stores import MilvusDocumentStore, WeaviateDocumentStore
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.document_stores.memory import InMemoryDocumentStore
from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever
from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
from haystack.schema import Document

from ..conftest import SAMPLES_PATH

@pytest.fixture ( )
def docs ( ) :
documents = [
Document (
content = """ Aaron Aaron ( or ; " " Ahärôn " " ) is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother ' s spokesman ( " " prophet " " ) to the Pharaoh. Part of the Law (Torah) that Moses received from """ ,
meta = { " name " : " 0 " } ,
id = " 1 " ,
) ,
Document (
content = """ Democratic Republic of the Congo to the south. Angola ' s capital, Luanda, lies on the Atlantic coast in the northwest of the country. Angola, although located in a tropical zone, has a climate that is not characterized for this region, due to the confluence of three factors: As a result, Angola ' s climate is characterized by two seasons: rainfall from October to April and drought, known as " " Cacimbo " " , from May to August, drier, as the name implies, and with lower temperatures. On the other hand, while the coastline has high rainfall rates, decreasing from North to South and from to , with """ ,
id = " 2 " ,
) ,
Document (
content = """ Schopenhauer, describing him as an ultimately shallow thinker: " " Schopenhauer has quite a crude mind ... where real depth starts, his comes to an end. " " His friend Bertrand Russell had a low opinion on the philosopher, and attacked him in his famous " " History of Western Philosophy " " for hypocritically praising asceticism yet not acting upon it. On the opposite isle of Russell on the foundations of mathematics, the Dutch mathematician L. E. J. Brouwer incorporated the ideas of Kant and Schopenhauer in intuitionism, where mathematics is considered a purely mental activity, instead of an analytic activity wherein objective properties of reality are """ ,
meta = { " name " : " 1 " } ,
id = " 3 " ,
) ,
Document (
content = """ The Dothraki vocabulary was created by David J. Peterson well in advance of the adaptation. HBO hired the Language Creatio """ ,
meta = { " name " : " 2 " } ,
id = " 4 " ,
) ,
Document (
content = """ The title of the episode refers to the Great Sept of Baelor, the main religious building in King ' s Landing, where the episode ' s pivotal scene takes place. In the world created by George R. R. Martin """ ,
meta = { } ,
id = " 5 " ,
) ,
]
return documents
2021-02-12 14:57:06 +01:00
2022-02-03 13:43:18 +01:00
# TODO check if we this works with only "memory" arg
2021-02-12 14:57:06 +01:00
@pytest.mark.parametrize (
" retriever_with_docs,document_store_with_docs " ,
[
( " dpr " , " elasticsearch " ) ,
( " dpr " , " faiss " ) ,
( " dpr " , " memory " ) ,
2022-02-24 17:43:38 +01:00
( " dpr " , " milvus1 " ) ,
2021-02-12 14:57:06 +01:00
( " embedding " , " elasticsearch " ) ,
( " embedding " , " faiss " ) ,
( " embedding " , " memory " ) ,
2022-02-24 17:43:38 +01:00
( " embedding " , " milvus1 " ) ,
2021-02-12 14:57:06 +01:00
( " elasticsearch " , " elasticsearch " ) ,
( " es_filter_only " , " elasticsearch " ) ,
( " tfidf " , " memory " ) ,
] ,
indirect = True ,
)
def test_retrieval ( retriever_with_docs , document_store_with_docs ) :
2022-04-29 10:16:02 +02:00
if not isinstance ( retriever_with_docs , ( BM25Retriever , FilterRetriever , TfidfRetriever ) ) :
2021-02-12 14:57:06 +01:00
document_store_with_docs . update_embeddings ( retriever_with_docs )
# test without filters
res = retriever_with_docs . retrieve ( query = " Who lives in Berlin? " )
2021-10-13 14:23:23 +02:00
assert res [ 0 ] . content == " My name is Carla and I live in Berlin "
2022-02-04 13:43:12 +01:00
assert len ( res ) == 5
2021-02-12 14:57:06 +01:00
assert res [ 0 ] . meta [ " name " ] == " filename1 "
# test with filters
if not isinstance ( document_store_with_docs , ( FAISSDocumentStore , MilvusDocumentStore ) ) and not isinstance (
retriever_with_docs , TfidfRetriever
) :
# single filter
2022-03-25 17:53:42 +01:00
result = retriever_with_docs . retrieve ( query = " Christelle " , filters = { " name " : [ " filename3 " ] } , top_k = 5 )
2021-02-12 14:57:06 +01:00
assert len ( result ) == 1
assert type ( result [ 0 ] ) == Document
2021-10-13 14:23:23 +02:00
assert result [ 0 ] . content == " My name is Christelle and I live in Paris "
2021-02-12 14:57:06 +01:00
assert result [ 0 ] . meta [ " name " ] == " filename3 "
# multiple filters
result = retriever_with_docs . retrieve (
2022-03-25 17:53:42 +01:00
query = " Paul " , filters = { " name " : [ " filename2 " ] , " meta_field " : [ " test2 " , " test3 " ] } , top_k = 5
2021-02-12 14:57:06 +01:00
)
assert len ( result ) == 1
assert type ( result [ 0 ] ) == Document
assert result [ 0 ] . meta [ " name " ] == " filename2 "
result = retriever_with_docs . retrieve (
2022-03-25 17:53:42 +01:00
query = " Carla " , filters = { " name " : [ " filename1 " ] , " meta_field " : [ " test2 " , " test3 " ] } , top_k = 5
2021-02-12 14:57:06 +01:00
)
assert len ( result ) == 0
2022-05-11 11:11:00 +02:00
def test_batch_retrieval_single_query ( retriever_with_docs , document_store_with_docs ) :
if not isinstance ( retriever_with_docs , ( BM25Retriever , FilterRetriever , TfidfRetriever ) ) :
document_store_with_docs . update_embeddings ( retriever_with_docs )
2022-05-24 12:33:45 +02:00
res = retriever_with_docs . retrieve_batch ( queries = [ " Who lives in Berlin? " ] )
2022-05-11 11:11:00 +02:00
2022-05-24 12:33:45 +02:00
# Expected return type: List of lists of Documents
2022-05-11 11:11:00 +02:00
assert isinstance ( res , list )
2022-05-24 12:33:45 +02:00
assert isinstance ( res [ 0 ] , list )
assert isinstance ( res [ 0 ] [ 0 ] , Document )
2022-05-11 11:11:00 +02:00
2022-05-24 12:33:45 +02:00
assert len ( res ) == 1
assert len ( res [ 0 ] ) == 5
assert res [ 0 ] [ 0 ] . content == " My name is Carla and I live in Berlin "
assert res [ 0 ] [ 0 ] . meta [ " name " ] == " filename1 "
2022-05-11 11:11:00 +02:00
def test_batch_retrieval_multiple_queries ( retriever_with_docs , document_store_with_docs ) :
if not isinstance ( retriever_with_docs , ( BM25Retriever , FilterRetriever , TfidfRetriever ) ) :
document_store_with_docs . update_embeddings ( retriever_with_docs )
res = retriever_with_docs . retrieve_batch ( queries = [ " Who lives in Berlin? " , " Who lives in New York? " ] )
# Expected return type: list of lists of Documents
assert isinstance ( res , list )
assert isinstance ( res [ 0 ] , list )
assert isinstance ( res [ 0 ] [ 0 ] , Document )
assert res [ 0 ] [ 0 ] . content == " My name is Carla and I live in Berlin "
assert len ( res [ 0 ] ) == 5
assert res [ 0 ] [ 0 ] . meta [ " name " ] == " filename1 "
assert res [ 1 ] [ 0 ] . content == " My name is Paul and I live in New York "
assert len ( res [ 1 ] ) == 5
assert res [ 1 ] [ 0 ] . meta [ " name " ] == " filename2 "
2021-02-12 14:57:06 +01:00
@pytest.mark.elasticsearch
2021-10-29 13:52:28 +05:30
def test_elasticsearch_custom_query ( ) :
2021-02-12 14:57:06 +01:00
client = Elasticsearch ( )
client . indices . delete ( index = " haystack_test_custom " , ignore = [ 404 ] )
document_store = ElasticsearchDocumentStore (
2021-10-13 14:23:23 +02:00
index = " haystack_test_custom " , content_field = " custom_text_field " , embedding_field = " custom_embedding_field "
2021-02-12 14:57:06 +01:00
)
documents = [
2021-10-13 14:23:23 +02:00
{ " content " : " test_1 " , " meta " : { " year " : " 2019 " } } ,
{ " content " : " test_2 " , " meta " : { " year " : " 2020 " } } ,
{ " content " : " test_3 " , " meta " : { " year " : " 2021 " } } ,
{ " content " : " test_4 " , " meta " : { " year " : " 2021 " } } ,
{ " content " : " test_5 " , " meta " : { " year " : " 2021 " } } ,
2021-02-12 14:57:06 +01:00
]
document_store . write_documents ( documents )
# test custom "terms" query
2022-04-26 16:09:39 +02:00
retriever = BM25Retriever (
2021-02-12 14:57:06 +01:00
document_store = document_store ,
custom_query = """
{
2022-04-12 11:52:27 +02:00
" size " : 10 ,
2021-02-12 14:57:06 +01:00
" query " : {
" bool " : {
" should " : [ {
2021-10-13 14:23:23 +02:00
" multi_match " : { " query " : $ { query } , " type " : " most_fields " , " fields " : [ " content " ] } } ] ,
2021-02-12 14:57:06 +01:00
" filter " : [ { " terms " : { " year " : $ { years } } } ] } } } """ ,
)
2021-02-16 16:24:28 +01:00
results = retriever . retrieve ( query = " test " , filters = { " years " : [ " 2020 " , " 2021 " ] } )
2021-02-12 14:57:06 +01:00
assert len ( results ) == 4
# test custom "term" query
2022-04-26 16:09:39 +02:00
retriever = BM25Retriever (
2021-02-12 14:57:06 +01:00
document_store = document_store ,
custom_query = """
{
2022-04-12 11:52:27 +02:00
" size " : 10 ,
2021-02-12 14:57:06 +01:00
" query " : {
" bool " : {
" should " : [ {
2021-10-13 14:23:23 +02:00
" multi_match " : { " query " : $ { query } , " type " : " most_fields " , " fields " : [ " content " ] } } ] ,
2021-02-12 14:57:06 +01:00
" filter " : [ { " term " : { " year " : $ { years } } } ] } } } """ ,
)
2021-02-16 16:24:28 +01:00
results = retriever . retrieve ( query = " test " , filters = { " years " : " 2021 " } )
2021-02-12 14:57:06 +01:00
assert len ( results ) == 3
2020-10-26 19:19:10 +01:00
@pytest.mark.slow
2022-03-03 15:19:27 +01:00
@pytest.mark.parametrize (
2022-03-21 22:24:09 +07:00
" document_store " , [ " elasticsearch " , " faiss " , " memory " , " milvus1 " , " milvus " , " weaviate " , " pinecone " ] , indirect = True
2022-03-03 15:19:27 +01:00
)
2020-10-14 16:15:04 +02:00
@pytest.mark.parametrize ( " retriever " , [ " dpr " ] , indirect = True )
2021-11-04 09:27:12 +01:00
def test_dpr_embedding ( document_store , retriever , docs ) :
2020-06-30 19:05:45 +02:00
2021-02-12 14:57:06 +01:00
document_store . return_embedding = True
2021-11-04 09:27:12 +01:00
document_store . write_documents ( docs )
2020-10-27 08:33:39 +01:00
document_store . update_embeddings ( retriever = retriever )
time . sleep ( 1 )
2022-01-12 19:28:20 +01:00
# always normalize vector as faiss returns normalized vectors and other document stores do not
doc_1 = document_store . get_document_by_id ( " 1 " ) . embedding
doc_1 / = np . linalg . norm ( doc_1 )
assert len ( doc_1 ) == 768
assert abs ( doc_1 [ 0 ] - ( - 0.0250 ) ) < 0.001
doc_2 = document_store . get_document_by_id ( " 2 " ) . embedding
doc_2 / = np . linalg . norm ( doc_2 )
assert abs ( doc_2 [ 0 ] - ( - 0.0314 ) ) < 0.001
doc_3 = document_store . get_document_by_id ( " 3 " ) . embedding
doc_3 / = np . linalg . norm ( doc_3 )
assert abs ( doc_3 [ 0 ] - ( - 0.0200 ) ) < 0.001
doc_4 = document_store . get_document_by_id ( " 4 " ) . embedding
doc_4 / = np . linalg . norm ( doc_4 )
assert abs ( doc_4 [ 0 ] - ( - 0.0070 ) ) < 0.001
doc_5 = document_store . get_document_by_id ( " 5 " ) . embedding
doc_5 / = np . linalg . norm ( doc_5 )
assert abs ( doc_5 [ 0 ] - ( - 0.0049 ) ) < 0.001
2021-01-20 12:52:52 +01:00
2020-11-05 13:29:23 +01:00
2021-06-14 17:53:43 +02:00
@pytest.mark.slow
2022-03-03 15:19:27 +01:00
@pytest.mark.parametrize (
2022-03-21 22:24:09 +07:00
" document_store " , [ " elasticsearch " , " faiss " , " memory " , " milvus1 " , " milvus " , " weaviate " , " pinecone " ] , indirect = True
2022-03-03 15:19:27 +01:00
)
2021-06-14 17:53:43 +02:00
@pytest.mark.parametrize ( " retriever " , [ " retribert " ] , indirect = True )
2022-01-10 17:10:32 +00:00
@pytest.mark.embedding_dim ( 128 )
2021-11-04 09:27:12 +01:00
def test_retribert_embedding ( document_store , retriever , docs ) :
if isinstance ( document_store , WeaviateDocumentStore ) :
# Weaviate sets the embedding dimension to 768 as soon as it is initialized.
# We need 128 here and therefore initialize a new WeaviateDocumentStore.
2022-04-26 19:06:30 +02:00
document_store = WeaviateDocumentStore ( index = " haystack_test " , embedding_dim = 128 , recreate_index = True )
2021-06-14 17:53:43 +02:00
document_store . return_embedding = True
2021-11-04 09:27:12 +01:00
document_store . write_documents ( docs )
2021-06-14 17:53:43 +02:00
document_store . update_embeddings ( retriever = retriever )
time . sleep ( 1 )
assert len ( document_store . get_document_by_id ( " 1 " ) . embedding ) == 128
assert abs ( document_store . get_document_by_id ( " 1 " ) . embedding [ 0 ] ) < 0.6
assert abs ( document_store . get_document_by_id ( " 2 " ) . embedding [ 0 ] ) < 0.03
assert abs ( document_store . get_document_by_id ( " 3 " ) . embedding [ 0 ] ) < 0.095
assert abs ( document_store . get_document_by_id ( " 4 " ) . embedding [ 0 ] ) < 0.3
assert abs ( document_store . get_document_by_id ( " 5 " ) . embedding [ 0 ] ) < 0.32
2021-10-25 12:27:02 +02:00
@pytest.mark.slow
@pytest.mark.parametrize ( " retriever " , [ " table_text_retriever " ] , indirect = True )
@pytest.mark.parametrize ( " document_store " , [ " elasticsearch " ] , indirect = True )
2022-01-10 17:10:32 +00:00
@pytest.mark.embedding_dim ( 512 )
2021-11-04 09:27:12 +01:00
def test_table_text_retriever_embedding ( document_store , retriever , docs ) :
2021-10-25 12:27:02 +02:00
document_store . return_embedding = True
2021-11-04 09:27:12 +01:00
document_store . write_documents ( docs )
2021-10-25 12:27:02 +02:00
table_data = {
" Mountain " : [ " Mount Everest " , " K2 " , " Kangchenjunga " , " Lhotse " , " Makalu " ] ,
2022-02-03 13:43:18 +01:00
" Height " : [ " 8848m " , " 8,611 m " , " 8 586m " , " 8 516 m " , " 8,485m " ] ,
2021-10-25 12:27:02 +02:00
}
table = pd . DataFrame ( table_data )
table_doc = Document ( content = table , content_type = " table " , id = " 6 " )
document_store . write_documents ( [ table_doc ] )
document_store . update_embeddings ( retriever = retriever )
time . sleep ( 1 )
doc_1 = document_store . get_document_by_id ( " 1 " )
assert len ( doc_1 . embedding ) == 512
assert abs ( doc_1 . embedding [ 0 ] - ( 0.0593 ) ) < 0.001
doc_2 = document_store . get_document_by_id ( " 2 " )
assert abs ( doc_2 . embedding [ 0 ] - ( 0.9031 ) ) < 0.001
doc_3 = document_store . get_document_by_id ( " 3 " )
assert abs ( doc_3 . embedding [ 0 ] - ( 0.1366 ) ) < 0.001
doc_4 = document_store . get_document_by_id ( " 4 " )
assert abs ( doc_4 . embedding [ 0 ] - ( 0.0575 ) ) < 0.001
doc_5 = document_store . get_document_by_id ( " 5 " )
assert abs ( doc_5 . embedding [ 0 ] - ( 0.1486 ) ) < 0.001
doc_6 = document_store . get_document_by_id ( " 6 " )
assert len ( doc_6 . embedding ) == 512
assert abs ( doc_6 . embedding [ 0 ] - ( 0.2745 ) ) < 0.001
2020-11-05 13:29:23 +01:00
@pytest.mark.parametrize ( " retriever " , [ " dpr " ] , indirect = True )
@pytest.mark.parametrize ( " document_store " , [ " memory " ] , indirect = True )
2022-03-15 11:17:26 +01:00
def test_dpr_saving_and_loading ( tmp_path , retriever , document_store ) :
retriever . save ( f " { tmp_path } /test_dpr_save " )
2021-02-12 14:57:06 +01:00
2020-11-05 13:29:23 +01:00
def sum_params ( model ) :
s = [ ]
for p in model . parameters ( ) :
n = p . cpu ( ) . data . numpy ( )
s . append ( np . sum ( n ) )
return sum ( s )
2021-02-12 14:57:06 +01:00
2020-11-05 13:29:23 +01:00
original_sum_query = sum_params ( retriever . query_encoder )
original_sum_passage = sum_params ( retriever . passage_encoder )
del retriever
2022-03-15 11:17:26 +01:00
loaded_retriever = DensePassageRetriever . load ( f " { tmp_path } /test_dpr_save " , document_store )
2020-11-05 13:29:23 +01:00
loaded_sum_query = sum_params ( loaded_retriever . query_encoder )
loaded_sum_passage = sum_params ( loaded_retriever . passage_encoder )
assert abs ( original_sum_query - loaded_sum_query ) < 0.1
assert abs ( original_sum_passage - loaded_sum_passage ) < 0.1
# comparison of weights (RAM intense!)
# for p1, p2 in zip(retriever.query_encoder.parameters(), loaded_retriever.query_encoder.parameters()):
# assert (p1.data.ne(p2.data).sum() == 0)
#
# for p1, p2 in zip(retriever.passage_encoder.parameters(), loaded_retriever.passage_encoder.parameters()):
# assert (p1.data.ne(p2.data).sum() == 0)
# attributes
2021-04-01 18:23:05 +02:00
assert loaded_retriever . processor . embed_title == True
2020-11-05 13:29:23 +01:00
assert loaded_retriever . batch_size == 16
2021-04-01 18:23:05 +02:00
assert loaded_retriever . processor . max_seq_len_passage == 256
assert loaded_retriever . processor . max_seq_len_query == 64
2020-11-05 13:29:23 +01:00
# Tokenizer
assert isinstance ( loaded_retriever . passage_tokenizer , DPRContextEncoderTokenizerFast )
assert isinstance ( loaded_retriever . query_tokenizer , DPRQuestionEncoderTokenizerFast )
assert loaded_retriever . passage_tokenizer . do_lower_case == True
assert loaded_retriever . query_tokenizer . do_lower_case == True
assert loaded_retriever . passage_tokenizer . vocab_size == 30522
assert loaded_retriever . query_tokenizer . vocab_size == 30522
2021-01-21 10:15:41 +01:00
assert loaded_retriever . passage_tokenizer . model_max_length == 512
assert loaded_retriever . query_tokenizer . model_max_length == 512
2021-10-25 12:27:02 +02:00
@pytest.mark.parametrize ( " retriever " , [ " table_text_retriever " ] , indirect = True )
2022-01-10 17:10:32 +00:00
@pytest.mark.embedding_dim ( 512 )
2022-03-15 11:17:26 +01:00
def test_table_text_retriever_saving_and_loading ( tmp_path , retriever , document_store ) :
retriever . save ( f " { tmp_path } /test_table_text_retriever_save " )
2021-10-25 12:27:02 +02:00
def sum_params ( model ) :
s = [ ]
for p in model . parameters ( ) :
n = p . cpu ( ) . data . numpy ( )
s . append ( np . sum ( n ) )
return sum ( s )
original_sum_query = sum_params ( retriever . query_encoder )
original_sum_passage = sum_params ( retriever . passage_encoder )
original_sum_table = sum_params ( retriever . table_encoder )
del retriever
2022-03-15 11:17:26 +01:00
loaded_retriever = TableTextRetriever . load ( f " { tmp_path } /test_table_text_retriever_save " , document_store )
2021-10-25 12:27:02 +02:00
loaded_sum_query = sum_params ( loaded_retriever . query_encoder )
loaded_sum_passage = sum_params ( loaded_retriever . passage_encoder )
loaded_sum_table = sum_params ( loaded_retriever . table_encoder )
assert abs ( original_sum_query - loaded_sum_query ) < 0.1
assert abs ( original_sum_passage - loaded_sum_passage ) < 0.1
assert abs ( original_sum_table - loaded_sum_table ) < 0.01
# attributes
assert loaded_retriever . processor . embed_meta_fields == [ " name " , " section_title " , " caption " ]
assert loaded_retriever . batch_size == 16
assert loaded_retriever . processor . max_seq_len_passage == 256
assert loaded_retriever . processor . max_seq_len_table == 256
assert loaded_retriever . processor . max_seq_len_query == 64
# Tokenizer
assert isinstance ( loaded_retriever . passage_tokenizer , DPRContextEncoderTokenizerFast )
assert isinstance ( loaded_retriever . table_tokenizer , DPRContextEncoderTokenizerFast )
assert isinstance ( loaded_retriever . query_tokenizer , DPRQuestionEncoderTokenizerFast )
assert loaded_retriever . passage_tokenizer . do_lower_case == True
assert loaded_retriever . table_tokenizer . do_lower_case == True
assert loaded_retriever . query_tokenizer . do_lower_case == True
assert loaded_retriever . passage_tokenizer . vocab_size == 30522
assert loaded_retriever . table_tokenizer . vocab_size == 30522
assert loaded_retriever . query_tokenizer . vocab_size == 30522
assert loaded_retriever . passage_tokenizer . model_max_length == 512
assert loaded_retriever . table_tokenizer . model_max_length == 512
assert loaded_retriever . query_tokenizer . model_max_length == 512
2022-01-10 17:10:32 +00:00
@pytest.mark.embedding_dim ( 128 )
2021-10-25 12:27:02 +02:00
def test_table_text_retriever_training ( document_store ) :
retriever = TableTextRetriever (
document_store = document_store ,
query_embedding_model = " prajjwal1/bert-tiny " ,
passage_embedding_model = " prajjwal1/bert-tiny " ,
table_embedding_model = " prajjwal1/bert-tiny " ,
2022-02-03 13:43:18 +01:00
use_gpu = False ,
2021-10-25 12:27:02 +02:00
)
retriever . train (
2022-02-03 13:43:18 +01:00
data_dir = SAMPLES_PATH / " mmr " ,
2021-10-25 12:27:02 +02:00
train_filename = " sample.json " ,
n_epochs = 1 ,
n_gpu = 0 ,
2022-02-03 13:43:18 +01:00
save_dir = " test_table_text_retriever_train " ,
2021-10-25 12:27:02 +02:00
)
# Load trained model
retriever = TableTextRetriever . load ( load_dir = " test_table_text_retriever_train " , document_store = document_store )
2022-01-26 22:05:33 +05:30
@pytest.mark.elasticsearch
def test_elasticsearch_highlight ( ) :
client = Elasticsearch ( )
2022-02-03 13:43:18 +01:00
client . indices . delete ( index = " haystack_hl_test " , ignore = [ 404 ] )
2022-01-26 22:05:33 +05:30
# Mapping the content and title field as "text" perform search on these both fields.
2022-02-03 13:43:18 +01:00
document_store = ElasticsearchDocumentStore (
index = " haystack_hl_test " ,
content_field = " title " ,
custom_mapping = { " mappings " : { " properties " : { " content " : { " type " : " text " } , " title " : { " type " : " text " } } } } ,
2022-01-26 22:05:33 +05:30
)
documents = [
2022-02-03 13:43:18 +01:00
{
" title " : " Green tea components " ,
" meta " : {
" content " : " The green tea plant contains a range of healthy compounds that make it into the final drink "
} ,
" id " : " 1 " ,
} ,
{
" title " : " Green tea catechin " ,
" meta " : { " content " : " Green tea contains a catechin called epigallocatechin-3-gallate (EGCG). " } ,
" id " : " 2 " ,
} ,
{
" title " : " Minerals in Green tea " ,
" meta " : { " content " : " Green tea also has small amounts of minerals that can benefit your health. " } ,
" id " : " 3 " ,
} ,
{
" title " : " Green tea Benefits " ,
" meta " : { " content " : " Green tea does more than just keep you alert, it may also help boost brain function. " } ,
" id " : " 4 " ,
} ,
2022-01-26 22:05:33 +05:30
]
document_store . write_documents ( documents )
# Enabled highlighting on "title"&"content" field only using custom query
2022-04-26 16:09:39 +02:00
retriever_1 = BM25Retriever (
2022-02-03 13:43:18 +01:00
document_store = document_store ,
custom_query = """ {
2022-01-26 22:05:33 +05:30
" size " : 20 ,
" query " : {
" bool " : {
" should " : [
{
" multi_match " : {
" query " : $ { query } ,
" fields " : [
" content^3 " ,
" title^5 "
]
}
}
]
}
} ,
" highlight " : {
" pre_tags " : [
" ** "
] ,
" post_tags " : [
" ** "
] ,
" number_of_fragments " : 3 ,
" fragment_size " : 5 ,
" fields " : {
" content " : { } ,
" title " : { }
}
}
} """ ,
)
results = retriever_1 . retrieve ( query = " is green tea healthy " )
2022-02-03 13:43:18 +01:00
assert len ( results [ 0 ] . meta [ " highlighted " ] ) == 2
assert results [ 0 ] . meta [ " highlighted " ] [ " title " ] == [ " **Green** " , " **tea** components " ]
assert results [ 0 ] . meta [ " highlighted " ] [ " content " ] == [ " The **green** " , " **tea** plant " , " range of **healthy** " ]
2022-01-26 22:05:33 +05:30
2022-02-03 13:43:18 +01:00
# Enabled highlighting on "title" field only using custom query
2022-04-26 16:09:39 +02:00
retriever_2 = BM25Retriever (
2022-02-03 13:43:18 +01:00
document_store = document_store ,
custom_query = """ {
2022-01-26 22:05:33 +05:30
" size " : 20 ,
" query " : {
" bool " : {
" should " : [
{
" multi_match " : {
" query " : $ { query } ,
" fields " : [
" content^3 " ,
" title^5 "
]
}
}
]
}
} ,
" highlight " : {
" pre_tags " : [
" ** "
] ,
" post_tags " : [
" ** "
] ,
" number_of_fragments " : 3 ,
" fragment_size " : 5 ,
" fields " : {
" title " : { }
}
}
} """ ,
)
results = retriever_2 . retrieve ( query = " is green tea healthy " )
2022-02-03 13:43:18 +01:00
assert len ( results [ 0 ] . meta [ " highlighted " ] ) == 1
assert results [ 0 ] . meta [ " highlighted " ] [ " title " ] == [ " **Green** " , " **tea** components " ]
2022-03-25 17:53:42 +01:00
def test_elasticsearch_filter_must_not_increase_results ( ) :
index = " filter_must_not_increase_results "
client = Elasticsearch ( )
client . indices . delete ( index = index , ignore = [ 404 ] )
documents = [
{
" content " : " The green tea plant contains a range of healthy compounds that make it into the final drink " ,
" meta " : { " content_type " : " text " } ,
" id " : " 1 " ,
} ,
{
" content " : " Green tea contains a catechin called epigallocatechin-3-gallate (EGCG). " ,
" meta " : { " content_type " : " text " } ,
" id " : " 2 " ,
} ,
{
" content " : " Green tea also has small amounts of minerals that can benefit your health. " ,
" meta " : { " content_type " : " text " } ,
" id " : " 3 " ,
} ,
{
" content " : " Green tea does more than just keep you alert, it may also help boost brain function. " ,
" meta " : { " content_type " : " text " } ,
" id " : " 4 " ,
} ,
]
doc_store = ElasticsearchDocumentStore ( index = index )
doc_store . write_documents ( documents )
results_wo_filter = doc_store . query ( query = " drink " )
assert len ( results_wo_filter ) == 1
results_w_filter = doc_store . query ( query = " drink " , filters = { " content_type " : " text " } )
assert len ( results_w_filter ) == 1
doc_store . delete_index ( index )
2022-03-28 22:10:50 +02:00
def test_elasticsearch_all_terms_must_match ( ) :
index = " all_terms_must_match "
client = Elasticsearch ( )
client . indices . delete ( index = index , ignore = [ 404 ] )
documents = [
{
" content " : " The green tea plant contains a range of healthy compounds that make it into the final drink " ,
" meta " : { " content_type " : " text " } ,
" id " : " 1 " ,
} ,
{
" content " : " Green tea contains a catechin called epigallocatechin-3-gallate (EGCG). " ,
" meta " : { " content_type " : " text " } ,
" id " : " 2 " ,
} ,
{
" content " : " Green tea also has small amounts of minerals that can benefit your health. " ,
" meta " : { " content_type " : " text " } ,
" id " : " 3 " ,
} ,
{
" content " : " Green tea does more than just keep you alert, it may also help boost brain function. " ,
" meta " : { " content_type " : " text " } ,
" id " : " 4 " ,
} ,
]
doc_store = ElasticsearchDocumentStore ( index = index )
doc_store . write_documents ( documents )
results_wo_all_terms_must_match = doc_store . query ( query = " drink green tea " )
assert len ( results_wo_all_terms_must_match ) == 4
results_w_all_terms_must_match = doc_store . query ( query = " drink green tea " , all_terms_must_match = True )
assert len ( results_w_all_terms_must_match ) == 1
doc_store . delete_index ( index )
2022-04-12 11:52:27 +02:00
def test_embeddings_encoder_of_embedding_retriever_should_warn_about_model_format ( caplog ) :
document_store = InMemoryDocumentStore ( )
with caplog . at_level ( logging . WARNING ) :
EmbeddingRetriever (
document_store = document_store , embedding_model = " sentence-transformers/paraphrase-multilingual-mpnet-base-v2 "
)
assert (
" You may need to set ' model_format= ' sentence_transformers ' to ensure correct loading of model. "
in caplog . text
)
2022-04-25 00:53:48 -07:00
@pytest.mark.parametrize ( " retriever " , [ " es_filter_only " ] , indirect = True )
@pytest.mark.parametrize ( " document_store " , [ " elasticsearch " ] , indirect = True )
def test_es_filter_only ( document_store , retriever ) :
docs = [
Document ( content = " Doc1 " , meta = { " f1 " : " 0 " } ) ,
Document ( content = " Doc2 " , meta = { " f1 " : " 0 " } ) ,
Document ( content = " Doc3 " , meta = { " f1 " : " 0 " } ) ,
Document ( content = " Doc4 " , meta = { " f1 " : " 0 " } ) ,
Document ( content = " Doc5 " , meta = { " f1 " : " 0 " } ) ,
Document ( content = " Doc6 " , meta = { " f1 " : " 0 " } ) ,
Document ( content = " Doc7 " , meta = { " f1 " : " 1 " } ) ,
Document ( content = " Doc8 " , meta = { " f1 " : " 0 " } ) ,
Document ( content = " Doc9 " , meta = { " f1 " : " 0 " } ) ,
Document ( content = " Doc10 " , meta = { " f1 " : " 0 " } ) ,
Document ( content = " Doc11 " , meta = { " f1 " : " 0 " } ) ,
Document ( content = " Doc12 " , meta = { " f1 " : " 0 " } ) ,
]
document_store . write_documents ( docs )
retrieved_docs = retriever . retrieve ( query = " " , filters = { " f1 " : [ " 0 " ] } )
assert len ( retrieved_docs ) == 11