Rename every occurrence of 'embed_passages' to 'embed_documents' (#1667)

* Rename every occurrence of 'embed_passages' to 'embed_documents'

* Remove aliased method embed_passages

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Sara Zan 2021-10-28 12:17:56 +02:00 committed by GitHub
parent 6892955e95
commit eab475bb5d
13 changed files with 34 additions and 64 deletions
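For code that calls the retriever API the change is a pure rename, and since the alias was removed rather than deprecated, callers have to switch as well. A minimal before/after sketch (the import paths match this era of Haystack and the model names are the usual DPR defaults, shown for illustration):

```python
from haystack import Document
from haystack.document_store.memory import InMemoryDocumentStore  # haystack.document_stores in later releases
from haystack.retriever.dense import DensePassageRetriever        # haystack.nodes in later releases

retriever = DensePassageRetriever(
    document_store=InMemoryDocumentStore(),
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
)
docs = [Document(content="Berlin is the capital of Germany.")]

# Before this commit: embeddings = retriever.embed_passages(docs)
embeddings = retriever.embed_documents(docs)  # List[np.ndarray], one per document
```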


@@ -335,14 +335,14 @@ Create embeddings for a list of queries using the query encoder
 Embeddings, one per input query
-<a name="dense.DensePassageRetriever.embed_passages"></a>
-#### embed\_passages
+<a name="dense.DensePassageRetriever.embed_documents"></a>
+#### embed\_documents
 ```python
-| embed_passages(docs: List[Document]) -> List[np.ndarray]
+| embed_documents(docs: List[Document]) -> List[np.ndarray]
 ```
-Create embeddings for a list of passages using the passage encoder
+Create embeddings for a list of documents using the passage encoder
 **Arguments**:
@@ -503,7 +503,7 @@ Embeddings, one per input query
 | embed_documents(docs: List[Document]) -> List[np.ndarray]
 ```
-Create embeddings for a list of text passages and / or tables using the text passage encoder and
+Create embeddings for a list of text documents and / or tables using the text passage encoder and
 the table encoder.
 **Arguments**:
@@ -515,25 +515,6 @@ the table encoder.
 Embeddings of documents / passages. Shape: (batch_size, embedding_dim)
-<a name="dense.TableTextRetriever.embed_passages"></a>
-#### embed\_passages
-```python
-| embed_passages(docs: List[Document]) -> List[np.ndarray]
-```
-Create embeddings for a list of passages using the passage encoder.
-This method just calls embed_documents. It is needed as the document stores call embed_passages when updating
-embeddings.
-**Arguments**:
-- `docs`: List of Document objects used to represent documents / passages in a standardized way within Haystack.
-**Returns**:
-Embeddings of documents / passages shape (batch_size, embedding_dim)
 <a name="dense.TableTextRetriever.train"></a>
 #### train
@@ -683,14 +664,14 @@ Create embeddings for a list of queries.
 Embeddings, one per input query
-<a name="dense.EmbeddingRetriever.embed_passages"></a>
-#### embed\_passages
+<a name="dense.EmbeddingRetriever.embed_documents"></a>
+#### embed\_documents
 ```python
-| embed_passages(docs: List[Document]) -> List[np.ndarray]
+| embed_documents(docs: List[Document]) -> List[np.ndarray]
 ```
-Create embeddings for a list of passages.
+Create embeddings for a list of documents.
 **Arguments**:
@@ -698,7 +679,7 @@ Create embeddings for a list of passages.
 **Returns**:
-Embeddings, one per input passage
+Embeddings, one per input document
 <a name="text2sparql"></a>
 # Module text2sparql
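The renamed method keeps the documented signature, `embed_documents(docs: List[Document]) -> List[np.ndarray]`. A quick sketch exercising it through `EmbeddingRetriever` (the model name and embedding dimension are illustrative, not part of this diff):

```python
import numpy as np
from haystack import Document
from haystack.document_store.memory import InMemoryDocumentStore
from haystack.retriever.dense import EmbeddingRetriever  # haystack.nodes in later releases

retriever = EmbeddingRetriever(
    document_store=InMemoryDocumentStore(similarity="cosine", embedding_dim=384),
    embedding_model="sentence-transformers/paraphrase-MiniLM-L6-v2",  # illustrative choice
    model_format="sentence_transformers",
)
embeddings = retriever.embed_documents([Document(content="What a nice day!")])
assert isinstance(embeddings[0], np.ndarray)  # one vector per input document
```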


@@ -968,7 +968,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
 with tqdm(total=document_count, position=0, unit=" Docs", desc="Updating embeddings") as progress_bar:
     for result_batch in get_batches_from_generator(result, batch_size):
         document_batch = [self._convert_es_hit_to_document(hit, return_embedding=False) for hit in result_batch]
-        embeddings = retriever.embed_passages(document_batch) # type: ignore
+        embeddings = retriever.embed_documents(document_batch) # type: ignore
         assert len(document_batch) == len(embeddings)
         if embeddings[0].shape[0] != self.embedding_dim:
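Every document store touched below implements essentially this same `update_embeddings()` loop, which is why the rename fans out across so many files. A simplified sketch of the shared pattern (batching, progress bars, and store-specific persistence omitted):

```python
def update_embeddings_sketch(retriever, batched_documents, embedding_dim=768):
    """Simplified sketch of the loop each document store runs internally."""
    for document_batch in batched_documents:
        embeddings = retriever.embed_documents(document_batch)  # the renamed call
        assert len(document_batch) == len(embeddings)
        if embeddings[0].shape[0] != embedding_dim:
            raise RuntimeError("Embedding dimension of the model does not match the index")
        for doc, emb in zip(document_batch, embeddings):
            doc.embedding = emb  # each store then persists the vectors in its own way
```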


@@ -274,7 +274,7 @@ class FAISSDocumentStore(SQLDocumentStore):
 with tqdm(total=document_count, disable=not self.progress_bar, position=0, unit=" docs",
           desc="Updating Embedding") as progress_bar:
     for document_batch in batched_documents:
-        embeddings = retriever.embed_passages(document_batch) # type: ignore
+        embeddings = retriever.embed_documents(document_batch) # type: ignore
         assert len(document_batch) == len(embeddings)
         embeddings_to_index = np.array(embeddings, dtype="float32")


@@ -246,7 +246,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
 with tqdm(total=document_count, disable=not self.progress_bar, position=0, unit=" docs",
           desc="Updating Embedding") as progress_bar:
     for document_batch in batched_documents:
-        embeddings = retriever.embed_passages(document_batch) # type: ignore
+        embeddings = retriever.embed_documents(document_batch) # type: ignore
         assert len(document_batch) == len(embeddings)
         if embeddings[0].shape[0] != self.embedding_dim:


@@ -296,7 +296,7 @@ class MilvusDocumentStore(SQLDocumentStore):
 for document_batch in batched_documents:
     self._delete_vector_ids_from_milvus(documents=document_batch, index=index)
-    embeddings = retriever.embed_passages(document_batch) # type: ignore
+    embeddings = retriever.embed_documents(document_batch) # type: ignore
     embeddings_list = [embedding.tolist() for embedding in embeddings]
     assert len(document_batch) == len(embeddings_list)


@@ -403,7 +403,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
 ]
 self._delete_vector_ids_from_milvus(documents=document_batch, index=index)
-embeddings = retriever.embed_passages(document_batch) # type: ignore
+embeddings = retriever.embed_documents(document_batch) # type: ignore
 embeddings_list = [embedding.tolist() for embedding in embeddings]
 assert len(document_batch) == len(embeddings_list)


@@ -671,7 +671,7 @@ class WeaviateDocumentStore(BaseDocumentStore):
 for result_batch in get_batches_from_generator(result, batch_size):
     document_batch = [self._convert_weaviate_result_to_document(hit, return_embedding=False) for hit in result_batch]
-    embeddings = retriever.embed_passages(document_batch) # type: ignore
+    embeddings = retriever.embed_documents(document_batch) # type: ignore
     assert len(document_batch) == len(embeddings)
     if embeddings[0].shape[0] != self.embedding_dim:


@@ -183,7 +183,7 @@ class RAGenerator(BaseGenerator):
 if self.retriever is None:
     raise AttributeError("_prepare_passage_embeddings need a DPR instance as self.retriever to embed document")
-embeddings = self.retriever.embed_passages(docs)
+embeddings = self.retriever.embed_documents(docs)
 embeddings_in_tensor = torch.cat(
     [torch.from_numpy(embedding).float().unsqueeze(0) for embedding in embeddings],
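The `RAGenerator` then stacks the per-document NumPy vectors into a single tensor. Roughly what the truncated `torch.cat(...)` expression above evaluates to, as a standalone sketch:

```python
import numpy as np
import torch

# stand-in for the output of retriever.embed_documents(docs)
embeddings = [np.random.rand(768).astype(np.float32) for _ in range(4)]
embeddings_in_tensor = torch.cat(
    [torch.from_numpy(embedding).float().unsqueeze(0) for embedding in embeddings],
    dim=0,
)
print(embeddings_in_tensor.shape)  # torch.Size([4, 768])
```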


@@ -33,12 +33,12 @@ class _BaseEmbeddingEncoder:
 pass
 @abstractmethod
-def embed_passages(self, docs: List[Document]) -> List[np.ndarray]:
+def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
     """
-    Create embeddings for a list of passages.
+    Create embeddings for a list of documents.
     :param docs: List of documents to embed
-    :return: Embeddings, one per input passage
+    :return: Embeddings, one per input document
     """
     pass
@@ -78,7 +78,7 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
 def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
     return self.embed(texts)
-def embed_passages(self, docs: List[Document]) -> List[np.ndarray]:
+def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
     passages = [d.content for d in docs] # type: ignore
     return self.embed(passages)
@@ -116,7 +116,7 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
 def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
     return self.embed(texts)
-def embed_passages(self, docs: List[Document]) -> List[np.ndarray]:
+def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
     passages = [[d.meta["name"] if d.meta and "name" in d.meta else "", d.content] for d in docs] # type: ignore
     return self.embed(passages)
@@ -154,7 +154,7 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
 return np.concatenate(embeddings)
-def embed_passages(self, docs: List[Document]) -> List[np.ndarray]:
+def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
     doc_text = [{"text": d.content} for d in docs]
     dataloader = self._create_dataloader(doc_text)
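`_BaseEmbeddingEncoder` is the interface the three encoders above implement, so any custom encoder now overrides `embed_documents` instead of `embed_passages`. A hypothetical subclass, sketched only to show the contract (the class and its hashing scheme are invented for illustration):

```python
from typing import List

import numpy as np

from haystack import Document


class ToyEmbeddingEncoder:  # would derive from _BaseEmbeddingEncoder inside Haystack
    """Deterministic pseudo-embeddings: one 768-dim vector per input."""

    def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
        return [self._embed(t) for t in texts]

    def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
        return [self._embed(d.content) for d in docs]

    @staticmethod
    def _embed(text: str) -> np.ndarray:
        rng = np.random.default_rng(abs(hash(text)) % 2**32)
        return rng.random(768).astype(np.float32)
```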


@@ -243,7 +243,7 @@ class BaseRetriever(BaseComponent):
 if self.__class__.__name__ in ["DensePassageRetriever", "EmbeddingRetriever"]:
     documents = deepcopy(documents)
 document_objects = [Document.from_dict(doc) for doc in documents]
-embeddings = self.embed_passages(document_objects) # type: ignore
+embeddings = self.embed_documents(document_objects) # type: ignore
 for doc, emb in zip(documents, embeddings):
     doc["embedding"] = emb
 output = {"documents": documents}


@@ -274,9 +274,9 @@ class DensePassageRetriever(BaseRetriever):
 result = self._get_predictions(queries)["query"]
 return result
-def embed_passages(self, docs: List[Document]) -> List[np.ndarray]:
+def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
     """
-    Create embeddings for a list of passages using the passage encoder
+    Create embeddings for a list of documents using the passage encoder
     :param docs: List of Document objects used to represent documents / passages in a standardized way within Haystack.
     :return: Embeddings of documents / passages shape (batch_size, embedding_dim)
@@ -705,7 +705,7 @@ class TableTextRetriever(BaseRetriever):
 def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
     """
-    Create embeddings for a list of text passages and / or tables using the text passage encoder and
+    Create embeddings for a list of text documents and / or tables using the text passage encoder and
     the table encoder.
     :param docs: List of Document objects used to represent documents / passages in
@@ -742,17 +742,6 @@
 return embeddings
-def embed_passages(self, docs: List[Document]) -> List[np.ndarray]:
-    """
-    Create embeddings for a list of passages using the passage encoder.
-    This method just calls embed_documents. It is needed as the document stores call embed_passages when updating
-    embeddings.
-    :param docs: List of Document objects used to represent documents / passages in a standardized way within Haystack.
-    :return: Embeddings of documents / passages shape (batch_size, embedding_dim)
-    """
-    return self.embed_documents(docs)
 def train(self,
           data_dir: str,
           train_filename: str,
@@ -1025,11 +1014,11 @@ class EmbeddingRetriever(BaseRetriever):
 assert isinstance(texts, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
 return self.embedding_encoder.embed_queries(texts)
-def embed_passages(self, docs: List[Document]) -> List[np.ndarray]:
+def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
     """
-    Create embeddings for a list of passages.
+    Create embeddings for a list of documents.
     :param docs: List of documents to embed
-    :return: Embeddings, one per input passage
+    :return: Embeddings, one per input document
     """
-    return self.embedding_encoder.embed_passages(docs)
+    return self.embedding_encoder.embed_documents(docs)
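Because the alias is removed outright rather than deprecated, code that has to run on releases from both sides of this commit cannot count on either name existing. A hedged compatibility shim (`_embed` is a hypothetical helper, not part of Haystack):

```python
def _embed(retriever, docs):
    # prefer the new name; fall back for pre-#1667 releases
    if hasattr(retriever, "embed_documents"):
        return retriever.embed_documents(docs)
    return retriever.embed_passages(docs)
```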


@@ -126,7 +126,7 @@ def index_to_doc_store(doc_store, docs, retriever, labels=None):
 doc_store.write_labels(labels, index=label_index)
 # these lines are not run if the docs.embedding field is already populated with precomputed embeddings
 # See the prepare_data() fn in the retriever benchmark script
-if callable(getattr(retriever, "embed_passages", None)) and docs[0].embedding is None:
+if callable(getattr(retriever, "embed_documents", None)) and docs[0].embedding is None:
     doc_store.update_embeddings(retriever, index=doc_index)
 def load_config(config_filename, ci):


@@ -112,7 +112,7 @@ def test_update_docs(document_store, retriever, batch_size):
 # test if correct vectors are associated with docs
 for doc in documents_indexed:
     original_doc = [d for d in DOCUMENTS if d["content"] == doc.content][0]
-    updated_embedding = retriever.embed_passages([Document.from_dict(original_doc)])
+    updated_embedding = retriever.embed_documents([Document.from_dict(original_doc)])
     stored_doc = document_store.get_all_documents(filters={"name": [doc.meta["name"]]})[0]
     # compare original input vec with stored one (ignore extra dim added by hnsw)
     assert np.allclose(updated_embedding, stored_doc.embedding, rtol=0.01)
@@ -328,7 +328,7 @@ def test_faiss_cosine_similarity(tmp_path):
 # now check if vectors are normalized when updating embeddings
 class MockRetriever():
-    def embed_passages(self, docs):
+    def embed_documents(self, docs):
         return [np.random.rand(768).astype(np.float32) for doc in docs]
 retriever = MockRetriever()
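As the test shows, `update_embeddings()` only duck-types the retriever: any object with an `embed_documents` method will do. A sketch of driving a store with such a mock (the `FAISSDocumentStore` arguments and import path are illustrative for this era of Haystack):

```python
import numpy as np

from haystack import Document
from haystack.document_store.faiss import FAISSDocumentStore  # haystack.document_stores in later releases


class MockRetriever:
    def embed_documents(self, docs):
        return [np.random.rand(768).astype(np.float32) for _ in docs]


document_store = FAISSDocumentStore(similarity="cosine")  # defaults to a flat 768-dim index
document_store.write_documents([Document(content="some text")])
document_store.update_embeddings(retriever=MockRetriever())
```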