mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-30 08:37:20 +00:00
refactor: change MultiModal retriever to be of type DenseRetriever (#3598)
* changed Multimodal retriever to be of type DenseRetriever * format fix * Pylint fix * Added embed_queries and tests
This commit is contained in:
parent
6f9a0f2215
commit
95cf666a20
@ -6,7 +6,7 @@ from pathlib import Path
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from haystack.nodes.retriever import BaseRetriever
|
||||
from haystack.nodes.retriever import DenseRetriever
|
||||
from haystack.document_stores import BaseDocumentStore
|
||||
from haystack.schema import ContentTypes, Document
|
||||
from haystack.nodes.retriever.multimodal.embedder import MultiModalEmbedder, MultiModalRetrieverError
|
||||
@ -18,7 +18,7 @@ logger = logging.getLogger(__name__)
|
||||
FilterType = Optional[Dict[str, Union[Dict[str, Any], List[Any], str, int, float, bool]]]
|
||||
|
||||
|
||||
class MultiModalRetriever(BaseRetriever):
|
||||
class MultiModalRetriever(DenseRetriever):
|
||||
def __init__(
|
||||
self,
|
||||
document_store: BaseDocumentStore,
|
||||
@ -221,3 +221,7 @@ class MultiModalRetriever(BaseRetriever):
|
||||
|
||||
def embed_documents(self, docs: List[Document]) -> np.ndarray:
|
||||
return self.document_embedder.embed(documents=docs)
|
||||
|
||||
def embed_queries(self, queries: List[str]) -> np.ndarray:
|
||||
query_documents = [Document(content=query, content_type="text") for query in queries]
|
||||
return self.query_embedder.embed(documents=query_documents)
|
||||
|
||||
@ -831,6 +831,18 @@ def test_multimodal_table_retrieval(table_docs: List[Document]):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_multimodal_retriever_query():
|
||||
retriever = MultiModalRetriever(
|
||||
document_store=InMemoryDocumentStore(return_embedding=True, embedding_dim=512),
|
||||
query_embedding_model="sentence-transformers/clip-ViT-B-32",
|
||||
document_embedding_models={"image": "sentence-transformers/clip-ViT-B-32"},
|
||||
)
|
||||
|
||||
res_emb = retriever.embed_queries(["dummy query 1", "dummy query 1"])
|
||||
assert np.array_equal(res_emb[0], res_emb[1])
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_multimodal_image_retrieval(image_docs: List[Document]):
|
||||
retriever = MultiModalRetriever(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user