diff --git a/haystack/nodes/retriever/multimodal/retriever.py b/haystack/nodes/retriever/multimodal/retriever.py index b50701201..8ff9fdd2e 100644 --- a/haystack/nodes/retriever/multimodal/retriever.py +++ b/haystack/nodes/retriever/multimodal/retriever.py @@ -6,7 +6,7 @@ from pathlib import Path import torch import numpy as np -from haystack.nodes.retriever import BaseRetriever +from haystack.nodes.retriever import DenseRetriever from haystack.document_stores import BaseDocumentStore from haystack.schema import ContentTypes, Document from haystack.nodes.retriever.multimodal.embedder import MultiModalEmbedder, MultiModalRetrieverError @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) FilterType = Optional[Dict[str, Union[Dict[str, Any], List[Any], str, int, float, bool]]] -class MultiModalRetriever(BaseRetriever): +class MultiModalRetriever(DenseRetriever): def __init__( self, document_store: BaseDocumentStore, @@ -221,3 +221,7 @@ class MultiModalRetriever(BaseRetriever): def embed_documents(self, docs: List[Document]) -> np.ndarray: return self.document_embedder.embed(documents=docs) + + def embed_queries(self, queries: List[str]) -> np.ndarray: + query_documents = [Document(content=query, content_type="text") for query in queries] + return self.query_embedder.embed(documents=query_documents) diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py index b779b1ee5..af6e4f1c3 100644 --- a/test/nodes/test_retriever.py +++ b/test/nodes/test_retriever.py @@ -831,6 +831,18 @@ def test_multimodal_table_retrieval(table_docs: List[Document]): ) +@pytest.mark.integration +def test_multimodal_retriever_query(): + retriever = MultiModalRetriever( + document_store=InMemoryDocumentStore(return_embedding=True, embedding_dim=512), + query_embedding_model="sentence-transformers/clip-ViT-B-32", + document_embedding_models={"image": "sentence-transformers/clip-ViT-B-32"}, + ) + + res_emb = retriever.embed_queries(["dummy query 1", "dummy query 1"]) + assert np.array_equal(res_emb[0], res_emb[1]) + + @pytest.mark.integration def test_multimodal_image_retrieval(image_docs: List[Document]): retriever = MultiModalRetriever(