2022-04-12 11:52:27 +02:00
import logging
2022-10-14 15:01:03 +02:00
import os
2022-06-10 18:22:48 +02:00
from math import isclose
2023-03-27 18:14:58 +02:00
from typing import Dict , List , Optional , Union , Tuple
2023-03-27 15:31:22 +02:00
from unittest . mock import patch , Mock , DEFAULT
2020-06-30 19:05:45 +02:00
2022-10-17 18:58:35 +02:00
import pytest
2021-02-12 14:57:06 +01:00
import numpy as np
2021-10-25 12:27:02 +02:00
import pandas as pd
2023-03-27 18:14:58 +02:00
import requests
from boilerpy3 . extractors import ArticleExtractor
2022-10-17 18:58:35 +02:00
from pandas . testing import assert_frame_equal
2023-05-17 18:54:34 +02:00
from transformers import PreTrainedTokenizerFast
2021-11-04 09:27:12 +01:00
2023-04-26 10:14:20 +02:00
try :
from elasticsearch import Elasticsearch
except ( ImportError , ModuleNotFoundError ) as ie :
from haystack . utils . import_utils import _optional_component_not_installed
_optional_component_not_installed ( __name__ , " elasticsearch " , ie )
2022-12-12 14:04:29 +01:00
from haystack . document_stores . base import BaseDocumentStore , FilterType
2022-10-17 18:58:35 +02:00
from haystack . document_stores . memory import InMemoryDocumentStore
2021-11-04 09:27:12 +01:00
from haystack . document_stores import WeaviateDocumentStore
2022-07-22 16:29:30 +02:00
from haystack . nodes . retriever . base import BaseRetriever
2023-03-27 18:14:58 +02:00
from haystack . nodes . retriever . web import WebRetriever
from haystack . nodes . search_engine import WebSearch
2023-03-27 15:31:22 +02:00
from haystack . nodes . retriever import Text2SparqlRetriever
2022-10-14 15:01:03 +02:00
from haystack . pipelines import DocumentSearchPipeline
2021-10-25 15:50:23 +02:00
from haystack . schema import Document
from haystack . document_stores . elasticsearch import ElasticsearchDocumentStore
2022-09-26 15:18:12 +02:00
from haystack . nodes . retriever . dense import DensePassageRetriever , EmbeddingRetriever , TableTextRetriever
2022-04-29 10:16:02 +02:00
from haystack . nodes . retriever . sparse import BM25Retriever , FilterRetriever , TfidfRetriever
2022-10-17 18:58:35 +02:00
from haystack . nodes . retriever . multimodal import MultiModalRetriever
2020-06-30 19:05:45 +02:00
2023-04-26 10:14:20 +02:00
from . . conftest import MockBaseRetriever , fail_at_version
2022-01-26 18:12:55 +01:00
2021-10-25 12:27:02 +02:00
2022-02-03 13:43:18 +01:00
# TODO check if we this works with only "memory" arg
@pytest.mark.parametrize(
    "retriever_with_docs,document_store_with_docs",
    [
        ("mdr", "elasticsearch"),
        ("mdr", "faiss"),
        ("mdr", "memory"),
        ("mdr", "milvus"),
        ("dpr", "elasticsearch"),
        ("dpr", "faiss"),
        ("dpr", "memory"),
        ("dpr", "milvus"),
        ("embedding", "elasticsearch"),
        ("embedding", "faiss"),
        ("embedding", "memory"),
        ("embedding", "milvus"),
        ("bm25", "elasticsearch"),
        ("bm25", "memory"),
        ("bm25", "weaviate"),
        ("es_filter_only", "elasticsearch"),
        ("tfidf", "memory"),
    ],
    indirect=True,
)
def test_retrieval_without_filters(retriever_with_docs: BaseRetriever, document_store_with_docs: BaseDocumentStore):
    """Unfiltered retrieval returns all five sample docs, Berlin doc ranked first."""
    # Sparse retrievers (BM25/TfIdf) score on the raw text; only the others need embeddings.
    if not isinstance(retriever_with_docs, (BM25Retriever, TfidfRetriever)):
        document_store_with_docs.update_embeddings(retriever_with_docs)

    # NOTE: FilterRetriever simply returns all documents matching a filter,
    # so without filters applied it does nothing - skip the retrieval checks.
    if isinstance(retriever_with_docs, FilterRetriever):
        return

    # The BM25 implementation in Weaviate would NOT pick up the expected records
    # because of the lack of stemming: "Who lives in berlin" returns only 1 record while
    # "Who live in berlin" returns all 5 records.
    # TODO - In Weaviate 1.19.0 there is a fix for the lack of stemming, which means that once 1.19.0 is released
    # this special case can be removed, as the standard search query "Who lives in Berlin?" should work with Weaviate.
    # See https://github.com/weaviate/weaviate/issues/2439
    if isinstance(document_store_with_docs, WeaviateDocumentStore):
        res = retriever_with_docs.retrieve(query="Who live in berlin")
    else:
        res = retriever_with_docs.retrieve(query="Who lives in Berlin?")

    assert len(res) == 5
    assert res[0].content == "My name is Carla and I live in Berlin"
    assert res[0].meta["name"] == "filename1"
2021-02-12 14:57:06 +01:00
2022-11-22 09:24:52 +01:00
@pytest.mark.parametrize(
    "retriever_with_docs,document_store_with_docs",
    [
        ("mdr", "elasticsearch"),
        ("mdr", "memory"),
        ("dpr", "elasticsearch"),
        ("dpr", "memory"),
        ("embedding", "elasticsearch"),
        ("embedding", "memory"),
        ("bm25", "elasticsearch"),
        ("bm25", "weaviate"),
        ("es_filter_only", "elasticsearch"),
    ],
    indirect=True,
)
def test_retrieval_with_filters(retriever_with_docs: BaseRetriever, document_store_with_docs: BaseDocumentStore):
    """Retrieval must honour single and combined metadata filters."""
    if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever)):
        document_store_with_docs.update_embeddings(retriever_with_docs)

    # single filter
    result = retriever_with_docs.retrieve(query="Christelle", filters={"name": ["filename3"]}, top_k=5)
    assert len(result) == 1
    # isinstance instead of `type(...) ==` so Document subclasses pass too
    assert isinstance(result[0], Document)
    assert result[0].content == "My name is Christelle and I live in Paris"
    assert result[0].meta["name"] == "filename3"

    # multiple filters
    result = retriever_with_docs.retrieve(
        query="Paul", filters={"name": ["filename2"], "meta_field": ["test2", "test3"]}, top_k=5
    )
    assert len(result) == 1
    assert isinstance(result[0], Document)
    assert result[0].meta["name"] == "filename2"

    # contradictory filters must match nothing
    result = retriever_with_docs.retrieve(
        query="Carla", filters={"name": ["filename1"], "meta_field": ["test2", "test3"]}, top_k=5
    )
    assert len(result) == 0
2021-02-12 14:57:06 +01:00
2022-12-19 12:07:49 +01:00
def test_tfidf_retriever_multiple_indexes():
    """TfidfRetriever keeps a separate document count per index it was fitted on."""
    store = InMemoryDocumentStore(index="index_0")
    tfidf_retriever = TfidfRetriever(document_store=store)

    store.write_documents([Document(content=f"test_{i}") for i in (1, 2, 3)])
    tfidf_retriever.fit(store, index="index_0")
    store.write_documents([Document(content=f"test_{i}") for i in (4, 5)], index="index_1")
    tfidf_retriever.fit(store, index="index_1")

    for index_name in ("index_0", "index_1"):
        assert tfidf_retriever.document_counts[index_name] == store.get_document_count(index=index_name)
2022-10-17 19:00:13 +02:00
def test_retrieval_empty_query(document_store: BaseDocumentStore):
    """An empty query must still flow through run()/run_batch() and yield the mocked document."""
    mock_document = Document(id="0", content="test")
    retriever = MockBaseRetriever(document_store=document_store, mock_document=mock_document)

    run_result = retriever.run(root_node="Query", query="", filters={})
    assert run_result[0]["documents"][0] == mock_document

    batch_result = retriever.run_batch(root_node="Query", queries=[""], filters={})
    assert batch_result[0]["documents"][0][0] == mock_document
2023-02-08 14:39:20 +01:00
@pytest.mark.parametrize("retriever_with_docs", ["embedding", "dpr", "tfidf"], indirect=True)
def test_batch_retrieval_single_query(retriever_with_docs, document_store_with_docs):
    """retrieve_batch() with one query returns a list containing one list of Documents."""
    if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever, TfidfRetriever)):
        document_store_with_docs.update_embeddings(retriever_with_docs)

    res = retriever_with_docs.retrieve_batch(queries=["Who lives in Berlin?"])

    # Expected return type: list of lists of Documents
    assert isinstance(res, list)
    assert len(res) == 1
    per_query = res[0]
    assert isinstance(per_query, list)
    assert len(per_query) == 5
    assert isinstance(per_query[0], Document)
    assert per_query[0].content == "My name is Carla and I live in Berlin"
    assert per_query[0].meta["name"] == "filename1"
2022-05-11 11:11:00 +02:00
2023-02-08 14:39:20 +01:00
@pytest.mark.parametrize("retriever_with_docs", ["embedding", "dpr", "tfidf"], indirect=True)
def test_batch_retrieval_multiple_queries(retriever_with_docs, document_store_with_docs):
    """retrieve_batch() keeps per-query result lists aligned with the input query order."""
    if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever, TfidfRetriever)):
        document_store_with_docs.update_embeddings(retriever_with_docs)

    res = retriever_with_docs.retrieve_batch(queries=["Who lives in Berlin?", "Who lives in New York?"])

    # Expected return type: list of lists of Documents
    assert isinstance(res, list)
    assert isinstance(res[0], list)
    assert isinstance(res[0][0], Document)

    expectations = [
        ("My name is Carla and I live in Berlin", "filename1"),
        ("My name is Paul and I live in New York", "filename2"),
    ]
    for per_query, (expected_content, expected_name) in zip(res, expectations):
        assert len(per_query) == 5
        assert per_query[0].content == expected_content
        assert per_query[0].meta["name"] == expected_name
2022-12-12 14:04:29 +01:00
@pytest.mark.parametrize("retriever_with_docs", ["bm25"], indirect=True)
def test_batch_retrieval_multiple_queries_with_filters(retriever_with_docs, document_store_with_docs):
    """retrieve_batch() accepts one filter dict (or None) per query."""
    if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever)):
        document_store_with_docs.update_embeddings(retriever_with_docs)

    # Weaviate does not support BM25 with filters yet, only after Weaviate v1.18.0
    # TODO - remove this once Weaviate starts supporting BM25 WITH filters
    # You might also need to modify the first query, as Weaviate having problems with
    # retrieving the "My name is Carla and I live in Berlin" record just with the
    # "Who lives in Berlin?" BM25 query
    if isinstance(document_store_with_docs, WeaviateDocumentStore):
        return

    res = retriever_with_docs.retrieve_batch(
        queries=["Who lives in Berlin?", "Who lives in New York?"], filters=[{"name": "filename1"}, None]
    )

    # Expected return type: list of lists of Documents
    assert isinstance(res, list)
    assert isinstance(res[0], list)
    assert isinstance(res[0][0], Document)

    expectations = [
        ("My name is Carla and I live in Berlin", "filename1"),
        ("My name is Paul and I live in New York", "filename2"),
    ]
    for per_query, (expected_content, expected_name) in zip(res, expectations):
        assert len(per_query) == 5
        assert per_query[0].content == expected_content
        assert per_query[0].meta["name"] == expected_name
2021-02-12 14:57:06 +01:00
@pytest.mark.elasticsearch
def test_elasticsearch_custom_query():
    """BM25Retriever must interpolate ${query} / ${years} placeholders into custom ES query bodies."""
    client = Elasticsearch()
    client.indices.delete(index="haystack_test_custom", ignore=[404])
    document_store = ElasticsearchDocumentStore(
        index="haystack_test_custom", content_field="custom_text_field", embedding_field="custom_embedding_field"
    )
    # Five docs: one per year 2019/2020, three for 2021.
    documents = [
        {"content": f"test_{i}", "meta": {"year": year}}
        for i, year in enumerate(["2019", "2020", "2021", "2021", "2021"], start=1)
    ]
    document_store.write_documents(documents)

    # test custom "terms" query (filter value supplied as a list)
    retriever = BM25Retriever(
        document_store=document_store,
        custom_query="""
            {
                "size": 10,
                "query": {
                    "bool": {
                        "should": [{
                            "multi_match": {"query": ${query}, "type": "most_fields", "fields": ["content"]}}],
                        "filter": [{"terms": {"year": ${years}}}]}}}""",
    )
    results = retriever.retrieve(query="test", filters={"years": ["2020", "2021"]})
    assert len(results) == 4

    # test linefeeds in query
    results = retriever.retrieve(query="test\n", filters={"years": ["2020", "2021"]})
    assert len(results) == 3

    # test double quote in query
    results = retriever.retrieve(query='test"', filters={"years": ["2020", "2021"]})
    assert len(results) == 3

    # test custom "term" query (filter value supplied as a single string)
    retriever = BM25Retriever(
        document_store=document_store,
        custom_query="""
            {
                "size": 10,
                "query": {
                    "bool": {
                        "should": [{
                            "multi_match": {"query": ${query}, "type": "most_fields", "fields": ["content"]}}],
                        "filter": [{"term": {"year": ${years}}}]}}}""",
    )
    results = retriever.retrieve(query="test", filters={"years": "2021"})
    assert len(results) == 3
2022-06-07 09:23:03 +02:00
@pytest.mark.integration
@pytest.mark.parametrize(
    "document_store", ["elasticsearch", "faiss", "memory", "milvus", "weaviate", "pinecone"], indirect=True
)
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
def test_dpr_embedding(document_store: BaseDocumentStore, retriever, docs_with_ids):
    """DPR embeddings must be 768-dim and reproduce known reference values across stores."""
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    docs = sorted(document_store.get_all_documents(), key=lambda d: d.id)
    # First embedding component of each doc (after normalization), in doc-id order.
    # (A leftover debug print of the doc ids was removed here.)
    expected_values = [0.00892, 0.00780, 0.00482, -0.00626, 0.010966]
    for doc, expected_value in zip(docs, expected_values):
        embedding = doc.embedding
        assert len(embedding) == 768
        # always normalize vector as faiss returns normalized vectors and other document stores do not
        embedding /= np.linalg.norm(embedding)
        assert isclose(embedding[0], expected_value, rel_tol=0.01)
2021-01-20 12:52:52 +01:00
2020-11-05 13:29:23 +01:00
2022-06-07 09:23:03 +02:00
@pytest.mark.integration
@pytest.mark.parametrize(
    "document_store", ["elasticsearch", "faiss", "memory", "milvus", "weaviate", "pinecone"], indirect=True
)
@pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
@pytest.mark.embedding_dim(128)
def test_retribert_embedding(document_store, retriever, docs_with_ids):
    """Retribert embeddings must be 128-dim and reproduce known reference values."""
    if isinstance(document_store, WeaviateDocumentStore):
        # Weaviate sets the embedding dimension to 768 as soon as it is initialized.
        # We need 128 here and therefore initialize a new WeaviateDocumentStore.
        document_store = WeaviateDocumentStore(index="haystack_test", embedding_dim=128, recreate_index=True)
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    sorted_docs = sorted(document_store.get_all_documents(), key=lambda d: d.id)
    expected_values = [0.14017, 0.05975, 0.14267, 0.15099, 0.14383]
    for doc, expected_value in zip(sorted_docs, expected_values):
        embedding = doc.embedding
        assert len(embedding) == 128
        # always normalize vector as faiss returns normalized vectors and other document stores do not
        embedding /= np.linalg.norm(embedding)
        assert isclose(embedding[0], expected_value, rel_tol=0.001)
2021-06-14 17:53:43 +02:00
2022-12-19 17:06:48 +01:00
def test_openai_embedding_retriever_selection():
    """EmbeddingRetriever must map OpenAI model names onto the right query/doc encoder models.

    Covers the unified text-embedding-ada-002 model (released Dec 2022), the legacy
    text-search-<modelname>-*-001 family, and potential unreleased *-002 names.
    """
    cases = [
        ("text-embedding-ada-002", "text-embedding-ada-002", "text-embedding-ada-002"),
        ("ada", "text-search-ada-query-001", "text-search-ada-doc-001"),
        ("babbage", "text-search-babbage-query-001", "text-search-babbage-doc-001"),
        ("text-embedding-babbage-002", "text-embedding-babbage-002", "text-embedding-babbage-002"),
    ]
    for model_name, expected_query_model, expected_doc_model in cases:
        retriever = EmbeddingRetriever(embedding_model=model_name, document_store=None)
        assert retriever.model_format == "openai"
        assert retriever.embedding_encoder.query_encoder_model == expected_query_model
        assert retriever.embedding_encoder.doc_encoder_model == expected_doc_model
2022-10-14 15:01:03 +02:00
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["cohere"], indirect=True)
@pytest.mark.embedding_dim(1024)
@pytest.mark.skipif(
    not os.environ.get("COHERE_API_KEY", None),
    reason="Please export an env var called COHERE_API_KEY containing the Cohere API key to run this test.",
)
def test_basic_cohere_embedding(document_store, retriever, docs_with_ids):
    """Cohere-embedded documents must all carry 1024-dimensional vectors."""
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    for doc in sorted(document_store.get_all_documents(), key=lambda d: d.id):
        assert len(doc.embedding) == 1024
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["openai"], indirect=True)
@pytest.mark.embedding_dim(1536)
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason=("Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test."),
)
def test_basic_openai_embedding(document_store, retriever, docs_with_ids):
    """OpenAI-embedded documents must all carry 1536-dimensional vectors."""
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    for doc in sorted(document_store.get_all_documents(), key=lambda d: d.id):
        assert len(doc.embedding) == 1536
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["azure"], indirect=True)
@pytest.mark.embedding_dim(1536)
# NOTE(review): `and` skips only when ALL three env vars are missing; a partial
# configuration will run (and fail) rather than skip - confirm this is intended.
@pytest.mark.skipif(
    not os.environ.get("AZURE_OPENAI_API_KEY", None)
    and not os.environ.get("AZURE_OPENAI_BASE_URL", None)
    and not os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME_EMBED", None),
    reason=(
        "Please export env variables called AZURE_OPENAI_API_KEY containing "
        "the Azure OpenAI key, AZURE_OPENAI_BASE_URL containing "
        "the Azure OpenAI base URL, and AZURE_OPENAI_DEPLOYMENT_NAME_EMBED containing "
        "the Azure OpenAI deployment name to run this test."
    ),
)
def test_basic_azure_embedding(document_store, retriever, docs_with_ids):
    """Azure-OpenAI-embedded documents must all carry 1536-dimensional vectors."""
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    for doc in sorted(document_store.get_all_documents(), key=lambda d: d.id):
        assert len(doc.embedding) == 1536
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["cohere"], indirect=True)
@pytest.mark.embedding_dim(1024)
@pytest.mark.skipif(
    not os.environ.get("COHERE_API_KEY", None),
    reason="Please export an env var called COHERE_API_KEY containing the Cohere API key to run this test.",
)
def test_retriever_basic_cohere_search(document_store, retriever, docs_with_ids):
    """End-to-end: a DocumentSearchPipeline backed by Cohere embeddings finds the Madrid doc."""
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    pipeline = DocumentSearchPipeline(retriever)
    res = pipeline.run(query="Madrid", params={"Retriever": {"top_k": 1}})
    assert len(res["documents"]) == 1
    assert "Madrid" in res["documents"][0].content
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["openai"], indirect=True)
@pytest.mark.embedding_dim(1536)
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Please export env called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_retriever_basic_openai_search(document_store, retriever, docs_with_ids):
    """End-to-end: a DocumentSearchPipeline backed by OpenAI embeddings finds the Madrid doc."""
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    pipeline = DocumentSearchPipeline(retriever)
    res = pipeline.run(query="Madrid", params={"Retriever": {"top_k": 1}})
    assert len(res["documents"]) == 1
    assert "Madrid" in res["documents"][0].content
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["azure"], indirect=True)
@pytest.mark.embedding_dim(1536)
# NOTE(review): `and` skips only when ALL three env vars are missing; a partial
# configuration will run (and fail) rather than skip - confirm this is intended.
@pytest.mark.skipif(
    not os.environ.get("AZURE_OPENAI_API_KEY", None)
    and not os.environ.get("AZURE_OPENAI_BASE_URL", None)
    and not os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME_EMBED", None),
    reason=(
        "Please export env variables called AZURE_OPENAI_API_KEY containing "
        "the Azure OpenAI key, AZURE_OPENAI_BASE_URL containing "
        "the Azure OpenAI base URL, and AZURE_OPENAI_DEPLOYMENT_NAME_EMBED containing "
        "the Azure OpenAI deployment name to run this test."
    ),
)
def test_retriever_basic_azure_search(document_store, retriever, docs_with_ids):
    """End-to-end: a DocumentSearchPipeline backed by Azure OpenAI embeddings finds the Madrid doc."""
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    pipeline = DocumentSearchPipeline(retriever)
    res = pipeline.run(query="Madrid", params={"Retriever": {"top_k": 1}})
    assert len(res["documents"]) == 1
    assert "Madrid" in res["documents"][0].content
2022-06-07 09:23:03 +02:00
@pytest.mark.integration
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
@pytest.mark.embedding_dim(512)
def test_table_text_retriever_embedding(document_store, retriever, docs):
    """TableTextRetriever must embed text and table docs into 512-dim vectors with known first components."""
    # BM25 representation is incompatible with table retriever
    if isinstance(document_store, InMemoryDocumentStore):
        document_store.use_bm25 = False

    document_store.return_embedding = True
    document_store.write_documents(docs)
    table_data = {
        "Mountain": ["Mount Everest", "K2", "Kangchenjunga", "Lhotse", "Makalu"],
        "Height": ["8848m", "8,611 m", "8 586m", "8 516 m", "8,485m"],
    }
    table_doc = Document(content=pd.DataFrame(table_data), content_type="table", id="6")
    document_store.write_documents([table_doc])
    document_store.update_embeddings(retriever=retriever)

    expected_values = [0.061191384, 0.038075786, 0.27447605, 0.09399721, 0.0959682]
    embedded_docs = sorted(document_store.get_all_documents(), key=lambda d: d.id)
    for doc, expected_value in zip(embedded_docs, expected_values):
        assert len(doc.embedding) == 512
        assert isclose(doc.embedding[0], expected_value, rel_tol=0.001)
2021-10-25 12:27:02 +02:00
2023-02-09 10:38:16 +00:00
@pytest.mark.integration
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.embedding_dim(512)
def test_table_text_retriever_embedding_only_text(document_store, retriever):
    """Smoke test: embedding update must succeed when only text documents are present."""
    text_docs = [Document(content=body, content_type="text") for body in ("This is a test", "This is another test")]
    document_store.write_documents(text_docs)
    document_store.update_embeddings(retriever)
@pytest.mark.integration
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.embedding_dim(512)
def test_table_text_retriever_embedding_only_table(document_store, retriever):
    """Smoke test: embedding update must succeed when only a table document is present."""
    table = pd.DataFrame(columns=["id", "text"], data=[["1", "This is a test"], ["2", "This is another test"]])
    document_store.write_documents([Document(content=table, content_type="table")])
    document_store.update_embeddings(retriever)
2020-11-05 13:29:23 +01:00
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
def test_dpr_saving_and_loading(tmp_path, retriever, document_store):
    """Saving and re-loading a DensePassageRetriever must preserve weights, config, and tokenizers."""
    retriever.save(f"{tmp_path}/test_dpr_save")

    def sum_params(model):
        # Scalar fingerprint of all model weights; cheap to compare across save/load.
        return sum(np.sum(p.cpu().data.numpy()) for p in model.parameters())

    original_sum_query = sum_params(retriever.query_encoder)
    original_sum_passage = sum_params(retriever.passage_encoder)
    del retriever

    loaded_retriever = DensePassageRetriever.load(f"{tmp_path}/test_dpr_save", document_store)

    loaded_sum_query = sum_params(loaded_retriever.query_encoder)
    loaded_sum_passage = sum_params(loaded_retriever.passage_encoder)
    assert abs(original_sum_query - loaded_sum_query) < 0.1
    assert abs(original_sum_passage - loaded_sum_passage) < 0.1

    # comparison of weights (RAM intense!)
    # for p1, p2 in zip(retriever.query_encoder.parameters(), loaded_retriever.query_encoder.parameters()):
    #     assert (p1.data.ne(p2.data).sum() == 0)
    #
    # for p1, p2 in zip(retriever.passage_encoder.parameters(), loaded_retriever.passage_encoder.parameters()):
    #     assert (p1.data.ne(p2.data).sum() == 0)

    # attributes (use `is True` rather than `== True` for boolean identity)
    assert loaded_retriever.processor.embed_title is True
    assert loaded_retriever.batch_size == 16
    assert loaded_retriever.processor.max_seq_len_passage == 256
    assert loaded_retriever.processor.max_seq_len_query == 64

    # Tokenizer
    assert isinstance(loaded_retriever.passage_tokenizer, PreTrainedTokenizerFast)
    assert isinstance(loaded_retriever.query_tokenizer, PreTrainedTokenizerFast)
    assert loaded_retriever.passage_tokenizer.do_lower_case is True
    assert loaded_retriever.query_tokenizer.do_lower_case is True
    assert loaded_retriever.passage_tokenizer.vocab_size == 30522
    assert loaded_retriever.query_tokenizer.vocab_size == 30522
2021-10-25 12:27:02 +02:00
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
@pytest.mark.embedding_dim(512)
def test_table_text_retriever_saving_and_loading(tmp_path, retriever, document_store):
    """Saving and re-loading a TableTextRetriever must preserve all three encoders, config, and tokenizers."""
    retriever.save(f"{tmp_path}/test_table_text_retriever_save")

    def sum_params(model):
        # Scalar fingerprint of all model weights; cheap to compare across save/load.
        return sum(np.sum(p.cpu().data.numpy()) for p in model.parameters())

    original_sum_query = sum_params(retriever.query_encoder)
    original_sum_passage = sum_params(retriever.passage_encoder)
    original_sum_table = sum_params(retriever.table_encoder)
    del retriever

    loaded_retriever = TableTextRetriever.load(f"{tmp_path}/test_table_text_retriever_save", document_store)

    loaded_sum_query = sum_params(loaded_retriever.query_encoder)
    loaded_sum_passage = sum_params(loaded_retriever.passage_encoder)
    loaded_sum_table = sum_params(loaded_retriever.table_encoder)
    assert abs(original_sum_query - loaded_sum_query) < 0.1
    assert abs(original_sum_passage - loaded_sum_passage) < 0.1
    assert abs(original_sum_table - loaded_sum_table) < 0.01

    # attributes (use `is True` rather than `== True` for boolean identity)
    assert loaded_retriever.processor.embed_meta_fields == ["name", "section_title", "caption"]
    assert loaded_retriever.batch_size == 16
    assert loaded_retriever.processor.max_seq_len_passage == 256
    assert loaded_retriever.processor.max_seq_len_table == 256
    assert loaded_retriever.processor.max_seq_len_query == 64

    # Tokenizer
    assert isinstance(loaded_retriever.passage_tokenizer, PreTrainedTokenizerFast)
    assert isinstance(loaded_retriever.table_tokenizer, PreTrainedTokenizerFast)
    assert isinstance(loaded_retriever.query_tokenizer, PreTrainedTokenizerFast)
    assert loaded_retriever.passage_tokenizer.do_lower_case is True
    assert loaded_retriever.table_tokenizer.do_lower_case is True
    assert loaded_retriever.query_tokenizer.do_lower_case is True
    assert loaded_retriever.passage_tokenizer.vocab_size == 30522
    assert loaded_retriever.table_tokenizer.vocab_size == 30522
    assert loaded_retriever.query_tokenizer.vocab_size == 30522
2022-01-10 17:10:32 +00:00
@pytest.mark.embedding_dim(128)
def test_table_text_retriever_training(tmp_path, document_store, samples_path):
    """Train a TableTextRetriever for one epoch on the sample data and reload the saved model."""
    save_dir = f"{tmp_path}/test_table_text_retriever_train"
    encoder_models = {
        "query_embedding_model": "deepset/bert-small-mm_retrieval-question_encoder",
        "passage_embedding_model": "deepset/bert-small-mm_retrieval-passage_encoder",
        "table_embedding_model": "deepset/bert-small-mm_retrieval-table_encoder",
    }
    retriever = TableTextRetriever(document_store=document_store, use_gpu=False, **encoder_models)
    retriever.train(
        data_dir=samples_path / "mmr", train_filename="sample.json", n_epochs=1, n_gpu=0, save_dir=save_dir
    )
    # Smoke check: loading the trained model must succeed, i.e. training produced a valid checkpoint
    retriever = TableTextRetriever.load(load_dir=save_dir, document_store=document_store)
2022-01-26 22:05:33 +05:30
@pytest.mark.elasticsearch
def test_elasticsearch_highlight():
    """Elasticsearch highlighting enabled via ``custom_query`` must surface the highlighted
    fragments in ``Document.meta["highlighted"]``, restricted to the fields listed under
    ``highlight.fields`` in the custom query."""
    client = Elasticsearch()
    # Start from a clean index; ignore 404 in case it does not exist yet
    client.indices.delete(index="haystack_hl_test", ignore=[404])
    # Mapping the content and title field as "text" perform search on these both fields.
    document_store = ElasticsearchDocumentStore(
        index="haystack_hl_test",
        content_field="title",
        custom_mapping={"mappings": {"properties": {"content": {"type": "text"}, "title": {"type": "text"}}}},
    )
    documents = [
        {
            "title": "Green tea components",
            "meta": {
                "content": "The green tea plant contains a range of healthy compounds that make it into the final drink"
            },
            "id": "1",
        },
        {
            "title": "Green tea catechin",
            "meta": {"content": "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG)."},
            "id": "2",
        },
        {
            "title": "Minerals in Green tea",
            "meta": {"content": "Green tea also has small amounts of minerals that can benefit your health."},
            "id": "3",
        },
        {
            "title": "Green tea Benefits",
            "meta": {"content": "Green tea does more than just keep you alert, it may also help boost brain function."},
            "id": "4",
        },
    ]
    document_store.write_documents(documents)

    # Enabled highlighting on "title"&"content" field only using custom query
    retriever_1 = BM25Retriever(
        document_store=document_store,
        custom_query="""{
            "size": 20,
            "query": {
                "bool": {
                    "should": [
                        {
                            "multi_match": {
                                "query": ${query},
                                "fields": [
                                    "content^3",
                                    "title^5"
                                ]
                            }
                        }
                    ]
                }
            },
            "highlight": {
                "pre_tags": [
                    "**"
                ],
                "post_tags": [
                    "**"
                ],
                "number_of_fragments": 3,
                "fragment_size": 5,
                "fields": {
                    "content": {},
                    "title": {}
                }
            }
        }""",
    )
    results = retriever_1.retrieve(query="is green tea healthy")

    # Both "title" and "content" were listed under highlight.fields, so both appear in meta
    assert len(results[0].meta["highlighted"]) == 2
    assert results[0].meta["highlighted"]["title"] == ["**Green** ", "**tea** components"]
    assert results[0].meta["highlighted"]["content"] == ["The **green** ", "**tea** plant ", "range of **healthy** "]

    # Enabled highlighting on "title" field only using custom query
    retriever_2 = BM25Retriever(
        document_store=document_store,
        custom_query="""{
            "size": 20,
            "query": {
                "bool": {
                    "should": [
                        {
                            "multi_match": {
                                "query": ${query},
                                "fields": [
                                    "content^3",
                                    "title^5"
                                ]
                            }
                        }
                    ]
                }
            },
            "highlight": {
                "pre_tags": [
                    "**"
                ],
                "post_tags": [
                    "**"
                ],
                "number_of_fragments": 3,
                "fragment_size": 5,
                "fields": {
                    "title": {}
                }
            }
        }""",
    )
    results = retriever_2.retrieve(query="is green tea healthy")

    # Only "title" was listed under highlight.fields this time
    assert len(results[0].meta["highlighted"]) == 1
    assert results[0].meta["highlighted"]["title"] == ["**Green** ", "**tea** components"]
2022-03-25 17:53:42 +01:00
@pytest.mark.elasticsearch
def test_elasticsearch_filter_must_not_increase_results():
    """A filter that matches every document must not change the BM25 hit count.

    Fix: added the missing ``@pytest.mark.elasticsearch`` marker — this test needs a live
    Elasticsearch server, like the sibling ES tests in this module.
    """
    index = "filter_must_not_increase_results"
    client = Elasticsearch()
    # Start from a clean index; ignore 404 in case it does not exist yet
    client.indices.delete(index=index, ignore=[404])
    documents = [
        {
            "content": "The green tea plant contains a range of healthy compounds that make it into the final drink",
            "meta": {"content_type": "text"},
            "id": "1",
        },
        {
            "content": "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG).",
            "meta": {"content_type": "text"},
            "id": "2",
        },
        {
            "content": "Green tea also has small amounts of minerals that can benefit your health.",
            "meta": {"content_type": "text"},
            "id": "3",
        },
        {
            "content": "Green tea does more than just keep you alert, it may also help boost brain function.",
            "meta": {"content_type": "text"},
            "id": "4",
        },
    ]
    doc_store = ElasticsearchDocumentStore(index=index)
    doc_store.write_documents(documents)
    results_wo_filter = doc_store.query(query="drink")
    assert len(results_wo_filter) == 1
    # content_type == "text" matches all four documents, so the hit count must stay the same
    results_w_filter = doc_store.query(query="drink", filters={"content_type": "text"})
    assert len(results_w_filter) == 1
    doc_store.delete_index(index)
2022-03-28 22:10:50 +02:00
@pytest.mark.elasticsearch
def test_elasticsearch_all_terms_must_match():
    """``all_terms_must_match=True`` must AND the query terms (only doc 1 contains all of
    "drink", "green" and "tea"), while the default ORs them (all four docs match).

    Fix: added the missing ``@pytest.mark.elasticsearch`` marker — this test needs a live
    Elasticsearch server, like the sibling ES tests in this module.
    """
    index = "all_terms_must_match"
    client = Elasticsearch()
    # Start from a clean index; ignore 404 in case it does not exist yet
    client.indices.delete(index=index, ignore=[404])
    documents = [
        {
            "content": "The green tea plant contains a range of healthy compounds that make it into the final drink",
            "meta": {"content_type": "text"},
            "id": "1",
        },
        {
            "content": "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG).",
            "meta": {"content_type": "text"},
            "id": "2",
        },
        {
            "content": "Green tea also has small amounts of minerals that can benefit your health.",
            "meta": {"content_type": "text"},
            "id": "3",
        },
        {
            "content": "Green tea does more than just keep you alert, it may also help boost brain function.",
            "meta": {"content_type": "text"},
            "id": "4",
        },
    ]
    doc_store = ElasticsearchDocumentStore(index=index)
    doc_store.write_documents(documents)
    results_wo_all_terms_must_match = doc_store.query(query="drink green tea")
    assert len(results_wo_all_terms_must_match) == 4
    results_w_all_terms_must_match = doc_store.query(query="drink green tea", all_terms_must_match=True)
    assert len(results_w_all_terms_must_match) == 1
    doc_store.delete_index(index)
2022-04-12 11:52:27 +02:00
2022-12-08 12:48:43 +01:00
@pytest.mark.elasticsearch
def test_bm25retriever_all_terms_must_match():
    """BM25Retriever must honor all_terms_must_match both as a constructor argument and as a
    per-call retrieve() argument: with it, only the document containing every query term matches."""
    index = "all_terms_must_match"
    client = Elasticsearch()
    # Start from a clean index; ignore 404 in case it does not exist yet
    client.indices.delete(index=index, ignore=[404])
    contents = [
        "The green tea plant contains a range of healthy compounds that make it into the final drink",
        "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG).",
        "Green tea also has small amounts of minerals that can benefit your health.",
        "Green tea does more than just keep you alert, it may also help boost brain function.",
    ]
    documents = [
        {"content": text, "meta": {"content_type": "text"}, "id": str(doc_id)}
        for doc_id, text in enumerate(contents, start=1)
    ]
    doc_store = ElasticsearchDocumentStore(index=index)
    doc_store.write_documents(documents)

    query = "drink green tea"
    # Default behavior: any term may match -> all four docs are hits
    assert len(BM25Retriever(document_store=doc_store).retrieve(query=query)) == 4
    # Enforced at construction time -> only doc 1 contains all three terms
    strict_retriever = BM25Retriever(document_store=doc_store, all_terms_must_match=True)
    assert len(strict_retriever.retrieve(query=query)) == 1
    # Enforced per call on a default retriever -> same single hit
    default_retriever = BM25Retriever(document_store=doc_store)
    assert len(default_retriever.retrieve(query=query, all_terms_must_match=True)) == 1
    doc_store.delete_index(index)
2022-04-12 11:52:27 +02:00
def test_embeddings_encoder_of_embedding_retriever_should_warn_about_model_format(caplog):
    """Loading a sentence-transformers model with model_format="farm" should log a warning
    suggesting model_format="sentence_transformers" instead."""
    document_store = InMemoryDocumentStore()
    with caplog.at_level(logging.WARNING):
        EmbeddingRetriever(
            document_store=document_store,
            embedding_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
            model_format="farm",
        )
    expected_warning = "You may need to set model_format='sentence_transformers' to ensure correct loading of model."
    assert expected_warning in caplog.text
2022-04-25 00:53:48 -07:00
@pytest.mark.parametrize("retriever", ["es_filter_only"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
def test_es_filter_only(document_store, retriever):
    """The filter-only retriever must return every document matching the filter, ignoring the query."""
    # 12 docs total; only Doc7 has f1 == "1", all others have f1 == "0"
    docs = [Document(content=f"Doc{i}", meta={"f1": "1" if i == 7 else "0"}) for i in range(1, 13)]
    document_store.write_documents(docs)
    retrieved_docs = retriever.retrieve(query="", filters={"f1": ["0"]})
    assert len(retrieved_docs) == 11
2022-10-17 18:58:35 +02:00
#
# MultiModal
#
@pytest.fixture
def text_docs() -> List[Document]:
    """Five text documents with varied metadata, used by the multimodal retrieval tests."""
    rows = [
        ("My name is Paul and I live in New York", "test2", "filename2", "2019-10-01", 5.0, 0),
        ("My name is Carla and I live in Berlin", "test1", "filename1", "2020-03-01", 5.5, 1),
        ("My name is Christelle and I live in Paris", "test3", "filename3", "2018-10-01", 4.5, 1),
        ("My name is Camila and I live in Madrid", "test4", "filename4", "2021-02-01", 3.0, 0),
        ("My name is Matteo and I live in Rome", "test5", "filename5", "2019-01-01", 0.0, 1),
    ]
    return [
        Document(
            content=content,
            meta={
                "meta_field": meta_field,
                "name": name,
                "date_field": date_field,
                "numeric_field": numeric_field,
                "odd_field": odd_field,
            },
        )
        for content, meta_field, name, date_field, numeric_field, odd_field in rows
    ]
@pytest.fixture
def table_docs() -> List[Document]:
    """Three table documents (pandas DataFrames), used by the multimodal retrieval tests."""
    tables = [
        {
            "Mountain": ["Mount Everest", "K2", "Kangchenjunga", "Lhotse", "Makalu"],
            "Height": ["8848m", "8,611 m", "8 586m", "8 516 m", "8,485m"],
        },
        {
            "City": ["Paris", "Lyon", "Marseille", "Lille", "Toulouse", "Bordeaux"],
            "Population": ["13,114,718", "2,280,845", "1,873,270", "1,510,079", "1,454,158", "1,363,711"],
        },
        {
            "City": ["Berlin", "Hamburg", "Munich", "Cologne"],
            "Population": ["3,644,826", "1,841,179", "1,471,508", "1,085,664"],
        },
    ]
    return [Document(content=pd.DataFrame(data), content_type="table") for data in tables]
@pytest.fixture
def image_docs(samples_path) -> List[Document]:
    """One image document per file found in the samples "images" directory."""
    images_dir = samples_path / "images"
    docs = []
    for image_file in os.listdir(images_dir):
        docs.append(Document(content=str(images_dir / image_file), content_type="image"))
    return docs
@pytest.mark.integration
def test_multimodal_text_retrieval(text_docs: List[Document]):
    """Text-to-text multimodal retrieval should rank the Paris document first for a Paris query."""
    model = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
    retriever = MultiModalRetriever(
        document_store=InMemoryDocumentStore(return_embedding=True),
        query_embedding_model=model,
        document_embedding_models={"text": model},
    )
    retriever.document_store.write_documents(text_docs)
    retriever.document_store.update_embeddings(retriever=retriever)
    top_result = retriever.retrieve(query="Who lives in Paris?")[0]
    assert top_result.content == "My name is Christelle and I live in Paris"
2022-12-08 08:28:43 +01:00
@pytest.mark.integration
def test_multimodal_text_retrieval_batch(text_docs: List[Document]):
    """Batch multimodal retrieval should return the correct top document for each query."""
    model = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
    retriever = MultiModalRetriever(
        document_store=InMemoryDocumentStore(return_embedding=True),
        query_embedding_model=model,
        document_embedding_models={"text": model},
    )
    retriever.document_store.write_documents(text_docs)
    retriever.document_store.update_embeddings(retriever=retriever)
    results = retriever.retrieve_batch(queries=["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Madrid?"])
    expected_top_contents = [
        "My name is Christelle and I live in Paris",
        "My name is Carla and I live in Berlin",
        "My name is Camila and I live in Madrid",
    ]
    for per_query_results, expected_content in zip(results, expected_top_contents):
        assert per_query_results[0].content == expected_content
2022-10-17 18:58:35 +02:00
@pytest.mark.integration
def test_multimodal_table_retrieval(table_docs: List[Document]):
    """Table retrieval should return the German-cities table for a question about Hamburg."""
    model = "deepset/all-mpnet-base-v2-table"
    retriever = MultiModalRetriever(
        document_store=InMemoryDocumentStore(return_embedding=True),
        query_embedding_model=model,
        document_embedding_models={"table": model},
    )
    retriever.document_store.write_documents(table_docs)
    retriever.document_store.update_embeddings(retriever=retriever)
    results = retriever.retrieve(query="How many people live in Hamburg?")
    expected_table = pd.DataFrame(
        {
            "City": ["Berlin", "Hamburg", "Munich", "Cologne"],
            "Population": ["3,644,826", "1,841,179", "1,471,508", "1,085,664"],
        }
    )
    assert_frame_equal(results[0].content, expected_table)
2023-03-16 16:32:28 +01:00
@pytest.mark.skip("Must be reworked as it fails randomly")
@pytest.mark.integration
def test_multimodal_retriever_query():
    """Embedding the identical query twice must produce identical vectors (determinism check)."""
    clip_model = "sentence-transformers/clip-ViT-B-32"
    retriever = MultiModalRetriever(
        document_store=InMemoryDocumentStore(return_embedding=True, embedding_dim=512),
        query_embedding_model=clip_model,
        document_embedding_models={"image": clip_model},
    )
    embeddings = retriever.embed_queries(["dummy query 1", "dummy query 1"])
    assert np.array_equal(embeddings[0], embeddings[1])
2022-10-17 18:58:35 +02:00
@pytest.mark.integration
def test_multimodal_image_retrieval(image_docs: List[Document], samples_path):
    """Image retrieval should return the cat picture for a cat-related text query."""
    clip_model = "sentence-transformers/clip-ViT-B-32"
    retriever = MultiModalRetriever(
        document_store=InMemoryDocumentStore(return_embedding=True, embedding_dim=512),
        query_embedding_model=clip_model,
        document_embedding_models={"image": clip_model},
    )
    retriever.document_store.write_documents(image_docs)
    retriever.document_store.update_embeddings(retriever=retriever)
    top_result = retriever.retrieve(query="What's a cat?")[0]
    assert str(top_result.content) == str(samples_path / "images" / "cat.jpg")
2022-10-17 18:58:35 +02:00
@pytest.mark.skip("Not working yet as intended")
@pytest.mark.integration
def test_multimodal_text_image_retrieval(text_docs: List[Document], image_docs: List[Document], samples_path):
    """Mixed text+image retrieval: the best text hit and the best image hit should both be about Paris."""
    clip_model = "sentence-transformers/clip-ViT-B-32"
    retriever = MultiModalRetriever(
        document_store=InMemoryDocumentStore(return_embedding=True, embedding_dim=512),
        query_embedding_model=clip_model,
        document_embedding_models={"text": clip_model, "image": clip_model},
    )
    retriever.document_store.write_documents(image_docs)
    retriever.document_store.write_documents(text_docs)
    retriever.document_store.update_embeddings(retriever=retriever)
    results = retriever.retrieve(query="What's Paris?")
    # Highest-ranked hit of each modality
    best_text = next(res for res in results if res.content_type == "text")
    best_image = next(res for res in results if res.content_type == "image")
    assert str(best_image.content) == str(samples_path / "images" / "paris.jpg")
    assert best_text.content == "My name is Christelle and I live in Paris"
2023-03-27 15:31:22 +02:00
@pytest.mark.unit
def test_web_retriever_mode_raw_documents(monkeypatch):
    """In "raw_documents" mode the WebRetriever fetches each non-snippet search hit, extracts the
    page content via boilerpy3 and returns it as unsplit Documents (no preprocessing applied).

    Fix: None comparisons now use ``is None`` (identity) instead of ``== None``.
    """
    expected_search_results = {
        "documents": [
            Document(
                content="Eddard Stark",
                score=0.9090909090909091,
                meta={"title": "Eddard Stark", "link": "", "score": 0.9090909090909091},
                id_hash_keys=["content"],
                id="f408db6de8de0ffad0cb47cf8830dbb8",
            ),
            Document(
                content="The most likely answer for the clue is NED. How many solutions does Arya Stark's Father have? With crossword-solver.io you will find 1 solutions. We use ...",
                score=0.09090909090909091,
                meta={
                    "title": "Arya Stark's Father - Crossword Clue Answers",
                    "link": "https://crossword-solver.io/clue/arya-stark%27s-father/",
                    "position": 1,
                    "score": 0.09090909090909091,
                },
                id_hash_keys=["content"],
                id="51779277acf94cf90e7663db137c0732",
            ),
        ]
    }

    def mock_web_search_run(self, query: str) -> Tuple[Dict, str]:
        return expected_search_results, "output_1"

    class MockResponse:
        # Minimal stand-in for requests.Response: only .text and .status_code are read
        def __init__(self, text, status_code):
            self.text = text
            self.status_code = status_code

    def get(url, headers, timeout):
        return MockResponse("mocked", 200)

    def get_content(self, text: str) -> str:
        return "What are the top solutions for \nArya Stark's Father\nWe found 1 solutions for \nArya Stark's Father\n.The top solutions is determined by popularity, ratings and frequency of searches. The most likely answer for the clue is NED..."

    # Avoid any real search-engine call, HTTP request or HTML extraction
    monkeypatch.setattr(WebSearch, "run", mock_web_search_run)
    monkeypatch.setattr(ArticleExtractor, "get_content", get_content)
    monkeypatch.setattr(requests, "get", get)

    web_retriever = WebRetriever(api_key="", top_search_results=2, mode="raw_documents")
    result = web_retriever.retrieve(query="Who is the father of Arya Stark?")
    # Only the hit with a link is fetched; the linkless snippet result is dropped
    assert len(result) == 1
    assert isinstance(result[0], Document)
    assert (
        result[0].content
        == "What are the top solutions for \nArya Stark's Father\nWe found 1 solutions for \nArya Stark's Father\n.The top solutions is determined by popularity, ratings and frequency of searches. The most likely answer for the clue is NED..."
    )
    assert result[0].score is None
    assert result[0].meta["url"] == "https://crossword-solver.io/clue/arya-stark%27s-father/"
    # Only preprocessed docs but not raw docs should have the _split_id field
    assert "_split_id" not in result[0].meta
@pytest.mark.unit
def test_web_retriever_mode_preprocessed_documents(monkeypatch):
    """In "preprocessed_documents" mode the WebRetriever fetches and extracts the page content and
    runs it through the preprocessor, so the resulting Documents carry a ``_split_id``.

    Fix: None comparisons now use ``is None`` (identity) instead of ``== None``.
    """
    expected_search_results = {
        "documents": [
            Document(
                content="Eddard Stark",
                score=0.9090909090909091,
                meta={"title": "Eddard Stark", "link": "", "score": 0.9090909090909091},
                id_hash_keys=["content"],
                id="f408db6de8de0ffad0cb47cf8830dbb8",
            ),
            Document(
                content="The most likely answer for the clue is NED. How many solutions does Arya Stark's Father have? With crossword-solver.io you will find 1 solutions. We use ...",
                score=0.09090909090909091,
                meta={
                    "title": "Arya Stark's Father - Crossword Clue Answers",
                    "link": "https://crossword-solver.io/clue/arya-stark%27s-father/",
                    "position": 1,
                    "score": 0.09090909090909091,
                },
                id_hash_keys=["content"],
                id="51779277acf94cf90e7663db137c0732",
            ),
        ]
    }

    def mock_web_search_run(self, query: str) -> Tuple[Dict, str]:
        return expected_search_results, "output_1"

    class MockResponse:
        # Minimal stand-in for requests.Response: only .text and .status_code are read
        def __init__(self, text, status_code):
            self.text = text
            self.status_code = status_code

    def get(url, headers, timeout):
        return MockResponse("mocked", 200)

    def get_content(self, text: str) -> str:
        return "What are the top solutions for \nArya Stark's Father\nWe found 1 solutions for \nArya Stark's Father\n.The top solutions is determined by popularity, ratings and frequency of searches. The most likely answer for the clue is NED..."

    # Avoid any real search-engine call, HTTP request or HTML extraction
    monkeypatch.setattr(WebSearch, "run", mock_web_search_run)
    monkeypatch.setattr(ArticleExtractor, "get_content", get_content)
    monkeypatch.setattr(requests, "get", get)

    web_retriever = WebRetriever(api_key="", top_search_results=2, mode="preprocessed_documents")
    result = web_retriever.retrieve(query="Who is the father of Arya Stark?")
    # The extracted text is short enough to stay in a single split
    assert len(result) == 1
    assert isinstance(result[0], Document)
    assert (
        result[0].content
        == "What are the top solutions for \nArya Stark's Father\nWe found 1 solutions for \nArya Stark's Father\n.The top solutions is determined by popularity, ratings and frequency of searches. The most likely answer for the clue is NED..."
    )
    assert result[0].score is None
    assert result[0].meta["url"] == "https://crossword-solver.io/clue/arya-stark%27s-father/"
    # Preprocessed documents carry the split index added by the preprocessor
    assert result[0].meta["_split_id"] == 0
@pytest.mark.unit
def test_web_retriever_mode_snippets(monkeypatch):
    """In the default "snippets" mode the WebRetriever returns the search-engine documents as-is,
    without fetching or processing the linked pages."""
    search_documents = [
        Document(
            content="Eddard Stark",
            score=0.9090909090909091,
            meta={"title": "Eddard Stark", "link": "", "score": 0.9090909090909091},
            id_hash_keys=["content"],
            id="f408db6de8de0ffad0cb47cf8830dbb8",
        ),
        Document(
            content="The most likely answer for the clue is NED. How many solutions does Arya Stark's Father have? With crossword-solver.io you will find 1 solutions. We use ...",
            score=0.09090909090909091,
            meta={
                "title": "Arya Stark's Father - Crossword Clue Answers",
                "link": "https://crossword-solver.io/clue/arya-stark%27s-father/",
                "position": 1,
                "score": 0.09090909090909091,
            },
            id_hash_keys=["content"],
            id="51779277acf94cf90e7663db137c0732",
        ),
    ]
    expected_search_results = {"documents": search_documents}

    def mock_web_search_run(self, query: str) -> Tuple[Dict, str]:
        return expected_search_results, "output_1"

    # Avoid any real search-engine call
    monkeypatch.setattr(WebSearch, "run", mock_web_search_run)
    web_retriever = WebRetriever(api_key="", top_search_results=2)
    assert web_retriever.retrieve(query="Who is the father of Arya Stark?") == search_documents
2023-03-27 15:31:22 +02:00
@fail_at_version(1, 17)
def test_text_2_sparql_retriever_deprecation():
    """Instantiating Text2SparqlRetriever must emit exactly one DeprecationWarning with the
    documented message; the whole component is scheduled for removal (test fails at v1.17).

    Fix: removed the dead locals ``BartForConditionalGeneration = object()`` and
    ``BartTokenizer = object()`` — ``patch.multiple`` with DEFAULT replaces the module
    attributes with MagicMocks, so those locals were never used.
    """
    # Patch the transformers classes so no model is downloaded or loaded
    with patch.multiple(
        "haystack.nodes.retriever.text2sparql", BartForConditionalGeneration=DEFAULT, BartTokenizer=DEFAULT
    ):
        knowledge_graph = Mock()
        with pytest.warns(DeprecationWarning) as w:
            Text2SparqlRetriever(knowledge_graph)
        assert len(w) == 1
        assert (
            w[0].message.args[0]
            == "The Text2SparqlRetriever component is deprecated and will be removed in future versions."
        )