refactor: change MultiModal retriever to be of type DenseRetriever (#3598)

* changed Multimodal retriever to be of type DenseRetriever

* format fix

* Pylint fix

* Added embed_queries and tests
This commit is contained in:
Mayank Jobanputra 2022-11-28 19:24:22 +01:00 committed by GitHub
parent 6f9a0f2215
commit 95cf666a20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 2 deletions

View File

@ -6,7 +6,7 @@ from pathlib import Path
import torch
import numpy as np
from haystack.nodes.retriever import BaseRetriever
from haystack.nodes.retriever import DenseRetriever
from haystack.document_stores import BaseDocumentStore
from haystack.schema import ContentTypes, Document
from haystack.nodes.retriever.multimodal.embedder import MultiModalEmbedder, MultiModalRetrieverError
@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
FilterType = Optional[Dict[str, Union[Dict[str, Any], List[Any], str, int, float, bool]]]
class MultiModalRetriever(BaseRetriever):
class MultiModalRetriever(DenseRetriever):
def __init__(
self,
document_store: BaseDocumentStore,
@ -221,3 +221,7 @@ class MultiModalRetriever(BaseRetriever):
def embed_documents(self, docs: List[Document]) -> np.ndarray:
return self.document_embedder.embed(documents=docs)
def embed_queries(self, queries: List[str]) -> np.ndarray:
query_documents = [Document(content=query, content_type="text") for query in queries]
return self.query_embedder.embed(documents=query_documents)

View File

@ -831,6 +831,18 @@ def test_multimodal_table_retrieval(table_docs: List[Document]):
)
@pytest.mark.integration
def test_multimodal_retriever_query():
retriever = MultiModalRetriever(
document_store=InMemoryDocumentStore(return_embedding=True, embedding_dim=512),
query_embedding_model="sentence-transformers/clip-ViT-B-32",
document_embedding_models={"image": "sentence-transformers/clip-ViT-B-32"},
)
res_emb = retriever.embed_queries(["dummy query 1", "dummy query 1"])
assert np.array_equal(res_emb[0], res_emb[1])
@pytest.mark.integration
def test_multimodal_image_retrieval(image_docs: List[Document]):
retriever = MultiModalRetriever(