mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-03 06:08:40 +00:00

* add changes for api_base * format retriever * Update haystack/nodes/retriever/dense.py Co-authored-by: bogdankostic <bogdankostic@web.de> * Update haystack/nodes/audio/whisper_transcriber.py Co-authored-by: bogdankostic <bogdankostic@web.de> * Update haystack/preview/components/audio/whisper_remote.py Co-authored-by: bogdankostic <bogdankostic@web.de> * Update haystack/nodes/answer_generator/openai.py Co-authored-by: bogdankostic <bogdankostic@web.de> * Update test_retriever.py * Update test_whisper_remote.py * Update test_generator.py * Update test_retriever.py * reformat with black * Update haystack/nodes/prompt/invocation_layer/chatgpt.py Co-authored-by: Daria Fokina <daria.f93@gmail.com> * Add unit tests * apply docstring suggestions --------- Co-authored-by: bogdankostic <bogdankostic@web.de> Co-authored-by: michaelfeil <me@michaelfeil.eu> Co-authored-by: Daria Fokina <daria.f93@gmail.com>
1307 lines
53 KiB
Python
1307 lines
53 KiB
Python
import logging
|
|
import os
|
|
from math import isclose
|
|
from typing import Dict, List, Optional, Union, Tuple
|
|
from unittest.mock import patch, Mock, DEFAULT
|
|
|
|
import pytest
|
|
import numpy as np
|
|
import pandas as pd
|
|
import requests
|
|
from boilerpy3.extractors import ArticleExtractor
|
|
from pandas.testing import assert_frame_equal
|
|
from transformers import PreTrainedTokenizerFast
|
|
|
|
|
|
try:
|
|
from elasticsearch import Elasticsearch
|
|
except (ImportError, ModuleNotFoundError) as ie:
|
|
from haystack.utils.import_utils import _optional_component_not_installed
|
|
|
|
_optional_component_not_installed(__name__, "elasticsearch", ie)
|
|
|
|
|
|
from haystack.document_stores.base import BaseDocumentStore, FilterType
|
|
from haystack.document_stores.memory import InMemoryDocumentStore
|
|
from haystack.document_stores import WeaviateDocumentStore
|
|
from haystack.nodes.retriever.base import BaseRetriever
|
|
from haystack.nodes.retriever.web import WebRetriever
|
|
from haystack.nodes.search_engine import WebSearch
|
|
from haystack.pipelines import DocumentSearchPipeline
|
|
from haystack.schema import Document
|
|
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
|
|
from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever
|
|
from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
|
|
from haystack.nodes.retriever.multimodal import MultiModalRetriever
|
|
|
|
from ..conftest import MockBaseRetriever, fail_at_version
|
|
|
|
|
|
# TODO: check if this works with only the "memory" arg
@pytest.mark.parametrize(
    "retriever_with_docs,document_store_with_docs",
    [
        ("mdr", "elasticsearch"),
        ("mdr", "faiss"),
        ("mdr", "memory"),
        ("dpr", "elasticsearch"),
        ("dpr", "faiss"),
        ("dpr", "memory"),
        ("embedding", "elasticsearch"),
        ("embedding", "faiss"),
        ("embedding", "memory"),
        ("bm25", "elasticsearch"),
        ("bm25", "memory"),
        ("bm25", "weaviate"),
        ("es_filter_only", "elasticsearch"),
        ("tfidf", "memory"),
    ],
    indirect=True,
)
def test_retrieval_without_filters(retriever_with_docs: BaseRetriever, document_store_with_docs: BaseDocumentStore):
    """Retrieve without filters and check that the Berlin document is ranked first out of 5 results.

    Embeddings are computed first for every retriever that is not a BM25Retriever or TfidfRetriever.
    """
    if not isinstance(retriever_with_docs, (BM25Retriever, TfidfRetriever)):
        document_store_with_docs.update_embeddings(retriever_with_docs)

    # NOTE: FilterRetriever simply returns all documents matching a filter,
    # so without filters applied there is nothing to check.
    if isinstance(retriever_with_docs, FilterRetriever):
        return

    # The BM25 implementation in Weaviate would NOT pick up the expected records
    # because of the lack of stemming: "Who lives in berlin" returns only 1 record while
    # "Who live in berlin" returns all 5 records.
    # TODO - In Weaviate 1.19.0 there is a fix for the lack of stemming, which means that once 1.19.0 is released
    # this `if` can be removed, as the standard search query "Who lives in Berlin?" should work with Weaviate.
    # See https://github.com/weaviate/weaviate/issues/2439
    if isinstance(document_store_with_docs, WeaviateDocumentStore):
        query = "Who live in berlin"
    else:
        query = "Who lives in Berlin?"
    res = retriever_with_docs.retrieve(query=query)

    # Check the length first so an empty result fails with a clear assertion instead of an IndexError.
    assert len(res) == 5
    assert res[0].content == "My name is Carla and I live in Berlin"
    assert res[0].meta["name"] == "filename1"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "retriever_with_docs,document_store_with_docs",
    [
        ("mdr", "elasticsearch"),
        ("mdr", "memory"),
        ("dpr", "elasticsearch"),
        ("dpr", "memory"),
        ("embedding", "elasticsearch"),
        ("embedding", "memory"),
        ("bm25", "elasticsearch"),
        ("bm25", "weaviate"),
        ("es_filter_only", "elasticsearch"),
    ],
    indirect=True,
)
def test_retrieval_with_filters(retriever_with_docs: BaseRetriever, document_store_with_docs: BaseDocumentStore):
    """Metadata filters restrict retrieval to matching documents, including the case where none match."""
    if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever)):
        document_store_with_docs.update_embeddings(retriever_with_docs)

    # single filter
    result = retriever_with_docs.retrieve(query="Christelle", filters={"name": ["filename3"]}, top_k=5)
    assert len(result) == 1
    assert isinstance(result[0], Document)
    assert result[0].content == "My name is Christelle and I live in Paris"
    assert result[0].meta["name"] == "filename3"

    # multiple filters
    result = retriever_with_docs.retrieve(
        query="Paul", filters={"name": ["filename2"], "meta_field": ["test2", "test3"]}, top_k=5
    )
    assert len(result) == 1
    assert isinstance(result[0], Document)
    assert result[0].meta["name"] == "filename2"

    # multiple filters that no document satisfies at the same time
    result = retriever_with_docs.retrieve(
        query="Carla", filters={"name": ["filename1"], "meta_field": ["test2", "test3"]}, top_k=5
    )
    assert len(result) == 0
|
|
|
|
|
|
def test_tfidf_retriever_multiple_indexes():
    """Fitting a TfidfRetriever on two indexes keeps a separate document count per index."""
    docs_per_index = {
        "index_0": [Document(content=f"test_{i}") for i in (1, 2, 3)],
        "index_1": [Document(content=f"test_{i}") for i in (4, 5)],
    }
    store = InMemoryDocumentStore(index="index_0")
    tfidf = TfidfRetriever(document_store=store)

    # Write and fit one index at a time, in insertion order.
    for index_name, index_docs in docs_per_index.items():
        store.write_documents(index_docs, index=index_name)
        tfidf.fit(store, index=index_name)

    for index_name in docs_per_index:
        assert tfidf.document_counts[index_name] == store.get_document_count(index=index_name)
|
|
|
|
|
|
def test_retrieval_empty_query(document_store: BaseDocumentStore):
    """An empty query passed through run() and run_batch() still yields the retriever's documents."""
    expected_doc = Document(id="0", content="test")
    mock_retriever = MockBaseRetriever(document_store=document_store, mock_document=expected_doc)

    single_output = mock_retriever.run(root_node="Query", query="", filters={})[0]
    assert single_output["documents"][0] == expected_doc

    batch_output = mock_retriever.run_batch(root_node="Query", queries=[""], filters={})[0]
    assert batch_output["documents"][0][0] == expected_doc
|
|
|
|
|
|
@pytest.mark.parametrize("retriever_with_docs", ["embedding", "dpr", "tfidf"], indirect=True)
def test_batch_retrieval_single_query(retriever_with_docs, document_store_with_docs):
    """retrieve_batch() with one query returns a single list of 5 documents, Berlin first."""
    no_embedding_types = (BM25Retriever, FilterRetriever, TfidfRetriever)
    if not isinstance(retriever_with_docs, no_embedding_types):
        document_store_with_docs.update_embeddings(retriever_with_docs)

    batches = retriever_with_docs.retrieve_batch(queries=["Who lives in Berlin?"])

    # Expected return type: list of lists of Documents
    assert isinstance(batches, list)
    first_batch = batches[0]
    assert isinstance(first_batch, list)
    top_doc = first_batch[0]
    assert isinstance(top_doc, Document)

    assert len(batches) == 1
    assert len(first_batch) == 5
    assert top_doc.content == "My name is Carla and I live in Berlin"
    assert top_doc.meta["name"] == "filename1"
|
|
|
|
|
|
@pytest.mark.parametrize("retriever_with_docs", ["embedding", "dpr", "tfidf"], indirect=True)
def test_batch_retrieval_multiple_queries(retriever_with_docs, document_store_with_docs):
    """retrieve_batch() with two queries returns one ranked list of 5 documents per query."""
    if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever, TfidfRetriever)):
        document_store_with_docs.update_embeddings(retriever_with_docs)

    res = retriever_with_docs.retrieve_batch(queries=["Who lives in Berlin?", "Who lives in New York?"])

    # Expected return type: list of lists of Documents, one inner list per query
    assert isinstance(res, list)
    assert len(res) == 2
    assert isinstance(res[0], list)
    assert isinstance(res[0][0], Document)

    assert res[0][0].content == "My name is Carla and I live in Berlin"
    assert len(res[0]) == 5
    assert res[0][0].meta["name"] == "filename1"

    assert res[1][0].content == "My name is Paul and I live in New York"
    assert len(res[1]) == 5
    assert res[1][0].meta["name"] == "filename2"
|
|
|
|
|
|
@pytest.mark.parametrize("retriever_with_docs", ["bm25"], indirect=True)
def test_batch_retrieval_multiple_queries_with_filters(retriever_with_docs, document_store_with_docs):
    """retrieve_batch() accepts one filter per query (or None) and still ranks the expected document first."""
    if not isinstance(retriever_with_docs, (BM25Retriever, FilterRetriever)):
        document_store_with_docs.update_embeddings(retriever_with_docs)

    # Weaviate only supports BM25 combined with filters from v1.18.0 onwards.
    # TODO - remove this once Weaviate starts supporting BM25 WITH filters.
    # You might also need to modify the first query, as Weaviate has problems
    # retrieving the "My name is Carla and I live in Berlin" record with just the
    # "Who lives in Berlin?" BM25 query.
    # Skip explicitly instead of returning so the report shows this case was not exercised.
    if isinstance(document_store_with_docs, WeaviateDocumentStore):
        pytest.skip("Weaviate does not support BM25 retrieval combined with filters yet")

    res = retriever_with_docs.retrieve_batch(
        queries=["Who lives in Berlin?", "Who lives in New York?"], filters=[{"name": "filename1"}, None]
    )

    # Expected return type: list of lists of Documents
    assert isinstance(res, list)
    assert isinstance(res[0], list)
    assert isinstance(res[0][0], Document)

    assert res[0][0].content == "My name is Carla and I live in Berlin"
    assert len(res[0]) == 5
    assert res[0][0].meta["name"] == "filename1"

    assert res[1][0].content == "My name is Paul and I live in New York"
    assert len(res[1]) == 5
    assert res[1][0].meta["name"] == "filename2"
|
|
|
|
|
|
@pytest.mark.elasticsearch
def test_elasticsearch_custom_query():
    """BM25Retriever fills ${query} and filter values into a custom Elasticsearch query template."""
    index = "haystack_test_custom"
    Elasticsearch().indices.delete(index=index, ignore=[404])
    document_store = ElasticsearchDocumentStore(
        index=index, content_field="custom_text_field", embedding_field="custom_embedding_field"
    )
    years = ["2019", "2020", "2021", "2021", "2021"]
    document_store.write_documents(
        [{"content": f"test_{i}", "meta": {"year": year}} for i, year in enumerate(years, start=1)]
    )

    # custom "terms" query: ${years} is taken from filters["years"]
    terms_retriever = BM25Retriever(
        document_store=document_store,
        custom_query="""
            {
                "size": 10,
                "query": {
                    "bool": {
                        "should": [{
                            "multi_match": {"query": ${query}, "type": "most_fields", "fields": ["content"]}}],
                            "filter": [{"terms": {"year": ${years}}}]}}}""",
    )
    # plain query, then a query containing a linefeed, then one containing a double quote
    for query, expected_count in (("test", 4), ("test\n", 3), ('test"', 3)):
        results = terms_retriever.retrieve(query=query, filters={"years": ["2020", "2021"]})
        assert len(results) == expected_count

    # custom "term" query: ${years} is taken from a single filter value
    term_retriever = BM25Retriever(
        document_store=document_store,
        custom_query="""
            {
                "size": 10,
                "query": {
                    "bool": {
                        "should": [{
                            "multi_match": {"query": ${query}, "type": "most_fields", "fields": ["content"]}}],
                            "filter": [{"term": {"year": ${years}}}]}}}""",
    )
    assert len(term_retriever.retrieve(query="test", filters={"years": "2021"})) == 3
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "weaviate", "pinecone"], indirect=True)
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
def test_dpr_embedding(document_store: BaseDocumentStore, retriever, docs_with_ids):
    """Check the first component of each stored document's normalized 768-dim DPR embedding."""
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    docs = document_store.get_all_documents()
    docs.sort(key=lambda d: d.id)

    expected_values = [0.00892, 0.00780, 0.00482, -0.00626, 0.010966]
    # zip() stops at the shorter sequence, so make sure every expected value is actually compared.
    assert len(docs) >= len(expected_values)
    for doc, expected_value in zip(docs, expected_values):
        embedding = doc.embedding
        # always normalize vector as faiss returns normalized vectors and other document stores do not
        embedding /= np.linalg.norm(embedding)
        assert len(embedding) == 768
        assert isclose(embedding[0], expected_value, rel_tol=0.01)
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "weaviate", "pinecone"], indirect=True)
@pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
@pytest.mark.embedding_dim(128)
def test_retribert_embedding(document_store, retriever, docs_with_ids):
    """Check the first component of each stored document's normalized 128-dim RetriBERT embedding."""
    if isinstance(document_store, WeaviateDocumentStore):
        # Weaviate sets the embedding dimension to 768 as soon as it is initialized.
        # We need 128 here and therefore initialize a new WeaviateDocumentStore.
        document_store = WeaviateDocumentStore(index="haystack_test", embedding_dim=128, recreate_index=True)
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    docs = document_store.get_all_documents()
    docs = sorted(docs, key=lambda d: d.id)

    expected_values = [0.14017, 0.05975, 0.14267, 0.15099, 0.14383]
    # zip() stops at the shorter sequence, so make sure every expected value is actually compared.
    assert len(docs) >= len(expected_values)
    for doc, expected_value in zip(docs, expected_values):
        embedding = doc.embedding
        assert len(embedding) == 128
        # always normalize vector as faiss returns normalized vectors and other document stores do not
        embedding /= np.linalg.norm(embedding)
        assert isclose(embedding[0], expected_value, rel_tol=0.001)
|
|
|
|
|
|
def test_openai_embedding_retriever_selection():
    """EmbeddingRetriever maps OpenAI model names to the right query/document encoder models.

    OpenAI released (Dec 2022) a unifying embedding model called text-embedding-ada-002, which is
    used for both queries and documents. The older ada/babbage names map to separate
    text-search-<modelname>-{query,doc}-001 models, and potential unreleased text-embedding-*
    names are passed through unchanged. Every case must also use the default OpenAI API base.
    """
    # (embedding_model, expected query encoder model, expected doc encoder model)
    cases = [
        ("text-embedding-ada-002", "text-embedding-ada-002", "text-embedding-ada-002"),
        ("ada", "text-search-ada-query-001", "text-search-ada-doc-001"),
        ("babbage", "text-search-babbage-query-001", "text-search-babbage-doc-001"),
        ("text-embedding-babbage-002", "text-embedding-babbage-002", "text-embedding-babbage-002"),
    ]
    for model_name, query_model, doc_model in cases:
        er = EmbeddingRetriever(embedding_model=model_name, document_store=None)
        assert er.model_format == "openai", model_name
        assert er.embedding_encoder.query_encoder_model == query_model
        assert er.embedding_encoder.doc_encoder_model == doc_model
        assert er.api_base == "https://api.openai.com/v1"
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["cohere"], indirect=True)
@pytest.mark.embedding_dim(1024)
@pytest.mark.skipif(
    not os.environ.get("COHERE_API_KEY", None),
    reason="Please export an env var called COHERE_API_KEY containing the Cohere API key to run this test.",
)
def test_basic_cohere_embedding(document_store, retriever, docs_with_ids):
    """Embed the documents with the Cohere retriever; every stored embedding must have 1024 dimensions."""
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    stored_docs = sorted(document_store.get_all_documents(), key=lambda d: d.id)
    embedding_sizes = [len(doc.embedding) for doc in stored_docs]
    assert all(size == 1024 for size in embedding_sizes)
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["openai"], indirect=True)
@pytest.mark.embedding_dim(1536)
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_basic_openai_embedding(document_store, retriever, docs_with_ids):
    """Embed the documents with the OpenAI retriever; every stored embedding must have 1536 dimensions."""
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    stored_docs = sorted(document_store.get_all_documents(), key=lambda d: d.id)
    embedding_sizes = [len(doc.embedding) for doc in stored_docs]
    assert all(size == 1536 for size in embedding_sizes)
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["azure"], indirect=True)
@pytest.mark.embedding_dim(1536)
@pytest.mark.skipif(
    # All three variables are required, so skip as soon as ANY of them is missing
    # (the previous `and` chain only skipped when all three were missing).
    not all(
        os.environ.get(var_name)
        for var_name in ("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_BASE_URL", "AZURE_OPENAI_DEPLOYMENT_NAME_EMBED")
    ),
    reason=(
        "Please export env variables called AZURE_OPENAI_API_KEY containing "
        "the Azure OpenAI key, AZURE_OPENAI_BASE_URL containing "
        "the Azure OpenAI base URL, and AZURE_OPENAI_DEPLOYMENT_NAME_EMBED containing "
        "the Azure OpenAI deployment name to run this test."
    ),
)
def test_basic_azure_embedding(document_store, retriever, docs_with_ids):
    """Embed the documents with the Azure OpenAI retriever; every stored embedding must have 1536 dimensions."""
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    docs = document_store.get_all_documents()
    docs = sorted(docs, key=lambda d: d.id)

    for doc in docs:
        assert len(doc.embedding) == 1536
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["cohere"], indirect=True)
@pytest.mark.embedding_dim(1024)
@pytest.mark.skipif(
    not os.environ.get("COHERE_API_KEY", None),
    reason="Please export an env var called COHERE_API_KEY containing the Cohere API key to run this test.",
)
def test_retriever_basic_cohere_search(document_store, retriever, docs_with_ids):
    """A DocumentSearchPipeline over Cohere embeddings returns the Madrid document as the single top hit."""
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    pipeline = DocumentSearchPipeline(retriever)
    output = pipeline.run(query="Madrid", params={"Retriever": {"top_k": 1}})
    found = output["documents"]
    assert len(found) == 1
    assert "Madrid" in found[0].content
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["openai"], indirect=True)
@pytest.mark.embedding_dim(1536)
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Please export env called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_retriever_basic_openai_search(document_store, retriever, docs_with_ids):
    """A DocumentSearchPipeline over OpenAI embeddings returns the Madrid document as the single top hit."""
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    pipeline = DocumentSearchPipeline(retriever)
    output = pipeline.run(query="Madrid", params={"Retriever": {"top_k": 1}})
    found = output["documents"]
    assert len(found) == 1
    assert "Madrid" in found[0].content
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["azure"], indirect=True)
@pytest.mark.embedding_dim(1536)
@pytest.mark.skipif(
    # All three variables are required, so skip as soon as ANY of them is missing
    # (the previous `and` chain only skipped when all three were missing).
    not all(
        os.environ.get(var_name)
        for var_name in ("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_BASE_URL", "AZURE_OPENAI_DEPLOYMENT_NAME_EMBED")
    ),
    reason=(
        "Please export env variables called AZURE_OPENAI_API_KEY containing "
        "the Azure OpenAI key, AZURE_OPENAI_BASE_URL containing "
        "the Azure OpenAI base URL, and AZURE_OPENAI_DEPLOYMENT_NAME_EMBED containing "
        "the Azure OpenAI deployment name to run this test."
    ),
)
def test_retriever_basic_azure_search(document_store, retriever, docs_with_ids):
    """A DocumentSearchPipeline over Azure OpenAI embeddings returns the Madrid document as the single top hit."""
    document_store.return_embedding = True
    document_store.write_documents(docs_with_ids)
    document_store.update_embeddings(retriever=retriever)

    p_retrieval = DocumentSearchPipeline(retriever)
    res = p_retrieval.run(query="Madrid", params={"Retriever": {"top_k": 1}})
    assert len(res["documents"]) == 1
    assert "Madrid" in res["documents"][0].content
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch", "memory"], indirect=True)
@pytest.mark.embedding_dim(512)
def test_table_text_retriever_embedding(document_store, retriever, docs):
    """Embed text documents plus one table and check the first component of each 512-dim embedding."""
    # BM25 representation is incompatible with table retriever
    if isinstance(document_store, InMemoryDocumentStore):
        document_store.use_bm25 = False

    document_store.return_embedding = True
    document_store.write_documents(docs)
    table_data = {
        "Mountain": ["Mount Everest", "K2", "Kangchenjunga", "Lhotse", "Makalu"],
        "Height": ["8848m", "8,611 m", "8 586m", "8 516 m", "8,485m"],
    }
    table = pd.DataFrame(table_data)
    table_doc = Document(content=table, content_type="table", id="6")
    document_store.write_documents([table_doc])
    document_store.update_embeddings(retriever=retriever)

    docs = document_store.get_all_documents()
    docs = sorted(docs, key=lambda d: d.id)

    expected_values = [0.061191384, 0.038075786, 0.27447605, 0.09399721, 0.0959682]
    # zip() stops at the shorter sequence, so make sure every expected value is actually compared.
    assert len(docs) >= len(expected_values)
    for doc, expected_value in zip(docs, expected_values):
        assert len(doc.embedding) == 512
        assert isclose(doc.embedding[0], expected_value, rel_tol=0.001)
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.embedding_dim(512)
def test_table_text_retriever_embedding_only_text(document_store, retriever):
    """Embedding a store that holds only text documents (no tables) must not raise."""
    texts = ("This is a test", "This is another test")
    text_docs = [Document(content=text, content_type="text") for text in texts]
    document_store.write_documents(text_docs)
    document_store.update_embeddings(retriever)
|
|
|
|
|
|
@pytest.mark.integration
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.embedding_dim(512)
def test_table_text_retriever_embedding_only_table(document_store, retriever):
    """Embedding a store that holds a single table document (no text) must not raise."""
    rows = [["1", "This is a test"], ["2", "This is another test"]]
    table = pd.DataFrame(columns=["id", "text"], data=rows)
    document_store.write_documents([Document(content=table, content_type="table")])
    document_store.update_embeddings(retriever)
|
|
|
|
|
|
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
def test_dpr_saving_and_loading(tmp_path, retriever, document_store):
    """Save a DensePassageRetriever, reload it, and check that weights, settings and tokenizers survive."""
    save_dir = f"{tmp_path}/test_dpr_save"
    retriever.save(save_dir)

    def sum_params(model):
        # Cheap fingerprint of a model: the sum over all parameter values. An element-wise
        # comparison of both models would need both copies in memory at the same time.
        return sum(np.sum(p.cpu().data.numpy()) for p in model.parameters())

    original_sum_query = sum_params(retriever.query_encoder)
    original_sum_passage = sum_params(retriever.passage_encoder)
    # Drop the original model before loading the copy, presumably to limit peak memory.
    del retriever

    loaded_retriever = DensePassageRetriever.load(save_dir, document_store)

    loaded_sum_query = sum_params(loaded_retriever.query_encoder)
    loaded_sum_passage = sum_params(loaded_retriever.passage_encoder)

    assert abs(original_sum_query - loaded_sum_query) < 0.1
    assert abs(original_sum_passage - loaded_sum_passage) < 0.1

    # attributes
    assert loaded_retriever.processor.embed_title is True
    assert loaded_retriever.batch_size == 16
    assert loaded_retriever.processor.max_seq_len_passage == 256
    assert loaded_retriever.processor.max_seq_len_query == 64

    # Tokenizer
    for tokenizer in (loaded_retriever.passage_tokenizer, loaded_retriever.query_tokenizer):
        assert isinstance(tokenizer, PreTrainedTokenizerFast)
        assert tokenizer.do_lower_case is True
        assert tokenizer.vocab_size == 30522
|
|
|
|
|
|
@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
@pytest.mark.embedding_dim(512)
def test_table_text_retriever_saving_and_loading(tmp_path, retriever, document_store):
    """Save a TableTextRetriever, reload it, and check that weights, settings and tokenizers survive."""
    save_dir = f"{tmp_path}/test_table_text_retriever_save"
    retriever.save(save_dir)

    def sum_params(model):
        # Cheap fingerprint of a model: the sum over all parameter values.
        return sum(np.sum(p.cpu().data.numpy()) for p in model.parameters())

    original_sum_query = sum_params(retriever.query_encoder)
    original_sum_passage = sum_params(retriever.passage_encoder)
    original_sum_table = sum_params(retriever.table_encoder)
    # Drop the original model before loading the copy, presumably to limit peak memory.
    del retriever

    loaded_retriever = TableTextRetriever.load(save_dir, document_store)

    loaded_sum_query = sum_params(loaded_retriever.query_encoder)
    loaded_sum_passage = sum_params(loaded_retriever.passage_encoder)
    loaded_sum_table = sum_params(loaded_retriever.table_encoder)

    assert abs(original_sum_query - loaded_sum_query) < 0.1
    assert abs(original_sum_passage - loaded_sum_passage) < 0.1
    assert abs(original_sum_table - loaded_sum_table) < 0.01

    # attributes
    assert loaded_retriever.processor.embed_meta_fields == ["name", "section_title", "caption"]
    assert loaded_retriever.batch_size == 16
    assert loaded_retriever.processor.max_seq_len_passage == 256
    assert loaded_retriever.processor.max_seq_len_table == 256
    assert loaded_retriever.processor.max_seq_len_query == 64

    # Tokenizer
    tokenizers = (
        loaded_retriever.passage_tokenizer,
        loaded_retriever.table_tokenizer,
        loaded_retriever.query_tokenizer,
    )
    for tokenizer in tokenizers:
        assert isinstance(tokenizer, PreTrainedTokenizerFast)
        assert tokenizer.do_lower_case is True
        assert tokenizer.vocab_size == 30522
|
|
|
|
|
|
@pytest.mark.embedding_dim(128)
def test_table_text_retriever_training(tmp_path, document_store, samples_path):
    """Train a small TableTextRetriever for one epoch on CPU and check that the result can be reloaded."""
    save_dir = f"{tmp_path}/test_table_text_retriever_train"
    model_prefix = "deepset/bert-small-mm_retrieval"
    trainee = TableTextRetriever(
        document_store=document_store,
        query_embedding_model=f"{model_prefix}-question_encoder",
        passage_embedding_model=f"{model_prefix}-passage_encoder",
        table_embedding_model=f"{model_prefix}-table_encoder",
        use_gpu=False,
    )

    trainee.train(
        data_dir=samples_path / "mmr",
        train_filename="sample.json",
        n_epochs=1,
        n_gpu=0,
        save_dir=save_dir,
    )

    # Loading the trained model back must not raise.
    TableTextRetriever.load(load_dir=save_dir, document_store=document_store)
|
|
|
|
|
|
@pytest.mark.elasticsearch
def test_elasticsearch_highlight():
    """Custom queries with a "highlight" section put highlighted fragments into meta["highlighted"]."""
    index = "haystack_hl_test"
    Elasticsearch().indices.delete(index=index, ignore=[404])

    # Mapping the content and title field as "text" perform search on these both fields.
    document_store = ElasticsearchDocumentStore(
        index=index,
        content_field="title",
        custom_mapping={"mappings": {"properties": {"content": {"type": "text"}, "title": {"type": "text"}}}},
    )
    rows = [
        ("1", "Green tea components",
         "The green tea plant contains a range of healthy compounds that make it into the final drink"),
        ("2", "Green tea catechin", "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG)."),
        ("3", "Minerals in Green tea", "Green tea also has small amounts of minerals that can benefit your health."),
        ("4", "Green tea Benefits",
         "Green tea does more than just keep you alert, it may also help boost brain function."),
    ]
    document_store.write_documents(
        [{"title": title, "meta": {"content": content}, "id": doc_id} for doc_id, title, content in rows]
    )

    def highlight_query(highlight_fields: str) -> str:
        # Same search for both retrievers; only the set of highlighted fields differs.
        return (
            """{
            "size": 20,
            "query": {
                "bool": {
                    "should": [
                        {
                            "multi_match": {
                                "query": ${query},
                                "fields": [
                                    "content^3",
                                    "title^5"
                                ]
                            }
                        }
                    ]
                }
            },
            "highlight": {
                "pre_tags": [
                    "**"
                ],
                "post_tags": [
                    "**"
                ],
                "number_of_fragments": 3,
                "fragment_size": 5,
                "fields": """
            + highlight_fields
            + """
            }
        }"""
        )

    # Enabled highlighting on "title"&"content" field only using custom query
    both_fields = BM25Retriever(
        document_store=document_store, custom_query=highlight_query('{"content": {}, "title": {}}')
    )
    highlighted = both_fields.retrieve(query="is green tea healthy")[0].meta["highlighted"]
    assert len(highlighted) == 2
    assert highlighted["title"] == ["**Green**", "**tea** components"]
    assert highlighted["content"] == ["The **green**", "**tea** plant", "range of **healthy**"]

    # Enabled highlighting on "title" field only using custom query
    title_only = BM25Retriever(document_store=document_store, custom_query=highlight_query('{"title": {}}'))
    highlighted = title_only.retrieve(query="is green tea healthy")[0].meta["highlighted"]
    assert len(highlighted) == 1
    assert highlighted["title"] == ["**Green**", "**tea** components"]
|
|
|
|
|
|
@pytest.mark.elasticsearch
def test_elasticsearch_filter_must_not_increase_results():
    """A metadata filter must not widen the result set of a keyword query.

    Four documents are written, all with ``content_type == "text"``; only the first
    contains the word "drink". Querying "drink" must return exactly one document,
    both without a filter and with a filter that matches every document.
    """
    index = "filter_must_not_increase_results"
    # Start from a clean index; a 404 only means it did not exist yet.
    client = Elasticsearch()
    client.indices.delete(index=index, ignore=[404])
    documents = [
        {
            "content": "The green tea plant contains a range of healthy compounds that make it into the final drink",
            "meta": {"content_type": "text"},
            "id": "1",
        },
        {
            "content": "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG).",
            "meta": {"content_type": "text"},
            "id": "2",
        },
        {
            "content": "Green tea also has small amounts of minerals that can benefit your health.",
            "meta": {"content_type": "text"},
            "id": "3",
        },
        {
            "content": "Green tea does more than just keep you alert, it may also help boost brain function.",
            "meta": {"content_type": "text"},
            "id": "4",
        },
    ]
    doc_store = ElasticsearchDocumentStore(index=index)
    try:
        doc_store.write_documents(documents)
        results_wo_filter = doc_store.query(query="drink")
        assert len(results_wo_filter) == 1
        # The filter matches all four documents; it must not add non-matching hits.
        results_w_filter = doc_store.query(query="drink", filters={"content_type": "text"})
        assert len(results_w_filter) == 1
    finally:
        # Drop the index even when an assertion fails so later runs start clean.
        doc_store.delete_index(index)
|
|
|
|
|
|
@pytest.mark.elasticsearch
def test_elasticsearch_all_terms_must_match():
    """``all_terms_must_match=True`` on ``query`` turns an OR-match into an AND-match.

    All four documents mention "green tea", but only the first also contains "drink".
    Without the flag the query "drink green tea" matches all four documents; with it,
    only the one containing every term.
    """
    index = "all_terms_must_match"
    # Start from a clean index; a 404 only means it did not exist yet.
    client = Elasticsearch()
    client.indices.delete(index=index, ignore=[404])
    documents = [
        {
            "content": "The green tea plant contains a range of healthy compounds that make it into the final drink",
            "meta": {"content_type": "text"},
            "id": "1",
        },
        {
            "content": "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG).",
            "meta": {"content_type": "text"},
            "id": "2",
        },
        {
            "content": "Green tea also has small amounts of minerals that can benefit your health.",
            "meta": {"content_type": "text"},
            "id": "3",
        },
        {
            "content": "Green tea does more than just keep you alert, it may also help boost brain function.",
            "meta": {"content_type": "text"},
            "id": "4",
        },
    ]
    doc_store = ElasticsearchDocumentStore(index=index)
    try:
        doc_store.write_documents(documents)
        results_wo_all_terms_must_match = doc_store.query(query="drink green tea")
        assert len(results_wo_all_terms_must_match) == 4
        results_w_all_terms_must_match = doc_store.query(query="drink green tea", all_terms_must_match=True)
        assert len(results_w_all_terms_must_match) == 1
    finally:
        # Drop the index even when an assertion fails so later runs start clean.
        doc_store.delete_index(index)
|
|
|
|
|
|
@pytest.mark.elasticsearch
def test_bm25retriever_all_terms_must_match():
    """BM25Retriever honours ``all_terms_must_match`` both at construction and per call.

    All four documents mention "green tea", but only the first also contains "drink".
    The query "drink green tea" must match four documents by default, and one
    document when ``all_terms_must_match=True`` is set on the retriever or passed
    to ``retrieve``.
    """
    # Dedicated index name so this test cannot clash with another test that
    # uses the "all_terms_must_match" index (e.g. under parallel test runners).
    index = "bm25retriever_all_terms_must_match"
    # Start from a clean index; a 404 only means it did not exist yet.
    client = Elasticsearch()
    client.indices.delete(index=index, ignore=[404])
    documents = [
        {
            "content": "The green tea plant contains a range of healthy compounds that make it into the final drink",
            "meta": {"content_type": "text"},
            "id": "1",
        },
        {
            "content": "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG).",
            "meta": {"content_type": "text"},
            "id": "2",
        },
        {
            "content": "Green tea also has small amounts of minerals that can benefit your health.",
            "meta": {"content_type": "text"},
            "id": "3",
        },
        {
            "content": "Green tea does more than just keep you alert, it may also help boost brain function.",
            "meta": {"content_type": "text"},
            "id": "4",
        },
    ]
    doc_store = ElasticsearchDocumentStore(index=index)
    try:
        doc_store.write_documents(documents)

        # Default: any term may match.
        retriever = BM25Retriever(document_store=doc_store)
        results_wo_all_terms_must_match = retriever.retrieve(query="drink green tea")
        assert len(results_wo_all_terms_must_match) == 4

        # Flag set at construction time.
        retriever = BM25Retriever(document_store=doc_store, all_terms_must_match=True)
        results_w_all_terms_must_match = retriever.retrieve(query="drink green tea")
        assert len(results_w_all_terms_must_match) == 1

        # Flag passed per call overrides the retriever's default.
        retriever = BM25Retriever(document_store=doc_store)
        results_w_all_terms_must_match = retriever.retrieve(query="drink green tea", all_terms_must_match=True)
        assert len(results_w_all_terms_must_match) == 1
    finally:
        # Drop the index even when an assertion fails so later runs start clean.
        doc_store.delete_index(index)
|
|
|
|
|
|
def test_embeddings_encoder_of_embedding_retriever_should_warn_about_model_format(caplog):
    """Building an EmbeddingRetriever for a sentence-transformers model with ``model_format="farm"``
    must log a warning that suggests ``model_format='sentence_transformers'``.
    """
    expected_hint = "You may need to set model_format='sentence_transformers' to ensure correct loading of model."
    store = InMemoryDocumentStore()

    with caplog.at_level(logging.WARNING):
        EmbeddingRetriever(
            document_store=store,
            embedding_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
            model_format="farm",
        )

    assert expected_hint in caplog.text
|
|
|
|
|
|
@pytest.mark.parametrize("retriever", ["es_filter_only"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
def test_es_filter_only(document_store, retriever):
    """The filter-only retriever returns every document that matches the filter.

    Twelve documents "Doc1".."Doc12" are written. All except "Doc7" carry
    ``f1 == "0"``, so filtering on ``f1 in ["0"]`` with an empty query must yield 11.
    """
    docs = [Document(content=f"Doc{number}", meta={"f1": "1" if number == 7 else "0"}) for number in range(1, 13)]
    document_store.write_documents(docs)

    matching = retriever.retrieve(query="", filters={"f1": ["0"]})
    assert len(matching) == 11
|
|
|
|
|
|
#
|
|
# MultiModal
|
|
#
|
|
|
|
|
|
@pytest.fixture
def text_docs() -> List[Document]:
    """Five short "My name is X and I live in Y" text documents with assorted metadata."""
    # Columns: content, meta_field, name, date_field, numeric_field, odd_field
    rows = [
        ("My name is Paul and I live in New York", "test2", "filename2", "2019-10-01", 5.0, 0),
        ("My name is Carla and I live in Berlin", "test1", "filename1", "2020-03-01", 5.5, 1),
        ("My name is Christelle and I live in Paris", "test3", "filename3", "2018-10-01", 4.5, 1),
        ("My name is Camila and I live in Madrid", "test4", "filename4", "2021-02-01", 3.0, 0),
        ("My name is Matteo and I live in Rome", "test5", "filename5", "2019-01-01", 0.0, 1),
    ]
    return [
        Document(
            content=content,
            meta={
                "meta_field": meta_field,
                "name": name,
                "date_field": date_field,
                "numeric_field": numeric_field,
                "odd_field": odd_field,
            },
        )
        for content, meta_field, name, date_field, numeric_field, odd_field in rows
    ]
|
|
|
|
|
|
@pytest.fixture
def table_docs() -> List[Document]:
    """Three small tables (Himalayan peaks, French cities, German cities) as table documents."""
    # Each entry is the column mapping for one DataFrame; values are stored as strings on purpose.
    tables = [
        {
            "Mountain": ["Mount Everest", "K2", "Kangchenjunga", "Lhotse", "Makalu"],
            "Height": ["8848m", "8,611 m", "8 586m", "8 516 m", "8,485m"],
        },
        {
            "City": ["Paris", "Lyon", "Marseille", "Lille", "Toulouse", "Bordeaux"],
            "Population": ["13,114,718", "2,280,845", "1,873,270 ", "1,510,079", "1,454,158", "1,363,711"],
        },
        {
            "City": ["Berlin", "Hamburg", "Munich", "Cologne"],
            "Population": ["3,644,826", "1,841,179", "1,471,508", "1,085,664"],
        },
    ]
    return [Document(content=pd.DataFrame(columns), content_type="table") for columns in tables]
|
|
|
|
|
|
@pytest.fixture
def image_docs(samples_path) -> List[Document]:
    """One image document per file in ``samples_path/images``, in ``os.listdir`` order."""
    images_dir = samples_path / "images"
    documents = []
    for filename in os.listdir(images_dir):
        # The document content is the image's file path as a string.
        documents.append(Document(content=str(images_dir / filename), content_type="image"))
    return documents
|
|
|
|
|
|
@pytest.mark.integration
def test_multimodal_text_retrieval(text_docs: List[Document]):
    """A text-only MultiModalRetriever ranks the Paris document first for a Paris question."""
    text_model = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
    retriever = MultiModalRetriever(
        document_store=InMemoryDocumentStore(return_embedding=True),
        query_embedding_model=text_model,
        document_embedding_models={"text": text_model},
    )
    store = retriever.document_store
    store.write_documents(text_docs)
    store.update_embeddings(retriever=retriever)

    top_hit = retriever.retrieve(query="Who lives in Paris?")[0]
    assert top_hit.content == "My name is Christelle and I live in Paris"
|
|
|
|
|
|
@pytest.mark.integration
def test_multimodal_text_retrieval_batch(text_docs: List[Document]):
    """``retrieve_batch`` returns, per query, the matching city document as the top hit."""
    text_model = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
    retriever = MultiModalRetriever(
        document_store=InMemoryDocumentStore(return_embedding=True),
        query_embedding_model=text_model,
        document_embedding_models={"text": text_model},
    )
    store = retriever.document_store
    store.write_documents(text_docs)
    store.update_embeddings(retriever=retriever)

    queries = ["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Madrid?"]
    expected_top_contents = [
        "My name is Christelle and I live in Paris",
        "My name is Carla and I live in Berlin",
        "My name is Camila and I live in Madrid",
    ]
    results = retriever.retrieve_batch(queries=queries)
    for position, expected_content in enumerate(expected_top_contents):
        assert results[position][0].content == expected_content
|
|
|
|
|
|
@pytest.mark.integration
def test_multimodal_table_retrieval(table_docs: List[Document]):
    """A table-embedding MultiModalRetriever ranks the German-cities table first for a Hamburg question."""
    table_model = "deepset/all-mpnet-base-v2-table"
    retriever = MultiModalRetriever(
        document_store=InMemoryDocumentStore(return_embedding=True),
        query_embedding_model=table_model,
        document_embedding_models={"table": table_model},
    )
    store = retriever.document_store
    store.write_documents(table_docs)
    store.update_embeddings(retriever=retriever)

    results = retriever.retrieve(query="How many people live in Hamburg?")
    expected_table = pd.DataFrame(
        {
            "City": ["Berlin", "Hamburg", "Munich", "Cologne"],
            "Population": ["3,644,826", "1,841,179", "1,471,508", "1,085,664"],
        }
    )
    assert_frame_equal(results[0].content, expected_table)
|
|
|
|
|
|
@pytest.mark.skip("Must be reworked as it fails randomly")
@pytest.mark.integration
def test_multimodal_retriever_query():
    """Embedding the same query text twice must produce identical vectors."""
    clip_model = "sentence-transformers/clip-ViT-B-32"
    retriever = MultiModalRetriever(
        document_store=InMemoryDocumentStore(return_embedding=True, embedding_dim=512),
        query_embedding_model=clip_model,
        document_embedding_models={"image": clip_model},
    )

    query = "dummy query 1"
    embeddings = retriever.embed_queries([query, query])
    assert np.array_equal(embeddings[0], embeddings[1])
|
|
|
|
|
|
@pytest.mark.integration
def test_multimodal_image_retrieval(image_docs: List[Document], samples_path):
    """A CLIP-based MultiModalRetriever returns the cat picture first for "What's a cat?"."""
    clip_model = "sentence-transformers/clip-ViT-B-32"
    retriever = MultiModalRetriever(
        document_store=InMemoryDocumentStore(return_embedding=True, embedding_dim=512),
        query_embedding_model=clip_model,
        document_embedding_models={"image": clip_model},
    )
    store = retriever.document_store
    store.write_documents(image_docs)
    store.update_embeddings(retriever=retriever)

    results = retriever.retrieve(query="What's a cat?")
    expected_path = str(samples_path / "images" / "cat.jpg")
    assert str(results[0].content) == expected_path
|
|
|
|
|
|
@pytest.mark.skip("Not working yet as intended")
@pytest.mark.integration
def test_multimodal_text_image_retrieval(text_docs: List[Document], image_docs: List[Document], samples_path):
    """A mixed text+image store should surface both the Paris picture and the Paris sentence."""
    clip_model = "sentence-transformers/clip-ViT-B-32"
    retriever = MultiModalRetriever(
        document_store=InMemoryDocumentStore(return_embedding=True, embedding_dim=512),
        query_embedding_model=clip_model,
        document_embedding_models={
            "text": clip_model,
            "image": clip_model,
        },
    )
    store = retriever.document_store
    store.write_documents(image_docs)
    store.write_documents(text_docs)
    store.update_embeddings(retriever=retriever)

    results = retriever.retrieve(query="What's Paris?")

    # Split the ranked hits by modality, keeping ranking order within each group.
    text_hits = [doc for doc in results if doc.content_type == "text"]
    image_hits = [doc for doc in results if doc.content_type == "image"]

    assert str(image_hits[0].content) == str(samples_path / "images" / "paris.jpg")
    assert text_hits[0].content == "My name is Christelle and I live in Paris"
|
|
|
|
|
|
@pytest.mark.unit
def test_web_retriever_mode_raw_documents(monkeypatch):
    """In ``raw_documents`` mode the WebRetriever returns the extracted page text unsplit.

    ``WebSearch.run``, ``requests.get`` and ``ArticleExtractor.get_content`` are all
    monkeypatched, so no network access happens. The search returns two hits. One
    document is expected back, carrying the second hit's link as ``url``; presumably
    the first hit is dropped because its link is empty.
    """
    # Text the patched extractor "finds" on the fetched page; also the expected content.
    page_text = "What are the top solutions for\nArya Stark's Father\nWe found 1 solutions for\nArya Stark's Father\n.The top solutions is determined by popularity, ratings and frequency of searches. The most likely answer for the clue is NED..."

    expected_search_results = {
        "documents": [
            Document(
                content="Eddard Stark",
                score=0.9090909090909091,
                meta={"title": "Eddard Stark", "link": "", "score": 0.9090909090909091},
                id_hash_keys=["content"],
                id="f408db6de8de0ffad0cb47cf8830dbb8",
            ),
            Document(
                content="The most likely answer for the clue is NED. How many solutions does Arya Stark's Father have? With crossword-solver.io you will find 1 solutions. We use ...",
                score=0.09090909090909091,
                meta={
                    "title": "Arya Stark's Father - Crossword Clue Answers",
                    "link": "https://crossword-solver.io/clue/arya-stark%27s-father/",
                    "position": 1,
                    "score": 0.09090909090909091,
                },
                id_hash_keys=["content"],
                id="51779277acf94cf90e7663db137c0732",
            ),
        ]
    }

    def mock_web_search_run(self, query: str) -> Tuple[Dict, str]:
        return expected_search_results, "output_1"

    class MockResponse:
        def __init__(self, text, status_code):
            self.text = text
            self.status_code = status_code

    def get(url, headers, timeout):
        return MockResponse("mocked", 200)

    def get_content(self, text: str) -> str:
        return page_text

    monkeypatch.setattr(WebSearch, "run", mock_web_search_run)
    monkeypatch.setattr(ArticleExtractor, "get_content", get_content)
    monkeypatch.setattr(requests, "get", get)

    web_retriever = WebRetriever(api_key="", top_search_results=2, mode="raw_documents")
    result = web_retriever.retrieve(query="Who is the father of Arya Stark?")
    assert len(result) == 1
    assert isinstance(result[0], Document)
    assert result[0].content == page_text
    assert result[0].score is None
    assert result[0].meta["url"] == "https://crossword-solver.io/clue/arya-stark%27s-father/"
    # Only preprocessed docs but not raw docs should have the _split_id field
    assert "_split_id" not in result[0].meta
|
|
|
|
|
|
@pytest.mark.unit
def test_web_retriever_mode_preprocessed_documents(monkeypatch):
    """In ``preprocessed_documents`` mode the fetched page text comes back as split documents.

    ``WebSearch.run``, ``requests.get`` and ``ArticleExtractor.get_content`` are all
    monkeypatched, so no network access happens. The short page text fits in a
    single split, so exactly one document with ``_split_id == 0`` is expected.
    """
    # Text the patched extractor "finds" on the fetched page; also the expected content.
    page_text = "What are the top solutions for\nArya Stark's Father\nWe found 1 solutions for\nArya Stark's Father\n.The top solutions is determined by popularity, ratings and frequency of searches. The most likely answer for the clue is NED..."

    expected_search_results = {
        "documents": [
            Document(
                content="Eddard Stark",
                score=0.9090909090909091,
                meta={"title": "Eddard Stark", "link": "", "score": 0.9090909090909091},
                id_hash_keys=["content"],
                id="f408db6de8de0ffad0cb47cf8830dbb8",
            ),
            Document(
                content="The most likely answer for the clue is NED. How many solutions does Arya Stark's Father have? With crossword-solver.io you will find 1 solutions. We use ...",
                score=0.09090909090909091,
                meta={
                    "title": "Arya Stark's Father - Crossword Clue Answers",
                    "link": "https://crossword-solver.io/clue/arya-stark%27s-father/",
                    "position": 1,
                    "score": 0.09090909090909091,
                },
                id_hash_keys=["content"],
                id="51779277acf94cf90e7663db137c0732",
            ),
        ]
    }

    def mock_web_search_run(self, query: str) -> Tuple[Dict, str]:
        return expected_search_results, "output_1"

    class MockResponse:
        def __init__(self, text, status_code):
            self.text = text
            self.status_code = status_code

    def get(url, headers, timeout):
        return MockResponse("mocked", 200)

    def get_content(self, text: str) -> str:
        return page_text

    monkeypatch.setattr(WebSearch, "run", mock_web_search_run)
    monkeypatch.setattr(ArticleExtractor, "get_content", get_content)
    monkeypatch.setattr(requests, "get", get)

    web_retriever = WebRetriever(api_key="", top_search_results=2, mode="preprocessed_documents")
    result = web_retriever.retrieve(query="Who is the father of Arya Stark?")
    assert len(result) == 1
    assert isinstance(result[0], Document)
    assert result[0].content == page_text
    assert result[0].score is None
    assert result[0].meta["url"] == "https://crossword-solver.io/clue/arya-stark%27s-father/"
    assert result[0].meta["_split_id"] == 0
|
|
|
|
|
|
@pytest.mark.unit
def test_web_retriever_mode_snippets(monkeypatch):
    """With no ``mode`` argument, the retriever returns the WebSearch snippet documents unchanged.

    Only ``WebSearch.run`` is patched; no page fetching is expected in this mode.
    """
    snippet_docs = [
        Document(
            content="Eddard Stark",
            score=0.9090909090909091,
            meta={"title": "Eddard Stark", "link": "", "score": 0.9090909090909091},
            id_hash_keys=["content"],
            id="f408db6de8de0ffad0cb47cf8830dbb8",
        ),
        Document(
            content="The most likely answer for the clue is NED. How many solutions does Arya Stark's Father have? With crossword-solver.io you will find 1 solutions. We use ...",
            score=0.09090909090909091,
            meta={
                "title": "Arya Stark's Father - Crossword Clue Answers",
                "link": "https://crossword-solver.io/clue/arya-stark%27s-father/",
                "position": 1,
                "score": 0.09090909090909091,
            },
            id_hash_keys=["content"],
            id="51779277acf94cf90e7663db137c0732",
        ),
    ]
    search_output = {"documents": snippet_docs}

    def fake_search_run(self, query: str) -> Tuple[Dict, str]:
        return search_output, "output_1"

    monkeypatch.setattr(WebSearch, "run", fake_search_run)

    retriever = WebRetriever(api_key="", top_search_results=2)
    retrieved = retriever.retrieve(query="Who is the father of Arya Stark?")
    assert retrieved == search_output["documents"]
|
|
|
|
|
|
@pytest.mark.unit
@patch("haystack.nodes.retriever._openai_encoder.openai_request")
def test_openai_default_api_base(mock_request):
    """Without ``api_base`` the OpenAI embedding encoder posts to the public OpenAI endpoint.

    Both query and document embedding must target ``<api_base>/embeddings``.
    """
    default_base = "https://api.openai.com/v1"
    embeddings_url = "https://api.openai.com/v1/embeddings"

    # The tokenizer loader is patched so construction does not load a real tokenizer.
    with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"):
        retriever = EmbeddingRetriever(embedding_model="text-embedding-ada-002", api_key="fake_api_key")
    assert retriever.api_base == default_base

    retriever.embed_queries(queries=["test query"])
    assert mock_request.call_args.kwargs["url"] == embeddings_url
    mock_request.reset_mock()

    retriever.embed_documents(documents=[Document(content="test document")])
    assert mock_request.call_args.kwargs["url"] == embeddings_url
|
|
|
|
|
|
@pytest.mark.unit
@patch("haystack.nodes.retriever._openai_encoder.openai_request")
def test_openai_custom_api_base(mock_request):
    """A custom ``api_base`` is stored on the retriever and used for every embeddings request.

    Both query and document embedding must target ``<api_base>/embeddings``.
    """
    custom_base = "https://fake_api_base.com"
    embeddings_url = "https://fake_api_base.com/embeddings"

    # The tokenizer loader is patched so construction does not load a real tokenizer.
    with patch("haystack.nodes.retriever._openai_encoder.load_openai_tokenizer"):
        retriever = EmbeddingRetriever(
            embedding_model="text-embedding-ada-002", api_key="fake_api_key", api_base="https://fake_api_base.com"
        )
    assert retriever.api_base == custom_base

    retriever.embed_queries(queries=["test query"])
    assert mock_request.call_args.kwargs["url"] == embeddings_url
    mock_request.reset_mock()

    retriever.embed_documents(documents=[Document(content="test document")])
    assert mock_request.call_args.kwargs["url"] == embeddings_url
|