feat: embedding instructions for dense retrieval (#6372)

* Embedding instructions in EmbeddingRetriever

Query and document embeddings are prefixed with instructions, which is useful
for retrievers fine-tuned on specific tasks such as Q&A (see the usage sketch below).

* Tests

Checking the 0th component of the vectors against reference values, using different document stores.

* Normalizing vectors

* Release notes
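
A minimal usage sketch of the new parameters (assuming the Haystack v1 node API; the in-memory store, embedding dimension, and prompt strings below mirror the test setup added in this commit and are illustrative rather than required values):

```python
# Sketch only: instruction prompts on an EmbeddingRetriever (Haystack v1.x API assumed).
# Model and prompt strings follow the tests added in this commit; any instruction text works.
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever

document_store = InMemoryDocumentStore(embedding_dim=768)
retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/msmarco-distilbert-dot-v5",
    model_format="sentence_transformers",
    query_prompt="Embed this query for retrieval:",
    passage_prompt="Embed this passage for retrieval:",
)
# Queries are prefixed with query_prompt before encoding; document contents are
# prefixed with passage_prompt inside embed_documents() / update_embeddings().
```

Leaving either prompt at its default of None preserves the previous behaviour, since the prefix is applied only when the corresponding attribute is set.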
Daniel Fleischer 2023-11-21 13:56:40 +02:00 committed by GitHub
parent 07cda09aa8
commit 0cef17ac13
4 changed files with 46 additions and 1 deletion


@@ -1453,6 +1453,8 @@ class EmbeddingRetriever(DenseRetriever):
max_seq_len: int = 512,
model_format: Optional[str] = None,
pooling_strategy: str = "reduce_mean",
query_prompt: Optional[str] = None,
passage_prompt: Optional[str] = None,
emb_extraction_layer: int = -1,
top_k: int = 10,
progress_bar: bool = True,
@@ -1495,7 +1497,8 @@ class EmbeddingRetriever(DenseRetriever):
2. `reduce_mean` (sentence vector)
3. `reduce_max` (sentence vector)
4. `per_token` (individual token vectors)
:param query_prompt: Model instruction for embedding texts to be used as queries.
:param passage_prompt: Model instruction for embedding texts to be retrieved.
:param emb_extraction_layer: Number of layer from which the embeddings shall be extracted (for farm / transformers models only).
Default: -1 (very last layer).
:param top_k: How many documents to return per query.
@@ -1550,6 +1553,8 @@ class EmbeddingRetriever(DenseRetriever):
self.max_seq_len = max_seq_len
self.pooling_strategy = pooling_strategy
self.emb_extraction_layer = emb_extraction_layer
self.query_prompt = query_prompt
self.passage_prompt = passage_prompt
self.top_k = top_k
self.progress_bar = progress_bar
self.use_auth_token = use_auth_token
@@ -1830,6 +1835,8 @@ class EmbeddingRetriever(DenseRetriever):
if isinstance(queries, str):
queries = [queries]
assert isinstance(queries, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
if self.query_prompt:
queries = [self.query_prompt + " " + q for q in queries]
return self.embedding_encoder.embed_queries(queries)
def embed_documents(self, documents: List[Document]) -> np.ndarray:
@@ -1840,6 +1847,9 @@ class EmbeddingRetriever(DenseRetriever):
:return: Embeddings, one per input document, shape: (docs, embedding_dim)
"""
documents = self._preprocess_documents(documents)
if self.passage_prompt:
for doc in documents:
doc.content = self.passage_prompt + " " + doc.content
return self.embedding_encoder.embed_documents(documents)
def _preprocess_documents(self, docs: List[Document]) -> List[Document]:


@@ -0,0 +1,4 @@
---
features:
- |
Support for dense embedding instructions, used in retrieval models such as BGE and LLM-Embedder.
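
For instruction-tuned encoders such as BGE, typically only the query side carries an instruction. A hedged sketch, with the model name and instruction string taken from the BGE v1.5 model card rather than from this commit:

```python
# Sketch only: BGE-style setup where just queries get an instruction prefix
# (per the BGE v1.5 model card); passage_prompt is omitted, so documents stay unchanged.
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever

bge_retriever = EmbeddingRetriever(
    document_store=InMemoryDocumentStore(embedding_dim=768),
    embedding_model="BAAI/bge-base-en-v1.5",
    model_format="sentence_transformers",
    query_prompt="Represent this sentence for searching relevant passages:",
)
```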


@@ -647,6 +647,15 @@ def get_retriever(retriever_type, document_store):
model_format="sentence_transformers",
use_gpu=False,
)
elif retriever_type == "embedding_sbert_instructions":
retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model="sentence-transformers/msmarco-distilbert-dot-v5",
model_format="sentence_transformers",
query_prompt="Embed this query for retrieval:",
passage_prompt="Embed this passage for retrieval:",
use_gpu=False,
)
elif retriever_type == "retribert":
retriever = EmbeddingRetriever(
document_store=document_store, embedding_model="yjernite/retribert-base-uncased", use_gpu=False


@@ -311,6 +311,28 @@ def test_dpr_embedding(document_store: BaseDocumentStore, retriever, docs_with_ids):
assert isclose(embedding[0], expected_value, rel_tol=0.01)
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "weaviate", "pinecone"], indirect=True)
@pytest.mark.parametrize("retriever", ["embedding_sbert_instructions"], indirect=True)
def test_embedding_with_instructions(document_store: BaseDocumentStore, retriever, docs_with_ids):
document_store.return_embedding = True
document_store.write_documents(docs_with_ids)
document_store.update_embeddings(retriever=retriever)
docs = document_store.get_all_documents()
docs.sort(key=lambda d: d.id)
print([doc.id for doc in docs])
expected_values = [0.00484978, 0.02258789, 0.03414359, -0.01461711, 0.01784192]
for doc, expected_value in zip(docs, expected_values):
embedding = doc.embedding
# always normalize vector as faiss returns normalized vectors and other document stores do not
embedding /= np.linalg.norm(embedding)
assert len(embedding) == 768
assert isclose(embedding[0], expected_value, rel_tol=0.01)
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "weaviate", "pinecone"], indirect=True)
@pytest.mark.parametrize("retriever", ["retribert"], indirect=True)