diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py index 08847a90c..d4663e6cf 100644 --- a/haystack/nodes/retriever/dense.py +++ b/haystack/nodes/retriever/dense.py @@ -1453,6 +1453,8 @@ class EmbeddingRetriever(DenseRetriever): max_seq_len: int = 512, model_format: Optional[str] = None, pooling_strategy: str = "reduce_mean", + query_prompt: Optional[str] = None, + passage_prompt: Optional[str] = None, emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, @@ -1495,7 +1497,8 @@ class EmbeddingRetriever(DenseRetriever): 2. `reduce_mean` (sentence vector) 3. `reduce_max` (sentence vector) 4. `per_token` (individual token vectors) - + :param query_prompt: Model instruction for embedding texts to be used as queries. + :param passage_prompt: Model instruction for embedding texts to be retrieved. :param emb_extraction_layer: Number of layer from which the embeddings shall be extracted (for farm / transformers models only). Default: -1 (very last layer). :param top_k: How many documents to return per query. @@ -1550,6 +1553,8 @@ class EmbeddingRetriever(DenseRetriever): self.max_seq_len = max_seq_len self.pooling_strategy = pooling_strategy self.emb_extraction_layer = emb_extraction_layer + self.query_prompt = query_prompt + self.passage_prompt = passage_prompt self.top_k = top_k self.progress_bar = progress_bar self.use_auth_token = use_auth_token @@ -1830,6 +1835,8 @@ class EmbeddingRetriever(DenseRetriever): if isinstance(queries, str): queries = [queries] assert isinstance(queries, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])" + if self.query_prompt: + queries = [self.query_prompt + " " + q for q in queries] return self.embedding_encoder.embed_queries(queries) def embed_documents(self, documents: List[Document]) -> np.ndarray: @@ -1840,6 +1847,9 @@ class EmbeddingRetriever(DenseRetriever): :return: Embeddings, one per input document, shape: (docs, embedding_dim) """ documents = self._preprocess_documents(documents) + if self.passage_prompt: + for doc in documents: + doc.content = self.passage_prompt + " " + doc.content return self.embedding_encoder.embed_documents(documents) def _preprocess_documents(self, docs: List[Document]) -> List[Document]: diff --git a/releasenotes/notes/embedding-instructions-4feb216cbf796678.yaml b/releasenotes/notes/embedding-instructions-4feb216cbf796678.yaml new file mode 100644 index 000000000..612a3160c --- /dev/null +++ b/releasenotes/notes/embedding-instructions-4feb216cbf796678.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Support for dense embedding instructions, used in retrieval models such as BGE and LLM-Embedder. diff --git a/test/conftest.py b/test/conftest.py index 630ddf494..9e6fdb742 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -647,6 +647,15 @@ def get_retriever(retriever_type, document_store): model_format="sentence_transformers", use_gpu=False, ) + elif retriever_type == "embedding_sbert_instructions": + retriever = EmbeddingRetriever( + document_store=document_store, + embedding_model="sentence-transformers/msmarco-distilbert-dot-v5", + model_format="sentence_transformers", + query_prompt="Embed this query for retrieval:", + passage_prompt="Embed this passage for retrieval:", + use_gpu=False, + ) elif retriever_type == "retribert": retriever = EmbeddingRetriever( document_store=document_store, embedding_model="yjernite/retribert-base-uncased", use_gpu=False diff --git a/test/nodes/test_retriever.py b/test/nodes/test_retriever.py index f22304c0b..44cb09387 100644 --- a/test/nodes/test_retriever.py +++ b/test/nodes/test_retriever.py @@ -311,6 +311,28 @@ def test_dpr_embedding(document_store: BaseDocumentStore, retriever, docs_with_i assert isclose(embedding[0], expected_value, rel_tol=0.01) +@pytest.mark.integration +@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "weaviate", "pinecone"], indirect=True) +@pytest.mark.parametrize("retriever", ["embedding_sbert_instructions"], indirect=True) +def test_embedding_with_instructions(document_store: BaseDocumentStore, retriever, docs_with_ids): + document_store.return_embedding = True + document_store.write_documents(docs_with_ids) + document_store.update_embeddings(retriever=retriever) + + docs = document_store.get_all_documents() + docs.sort(key=lambda d: d.id) + + print([doc.id for doc in docs]) + + expected_values = [0.00484978, 0.02258789, 0.03414359, -0.01461711, 0.01784192] + for doc, expected_value in zip(docs, expected_values): + embedding = doc.embedding + # always normalize vector as faiss returns normalized vectors and other document stores do not + embedding /= np.linalg.norm(embedding) + assert len(embedding) == 768 + assert isclose(embedding[0], expected_value, rel_tol=0.01) + + @pytest.mark.integration @pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "weaviate", "pinecone"], indirect=True) @pytest.mark.parametrize("retriever", ["retribert"], indirect=True)