diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 60e8ab359..e88362125 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -17,6 +17,9 @@ jobs:
- name: Run Elasticsearch
run: docker run -d -p 9200:9200 -e "discovery.type=single-node" -e "ES_JAVA_OPTS=-Xms128m -Xmx128m" elasticsearch:7.9.2
+ - name: Run Milvus
+ run: docker run -d -p 19530:19530 -p 19121:19121 milvusdb/milvus:0.10.5-cpu-d010621-4eda95
+
- name: Run Apache Tika
run: docker run -d -p 9998:9998 -e "TIKA_CHILD_JAVA_OPTS=-JXms128m" -e "TIKA_CHILD_JAVA_OPTS=-JXmx128m" apache/tika:1.24.1
diff --git a/docs/_src/api/api/document_store.md b/docs/_src/api/api/document_store.md
index 116fc7908..3389b808d 100644
--- a/docs/_src/api/api/document_store.md
+++ b/docs/_src/api/api/document_store.md
@@ -346,7 +346,7 @@ None
#### delete\_all\_documents
```python
- | delete_all_documents(index: str, filters: Optional[Dict[str, List[str]]] = None)
+ | delete_all_documents(index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None)
```
Delete documents in an index. All documents are deleted if no filters are passed.
@@ -629,6 +629,7 @@ DocumentStore's default index (self.index) will be used.
- `filters`: Optional filters to narrow down the documents to return.
Example: {"name": ["some", "more"], "category": ["only_one"]}
- `return_embedding`: Whether to return the document embeddings.
+- `batch_size`: When working with large number of documents, batching can help reduce memory footprint.
#### get\_all\_labels
@@ -763,7 +764,7 @@ the vector embeddings are indexed in a FAISS Index.
#### \_\_init\_\_
```python
- | __init__(sql_url: str = "sqlite:///", vector_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional[faiss.swigfaiss.Index] = None, return_embedding: bool = False, update_existing_documents: bool = False, index: str = "document", similarity: str = "dot_product", **kwargs, ,)
+ | __init__(sql_url: str = "sqlite:///", vector_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional[faiss.swigfaiss.Index] = None, return_embedding: bool = False, update_existing_documents: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", **kwargs, ,)
```
**Arguments**:
@@ -796,6 +797,7 @@ added already exists.
- `index`: Name of index in document store to use.
- `similarity`: The similarity function used to compare document vectors. 'dot_product' is the default sine it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
+- `embedding_field`: Name of field containing an embedding vector.
#### write\_documents
@@ -881,7 +883,7 @@ None
#### delete\_all\_documents
```python
- | delete_all_documents(index=None)
+ | delete_all_documents(index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None)
```
Delete all documents from the document store.
diff --git a/haystack/document_store/base.py b/haystack/document_store/base.py
index dcba2dee6..58bbc7440 100644
--- a/haystack/document_store/base.py
+++ b/haystack/document_store/base.py
@@ -198,6 +198,6 @@ class BaseDocumentStore(ABC):
logger.error("File needs to be in json or jsonl format.")
@abstractmethod
- def delete_all_documents(self, index: str, filters: Optional[Dict[str, List[str]]] = None):
+ def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
pass
diff --git a/haystack/document_store/elasticsearch.py b/haystack/document_store/elasticsearch.py
index ca18e87d3..6c6649f4e 100644
--- a/haystack/document_store/elasticsearch.py
+++ b/haystack/document_store/elasticsearch.py
@@ -757,7 +757,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
bulk(self.client, doc_updates, request_timeout=300, refresh=self.refresh_type)
- def delete_all_documents(self, index: str, filters: Optional[Dict[str, List[str]]] = None):
+ def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
"""
Delete documents in an index. All documents are deleted if no filters are passed.
@@ -765,6 +765,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
:param filters: Optional filters to narrow down the documents to be deleted.
:return: None
"""
+ index = index or self.index
query: Dict[str, Any] = {"query": {}}
if filters:
filter_clause = []
diff --git a/haystack/document_store/faiss.py b/haystack/document_store/faiss.py
index e21b41dea..d0df99624 100644
--- a/haystack/document_store/faiss.py
+++ b/haystack/document_store/faiss.py
@@ -41,6 +41,7 @@ class FAISSDocumentStore(SQLDocumentStore):
update_existing_documents: bool = False,
index: str = "document",
similarity: str = "dot_product",
+ embedding_field: str = "embedding",
**kwargs,
):
"""
@@ -72,6 +73,7 @@ class FAISSDocumentStore(SQLDocumentStore):
:param index: Name of index in document store to use.
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default sine it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
+ :param embedding_field: Name of field containing an embedding vector.
"""
self.vector_dim = vector_dim
@@ -83,6 +85,7 @@ class FAISSDocumentStore(SQLDocumentStore):
self.faiss_index.set_direct_map_type(faiss.DirectMap.Hashtable)
self.return_embedding = return_embedding
+ self.embedding_field = embedding_field
if similarity == "dot_product":
self.similarity = similarity
else:
@@ -154,7 +157,7 @@ class FAISSDocumentStore(SQLDocumentStore):
def _create_document_field_map(self) -> Dict:
return {
- self.index: "embedding",
+ self.index: self.embedding_field,
}
def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None, batch_size: int = 10_000):
@@ -275,13 +278,13 @@ class FAISSDocumentStore(SQLDocumentStore):
embeddings = np.array(embeddings, dtype="float32")
self.faiss_index.train(embeddings)
- def delete_all_documents(self, index=None):
+ def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
"""
Delete all documents from the document store.
"""
index = index or self.index
self.faiss_index.reset()
- super().delete_all_documents(index=index)
+ super().delete_all_documents(index=index, filters=filters)
def query_by_embedding(self,
query_emb: np.array,
diff --git a/haystack/document_store/milvus.py b/haystack/document_store/milvus.py
new file mode 100644
index 000000000..28332ed0a
--- /dev/null
+++ b/haystack/document_store/milvus.py
@@ -0,0 +1,498 @@
+import logging
+from typing import Any, Dict, Generator, List, Optional, Union
+
+import numpy
+import numpy as np
+
+from milvus import IndexType, MetricType, Milvus, Status
+from scipy.special import expit
+from tqdm import tqdm
+
+from haystack import Document
+from haystack.document_store.sql import SQLDocumentStore
+from haystack.retriever.base import BaseRetriever
+from haystack.utils import get_batches_from_generator
+
+logger = logging.getLogger(__name__)
+
+
+class MilvusDocumentStore(SQLDocumentStore):
+ """
+ Milvus (https://milvus.io/) is a highly reliable, scalable Document Store specialized on storing and processing vectors.
+ Therefore, it is particularly suited for Haystack users that work with dense retrieval methods (like DPR).
+ In contrast to FAISS, Milvus ...
+ - runs as a separate service (e.g. a Docker container) and can scale easily in a distributed environment
+ - allows dynamic data management (i.e. you can insert/delete vectors without recreating the whole index)
+ - encapsulates multiple ANN libraries (FAISS, ANNOY ...)
+
+ This class uses Milvus for all vector related storage, processing and querying.
+ The meta-data (e.g. for filtering) and the document text are however stored in a separate SQL Database as Milvus
+ does not allow these data types (yet).
+
+ Usage:
+ 1. Start a Milvus server (see https://milvus.io/docs/v0.10.5/install_milvus.md)
+ 2. Init a MilvusDocumentStore in Haystack
+ """
+
+ def __init__(
+ self,
+ sql_url: str = "sqlite:///",
+ milvus_url: str = "tcp://localhost:19530",
+ connection_pool: str = "SingletonThread",
+ index: str = "document",
+ vector_dim: int = 768,
+ index_file_size: int = 1024,
+ similarity: str = "dot_product",
+ index_type: IndexType = IndexType.FLAT,
+ index_param: Optional[Dict[str, Any]] = None,
+ search_param: Optional[Dict[str, Any]] = None,
+ update_existing_documents: bool = False,
+ return_embedding: bool = False,
+ embedding_field: str = "embedding",
+ **kwargs,
+ ):
+ """
+ :param sql_url: SQL connection URL for storing document texts and metadata. It defaults to a local, file based SQLite DB. For large scale
+ deployment, Postgres is recommended. If using MySQL then same server can also be used for
+ Milvus metadata. For more details see https://milvus.io/docs/v0.10.5/data_manage.md.
+ :param milvus_url: Milvus server connection URL for storing and processing vectors.
+ Protocol, host and port will automatically be inferred from the URL.
+ See https://milvus.io/docs/v0.10.5/install_milvus.md for instructions to start a Milvus instance.
+ :param connection_pool: Connection pool type to connect with Milvus server. Default: "SingletonThread".
+ :param index: Index name for text, embedding and metadata (in Milvus terms, this is the "collection name").
+ :param vector_dim: The embedding vector size. Default: 768.
+ :param index_file_size: Specifies the size of each segment file that is stored by Milvus and its default value is 1024 MB.
+ When the size of newly inserted vectors reaches the specified volume, Milvus packs these vectors into a new segment.
+ Milvus creates one index file for each segment. When conducting a vector search, Milvus searches all index files one by one.
+ As a rule of thumb, we would see a 30% ~ 50% increase in the search performance after changing the value of index_file_size from 1024 to 2048.
+ Note that an overly large index_file_size value may cause failure to load a segment into the memory or graphics memory.
+ (From https://milvus.io/docs/v0.10.5/performance_faq.md#How-can-I-get-the-best-performance-from-Milvus-through-setting-index_file_size)
+ :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default and recommended for DPR embeddings.
+ 'cosine' is recommended for Sentence Transformers, but is not directly supported by Milvus.
+ However, you can normalize your embeddings and use `dot_product` to get the same results.
+ See https://milvus.io/docs/v0.10.5/metric.md?Inner-product-(IP)#floating.
+ :param index_type: Type of approximate nearest neighbour (ANN) index used. The choice here determines your tradeoff between speed and accuracy.
+ Some popular options:
+ - FLAT (default): Exact method, slow
+ - IVF_FLAT, inverted file based heuristic, fast
+ - HSNW: Graph based, fast
+ - ANNOY: Tree based, fast
+ See: https://milvus.io/docs/v0.10.5/index.md
+ :param index_param: Configuration parameters for the chose index_type needed at indexing time.
+ For example: {"nlist": 16384} as the number of cluster units to create for index_type IVF_FLAT.
+ See https://milvus.io/docs/v0.10.5/index.md
+ :param search_param: Configuration parameters for the chose index_type needed at query time
+ For example: {"nprobe": 10} as the number of cluster units to query for index_type IVF_FLAT.
+ See https://milvus.io/docs/v0.10.5/index.md
+ :param update_existing_documents: Whether to update any existing documents with the same ID when adding
+ documents. When set as True, any document with an existing ID gets updated.
+ If set to False, an error is raised if the document ID of the document being
+ added already exists.
+ :param return_embedding: To return document embedding.
+ :param embedding_field: Name of field containing an embedding vector.
+ """
+ self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
+ self.vector_dim = vector_dim
+ self.index_file_size = index_file_size
+
+ if similarity == "dot_product":
+ self.metric_type = MetricType.L2
+ else:
+ raise ValueError("The Milvus document store can currently only support dot_product similarity. "
+ "Please set similarity=\"dot_product\"")
+
+ self.index_type = index_type
+ self.index_param = index_param or {"nlist": 16384}
+ self.search_param = search_param or {"nprobe": 10}
+ self.index = index
+ self._create_collection_and_index_if_not_exist(self.index)
+ self.return_embedding = return_embedding
+ self.embedding_field = embedding_field
+
+ super().__init__(
+ url=sql_url,
+ update_existing_documents=update_existing_documents,
+ index=index
+ )
+
+ def __del__(self):
+ return self.milvus_server.close()
+
+ def _create_collection_and_index_if_not_exist(
+ self,
+ index: Optional[str] = None,
+ index_param: Optional[Dict[str, Any]] = None
+ ):
+ index = index or self.index
+ index_param = index_param or self.index_param
+
+ status, ok = self.milvus_server.has_collection(collection_name=index)
+ if not ok:
+ collection_param = {
+ 'collection_name': index,
+ 'dimension': self.vector_dim,
+ 'index_file_size': self.index_file_size,
+ 'metric_type': self.metric_type
+ }
+
+ status = self.milvus_server.create_collection(collection_param)
+ if status.code != Status.SUCCESS:
+ raise RuntimeError(f'Collection creation on Milvus server failed: {status}')
+
+ status = self.milvus_server.create_index(index, self.index_type, index_param)
+ if status.code != Status.SUCCESS:
+ raise RuntimeError(f'Index creation on Milvus server failed: {status}')
+
+ def _create_document_field_map(self) -> Dict:
+ return {
+ self.index: self.embedding_field,
+ }
+
+ def write_documents(
+ self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, batch_size: int = 10_000
+ ):
+ """
+ Add new documents to the DocumentStore.
+
+ :param documents: List of `Dicts` or List of `Documents`. If they already contain the embeddings, we'll index
+ them right away in Milvus. If not, you can later call update_embeddings() to create & index them.
+ :param index: (SQL) index name for storing the docs and metadata
+ :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
+ :return:
+ """
+ index = index or self.index
+ self._create_collection_and_index_if_not_exist(index)
+ field_map = self._create_document_field_map()
+
+ if len(documents) == 0:
+ logger.warning("Calling DocumentStore.write_documents() with empty list")
+ return
+
+ document_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
+
+ add_vectors = False if document_objects[0].embedding is None else True
+
+ batched_documents = get_batches_from_generator(document_objects, batch_size)
+ with tqdm(total=len(document_objects)) as progress_bar:
+ for document_batch in batched_documents:
+ vector_ids = []
+ if add_vectors:
+ doc_ids = []
+ embeddings = []
+ for doc in document_batch:
+ doc_ids.append(doc.id)
+ if isinstance(doc.embedding, np.ndarray):
+ embeddings.append(doc.embedding.tolist())
+ elif isinstance(doc.embedding, list):
+ embeddings.append(doc.embedding)
+ else:
+ raise AttributeError(f'Format of supplied document embedding {type(doc.embedding)} is not '
+ f'supported. Please use list or numpy.ndarray')
+
+ if self.update_existing_documents:
+ existing_docs = super().get_documents_by_id(ids=doc_ids, index=index)
+ self._delete_vector_ids_from_milvus(documents=existing_docs, index=index)
+
+ status, vector_ids = self.milvus_server.insert(collection_name=index, records=embeddings)
+ if status.code != Status.SUCCESS:
+ raise RuntimeError(f'Vector embedding insertion failed: {status}')
+
+ docs_to_write_in_sql = []
+ for idx, doc in enumerate(document_batch):
+ meta = doc.meta
+ if add_vectors:
+ meta["vector_id"] = vector_ids[idx]
+ docs_to_write_in_sql.append(doc)
+
+ super().write_documents(docs_to_write_in_sql, index=index)
+ progress_bar.update(batch_size)
+ progress_bar.close()
+
+ self.milvus_server.flush([index])
+ if self.update_existing_documents:
+ self.milvus_server.compact(collection_name=index)
+
+ def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None, batch_size: int = 10_000):
+ """
+ Updates the embeddings in the the document store using the encoding model specified in the retriever.
+ This can be useful if want to add or change the embeddings for your documents (e.g. after changing the retriever config).
+
+ :param retriever: Retriever to use to get embeddings for text
+ :param index: (SQL) index name for storing the docs and metadata
+ :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
+ :return: None
+ """
+ index = index or self.index
+ self._create_collection_and_index_if_not_exist(index)
+
+ document_count = self.get_document_count(index=index)
+ if document_count == 0:
+ logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
+ return
+
+ logger.info(f"Updating embeddings for {document_count} docs...")
+
+ result = self.get_all_documents_generator(index=index, batch_size=batch_size, return_embedding=False)
+ batched_documents = get_batches_from_generator(result, batch_size)
+ with tqdm(total=document_count) as progress_bar:
+ for document_batch in batched_documents:
+ self._delete_vector_ids_from_milvus(documents=document_batch, index=index)
+
+ embeddings = retriever.embed_passages(document_batch) # type: ignore
+ embeddings_list = [embedding.tolist() for embedding in embeddings]
+ assert len(document_batch) == len(embeddings_list)
+
+ status, vector_ids = self.milvus_server.insert(collection_name=index, records=embeddings_list)
+ if status.code != Status.SUCCESS:
+ raise RuntimeError(f'Vector embedding insertion failed: {status}')
+
+ vector_id_map = {}
+ for vector_id, doc in zip(vector_ids, document_batch):
+ vector_id_map[doc.id] = vector_id
+
+ self.update_vector_ids(vector_id_map, index=index)
+ progress_bar.update(batch_size)
+ progress_bar.close()
+
+ self.milvus_server.flush([index])
+ self.milvus_server.compact(collection_name=index)
+
+ def query_by_embedding(self,
+ query_emb: np.array,
+ filters: Optional[dict] = None,
+ top_k: int = 10,
+ index: Optional[str] = None,
+ return_embedding: Optional[bool] = None) -> List[Document]:
+ """
+ Find the document that is most similar to the provided `query_emb` by using a vector similarity metric.
+
+ :param query_emb: Embedding of the query (e.g. gathered from DPR)
+ :param filters: Optional filters to narrow down the search space.
+ Example: {"name": ["some", "more"], "category": ["only_one"]}
+ :param top_k: How many documents to return
+ :param index: (SQL) index name for storing the docs and metadata
+ :param return_embedding: To return document embedding
+ :return:
+ """
+ if filters:
+ raise Exception("Query filters are not implemented for the MilvusDocumentStore.")
+
+ index = index or self.index
+ status, ok = self.milvus_server.has_collection(collection_name=index)
+ if status.code != Status.SUCCESS:
+ raise RuntimeError(f'Milvus has collection check failed: {status}')
+ if not ok:
+ raise Exception("No index exists. Use 'update_embeddings()` to create an index.")
+
+ if return_embedding is None:
+ return_embedding = self.return_embedding
+ index = index or self.index
+
+ query_emb = query_emb.reshape(1, -1).astype(np.float32)
+ status, search_result = self.milvus_server.search(
+ collection_name=index,
+ query_records=query_emb,
+ top_k=top_k,
+ params=self.search_param
+ )
+ if status.code != Status.SUCCESS:
+ raise RuntimeError(f'Vector embedding search failed: {status}')
+
+ vector_ids_for_query = []
+ scores_for_vector_ids: Dict[str, float] = {}
+ for vector_id_list, distance_list in zip(search_result.id_array, search_result.distance_array):
+ for vector_id, distance in zip(vector_id_list, distance_list):
+ vector_ids_for_query.append(str(vector_id))
+ scores_for_vector_ids[str(vector_id)] = distance
+
+ documents = self.get_documents_by_vector_ids(vector_ids_for_query, index=index)
+
+ if return_embedding:
+ self._populate_embeddings_to_docs(index=index, docs=documents)
+
+ for doc in documents:
+ doc.score = scores_for_vector_ids[doc.meta["vector_id"]]
+ doc.probability = float(expit(np.asarray(doc.score / 100)))
+
+ return documents
+
+ def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
+ """
+ Delete all documents (from SQL AND Milvus).
+ :param index: (SQL) index name for storing the docs and metadata
+ :param filters: Optional filters to narrow down the search space.
+ Example: {"name": ["some", "more"], "category": ["only_one"]}
+ :return: None
+ """
+ index = index or self.index
+ super().delete_all_documents(index=index, filters=filters)
+ status, ok = self.milvus_server.has_collection(collection_name=index)
+ if status.code != Status.SUCCESS:
+ raise RuntimeError(f'Milvus has collection check failed: {status}')
+ if ok:
+ status = self.milvus_server.drop_collection(collection_name=index)
+ if status.code != Status.SUCCESS:
+ raise RuntimeError(f'Milvus drop collection failed: {status}')
+
+ self.milvus_server.flush([index])
+ self.milvus_server.compact(collection_name=index)
+
+ def get_all_documents_generator(
+ self,
+ index: Optional[str] = None,
+ filters: Optional[Dict[str, List[str]]] = None,
+ return_embedding: Optional[bool] = None,
+ batch_size: int = 10_000,
+ ) -> Generator[Document, None, None]:
+ """
+ Get all documents from the document store. Under-the-hood, documents are fetched in batches from the
+ document store and yielded as individual documents. This method can be used to iteratively process
+ a large number of documents without having to load all documents in memory.
+
+ :param index: Name of the index to get the documents from. If None, the
+ DocumentStore's default index (self.index) will be used.
+ :param filters: Optional filters to narrow down the documents to return.
+ Example: {"name": ["some", "more"], "category": ["only_one"]}
+ :param return_embedding: Whether to return the document embeddings.
+ :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
+ """
+ index = index or self.index
+ documents = super().get_all_documents_generator(
+ index=index, filters=filters, batch_size=batch_size
+ )
+ if return_embedding is None:
+ return_embedding = self.return_embedding
+
+ for doc in documents:
+ if return_embedding:
+ self._populate_embeddings_to_docs(index=index, docs=[doc])
+ yield doc
+
+ def get_all_documents(
+ self,
+ index: Optional[str] = None,
+ filters: Optional[Dict[str, List[str]]] = None,
+ return_embedding: Optional[bool] = None,
+ batch_size: int = 10_000,
+ ) -> List[Document]:
+ """
+ Get documents from the document store (optionally using filter criteria).
+
+ :param index: Name of the index to get the documents from. If None, the
+ DocumentStore's default index (self.index) will be used.
+ :param filters: Optional filters to narrow down the documents to return.
+ Example: {"name": ["some", "more"], "category": ["only_one"]}
+ :param return_embedding: Whether to return the document embeddings.
+ :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
+ """
+
+ index = index or self.index
+ result = self.get_all_documents_generator(
+ index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
+ )
+ documents = list(result)
+ return documents
+
+ def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
+ """
+ Fetch a document by specifying its text id string
+
+ :param id: ID of the document
+ :param index: Name of the index to get the documents from. If None, the
+ DocumentStore's default index (self.index) will be used.
+ """
+ documents = self.get_documents_by_id([id], index)
+ document = documents[0] if documents else None
+ return document
+
+ def get_documents_by_id(
+ self, ids: List[str], index: Optional[str] = None, batch_size: int = 10_000
+ ) -> List[Document]:
+ """
+ Fetch multiple documents by specifying their IDs (strings)
+
+ :param ids: List of IDs of the documents
+ :param index: Name of the index to get the documents from. If None, the
+ DocumentStore's default index (self.index) will be used.
+ :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
+ """
+ index = index or self.index
+ documents = super().get_documents_by_id(ids=ids, index=index)
+ if self.return_embedding:
+ self._populate_embeddings_to_docs(index=index, docs=documents)
+
+ return documents
+
+ def _populate_embeddings_to_docs(self, docs: List[Document], index: Optional[str] = None):
+ index = index or self.index
+ docs_with_vector_ids = []
+ for doc in docs:
+ if doc.meta and doc.meta.get("vector_id") is not None:
+ docs_with_vector_ids.append(doc)
+
+ if len(docs_with_vector_ids) == 0:
+ return
+
+ ids = [int(doc.meta.get("vector_id")) for doc in docs_with_vector_ids] # type: ignore
+ status, vector_embeddings = self.milvus_server.get_entity_by_id(
+ collection_name=index,
+ ids=ids
+ )
+ if status.code != Status.SUCCESS:
+ raise RuntimeError(f'Getting vector embedding by id failed: {status}')
+
+ for embedding, doc in zip(vector_embeddings, docs_with_vector_ids):
+ doc.embedding = numpy.array(embedding, dtype="float32")
+
+ def _delete_vector_ids_from_milvus(self, documents: List[Document], index: Optional[str] = None):
+ index = index or self.index
+ existing_vector_ids = []
+ for doc in documents:
+ if "vector_id" in doc.meta:
+ existing_vector_ids.append(int(doc.meta["vector_id"]))
+ if len(existing_vector_ids) > 0:
+ status = self.milvus_server.delete_entity_by_id(
+ collection_name=index,
+ id_array=existing_vector_ids
+ )
+ if status.code != Status.SUCCESS:
+ raise RuntimeError("E existing vector ids deletion failed: {status}")
+
+ def get_all_vectors(self, index=None) -> List[np.array]:
+ """
+ Helper function to dump all vectors stored in Milvus server.
+
+ :param index: Name of the index to get the documents from. If None, the
+ DocumentStore's default index (self.index) will be used.
+ :return: List[np.array]: List of vectors.
+ """
+ index = index or self.index
+ status, collection_info = self.milvus_server.get_collection_stats(collection_name=index)
+ if not status.OK():
+ logger.info(f"Failed fetch stats from store ...")
+ return list()
+
+ logger.debug(f"collection_info = {collection_info}")
+
+ ids = list()
+ partition_list = collection_info["partitions"]
+ for partition in partition_list:
+ segment_list = partition["segments"]
+ for segment in segment_list:
+ segment_name = segment["name"]
+ status, id_list = self.milvus_server.list_id_in_segment(
+ collection_name=index,
+ segment_name=segment_name)
+ logger.debug(f"{status}: segment {segment_name} has {len(id_list)} vectors ...")
+ ids.extend(id_list)
+
+ if len(ids) == 0:
+ logger.info(f"No documents in the store ...")
+ return list()
+
+ status, vectors = self.milvus_server.get_entity_by_id(collection_name=index, ids=ids)
+ if not status.OK():
+ logger.info(f"Failed fetch document for ids {ids} from store ...")
+ return list()
+
+ return vectors
diff --git a/haystack/document_store/sql.py b/haystack/document_store/sql.py
index df52c5b76..e0aa85b25 100644
--- a/haystack/document_store/sql.py
+++ b/haystack/document_store/sql.py
@@ -165,6 +165,7 @@ class SQLDocumentStore(BaseDocumentStore):
:param filters: Optional filters to narrow down the documents to return.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param return_embedding: Whether to return the document embeddings.
+ :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
"""
index = index or self.index
@@ -408,7 +409,7 @@ class SQLDocumentStore(BaseDocumentStore):
"""
if filters:
- raise NotImplementedError("Delete by filters is not implemented for SQLDocumentStore.")
+ raise NotImplementedError(f"Delete by filters is not implemented for {type(self).__name__}")
index = index or self.index
documents = self.session.query(DocumentORM).filter_by(index=index)
documents.delete(synchronize_session=False)
diff --git a/haystack/generator/transformers.py b/haystack/generator/transformers.py
index 97e5509d1..99b642691 100644
--- a/haystack/generator/transformers.py
+++ b/haystack/generator/transformers.py
@@ -176,7 +176,7 @@ class RAGenerator(BaseGenerator):
embeddings = self.retriever.embed_passages(docs)
embeddings_in_tensor = torch.cat(
- [torch.from_numpy(embedding).unsqueeze(0) for embedding in embeddings],
+ [torch.from_numpy(embedding).float().unsqueeze(0) for embedding in embeddings],
dim=0
)
diff --git a/requirements.txt b/requirements.txt
index 83a77c027..85db4b510 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,4 +22,6 @@ uvloop; sys_platform != 'win32' and sys_platform != 'cygwin'
httptools
nltk
more_itertools
-networkx
\ No newline at end of file
+networkx
+# Refer milvus version support matrix at https://github.com/milvus-io/pymilvus#install-pymilvus
+pymilvus
\ No newline at end of file
diff --git a/test/conftest.py b/test/conftest.py
index 94110ad65..7537a3255 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -6,6 +6,9 @@ from sys import platform
import pytest
import requests
from elasticsearch import Elasticsearch
+from milvus import Milvus
+
+from haystack.document_store.milvus import MilvusDocumentStore
from haystack.generator.transformers import RAGenerator, RAGeneratorType
from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever, ElasticsearchRetriever, TfidfRetriever
@@ -78,6 +81,20 @@ def elasticsearch_fixture():
time.sleep(30)
+@pytest.fixture(scope="session")
+def milvus_fixture():
+ # test if a Milvus server is already running. If not, start Milvus docker container locally.
+ # Make sure you have given > 6GB memory to docker engine
+ try:
+ milvus_server = Milvus(uri="tcp://localhost:19530", timeout=5, wait_timeout=5)
+ milvus_server.server_status(timeout=5)
+ except:
+ print("Starting Milvus ...")
+ status = subprocess.run(['docker run -d --name milvus_cpu_0.10.5 -p 19530:19530 -p 19121:19121 '
+ 'milvusdb/milvus:0.10.5-cpu-d010621-4eda95'], shell=True)
+ time.sleep(40)
+
+
@pytest.fixture(scope="session")
def tika_fixture():
try:
@@ -245,21 +262,19 @@ def get_retriever(retriever_type, document_store):
return retriever
-@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql"])
+@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql", "milvus"])
def document_store_with_docs(request, test_docs_xs):
document_store = get_document_store(request.param)
document_store.write_documents(test_docs_xs)
yield document_store
- if request.param == "faiss":
- document_store.faiss_index.reset()
+ document_store.delete_all_documents()
-@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql"])
+@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql", "milvus"])
def document_store(request, test_docs_xs):
document_store = get_document_store(request.param)
yield document_store
- if request.param == "faiss":
- document_store.faiss_index.reset()
+ document_store.delete_all_documents()
def get_document_store(document_store_type, embedding_field="embedding"):
@@ -284,6 +299,14 @@ def get_document_store(document_store_type, embedding_field="embedding"):
index="haystack_test",
)
return document_store
+ elif document_store_type == "milvus":
+ document_store = MilvusDocumentStore(
+ sql_url="sqlite://",
+ return_embedding=True,
+ embedding_field=embedding_field,
+ index="haystack_test",
+ )
+ return document_store
else:
raise Exception(f"No document store fixture for '{document_store_type}'")
diff --git a/test/test_document_store.py b/test/test_document_store.py
index b02a9166c..44239b703 100644
--- a/test/test_document_store.py
+++ b/test/test_document_store.py
@@ -113,7 +113,7 @@ def test_get_all_documents_generator(document_store):
@pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store", ["elasticsearch", "sql", "faiss"], indirect=True)
+@pytest.mark.parametrize("document_store", ["elasticsearch", "sql", "faiss", "milvus"], indirect=True)
@pytest.mark.parametrize("update_existing_documents", [True, False])
def test_update_existing_documents(document_store, update_existing_documents):
original_docs = [
@@ -177,7 +177,7 @@ def test_write_document_index(document_store):
@pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
+@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
def test_document_with_embeddings(document_store):
documents = [
{"text": "text1", "id": "1", "embedding": np.random.rand(768).astype(np.float32)},
@@ -196,7 +196,7 @@ def test_document_with_embeddings(document_store):
@pytest.mark.parametrize("retriever", ["dpr", "embedding"], indirect=True)
-@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
+@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
def test_update_embeddings(document_store, retriever):
documents = []
for i in range(23):
@@ -232,17 +232,17 @@ def test_update_embeddings(document_store, retriever):
@pytest.mark.elasticsearch
def test_delete_all_documents(document_store_with_docs):
- assert len(document_store_with_docs.get_all_documents(index="haystack_test")) == 3
+ assert len(document_store_with_docs.get_all_documents()) == 3
- document_store_with_docs.delete_all_documents(index="haystack_test")
- documents = document_store_with_docs.get_all_documents(index="haystack_test")
+ document_store_with_docs.delete_all_documents()
+ documents = document_store_with_docs.get_all_documents()
assert len(documents) == 0
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_delete_documents_with_filters(document_store_with_docs):
- document_store_with_docs.delete_all_documents(index="haystack_test", filters={"meta_field": ["test1", "test2"]})
+ document_store_with_docs.delete_all_documents(filters={"meta_field": ["test1", "test2"]})
documents = document_store_with_docs.get_all_documents()
assert len(documents) == 1
assert documents[0].meta["meta_field"] == "test3"
diff --git a/test/test_dpr_retriever.py b/test/test_dpr_retriever.py
index a9c8a2edc..45b72c398 100644
--- a/test/test_dpr_retriever.py
+++ b/test/test_dpr_retriever.py
@@ -4,13 +4,14 @@ import numpy as np
from haystack import Document
from haystack.document_store.faiss import FAISSDocumentStore
+from haystack.document_store.milvus import MilvusDocumentStore
from haystack.retriever.dense import DensePassageRetriever
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
@pytest.mark.slow
@pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
+@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("return_embedding", [True, False])
def test_dpr_retrieval(document_store, retriever, return_embedding):
@@ -71,7 +72,7 @@ def test_dpr_retrieval(document_store, retriever, return_embedding):
assert res[0].embedding is None
# test filtering
- if not isinstance(document_store, FAISSDocumentStore):
+ if not isinstance(document_store, FAISSDocumentStore) and not isinstance(document_store, MilvusDocumentStore):
res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?", filters={"name": ["0", "2"]})
assert len(res) == 2
for r in res:
diff --git a/test/test_embedding_retriever.py b/test/test_embedding_retriever.py
index 4ab5f000e..049814a26 100644
--- a/test/test_embedding_retriever.py
+++ b/test/test_embedding_retriever.py
@@ -4,7 +4,7 @@ from haystack import Finder
@pytest.mark.slow
@pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
+@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
def test_embedding_retriever(retriever, document_store):
diff --git a/test/test_faiss.py b/test/test_faiss_and_milvus.py
similarity index 79%
rename from test/test_faiss.py
rename to test/test_faiss_and_milvus.py
index 5cc76667a..163568690 100644
--- a/test/test_faiss.py
+++ b/test/test_faiss_and_milvus.py
@@ -1,4 +1,5 @@
import os
+from copy import deepcopy
import faiss
import numpy as np
@@ -63,9 +64,9 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
-@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
+@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
@pytest.mark.parametrize("batch_size", [4, 6])
-def test_faiss_update_docs(document_store, retriever, batch_size):
+def test_update_docs(document_store, retriever, batch_size):
# initial write
document_store.write_documents(DOCUMENTS)
@@ -82,9 +83,35 @@ def test_faiss_update_docs(document_store, retriever, batch_size):
assert np.allclose(updated_embedding, stored_doc.embedding, rtol=0.01)
+@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
-@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
-def test_faiss_update_with_empty_store(document_store, retriever):
+@pytest.mark.parametrize("document_store", ["milvus", "faiss"], indirect=True)
+def test_update_exiting_docs(document_store, retriever):
+ document_store.update_existing_documents = True
+ old_document = Document(text="text_1")
+ # initial write
+ document_store.write_documents([old_document])
+ document_store.update_embeddings(retriever=retriever)
+ old_documents_indexed = document_store.get_all_documents()
+ assert len(old_documents_indexed) == 1
+
+ # Update document data
+ new_document = Document(text="text_2")
+ new_document.id = old_document.id
+ document_store.write_documents([new_document])
+ document_store.update_embeddings(retriever=retriever)
+ new_documents_indexed = document_store.get_all_documents()
+ assert len(new_documents_indexed) == 1
+
+ assert old_documents_indexed[0].id == new_documents_indexed[0].id
+ assert old_documents_indexed[0].text == "text_1"
+ assert new_documents_indexed[0].text == "text_2"
+ assert not np.allclose(old_documents_indexed[0].embedding, new_documents_indexed[0].embedding, rtol=0.01)
+
+
+@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
+@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
+def test_update_with_empty_store(document_store, retriever):
# Call update with empty doc store
document_store.update_embeddings(retriever=retriever)
@@ -125,8 +152,8 @@ def test_faiss_retrieving(index_factory):
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
-@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
-def test_faiss_finding(document_store, retriever):
+@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
+def test_finding(document_store, retriever):
document_store.write_documents(DOCUMENTS)
finder = Finder(reader=None, retriever=retriever)
@@ -136,8 +163,8 @@ def test_faiss_finding(document_store, retriever):
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
-@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
-def test_faiss_pipeline(document_store, retriever):
+@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
+def test_pipeline(document_store, retriever):
documents = [
{"name": "name_1", "text": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
{"name": "name_2", "text": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
diff --git a/test/test_generator.py b/test/test_generator.py
index 9528080ce..bf8a6083a 100644
--- a/test/test_generator.py
+++ b/test/test_generator.py
@@ -415,7 +415,7 @@ def test_rag_token_generator(rag_generator):
@pytest.mark.elasticsearch
@pytest.mark.parametrize(
"retriever,document_store",
- [("embedding", "memory"), ("embedding", "faiss"), ("elasticsearch", "elasticsearch")],
+ [("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("elasticsearch", "elasticsearch")],
indirect=True,
)
def test_generator_pipeline(document_store, retriever, rag_generator):
diff --git a/test/test_imports.py b/test/test_imports.py
index b6eea08ce..8ef311a2e 100644
--- a/test/test_imports.py
+++ b/test/test_imports.py
@@ -1,6 +1,10 @@
def test_module_imports():
from haystack import Finder
from haystack.document_store.sql import SQLDocumentStore
+ from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
+ from haystack.document_store.faiss import FAISSDocumentStore
+ from haystack.document_store.milvus import MilvusDocumentStore
+ from haystack.document_store.base import BaseDocumentStore
from haystack.preprocessor.cleaning import clean_wiki_text
from haystack.preprocessor.utils import convert_files_to_dicts, fetch_archive_from_http
from haystack.reader.farm import FARMReader
@@ -10,6 +14,10 @@ def test_module_imports():
assert Finder is not None
assert SQLDocumentStore is not None
+ assert ElasticsearchDocumentStore is not None
+ assert FAISSDocumentStore is not None
+ assert MilvusDocumentStore is not None
+ assert BaseDocumentStore is not None
assert clean_wiki_text is not None
assert convert_files_to_dicts is not None
assert fetch_archive_from_http is not None
diff --git a/test/test_pipeline.py b/test/test_pipeline.py
index 7109d99ae..27cfd95d8 100644
--- a/test/test_pipeline.py
+++ b/test/test_pipeline.py
@@ -67,7 +67,7 @@ def test_extractive_qa_answers_single_result(reader, retriever_with_docs):
@pytest.mark.elasticsearch
@pytest.mark.parametrize(
"retriever,document_store",
- [("embedding", "memory"), ("embedding", "faiss"), ("embedding", "elasticsearch")],
+ [("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("embedding", "elasticsearch")],
indirect=True,
)
def test_faq_pipeline(retriever, document_store):
@@ -97,7 +97,7 @@ def test_faq_pipeline(retriever, document_store):
@pytest.mark.elasticsearch
@pytest.mark.parametrize(
"retriever,document_store",
- [("embedding", "memory"), ("embedding", "faiss"), ("embedding", "elasticsearch")],
+ [("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("embedding", "elasticsearch")],
indirect=True,
)
def test_document_search_pipeline(retriever, document_store):
diff --git a/test/test_summarizer.py b/test/test_summarizer.py
index fa7ff18a2..0f9a6527e 100644
--- a/test/test_summarizer.py
+++ b/test/test_summarizer.py
@@ -57,7 +57,7 @@ def test_summarization_one_summary(summarizer):
@pytest.mark.summarizer
@pytest.mark.parametrize(
"retriever,document_store",
- [("embedding", "memory"), ("embedding", "faiss"), ("elasticsearch", "elasticsearch")],
+ [("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("elasticsearch", "elasticsearch")],
indirect=True,
)
def test_summarization_pipeline(document_store, retriever, summarizer):
@@ -79,7 +79,7 @@ def test_summarization_pipeline(document_store, retriever, summarizer):
@pytest.mark.summarizer
@pytest.mark.parametrize(
"retriever,document_store",
- [("embedding", "memory"), ("embedding", "faiss"), ("elasticsearch", "elasticsearch")],
+ [("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("elasticsearch", "elasticsearch")],
indirect=True,
)
def test_summarization_pipeline_one_summary(document_store, retriever, summarizer):