Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-12-26 14:38:36 +00:00)
feat: embedding instructions for dense retrieval (#6372)
* Embedding instructions in EmbeddingRetriever: query and document embeddings are prefixed with instructions, which is useful for retrievers fine-tuned on specific tasks, such as Q&A.
* Tests: checking the vectors' 0th component against reference values, using different document stores.
* Normalizing vectors
* Release notes
parent 07cda09aa8
commit 0cef17ac13
@@ -1453,6 +1453,8 @@ class EmbeddingRetriever(DenseRetriever):
         max_seq_len: int = 512,
         model_format: Optional[str] = None,
         pooling_strategy: str = "reduce_mean",
+        query_prompt: Optional[str] = None,
+        passage_prompt: Optional[str] = None,
         emb_extraction_layer: int = -1,
         top_k: int = 10,
         progress_bar: bool = True,

@@ -1495,7 +1497,8 @@ class EmbeddingRetriever(DenseRetriever):
                                  2. `reduce_mean` (sentence vector)
                                  3. `reduce_max` (sentence vector)
                                  4. `per_token` (individual token vectors)
 
+        :param query_prompt: Model instruction for embedding texts to be used as queries.
+        :param passage_prompt: Model instruction for embedding texts to be retrieved.
         :param emb_extraction_layer: Number of layer from which the embeddings shall be extracted (for farm / transformers models only).
                                      Default: -1 (very last layer).
         :param top_k: How many documents to return per query.

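The two new parameters are plain instruction strings that get prepended to the texts before encoding. A minimal usage sketch follows; the model name and prompt strings are the illustrative ones used in the tests added by this commit, not recommendations, and the in-memory store is an assumption:

from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever

document_store = InMemoryDocumentStore(embedding_dim=768)
retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/msmarco-distilbert-dot-v5",
    model_format="sentence_transformers",
    # Instructions prepended to queries and passages, respectively, before embedding.
    query_prompt="Embed this query for retrieval:",
    passage_prompt="Embed this passage for retrieval:",
    use_gpu=False,
)
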
@@ -1550,6 +1553,8 @@ class EmbeddingRetriever(DenseRetriever):
         self.max_seq_len = max_seq_len
         self.pooling_strategy = pooling_strategy
         self.emb_extraction_layer = emb_extraction_layer
+        self.query_prompt = query_prompt
+        self.passage_prompt = passage_prompt
         self.top_k = top_k
         self.progress_bar = progress_bar
         self.use_auth_token = use_auth_token

@@ -1830,6 +1835,8 @@ class EmbeddingRetriever(DenseRetriever):
         if isinstance(queries, str):
             queries = [queries]
         assert isinstance(queries, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
+        if self.query_prompt:
+            queries = [self.query_prompt + " " + q for q in queries]
         return self.embedding_encoder.embed_queries(queries)
 
     def embed_documents(self, documents: List[Document]) -> np.ndarray:

@@ -1840,6 +1847,9 @@ class EmbeddingRetriever(DenseRetriever):
         :return: Embeddings, one per input document, shape: (docs, embedding_dim)
         """
         documents = self._preprocess_documents(documents)
+        if self.passage_prompt:
+            for doc in documents:
+                doc.content = self.passage_prompt + " " + doc.content
         return self.embedding_encoder.embed_documents(documents)
 
     def _preprocess_documents(self, docs: List[Document]) -> List[Document]:

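Taken together, these two hunks mean that when the prompts are set the encoder never sees the raw texts: each query and each document's content is prefixed with the corresponding instruction plus a single space. A rough sketch of the resulting strings, with illustrative prompt and sample texts that are assumptions, not part of the commit:

query_prompt = "Embed this query for retrieval:"
passage_prompt = "Embed this passage for retrieval:"

query = "Who lives in Berlin?"
passage = "My name is Carla and I live in Berlin."

# Mirrors the prefixing done in embed_queries() and embed_documents() above.
print(query_prompt + " " + query)      # Embed this query for retrieval: Who lives in Berlin?
print(passage_prompt + " " + passage)  # Embed this passage for retrieval: My name is Carla and I live in Berlin.
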
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Support for dense embedding instructions, used in retrieval models such as BGE and LLM-Embedder.

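Instruction-tuned retrieval models such as BGE and LLM-Embedder document the instruction string to prepend in their model cards, and these new parameters are the way to pass such a string through. A hedged sketch under that assumption; the model name, embedding dimension, and instruction string below are taken from the public BGE model card, not from this commit, and should be double-checked there:

from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever

retriever = EmbeddingRetriever(
    document_store=InMemoryDocumentStore(embedding_dim=768),
    embedding_model="BAAI/bge-base-en-v1.5",
    model_format="sentence_transformers",
    # BGE-style models prepend an instruction to queries only, so passage_prompt is left unset.
    query_prompt="Represent this sentence for searching relevant passages:",
)
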
@@ -647,6 +647,15 @@ def get_retriever(retriever_type, document_store):
             model_format="sentence_transformers",
             use_gpu=False,
         )
+    elif retriever_type == "embedding_sbert_instructions":
+        retriever = EmbeddingRetriever(
+            document_store=document_store,
+            embedding_model="sentence-transformers/msmarco-distilbert-dot-v5",
+            model_format="sentence_transformers",
+            query_prompt="Embed this query for retrieval:",
+            passage_prompt="Embed this passage for retrieval:",
+            use_gpu=False,
+        )
     elif retriever_type == "retribert":
         retriever = EmbeddingRetriever(
             document_store=document_store, embedding_model="yjernite/retribert-base-uncased", use_gpu=False

@@ -311,6 +311,28 @@ def test_dpr_embedding(document_store: BaseDocumentStore, retriever, docs_with_ids):
         assert isclose(embedding[0], expected_value, rel_tol=0.01)
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "weaviate", "pinecone"], indirect=True)
+@pytest.mark.parametrize("retriever", ["embedding_sbert_instructions"], indirect=True)
+def test_embedding_with_instructions(document_store: BaseDocumentStore, retriever, docs_with_ids):
+    document_store.return_embedding = True
+    document_store.write_documents(docs_with_ids)
+    document_store.update_embeddings(retriever=retriever)
+
+    docs = document_store.get_all_documents()
+    docs.sort(key=lambda d: d.id)
+
+    print([doc.id for doc in docs])
+
+    expected_values = [0.00484978, 0.02258789, 0.03414359, -0.01461711, 0.01784192]
+    for doc, expected_value in zip(docs, expected_values):
+        embedding = doc.embedding
+        # always normalize vector as faiss returns normalized vectors and other document stores do not
+        embedding /= np.linalg.norm(embedding)
+        assert len(embedding) == 768
+        assert isclose(embedding[0], expected_value, rel_tol=0.01)
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "weaviate", "pinecone"], indirect=True)
 @pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
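The reference values in the test are components of unit-normalized vectors, hence the explicit normalization before each comparison: FAISS stores normalized embeddings while the other stores return them as produced. A minimal sketch of that check outside pytest, with a made-up vector standing in for doc.embedding:

import numpy as np
from math import isclose

embedding = np.random.rand(768).astype(np.float32)  # stand-in for doc.embedding
embedding /= np.linalg.norm(embedding)              # unit-normalize, as FAISS would return it
assert isclose(float(np.linalg.norm(embedding)), 1.0, rel_tol=1e-5)
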