Milvus integration (#771)
* Initial commit for Milvus integration
* Add latest docstring and tutorial changes
* Updating implementation of Milvus document store
* Add latest docstring and tutorial changes
* Adding tests and updating docstrings
* Add latest docstring and tutorial changes
* Fixing issue caught by tests
* Addressing review comments
* Fixing mypy detected issue
* Fixing issue caught in test about sorting of vector ids
* Fixing test
* Fixing generator test failure
* Update docstrings
* Addressing review comments about multiple network calls while fetching embeddings from the Milvus server
* Add latest docstring and tutorial changes
* Ignoring mypy issue while converting vector_id to int

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
This commit is contained in:
parent 6efa4f06c1
commit 9f7f95221f
.github/workflows/ci.yml (vendored, 3 changes)

@@ -17,6 +17,9 @@ jobs:
     - name: Run Elasticsearch
       run: docker run -d -p 9200:9200 -e "discovery.type=single-node" -e "ES_JAVA_OPTS=-Xms128m -Xmx128m" elasticsearch:7.9.2
 
+    - name: Run Milvus
+      run: docker run -d -p 19530:19530 -p 19121:19121 milvusdb/milvus:0.10.5-cpu-d010621-4eda95
+
     - name: Run Apache Tika
       run: docker run -d -p 9998:9998 -e "TIKA_CHILD_JAVA_OPTS=-JXms128m" -e "TIKA_CHILD_JAVA_OPTS=-JXmx128m" apache/tika:1.24.1
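Once the Milvus container from the new CI step is up, connectivity can be sanity-checked from Python. A minimal sketch, assuming the pymilvus 0.x client used by this commit and the default port 19530:

```python
from milvus import Milvus

# Connect to the locally running Milvus container (hypothetical local setup).
client = Milvus(uri="tcp://localhost:19530", timeout=5, wait_timeout=5)

# In the 0.x client, calls return a (Status, result) pair.
status, message = client.server_status(timeout=5)
print(status, message)
```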
@@ -346,7 +346,7 @@ None
 #### delete\_all\_documents
 
 ```python
-| delete_all_documents(index: str, filters: Optional[Dict[str, List[str]]] = None)
+| delete_all_documents(index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None)
 ```
 
 Delete documents in an index. All documents are deleted if no filters are passed.
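The `index` argument is now optional across document stores; when omitted, the store's default index is used. A usage sketch (assuming an ElasticsearchDocumentStore whose documents carry a `meta_field` key, as in this commit's tests):

```python
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore

store = ElasticsearchDocumentStore(index="haystack_test")  # hypothetical index name

# Delete only documents whose meta_field is "test1" or "test2";
# with no filters, every document in the default index would be removed.
store.delete_all_documents(filters={"meta_field": ["test1", "test2"]})
```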
@@ -629,6 +629,7 @@ DocumentStore's default index (self.index) will be used.
 - `filters`: Optional filters to narrow down the documents to return.
     Example: {"name": ["some", "more"], "category": ["only_one"]}
 - `return_embedding`: Whether to return the document embeddings.
+- `batch_size`: When working with large number of documents, batching can help reduce memory footprint.
 
 <a name="sql.SQLDocumentStore.get_all_labels"></a>
 #### get\_all\_labels
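For the `batch_size` parameter documented above, a short sketch of consuming the generator (hypothetical store and values):

```python
from haystack.document_store.sql import SQLDocumentStore

store = SQLDocumentStore(url="sqlite:///")  # hypothetical local DB

# Documents are fetched from SQL in batches of 500 but yielded one at a time,
# so peak memory stays bounded even for a large index.
for doc in store.get_all_documents_generator(batch_size=500):
    print(doc.id)
```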
@@ -763,7 +764,7 @@ the vector embeddings are indexed in a FAISS Index.
 #### \_\_init\_\_
 
 ```python
-| __init__(sql_url: str = "sqlite:///", vector_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional[faiss.swigfaiss.Index] = None, return_embedding: bool = False, update_existing_documents: bool = False, index: str = "document", similarity: str = "dot_product", **kwargs)
+| __init__(sql_url: str = "sqlite:///", vector_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional[faiss.swigfaiss.Index] = None, return_embedding: bool = False, update_existing_documents: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", **kwargs)
 ```
 
 **Arguments**:
@@ -796,6 +797,7 @@ added already exists.
 - `index`: Name of index in document store to use.
 - `similarity`: The similarity function used to compare document vectors. 'dot_product' is the default since it is
     more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
+- `embedding_field`: Name of field containing an embedding vector.
 
 <a name="faiss.FAISSDocumentStore.write_documents"></a>
 #### write\_documents
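Tying the two hunks above together, an initialization sketch showing the new `embedding_field` parameter (values are illustrative):

```python
from haystack.document_store.faiss import FAISSDocumentStore

store = FAISSDocumentStore(
    sql_url="sqlite:///",          # text and metadata live in SQL
    vector_dim=768,                # must match the retriever's output size
    similarity="dot_product",      # default; suited to DPR embeddings
    embedding_field="embedding",   # field on Document objects holding the vector
)
```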
@@ -881,7 +883,7 @@ None
 #### delete\_all\_documents
 
 ```python
-| delete_all_documents(index=None)
+| delete_all_documents(index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None)
 ```
 
 Delete all documents from the document store.
@@ -198,6 +198,6 @@ class BaseDocumentStore(ABC):
             logger.error("File needs to be in json or jsonl format.")
 
     @abstractmethod
-    def delete_all_documents(self, index: str, filters: Optional[Dict[str, List[str]]] = None):
+    def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
         pass
@@ -757,7 +757,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
 
         bulk(self.client, doc_updates, request_timeout=300, refresh=self.refresh_type)
 
-    def delete_all_documents(self, index: str, filters: Optional[Dict[str, List[str]]] = None):
+    def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
         """
         Delete documents in an index. All documents are deleted if no filters are passed.
 
@@ -765,6 +765,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         :param filters: Optional filters to narrow down the documents to be deleted.
         :return: None
         """
+        index = index or self.index
         query: Dict[str, Any] = {"query": {}}
        if filters:
             filter_clause = []
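The added `index = index or self.index` line makes the default index apply here as well; the `filters` argument is then expanded into a terms-based filter clause (its construction is truncated above). The final request body presumably takes a shape like this sketch:

```python
# Hypothetical expansion of filters={"meta_field": ["test1", "test2"]};
# the exact structure is an assumption based on the truncated filter_clause code.
query = {
    "query": {
        "bool": {
            "filter": [
                {"terms": {"meta_field": ["test1", "test2"]}}
            ]
        }
    }
}
```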
@@ -41,6 +41,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         update_existing_documents: bool = False,
         index: str = "document",
         similarity: str = "dot_product",
+        embedding_field: str = "embedding",
         **kwargs,
     ):
         """
@@ -72,6 +73,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         :param index: Name of index in document store to use.
         :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
                            more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
+        :param embedding_field: Name of field containing an embedding vector.
         """
         self.vector_dim = vector_dim
 
@@ -83,6 +85,7 @@ class FAISSDocumentStore(SQLDocumentStore):
             self.faiss_index.set_direct_map_type(faiss.DirectMap.Hashtable)
 
         self.return_embedding = return_embedding
+        self.embedding_field = embedding_field
         if similarity == "dot_product":
             self.similarity = similarity
         else:
@@ -154,7 +157,7 @@ class FAISSDocumentStore(SQLDocumentStore):
 
     def _create_document_field_map(self) -> Dict:
         return {
-            self.index: "embedding",
+            self.index: self.embedding_field,
         }
 
     def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None, batch_size: int = 10_000):
@@ -275,13 +278,13 @@ class FAISSDocumentStore(SQLDocumentStore):
         embeddings = np.array(embeddings, dtype="float32")
         self.faiss_index.train(embeddings)
 
-    def delete_all_documents(self, index=None):
+    def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
         """
         Delete all documents from the document store.
         """
         index = index or self.index
         self.faiss_index.reset()
-        super().delete_all_documents(index=index)
+        super().delete_all_documents(index=index, filters=filters)
 
     def query_by_embedding(self,
                            query_emb: np.array,
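Both the FAISS store and the new Milvus store below accept only similarity="dot_product"; for cosine-style similarity the docstrings suggest L2-normalizing embeddings first, since the dot product of unit vectors equals their cosine. A small preprocessing sketch (not part of this commit):

```python
import numpy as np

def normalize(embeddings: np.ndarray) -> np.ndarray:
    """L2-normalize row vectors so that dot product equals cosine similarity."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings / np.clip(norms, a_min=1e-12, a_max=None)

vectors = np.random.rand(4, 768).astype(np.float32)
unit_vectors = normalize(vectors)  # <u, v> now equals cos(u, v)
```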
haystack/document_store/milvus.py (new file, 498 lines)
@@ -0,0 +1,498 @@
import logging
from typing import Any, Dict, Generator, List, Optional, Union

import numpy
import numpy as np

from milvus import IndexType, MetricType, Milvus, Status
from scipy.special import expit
from tqdm import tqdm

from haystack import Document
from haystack.document_store.sql import SQLDocumentStore
from haystack.retriever.base import BaseRetriever
from haystack.utils import get_batches_from_generator

logger = logging.getLogger(__name__)


class MilvusDocumentStore(SQLDocumentStore):
    """
    Milvus (https://milvus.io/) is a highly reliable, scalable Document Store specialized in storing and processing vectors.
    Therefore, it is particularly suited for Haystack users that work with dense retrieval methods (like DPR).
    In contrast to FAISS, Milvus ...
     - runs as a separate service (e.g. a Docker container) and can scale easily in a distributed environment
     - allows dynamic data management (i.e. you can insert/delete vectors without recreating the whole index)
     - encapsulates multiple ANN libraries (FAISS, ANNOY ...)

    This class uses Milvus for all vector related storage, processing and querying.
    The meta-data (e.g. for filtering) and the document text are however stored in a separate SQL Database as Milvus
    does not allow these data types (yet).

    Usage:
    1. Start a Milvus server (see https://milvus.io/docs/v0.10.5/install_milvus.md)
    2. Init a MilvusDocumentStore in Haystack
    """

    def __init__(
        self,
        sql_url: str = "sqlite:///",
        milvus_url: str = "tcp://localhost:19530",
        connection_pool: str = "SingletonThread",
        index: str = "document",
        vector_dim: int = 768,
        index_file_size: int = 1024,
        similarity: str = "dot_product",
        index_type: IndexType = IndexType.FLAT,
        index_param: Optional[Dict[str, Any]] = None,
        search_param: Optional[Dict[str, Any]] = None,
        update_existing_documents: bool = False,
        return_embedding: bool = False,
        embedding_field: str = "embedding",
        **kwargs,
    ):
        """
        :param sql_url: SQL connection URL for storing document texts and metadata. It defaults to a local, file based SQLite DB. For large scale
                        deployment, Postgres is recommended. If using MySQL, the same server can also be used for
                        Milvus metadata. For more details see https://milvus.io/docs/v0.10.5/data_manage.md.
        :param milvus_url: Milvus server connection URL for storing and processing vectors.
                           Protocol, host and port will automatically be inferred from the URL.
                           See https://milvus.io/docs/v0.10.5/install_milvus.md for instructions to start a Milvus instance.
        :param connection_pool: Connection pool type to connect with Milvus server. Default: "SingletonThread".
        :param index: Index name for text, embedding and metadata (in Milvus terms, this is the "collection name").
        :param vector_dim: The embedding vector size. Default: 768.
        :param index_file_size: Specifies the size of each segment file that is stored by Milvus. Its default value is 1024 MB.
                                When the size of newly inserted vectors reaches the specified volume, Milvus packs these vectors into a new segment.
                                Milvus creates one index file for each segment. When conducting a vector search, Milvus searches all index files one by one.
                                As a rule of thumb, we would see a 30% ~ 50% increase in search performance after changing the value of index_file_size from 1024 to 2048.
                                Note that an overly large index_file_size value may cause failure to load a segment into memory or graphics memory.
                                (From https://milvus.io/docs/v0.10.5/performance_faq.md#How-can-I-get-the-best-performance-from-Milvus-through-setting-index_file_size)
        :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default and recommended for DPR embeddings.
                           'cosine' is recommended for Sentence Transformers, but is not directly supported by Milvus.
                           However, you can normalize your embeddings and use `dot_product` to get the same results.
                           See https://milvus.io/docs/v0.10.5/metric.md?Inner-product-(IP)#floating.
        :param index_type: Type of approximate nearest neighbour (ANN) index used. The choice here determines your tradeoff between speed and accuracy.
                           Some popular options:
                           - FLAT (default): Exact method, slow
                           - IVF_FLAT: Inverted file based heuristic, fast
                           - HNSW: Graph based, fast
                           - ANNOY: Tree based, fast
                           See: https://milvus.io/docs/v0.10.5/index.md
        :param index_param: Configuration parameters for the chosen index_type needed at indexing time.
                            For example: {"nlist": 16384} as the number of cluster units to create for index_type IVF_FLAT.
                            See https://milvus.io/docs/v0.10.5/index.md
        :param search_param: Configuration parameters for the chosen index_type needed at query time.
                             For example: {"nprobe": 10} as the number of cluster units to query for index_type IVF_FLAT.
                             See https://milvus.io/docs/v0.10.5/index.md
        :param update_existing_documents: Whether to update any existing documents with the same ID when adding
                                          documents. When set as True, any document with an existing ID gets updated.
                                          If set to False, an error is raised if the document ID of the document being
                                          added already exists.
        :param return_embedding: To return document embedding.
        :param embedding_field: Name of field containing an embedding vector.
        """
        self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
        self.vector_dim = vector_dim
        self.index_file_size = index_file_size

        if similarity == "dot_product":
            self.metric_type = MetricType.L2
        else:
            raise ValueError("The Milvus document store can currently only support dot_product similarity. "
                             "Please set similarity=\"dot_product\"")

        self.index_type = index_type
        self.index_param = index_param or {"nlist": 16384}
        self.search_param = search_param or {"nprobe": 10}
        self.index = index
        self._create_collection_and_index_if_not_exist(self.index)
        self.return_embedding = return_embedding
        self.embedding_field = embedding_field

        super().__init__(
            url=sql_url,
            update_existing_documents=update_existing_documents,
            index=index
        )

    def __del__(self):
        return self.milvus_server.close()

    def _create_collection_and_index_if_not_exist(
        self,
        index: Optional[str] = None,
        index_param: Optional[Dict[str, Any]] = None
    ):
        index = index or self.index
        index_param = index_param or self.index_param

        status, ok = self.milvus_server.has_collection(collection_name=index)
        if not ok:
            collection_param = {
                'collection_name': index,
                'dimension': self.vector_dim,
                'index_file_size': self.index_file_size,
                'metric_type': self.metric_type
            }

            status = self.milvus_server.create_collection(collection_param)
            if status.code != Status.SUCCESS:
                raise RuntimeError(f'Collection creation on Milvus server failed: {status}')

            status = self.milvus_server.create_index(index, self.index_type, index_param)
            if status.code != Status.SUCCESS:
                raise RuntimeError(f'Index creation on Milvus server failed: {status}')

    def _create_document_field_map(self) -> Dict:
        return {
            self.index: self.embedding_field,
        }

    def write_documents(
        self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, batch_size: int = 10_000
    ):
        """
        Add new documents to the DocumentStore.

        :param documents: List of `Dicts` or List of `Documents`. If they already contain the embeddings, we'll index
                          them right away in Milvus. If not, you can later call update_embeddings() to create & index them.
        :param index: (SQL) index name for storing the docs and metadata
        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
        :return: None
        """
        index = index or self.index
        self._create_collection_and_index_if_not_exist(index)
        field_map = self._create_document_field_map()

        if len(documents) == 0:
            logger.warning("Calling DocumentStore.write_documents() with empty list")
            return

        document_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]

        add_vectors = False if document_objects[0].embedding is None else True

        batched_documents = get_batches_from_generator(document_objects, batch_size)
        with tqdm(total=len(document_objects)) as progress_bar:
            for document_batch in batched_documents:
                vector_ids = []
                if add_vectors:
                    doc_ids = []
                    embeddings = []
                    for doc in document_batch:
                        doc_ids.append(doc.id)
                        if isinstance(doc.embedding, np.ndarray):
                            embeddings.append(doc.embedding.tolist())
                        elif isinstance(doc.embedding, list):
                            embeddings.append(doc.embedding)
                        else:
                            raise AttributeError(f'Format of supplied document embedding {type(doc.embedding)} is not '
                                                 f'supported. Please use list or numpy.ndarray')

                    if self.update_existing_documents:
                        existing_docs = super().get_documents_by_id(ids=doc_ids, index=index)
                        self._delete_vector_ids_from_milvus(documents=existing_docs, index=index)

                    status, vector_ids = self.milvus_server.insert(collection_name=index, records=embeddings)
                    if status.code != Status.SUCCESS:
                        raise RuntimeError(f'Vector embedding insertion failed: {status}')

                docs_to_write_in_sql = []
                for idx, doc in enumerate(document_batch):
                    meta = doc.meta
                    if add_vectors:
                        meta["vector_id"] = vector_ids[idx]
                    docs_to_write_in_sql.append(doc)

                super().write_documents(docs_to_write_in_sql, index=index)
                progress_bar.update(batch_size)
        progress_bar.close()

        self.milvus_server.flush([index])
        if self.update_existing_documents:
            self.milvus_server.compact(collection_name=index)

    def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None, batch_size: int = 10_000):
        """
        Updates the embeddings in the document store using the encoding model specified in the retriever.
        This can be useful if you want to add or change the embeddings for your documents (e.g. after changing the retriever config).

        :param retriever: Retriever to use to get embeddings for text
        :param index: (SQL) index name for storing the docs and metadata
        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
        :return: None
        """
        index = index or self.index
        self._create_collection_and_index_if_not_exist(index)

        document_count = self.get_document_count(index=index)
        if document_count == 0:
            logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
            return

        logger.info(f"Updating embeddings for {document_count} docs...")

        result = self.get_all_documents_generator(index=index, batch_size=batch_size, return_embedding=False)
        batched_documents = get_batches_from_generator(result, batch_size)
        with tqdm(total=document_count) as progress_bar:
            for document_batch in batched_documents:
                self._delete_vector_ids_from_milvus(documents=document_batch, index=index)

                embeddings = retriever.embed_passages(document_batch)  # type: ignore
                embeddings_list = [embedding.tolist() for embedding in embeddings]
                assert len(document_batch) == len(embeddings_list)

                status, vector_ids = self.milvus_server.insert(collection_name=index, records=embeddings_list)
                if status.code != Status.SUCCESS:
                    raise RuntimeError(f'Vector embedding insertion failed: {status}')

                vector_id_map = {}
                for vector_id, doc in zip(vector_ids, document_batch):
                    vector_id_map[doc.id] = vector_id

                self.update_vector_ids(vector_id_map, index=index)
                progress_bar.update(batch_size)
        progress_bar.close()

        self.milvus_server.flush([index])
        self.milvus_server.compact(collection_name=index)

    def query_by_embedding(self,
                           query_emb: np.array,
                           filters: Optional[dict] = None,
                           top_k: int = 10,
                           index: Optional[str] = None,
                           return_embedding: Optional[bool] = None) -> List[Document]:
        """
        Find the document that is most similar to the provided `query_emb` by using a vector similarity metric.

        :param query_emb: Embedding of the query (e.g. gathered from DPR)
        :param filters: Optional filters to narrow down the search space.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param top_k: How many documents to return
        :param index: (SQL) index name for storing the docs and metadata
        :param return_embedding: To return document embedding
        :return: List of matching documents
        """
        if filters:
            raise Exception("Query filters are not implemented for the MilvusDocumentStore.")

        index = index or self.index
        status, ok = self.milvus_server.has_collection(collection_name=index)
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Milvus has collection check failed: {status}')
        if not ok:
            raise Exception("No index exists. Use 'update_embeddings()' to create an index.")

        if return_embedding is None:
            return_embedding = self.return_embedding

        query_emb = query_emb.reshape(1, -1).astype(np.float32)
        status, search_result = self.milvus_server.search(
            collection_name=index,
            query_records=query_emb,
            top_k=top_k,
            params=self.search_param
        )
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Vector embedding search failed: {status}')

        vector_ids_for_query = []
        scores_for_vector_ids: Dict[str, float] = {}
        for vector_id_list, distance_list in zip(search_result.id_array, search_result.distance_array):
            for vector_id, distance in zip(vector_id_list, distance_list):
                vector_ids_for_query.append(str(vector_id))
                scores_for_vector_ids[str(vector_id)] = distance

        documents = self.get_documents_by_vector_ids(vector_ids_for_query, index=index)

        if return_embedding:
            self._populate_embeddings_to_docs(index=index, docs=documents)

        for doc in documents:
            doc.score = scores_for_vector_ids[doc.meta["vector_id"]]
            doc.probability = float(expit(np.asarray(doc.score / 100)))

        return documents

    def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
        """
        Delete all documents (from SQL AND Milvus).

        :param index: (SQL) index name for storing the docs and metadata
        :param filters: Optional filters to narrow down the search space.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :return: None
        """
        index = index or self.index
        super().delete_all_documents(index=index, filters=filters)
        status, ok = self.milvus_server.has_collection(collection_name=index)
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Milvus has collection check failed: {status}')
        if ok:
            status = self.milvus_server.drop_collection(collection_name=index)
            if status.code != Status.SUCCESS:
                raise RuntimeError(f'Milvus drop collection failed: {status}')

            self.milvus_server.flush([index])
            self.milvus_server.compact(collection_name=index)

    def get_all_documents_generator(
        self,
        index: Optional[str] = None,
        filters: Optional[Dict[str, List[str]]] = None,
        return_embedding: Optional[bool] = None,
        batch_size: int = 10_000,
    ) -> Generator[Document, None, None]:
        """
        Get all documents from the document store. Under-the-hood, documents are fetched in batches from the
        document store and yielded as individual documents. This method can be used to iteratively process
        a large number of documents without having to load all documents in memory.

        :param index: Name of the index to get the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :param filters: Optional filters to narrow down the documents to return.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param return_embedding: Whether to return the document embeddings.
        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
        """
        index = index or self.index
        documents = super().get_all_documents_generator(
            index=index, filters=filters, batch_size=batch_size
        )
        if return_embedding is None:
            return_embedding = self.return_embedding

        for doc in documents:
            if return_embedding:
                self._populate_embeddings_to_docs(index=index, docs=[doc])
            yield doc

    def get_all_documents(
        self,
        index: Optional[str] = None,
        filters: Optional[Dict[str, List[str]]] = None,
        return_embedding: Optional[bool] = None,
        batch_size: int = 10_000,
    ) -> List[Document]:
        """
        Get documents from the document store (optionally using filter criteria).

        :param index: Name of the index to get the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :param filters: Optional filters to narrow down the documents to return.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param return_embedding: Whether to return the document embeddings.
        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
        """
        index = index or self.index
        result = self.get_all_documents_generator(
            index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
        )
        documents = list(result)
        return documents

    def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
        """
        Fetch a document by specifying its text id string.

        :param id: ID of the document
        :param index: Name of the index to get the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        """
        documents = self.get_documents_by_id([id], index)
        document = documents[0] if documents else None
        return document

    def get_documents_by_id(
        self, ids: List[str], index: Optional[str] = None, batch_size: int = 10_000
    ) -> List[Document]:
        """
        Fetch multiple documents by specifying their IDs (strings).

        :param ids: List of IDs of the documents
        :param index: Name of the index to get the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
        """
        index = index or self.index
        documents = super().get_documents_by_id(ids=ids, index=index)
        if self.return_embedding:
            self._populate_embeddings_to_docs(index=index, docs=documents)

        return documents

    def _populate_embeddings_to_docs(self, docs: List[Document], index: Optional[str] = None):
        index = index or self.index
        docs_with_vector_ids = []
        for doc in docs:
            if doc.meta and doc.meta.get("vector_id") is not None:
                docs_with_vector_ids.append(doc)

        if len(docs_with_vector_ids) == 0:
            return

        ids = [int(doc.meta.get("vector_id")) for doc in docs_with_vector_ids]  # type: ignore
        status, vector_embeddings = self.milvus_server.get_entity_by_id(
            collection_name=index,
            ids=ids
        )
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Getting vector embedding by id failed: {status}')

        for embedding, doc in zip(vector_embeddings, docs_with_vector_ids):
            doc.embedding = numpy.array(embedding, dtype="float32")

    def _delete_vector_ids_from_milvus(self, documents: List[Document], index: Optional[str] = None):
        index = index or self.index
        existing_vector_ids = []
        for doc in documents:
            if "vector_id" in doc.meta:
                existing_vector_ids.append(int(doc.meta["vector_id"]))
        if len(existing_vector_ids) > 0:
            status = self.milvus_server.delete_entity_by_id(
                collection_name=index,
                id_array=existing_vector_ids
            )
            if status.code != Status.SUCCESS:
                raise RuntimeError(f"Existing vector ids deletion failed: {status}")

    def get_all_vectors(self, index=None) -> List[np.array]:
        """
        Helper function to dump all vectors stored in the Milvus server.

        :param index: Name of the index to get the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :return: List[np.array]: List of vectors.
        """
        index = index or self.index
        status, collection_info = self.milvus_server.get_collection_stats(collection_name=index)
        if not status.OK():
            logger.info("Failed to fetch stats from the store ...")
            return list()

        logger.debug(f"collection_info = {collection_info}")

        ids = list()
        partition_list = collection_info["partitions"]
        for partition in partition_list:
            segment_list = partition["segments"]
            for segment in segment_list:
                segment_name = segment["name"]
                status, id_list = self.milvus_server.list_id_in_segment(
                    collection_name=index,
                    segment_name=segment_name)
                logger.debug(f"{status}: segment {segment_name} has {len(id_list)} vectors ...")
                ids.extend(id_list)

        if len(ids) == 0:
            logger.info("No documents in the store ...")
            return list()

        status, vectors = self.milvus_server.get_entity_by_id(collection_name=index, ids=ids)
        if not status.OK():
            logger.info(f"Failed to fetch documents for ids {ids} from the store ...")
            return list()

        return vectors
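Putting the new store to work end-to-end, a minimal sketch (hypothetical local setup: assumes a Milvus 0.10.x server on the default port and the default DPR models; not part of the commit):

```python
from haystack.document_store.milvus import MilvusDocumentStore
from haystack.retriever.dense import DensePassageRetriever

# Vectors live in Milvus; text and metadata go to the SQLite DB behind sql_url.
document_store = MilvusDocumentStore(sql_url="sqlite:///", milvus_url="tcp://localhost:19530")

document_store.write_documents([
    {"text": "Schopenhauer was a German philosopher.", "meta": {"name": "doc_1"}},
])

# Embed all stored documents with DPR and index the vectors in Milvus.
retriever = DensePassageRetriever(document_store=document_store)
document_store.update_embeddings(retriever=retriever)

# Dense retrieval now queries the Milvus collection.
results = retriever.retrieve(query="Who was Schopenhauer?")
```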
@@ -165,6 +165,7 @@ class SQLDocumentStore(BaseDocumentStore):
         :param filters: Optional filters to narrow down the documents to return.
                         Example: {"name": ["some", "more"], "category": ["only_one"]}
         :param return_embedding: Whether to return the document embeddings.
+        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
         """
 
         index = index or self.index
@@ -408,7 +409,7 @@ class SQLDocumentStore(BaseDocumentStore):
         """
 
         if filters:
-            raise NotImplementedError("Delete by filters is not implemented for SQLDocumentStore.")
+            raise NotImplementedError(f"Delete by filters is not implemented for {type(self).__name__}")
         index = index or self.index
         documents = self.session.query(DocumentORM).filter_by(index=index)
         documents.delete(synchronize_session=False)
@@ -176,7 +176,7 @@ class RAGenerator(BaseGenerator):
         embeddings = self.retriever.embed_passages(docs)
 
         embeddings_in_tensor = torch.cat(
-            [torch.from_numpy(embedding).unsqueeze(0) for embedding in embeddings],
+            [torch.from_numpy(embedding).float().unsqueeze(0) for embedding in embeddings],
             dim=0
         )
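The added `.float()` guards against dtype mismatches: `torch.from_numpy` preserves the NumPy dtype, so a float64 embedding would yield a float64 tensor where the model expects float32. A quick illustration:

```python
import numpy as np
import torch

embedding = np.random.rand(768)            # NumPy defaults to float64
t = torch.from_numpy(embedding)
print(t.dtype)                             # torch.float64

t32 = torch.from_numpy(embedding).float()  # explicit cast, as in the fix
print(t32.dtype)                           # torch.float32
```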
@@ -23,3 +23,5 @@ httptools
 nltk
 more_itertools
 networkx
+# Refer to the Milvus version support matrix at https://github.com/milvus-io/pymilvus#install-pymilvus
+pymilvus
@@ -6,6 +6,9 @@ from sys import platform
 import pytest
 import requests
 from elasticsearch import Elasticsearch
+from milvus import Milvus
+
+from haystack.document_store.milvus import MilvusDocumentStore
 from haystack.generator.transformers import RAGenerator, RAGeneratorType
 
 from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever, ElasticsearchRetriever, TfidfRetriever
@@ -78,6 +81,20 @@ def elasticsearch_fixture():
         time.sleep(30)
 
 
+@pytest.fixture(scope="session")
+def milvus_fixture():
+    # test if a Milvus server is already running. If not, start a Milvus docker container locally.
+    # Make sure you have given > 6GB memory to the docker engine
+    try:
+        milvus_server = Milvus(uri="tcp://localhost:19530", timeout=5, wait_timeout=5)
+        milvus_server.server_status(timeout=5)
+    except:
+        print("Starting Milvus ...")
+        status = subprocess.run(['docker run -d --name milvus_cpu_0.10.5 -p 19530:19530 -p 19121:19121 '
+                                 'milvusdb/milvus:0.10.5-cpu-d010621-4eda95'], shell=True)
+        time.sleep(40)
+
+
 @pytest.fixture(scope="session")
 def tika_fixture():
     try:
@@ -245,21 +262,19 @@ def get_retriever(retriever_type, document_store):
     return retriever
 
 
-@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql"])
+@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql", "milvus"])
 def document_store_with_docs(request, test_docs_xs):
     document_store = get_document_store(request.param)
     document_store.write_documents(test_docs_xs)
     yield document_store
-    if request.param == "faiss":
-        document_store.faiss_index.reset()
+    document_store.delete_all_documents()
 
 
-@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql"])
+@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql", "milvus"])
 def document_store(request, test_docs_xs):
     document_store = get_document_store(request.param)
     yield document_store
-    if request.param == "faiss":
-        document_store.faiss_index.reset()
+    document_store.delete_all_documents()
 
 
 def get_document_store(document_store_type, embedding_field="embedding"):
@@ -284,6 +299,14 @@ def get_document_store(document_store_type, embedding_field="embedding"):
             index="haystack_test",
         )
         return document_store
+    elif document_store_type == "milvus":
+        document_store = MilvusDocumentStore(
+            sql_url="sqlite://",
+            return_embedding=True,
+            embedding_field=embedding_field,
+            index="haystack_test",
+        )
+        return document_store
     else:
         raise Exception(f"No document store fixture for '{document_store_type}'")
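With the fixture parameterization above, every store-parametrized test now also runs against Milvus. A minimal sketch of how a test opts in (pattern taken from the test changes below; the test body is illustrative):

```python
import pytest

@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_write_and_count(document_store):
    # `document_store` is built by the conftest fixture for each backend in turn.
    document_store.write_documents([{"text": "hello", "meta": {"name": "doc_1"}}])
    assert len(document_store.get_all_documents()) == 1
```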
@@ -113,7 +113,7 @@ def test_get_all_documents_generator(document_store):
 
 
 @pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store", ["elasticsearch", "sql", "faiss"], indirect=True)
+@pytest.mark.parametrize("document_store", ["elasticsearch", "sql", "faiss", "milvus"], indirect=True)
 @pytest.mark.parametrize("update_existing_documents", [True, False])
 def test_update_existing_documents(document_store, update_existing_documents):
     original_docs = [
@@ -177,7 +177,7 @@ def test_write_document_index(document_store):
 
 
 @pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
+@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
 def test_document_with_embeddings(document_store):
     documents = [
         {"text": "text1", "id": "1", "embedding": np.random.rand(768).astype(np.float32)},
@@ -196,7 +196,7 @@ def test_document_with_embeddings(document_store):
 
 
 @pytest.mark.parametrize("retriever", ["dpr", "embedding"], indirect=True)
-@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
+@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
 def test_update_embeddings(document_store, retriever):
     documents = []
     for i in range(23):
@@ -232,17 +232,17 @@ def test_update_embeddings(document_store, retriever):
 
 @pytest.mark.elasticsearch
 def test_delete_all_documents(document_store_with_docs):
-    assert len(document_store_with_docs.get_all_documents(index="haystack_test")) == 3
+    assert len(document_store_with_docs.get_all_documents()) == 3
 
-    document_store_with_docs.delete_all_documents(index="haystack_test")
-    documents = document_store_with_docs.get_all_documents(index="haystack_test")
+    document_store_with_docs.delete_all_documents()
+    documents = document_store_with_docs.get_all_documents()
     assert len(documents) == 0
 
 
 @pytest.mark.elasticsearch
 @pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
 def test_delete_documents_with_filters(document_store_with_docs):
-    document_store_with_docs.delete_all_documents(index="haystack_test", filters={"meta_field": ["test1", "test2"]})
+    document_store_with_docs.delete_all_documents(filters={"meta_field": ["test1", "test2"]})
     documents = document_store_with_docs.get_all_documents()
     assert len(documents) == 1
     assert documents[0].meta["meta_field"] == "test3"
@@ -4,13 +4,14 @@ import numpy as np
 
 from haystack import Document
 from haystack.document_store.faiss import FAISSDocumentStore
+from haystack.document_store.milvus import MilvusDocumentStore
 from haystack.retriever.dense import DensePassageRetriever
 
 from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
 
 @pytest.mark.slow
 @pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
+@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
 @pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
 @pytest.mark.parametrize("return_embedding", [True, False])
 def test_dpr_retrieval(document_store, retriever, return_embedding):
@@ -71,7 +72,7 @@ def test_dpr_retrieval(document_store, retriever, return_embedding):
     assert res[0].embedding is None
 
     # test filtering
-    if not isinstance(document_store, FAISSDocumentStore):
+    if not isinstance(document_store, FAISSDocumentStore) and not isinstance(document_store, MilvusDocumentStore):
         res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?", filters={"name": ["0", "2"]})
         assert len(res) == 2
         for r in res:
@@ -4,7 +4,7 @@ from haystack import Finder
 
 @pytest.mark.slow
 @pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
+@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
 @pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
 def test_embedding_retriever(retriever, document_store):
@@ -1,4 +1,5 @@
 import os
+from copy import deepcopy
 
 import faiss
 import numpy as np
@@ -63,9 +64,9 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
 
 @pytest.mark.slow
 @pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
-@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
+@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
 @pytest.mark.parametrize("batch_size", [4, 6])
-def test_faiss_update_docs(document_store, retriever, batch_size):
+def test_update_docs(document_store, retriever, batch_size):
     # initial write
     document_store.write_documents(DOCUMENTS)
 
@@ -82,9 +83,35 @@ def test_faiss_update_docs(document_store, retriever, batch_size):
     assert np.allclose(updated_embedding, stored_doc.embedding, rtol=0.01)
 
 
+@pytest.mark.slow
 @pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
-@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
-def test_faiss_update_with_empty_store(document_store, retriever):
+@pytest.mark.parametrize("document_store", ["milvus", "faiss"], indirect=True)
+def test_update_exiting_docs(document_store, retriever):
+    document_store.update_existing_documents = True
+    old_document = Document(text="text_1")
+    # initial write
+    document_store.write_documents([old_document])
+    document_store.update_embeddings(retriever=retriever)
+    old_documents_indexed = document_store.get_all_documents()
+    assert len(old_documents_indexed) == 1
+
+    # Update document data
+    new_document = Document(text="text_2")
+    new_document.id = old_document.id
+    document_store.write_documents([new_document])
+    document_store.update_embeddings(retriever=retriever)
+    new_documents_indexed = document_store.get_all_documents()
+    assert len(new_documents_indexed) == 1
+
+    assert old_documents_indexed[0].id == new_documents_indexed[0].id
+    assert old_documents_indexed[0].text == "text_1"
+    assert new_documents_indexed[0].text == "text_2"
+    assert not np.allclose(old_documents_indexed[0].embedding, new_documents_indexed[0].embedding, rtol=0.01)
+
+
+@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
+@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
+def test_update_with_empty_store(document_store, retriever):
     # Call update with empty doc store
     document_store.update_embeddings(retriever=retriever)
 
@@ -125,8 +152,8 @@ def test_faiss_retrieving(index_factory):
 
 
 @pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
-@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
-def test_faiss_finding(document_store, retriever):
+@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
+def test_finding(document_store, retriever):
     document_store.write_documents(DOCUMENTS)
     finder = Finder(reader=None, retriever=retriever)
 
@@ -136,8 +163,8 @@ def test_faiss_finding(document_store, retriever):
 
 
 @pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
-@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
-def test_faiss_pipeline(document_store, retriever):
+@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
+def test_pipeline(document_store, retriever):
     documents = [
         {"name": "name_1", "text": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
         {"name": "name_2", "text": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
@@ -415,7 +415,7 @@ def test_rag_token_generator(rag_generator):
 @pytest.mark.elasticsearch
 @pytest.mark.parametrize(
     "retriever,document_store",
-    [("embedding", "memory"), ("embedding", "faiss"), ("elasticsearch", "elasticsearch")],
+    [("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("elasticsearch", "elasticsearch")],
     indirect=True,
 )
 def test_generator_pipeline(document_store, retriever, rag_generator):
@@ -1,6 +1,10 @@
 def test_module_imports():
     from haystack import Finder
     from haystack.document_store.sql import SQLDocumentStore
+    from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
+    from haystack.document_store.faiss import FAISSDocumentStore
+    from haystack.document_store.milvus import MilvusDocumentStore
+    from haystack.document_store.base import BaseDocumentStore
     from haystack.preprocessor.cleaning import clean_wiki_text
     from haystack.preprocessor.utils import convert_files_to_dicts, fetch_archive_from_http
     from haystack.reader.farm import FARMReader
@@ -10,6 +14,10 @@ def test_module_imports():
 
     assert Finder is not None
     assert SQLDocumentStore is not None
+    assert ElasticsearchDocumentStore is not None
+    assert FAISSDocumentStore is not None
+    assert MilvusDocumentStore is not None
+    assert BaseDocumentStore is not None
     assert clean_wiki_text is not None
     assert convert_files_to_dicts is not None
     assert fetch_archive_from_http is not None
@@ -67,7 +67,7 @@ def test_extractive_qa_answers_single_result(reader, retriever_with_docs):
 @pytest.mark.elasticsearch
 @pytest.mark.parametrize(
     "retriever,document_store",
-    [("embedding", "memory"), ("embedding", "faiss"), ("embedding", "elasticsearch")],
+    [("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("embedding", "elasticsearch")],
     indirect=True,
 )
 def test_faq_pipeline(retriever, document_store):
@@ -97,7 +97,7 @@ def test_faq_pipeline(retriever, document_store):
 @pytest.mark.elasticsearch
 @pytest.mark.parametrize(
     "retriever,document_store",
-    [("embedding", "memory"), ("embedding", "faiss"), ("embedding", "elasticsearch")],
+    [("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("embedding", "elasticsearch")],
     indirect=True,
 )
 def test_document_search_pipeline(retriever, document_store):
@@ -57,7 +57,7 @@ def test_summarization_one_summary(summarizer):
 @pytest.mark.summarizer
 @pytest.mark.parametrize(
     "retriever,document_store",
-    [("embedding", "memory"), ("embedding", "faiss"), ("elasticsearch", "elasticsearch")],
+    [("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("elasticsearch", "elasticsearch")],
     indirect=True,
 )
 def test_summarization_pipeline(document_store, retriever, summarizer):
@@ -79,7 +79,7 @@ def test_summarization_pipeline(document_store, retriever, summarizer):
 @pytest.mark.summarizer
 @pytest.mark.parametrize(
     "retriever,document_store",
-    [("embedding", "memory"), ("embedding", "faiss"), ("elasticsearch", "elasticsearch")],
+    [("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("elasticsearch", "elasticsearch")],
     indirect=True,
 )
 def test_summarization_pipeline_one_summary(document_store, retriever, summarizer):