Milvus integration (#771)

* Initial commit for Milvus integration

* Add latest docstring and tutorial changes

* Updating implementation of Milvus document store

* Add latest docstring and tutorial changes

* Adding tests and updating doc string

* Add latest docstring and tutorial changes

* Fixing issue caught by tests

* Addressing review comments

* Fixing mypy detected issue

* Fixing issue caught in test about sorting of vector ids

* fixing test

* Fixing generator test failure

* update docstrings

* Addressing review comments about multiple network call while fetching embedding from milvus server

* Add latest docstring and tutorial changes

* Ignoring mypy issue while converting vector_id to int

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
This commit is contained in:
Lalit Pagaria 2021-01-29 13:29:12 +01:00 committed by GitHub
parent 6efa4f06c1
commit 9f7f95221f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 609 additions and 40 deletions

View File

@ -17,6 +17,9 @@ jobs:
- name: Run Elasticsearch
run: docker run -d -p 9200:9200 -e "discovery.type=single-node" -e "ES_JAVA_OPTS=-Xms128m -Xmx128m" elasticsearch:7.9.2
- name: Run Milvus
run: docker run -d -p 19530:19530 -p 19121:19121 milvusdb/milvus:0.10.5-cpu-d010621-4eda95
- name: Run Apache Tika
run: docker run -d -p 9998:9998 -e "TIKA_CHILD_JAVA_OPTS=-JXms128m" -e "TIKA_CHILD_JAVA_OPTS=-JXmx128m" apache/tika:1.24.1

View File

@ -346,7 +346,7 @@ None
#### delete\_all\_documents
```python
| delete_all_documents(index: str, filters: Optional[Dict[str, List[str]]] = None)
| delete_all_documents(index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None)
```
Delete documents in an index. All documents are deleted if no filters are passed.
@ -629,6 +629,7 @@ DocumentStore's default index (self.index) will be used.
- `filters`: Optional filters to narrow down the documents to return.
Example: {"name": ["some", "more"], "category": ["only_one"]}
- `return_embedding`: Whether to return the document embeddings.
- `batch_size`: When working with large number of documents, batching can help reduce memory footprint.
<a name="sql.SQLDocumentStore.get_all_labels"></a>
#### get\_all\_labels
@ -763,7 +764,7 @@ the vector embeddings are indexed in a FAISS Index.
#### \_\_init\_\_
```python
| __init__(sql_url: str = "sqlite:///", vector_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional[faiss.swigfaiss.Index] = None, return_embedding: bool = False, update_existing_documents: bool = False, index: str = "document", similarity: str = "dot_product", **kwargs, ,)
| __init__(sql_url: str = "sqlite:///", vector_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional[faiss.swigfaiss.Index] = None, return_embedding: bool = False, update_existing_documents: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", **kwargs, ,)
```
**Arguments**:
@ -796,6 +797,7 @@ added already exists.
- `index`: Name of index in document store to use.
- `similarity`: The similarity function used to compare document vectors. 'dot_product' is the default sine it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
- `embedding_field`: Name of field containing an embedding vector.
<a name="faiss.FAISSDocumentStore.write_documents"></a>
#### write\_documents
@ -881,7 +883,7 @@ None
#### delete\_all\_documents
```python
| delete_all_documents(index=None)
| delete_all_documents(index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None)
```
Delete all documents from the document store.

View File

@ -198,6 +198,6 @@ class BaseDocumentStore(ABC):
logger.error("File needs to be in json or jsonl format.")
@abstractmethod
def delete_all_documents(self, index: str, filters: Optional[Dict[str, List[str]]] = None):
def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
pass

View File

@ -757,7 +757,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
bulk(self.client, doc_updates, request_timeout=300, refresh=self.refresh_type)
def delete_all_documents(self, index: str, filters: Optional[Dict[str, List[str]]] = None):
def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
"""
Delete documents in an index. All documents are deleted if no filters are passed.
@ -765,6 +765,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
:param filters: Optional filters to narrow down the documents to be deleted.
:return: None
"""
index = index or self.index
query: Dict[str, Any] = {"query": {}}
if filters:
filter_clause = []

View File

@ -41,6 +41,7 @@ class FAISSDocumentStore(SQLDocumentStore):
update_existing_documents: bool = False,
index: str = "document",
similarity: str = "dot_product",
embedding_field: str = "embedding",
**kwargs,
):
"""
@ -72,6 +73,7 @@ class FAISSDocumentStore(SQLDocumentStore):
:param index: Name of index in document store to use.
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default sine it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
:param embedding_field: Name of field containing an embedding vector.
"""
self.vector_dim = vector_dim
@ -83,6 +85,7 @@ class FAISSDocumentStore(SQLDocumentStore):
self.faiss_index.set_direct_map_type(faiss.DirectMap.Hashtable)
self.return_embedding = return_embedding
self.embedding_field = embedding_field
if similarity == "dot_product":
self.similarity = similarity
else:
@ -154,7 +157,7 @@ class FAISSDocumentStore(SQLDocumentStore):
def _create_document_field_map(self) -> Dict:
return {
self.index: "embedding",
self.index: self.embedding_field,
}
def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None, batch_size: int = 10_000):
@ -275,13 +278,13 @@ class FAISSDocumentStore(SQLDocumentStore):
embeddings = np.array(embeddings, dtype="float32")
self.faiss_index.train(embeddings)
def delete_all_documents(self, index=None):
def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
"""
Delete all documents from the document store.
"""
index = index or self.index
self.faiss_index.reset()
super().delete_all_documents(index=index)
super().delete_all_documents(index=index, filters=filters)
def query_by_embedding(self,
query_emb: np.array,

View File

@ -0,0 +1,498 @@
import logging
from typing import Any, Dict, Generator, List, Optional, Union
import numpy
import numpy as np
from milvus import IndexType, MetricType, Milvus, Status
from scipy.special import expit
from tqdm import tqdm
from haystack import Document
from haystack.document_store.sql import SQLDocumentStore
from haystack.retriever.base import BaseRetriever
from haystack.utils import get_batches_from_generator
logger = logging.getLogger(__name__)
class MilvusDocumentStore(SQLDocumentStore):
"""
Milvus (https://milvus.io/) is a highly reliable, scalable Document Store specialized on storing and processing vectors.
Therefore, it is particularly suited for Haystack users that work with dense retrieval methods (like DPR).
In contrast to FAISS, Milvus ...
- runs as a separate service (e.g. a Docker container) and can scale easily in a distributed environment
- allows dynamic data management (i.e. you can insert/delete vectors without recreating the whole index)
- encapsulates multiple ANN libraries (FAISS, ANNOY ...)
This class uses Milvus for all vector related storage, processing and querying.
The meta-data (e.g. for filtering) and the document text are however stored in a separate SQL Database as Milvus
does not allow these data types (yet).
Usage:
1. Start a Milvus server (see https://milvus.io/docs/v0.10.5/install_milvus.md)
2. Init a MilvusDocumentStore in Haystack
"""
def __init__(
self,
sql_url: str = "sqlite:///",
milvus_url: str = "tcp://localhost:19530",
connection_pool: str = "SingletonThread",
index: str = "document",
vector_dim: int = 768,
index_file_size: int = 1024,
similarity: str = "dot_product",
index_type: IndexType = IndexType.FLAT,
index_param: Optional[Dict[str, Any]] = None,
search_param: Optional[Dict[str, Any]] = None,
update_existing_documents: bool = False,
return_embedding: bool = False,
embedding_field: str = "embedding",
**kwargs,
):
"""
:param sql_url: SQL connection URL for storing document texts and metadata. It defaults to a local, file based SQLite DB. For large scale
deployment, Postgres is recommended. If using MySQL then same server can also be used for
Milvus metadata. For more details see https://milvus.io/docs/v0.10.5/data_manage.md.
:param milvus_url: Milvus server connection URL for storing and processing vectors.
Protocol, host and port will automatically be inferred from the URL.
See https://milvus.io/docs/v0.10.5/install_milvus.md for instructions to start a Milvus instance.
:param connection_pool: Connection pool type to connect with Milvus server. Default: "SingletonThread".
:param index: Index name for text, embedding and metadata (in Milvus terms, this is the "collection name").
:param vector_dim: The embedding vector size. Default: 768.
:param index_file_size: Specifies the size of each segment file that is stored by Milvus and its default value is 1024 MB.
When the size of newly inserted vectors reaches the specified volume, Milvus packs these vectors into a new segment.
Milvus creates one index file for each segment. When conducting a vector search, Milvus searches all index files one by one.
As a rule of thumb, we would see a 30% ~ 50% increase in the search performance after changing the value of index_file_size from 1024 to 2048.
Note that an overly large index_file_size value may cause failure to load a segment into the memory or graphics memory.
(From https://milvus.io/docs/v0.10.5/performance_faq.md#How-can-I-get-the-best-performance-from-Milvus-through-setting-index_file_size)
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default and recommended for DPR embeddings.
'cosine' is recommended for Sentence Transformers, but is not directly supported by Milvus.
However, you can normalize your embeddings and use `dot_product` to get the same results.
See https://milvus.io/docs/v0.10.5/metric.md?Inner-product-(IP)#floating.
:param index_type: Type of approximate nearest neighbour (ANN) index used. The choice here determines your tradeoff between speed and accuracy.
Some popular options:
- FLAT (default): Exact method, slow
- IVF_FLAT, inverted file based heuristic, fast
- HSNW: Graph based, fast
- ANNOY: Tree based, fast
See: https://milvus.io/docs/v0.10.5/index.md
:param index_param: Configuration parameters for the chose index_type needed at indexing time.
For example: {"nlist": 16384} as the number of cluster units to create for index_type IVF_FLAT.
See https://milvus.io/docs/v0.10.5/index.md
:param search_param: Configuration parameters for the chose index_type needed at query time
For example: {"nprobe": 10} as the number of cluster units to query for index_type IVF_FLAT.
See https://milvus.io/docs/v0.10.5/index.md
:param update_existing_documents: Whether to update any existing documents with the same ID when adding
documents. When set as True, any document with an existing ID gets updated.
If set to False, an error is raised if the document ID of the document being
added already exists.
:param return_embedding: To return document embedding.
:param embedding_field: Name of field containing an embedding vector.
"""
self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
self.vector_dim = vector_dim
self.index_file_size = index_file_size
if similarity == "dot_product":
self.metric_type = MetricType.L2
else:
raise ValueError("The Milvus document store can currently only support dot_product similarity. "
"Please set similarity=\"dot_product\"")
self.index_type = index_type
self.index_param = index_param or {"nlist": 16384}
self.search_param = search_param or {"nprobe": 10}
self.index = index
self._create_collection_and_index_if_not_exist(self.index)
self.return_embedding = return_embedding
self.embedding_field = embedding_field
super().__init__(
url=sql_url,
update_existing_documents=update_existing_documents,
index=index
)
def __del__(self):
return self.milvus_server.close()
def _create_collection_and_index_if_not_exist(
self,
index: Optional[str] = None,
index_param: Optional[Dict[str, Any]] = None
):
index = index or self.index
index_param = index_param or self.index_param
status, ok = self.milvus_server.has_collection(collection_name=index)
if not ok:
collection_param = {
'collection_name': index,
'dimension': self.vector_dim,
'index_file_size': self.index_file_size,
'metric_type': self.metric_type
}
status = self.milvus_server.create_collection(collection_param)
if status.code != Status.SUCCESS:
raise RuntimeError(f'Collection creation on Milvus server failed: {status}')
status = self.milvus_server.create_index(index, self.index_type, index_param)
if status.code != Status.SUCCESS:
raise RuntimeError(f'Index creation on Milvus server failed: {status}')
def _create_document_field_map(self) -> Dict:
return {
self.index: self.embedding_field,
}
def write_documents(
self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, batch_size: int = 10_000
):
"""
Add new documents to the DocumentStore.
:param documents: List of `Dicts` or List of `Documents`. If they already contain the embeddings, we'll index
them right away in Milvus. If not, you can later call update_embeddings() to create & index them.
:param index: (SQL) index name for storing the docs and metadata
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
:return:
"""
index = index or self.index
self._create_collection_and_index_if_not_exist(index)
field_map = self._create_document_field_map()
if len(documents) == 0:
logger.warning("Calling DocumentStore.write_documents() with empty list")
return
document_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
add_vectors = False if document_objects[0].embedding is None else True
batched_documents = get_batches_from_generator(document_objects, batch_size)
with tqdm(total=len(document_objects)) as progress_bar:
for document_batch in batched_documents:
vector_ids = []
if add_vectors:
doc_ids = []
embeddings = []
for doc in document_batch:
doc_ids.append(doc.id)
if isinstance(doc.embedding, np.ndarray):
embeddings.append(doc.embedding.tolist())
elif isinstance(doc.embedding, list):
embeddings.append(doc.embedding)
else:
raise AttributeError(f'Format of supplied document embedding {type(doc.embedding)} is not '
f'supported. Please use list or numpy.ndarray')
if self.update_existing_documents:
existing_docs = super().get_documents_by_id(ids=doc_ids, index=index)
self._delete_vector_ids_from_milvus(documents=existing_docs, index=index)
status, vector_ids = self.milvus_server.insert(collection_name=index, records=embeddings)
if status.code != Status.SUCCESS:
raise RuntimeError(f'Vector embedding insertion failed: {status}')
docs_to_write_in_sql = []
for idx, doc in enumerate(document_batch):
meta = doc.meta
if add_vectors:
meta["vector_id"] = vector_ids[idx]
docs_to_write_in_sql.append(doc)
super().write_documents(docs_to_write_in_sql, index=index)
progress_bar.update(batch_size)
progress_bar.close()
self.milvus_server.flush([index])
if self.update_existing_documents:
self.milvus_server.compact(collection_name=index)
def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None, batch_size: int = 10_000):
"""
Updates the embeddings in the the document store using the encoding model specified in the retriever.
This can be useful if want to add or change the embeddings for your documents (e.g. after changing the retriever config).
:param retriever: Retriever to use to get embeddings for text
:param index: (SQL) index name for storing the docs and metadata
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
:return: None
"""
index = index or self.index
self._create_collection_and_index_if_not_exist(index)
document_count = self.get_document_count(index=index)
if document_count == 0:
logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
return
logger.info(f"Updating embeddings for {document_count} docs...")
result = self.get_all_documents_generator(index=index, batch_size=batch_size, return_embedding=False)
batched_documents = get_batches_from_generator(result, batch_size)
with tqdm(total=document_count) as progress_bar:
for document_batch in batched_documents:
self._delete_vector_ids_from_milvus(documents=document_batch, index=index)
embeddings = retriever.embed_passages(document_batch) # type: ignore
embeddings_list = [embedding.tolist() for embedding in embeddings]
assert len(document_batch) == len(embeddings_list)
status, vector_ids = self.milvus_server.insert(collection_name=index, records=embeddings_list)
if status.code != Status.SUCCESS:
raise RuntimeError(f'Vector embedding insertion failed: {status}')
vector_id_map = {}
for vector_id, doc in zip(vector_ids, document_batch):
vector_id_map[doc.id] = vector_id
self.update_vector_ids(vector_id_map, index=index)
progress_bar.update(batch_size)
progress_bar.close()
self.milvus_server.flush([index])
self.milvus_server.compact(collection_name=index)
def query_by_embedding(self,
query_emb: np.array,
filters: Optional[dict] = None,
top_k: int = 10,
index: Optional[str] = None,
return_embedding: Optional[bool] = None) -> List[Document]:
"""
Find the document that is most similar to the provided `query_emb` by using a vector similarity metric.
:param query_emb: Embedding of the query (e.g. gathered from DPR)
:param filters: Optional filters to narrow down the search space.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param top_k: How many documents to return
:param index: (SQL) index name for storing the docs and metadata
:param return_embedding: To return document embedding
:return:
"""
if filters:
raise Exception("Query filters are not implemented for the MilvusDocumentStore.")
index = index or self.index
status, ok = self.milvus_server.has_collection(collection_name=index)
if status.code != Status.SUCCESS:
raise RuntimeError(f'Milvus has collection check failed: {status}')
if not ok:
raise Exception("No index exists. Use 'update_embeddings()` to create an index.")
if return_embedding is None:
return_embedding = self.return_embedding
index = index or self.index
query_emb = query_emb.reshape(1, -1).astype(np.float32)
status, search_result = self.milvus_server.search(
collection_name=index,
query_records=query_emb,
top_k=top_k,
params=self.search_param
)
if status.code != Status.SUCCESS:
raise RuntimeError(f'Vector embedding search failed: {status}')
vector_ids_for_query = []
scores_for_vector_ids: Dict[str, float] = {}
for vector_id_list, distance_list in zip(search_result.id_array, search_result.distance_array):
for vector_id, distance in zip(vector_id_list, distance_list):
vector_ids_for_query.append(str(vector_id))
scores_for_vector_ids[str(vector_id)] = distance
documents = self.get_documents_by_vector_ids(vector_ids_for_query, index=index)
if return_embedding:
self._populate_embeddings_to_docs(index=index, docs=documents)
for doc in documents:
doc.score = scores_for_vector_ids[doc.meta["vector_id"]]
doc.probability = float(expit(np.asarray(doc.score / 100)))
return documents
def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
"""
Delete all documents (from SQL AND Milvus).
:param index: (SQL) index name for storing the docs and metadata
:param filters: Optional filters to narrow down the search space.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:return: None
"""
index = index or self.index
super().delete_all_documents(index=index, filters=filters)
status, ok = self.milvus_server.has_collection(collection_name=index)
if status.code != Status.SUCCESS:
raise RuntimeError(f'Milvus has collection check failed: {status}')
if ok:
status = self.milvus_server.drop_collection(collection_name=index)
if status.code != Status.SUCCESS:
raise RuntimeError(f'Milvus drop collection failed: {status}')
self.milvus_server.flush([index])
self.milvus_server.compact(collection_name=index)
def get_all_documents_generator(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
) -> Generator[Document, None, None]:
"""
Get all documents from the document store. Under-the-hood, documents are fetched in batches from the
document store and yielded as individual documents. This method can be used to iteratively process
a large number of documents without having to load all documents in memory.
:param index: Name of the index to get the documents from. If None, the
DocumentStore's default index (self.index) will be used.
:param filters: Optional filters to narrow down the documents to return.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param return_embedding: Whether to return the document embeddings.
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
"""
index = index or self.index
documents = super().get_all_documents_generator(
index=index, filters=filters, batch_size=batch_size
)
if return_embedding is None:
return_embedding = self.return_embedding
for doc in documents:
if return_embedding:
self._populate_embeddings_to_docs(index=index, docs=[doc])
yield doc
def get_all_documents(
self,
index: Optional[str] = None,
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
) -> List[Document]:
"""
Get documents from the document store (optionally using filter criteria).
:param index: Name of the index to get the documents from. If None, the
DocumentStore's default index (self.index) will be used.
:param filters: Optional filters to narrow down the documents to return.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param return_embedding: Whether to return the document embeddings.
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
"""
index = index or self.index
result = self.get_all_documents_generator(
index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
)
documents = list(result)
return documents
def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
"""
Fetch a document by specifying its text id string
:param id: ID of the document
:param index: Name of the index to get the documents from. If None, the
DocumentStore's default index (self.index) will be used.
"""
documents = self.get_documents_by_id([id], index)
document = documents[0] if documents else None
return document
def get_documents_by_id(
self, ids: List[str], index: Optional[str] = None, batch_size: int = 10_000
) -> List[Document]:
"""
Fetch multiple documents by specifying their IDs (strings)
:param ids: List of IDs of the documents
:param index: Name of the index to get the documents from. If None, the
DocumentStore's default index (self.index) will be used.
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
"""
index = index or self.index
documents = super().get_documents_by_id(ids=ids, index=index)
if self.return_embedding:
self._populate_embeddings_to_docs(index=index, docs=documents)
return documents
def _populate_embeddings_to_docs(self, docs: List[Document], index: Optional[str] = None):
index = index or self.index
docs_with_vector_ids = []
for doc in docs:
if doc.meta and doc.meta.get("vector_id") is not None:
docs_with_vector_ids.append(doc)
if len(docs_with_vector_ids) == 0:
return
ids = [int(doc.meta.get("vector_id")) for doc in docs_with_vector_ids] # type: ignore
status, vector_embeddings = self.milvus_server.get_entity_by_id(
collection_name=index,
ids=ids
)
if status.code != Status.SUCCESS:
raise RuntimeError(f'Getting vector embedding by id failed: {status}')
for embedding, doc in zip(vector_embeddings, docs_with_vector_ids):
doc.embedding = numpy.array(embedding, dtype="float32")
def _delete_vector_ids_from_milvus(self, documents: List[Document], index: Optional[str] = None):
index = index or self.index
existing_vector_ids = []
for doc in documents:
if "vector_id" in doc.meta:
existing_vector_ids.append(int(doc.meta["vector_id"]))
if len(existing_vector_ids) > 0:
status = self.milvus_server.delete_entity_by_id(
collection_name=index,
id_array=existing_vector_ids
)
if status.code != Status.SUCCESS:
raise RuntimeError("E existing vector ids deletion failed: {status}")
def get_all_vectors(self, index=None) -> List[np.array]:
"""
Helper function to dump all vectors stored in Milvus server.
:param index: Name of the index to get the documents from. If None, the
DocumentStore's default index (self.index) will be used.
:return: List[np.array]: List of vectors.
"""
index = index or self.index
status, collection_info = self.milvus_server.get_collection_stats(collection_name=index)
if not status.OK():
logger.info(f"Failed fetch stats from store ...")
return list()
logger.debug(f"collection_info = {collection_info}")
ids = list()
partition_list = collection_info["partitions"]
for partition in partition_list:
segment_list = partition["segments"]
for segment in segment_list:
segment_name = segment["name"]
status, id_list = self.milvus_server.list_id_in_segment(
collection_name=index,
segment_name=segment_name)
logger.debug(f"{status}: segment {segment_name} has {len(id_list)} vectors ...")
ids.extend(id_list)
if len(ids) == 0:
logger.info(f"No documents in the store ...")
return list()
status, vectors = self.milvus_server.get_entity_by_id(collection_name=index, ids=ids)
if not status.OK():
logger.info(f"Failed fetch document for ids {ids} from store ...")
return list()
return vectors

View File

@ -165,6 +165,7 @@ class SQLDocumentStore(BaseDocumentStore):
:param filters: Optional filters to narrow down the documents to return.
Example: {"name": ["some", "more"], "category": ["only_one"]}
:param return_embedding: Whether to return the document embeddings.
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
"""
index = index or self.index
@ -408,7 +409,7 @@ class SQLDocumentStore(BaseDocumentStore):
"""
if filters:
raise NotImplementedError("Delete by filters is not implemented for SQLDocumentStore.")
raise NotImplementedError(f"Delete by filters is not implemented for {type(self).__name__}")
index = index or self.index
documents = self.session.query(DocumentORM).filter_by(index=index)
documents.delete(synchronize_session=False)

View File

@ -176,7 +176,7 @@ class RAGenerator(BaseGenerator):
embeddings = self.retriever.embed_passages(docs)
embeddings_in_tensor = torch.cat(
[torch.from_numpy(embedding).unsqueeze(0) for embedding in embeddings],
[torch.from_numpy(embedding).float().unsqueeze(0) for embedding in embeddings],
dim=0
)

View File

@ -22,4 +22,6 @@ uvloop; sys_platform != 'win32' and sys_platform != 'cygwin'
httptools
nltk
more_itertools
networkx
networkx
# Refer milvus version support matrix at https://github.com/milvus-io/pymilvus#install-pymilvus
pymilvus

View File

@ -6,6 +6,9 @@ from sys import platform
import pytest
import requests
from elasticsearch import Elasticsearch
from milvus import Milvus
from haystack.document_store.milvus import MilvusDocumentStore
from haystack.generator.transformers import RAGenerator, RAGeneratorType
from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever, ElasticsearchRetriever, TfidfRetriever
@ -78,6 +81,20 @@ def elasticsearch_fixture():
time.sleep(30)
@pytest.fixture(scope="session")
def milvus_fixture():
# test if a Milvus server is already running. If not, start Milvus docker container locally.
# Make sure you have given > 6GB memory to docker engine
try:
milvus_server = Milvus(uri="tcp://localhost:19530", timeout=5, wait_timeout=5)
milvus_server.server_status(timeout=5)
except:
print("Starting Milvus ...")
status = subprocess.run(['docker run -d --name milvus_cpu_0.10.5 -p 19530:19530 -p 19121:19121 '
'milvusdb/milvus:0.10.5-cpu-d010621-4eda95'], shell=True)
time.sleep(40)
@pytest.fixture(scope="session")
def tika_fixture():
try:
@ -245,21 +262,19 @@ def get_retriever(retriever_type, document_store):
return retriever
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql"])
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql", "milvus"])
def document_store_with_docs(request, test_docs_xs):
document_store = get_document_store(request.param)
document_store.write_documents(test_docs_xs)
yield document_store
if request.param == "faiss":
document_store.faiss_index.reset()
document_store.delete_all_documents()
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql"])
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql", "milvus"])
def document_store(request, test_docs_xs):
document_store = get_document_store(request.param)
yield document_store
if request.param == "faiss":
document_store.faiss_index.reset()
document_store.delete_all_documents()
def get_document_store(document_store_type, embedding_field="embedding"):
@ -284,6 +299,14 @@ def get_document_store(document_store_type, embedding_field="embedding"):
index="haystack_test",
)
return document_store
elif document_store_type == "milvus":
document_store = MilvusDocumentStore(
sql_url="sqlite://",
return_embedding=True,
embedding_field=embedding_field,
index="haystack_test",
)
return document_store
else:
raise Exception(f"No document store fixture for '{document_store_type}'")

View File

@ -113,7 +113,7 @@ def test_get_all_documents_generator(document_store):
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch", "sql", "faiss"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch", "sql", "faiss", "milvus"], indirect=True)
@pytest.mark.parametrize("update_existing_documents", [True, False])
def test_update_existing_documents(document_store, update_existing_documents):
original_docs = [
@ -177,7 +177,7 @@ def test_write_document_index(document_store):
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
def test_document_with_embeddings(document_store):
documents = [
{"text": "text1", "id": "1", "embedding": np.random.rand(768).astype(np.float32)},
@ -196,7 +196,7 @@ def test_document_with_embeddings(document_store):
@pytest.mark.parametrize("retriever", ["dpr", "embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
def test_update_embeddings(document_store, retriever):
documents = []
for i in range(23):
@ -232,17 +232,17 @@ def test_update_embeddings(document_store, retriever):
@pytest.mark.elasticsearch
def test_delete_all_documents(document_store_with_docs):
assert len(document_store_with_docs.get_all_documents(index="haystack_test")) == 3
assert len(document_store_with_docs.get_all_documents()) == 3
document_store_with_docs.delete_all_documents(index="haystack_test")
documents = document_store_with_docs.get_all_documents(index="haystack_test")
document_store_with_docs.delete_all_documents()
documents = document_store_with_docs.get_all_documents()
assert len(documents) == 0
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_delete_documents_with_filters(document_store_with_docs):
document_store_with_docs.delete_all_documents(index="haystack_test", filters={"meta_field": ["test1", "test2"]})
document_store_with_docs.delete_all_documents(filters={"meta_field": ["test1", "test2"]})
documents = document_store_with_docs.get_all_documents()
assert len(documents) == 1
assert documents[0].meta["meta_field"] == "test3"

View File

@ -4,13 +4,14 @@ import numpy as np
from haystack import Document
from haystack.document_store.faiss import FAISSDocumentStore
from haystack.document_store.milvus import MilvusDocumentStore
from haystack.retriever.dense import DensePassageRetriever
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
@pytest.mark.slow
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("return_embedding", [True, False])
def test_dpr_retrieval(document_store, retriever, return_embedding):
@ -71,7 +72,7 @@ def test_dpr_retrieval(document_store, retriever, return_embedding):
assert res[0].embedding is None
# test filtering
if not isinstance(document_store, FAISSDocumentStore):
if not isinstance(document_store, FAISSDocumentStore) and not isinstance(document_store, MilvusDocumentStore):
res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?", filters={"name": ["0", "2"]})
assert len(res) == 2
for r in res:

View File

@ -4,7 +4,7 @@ from haystack import Finder
@pytest.mark.slow
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
def test_embedding_retriever(retriever, document_store):

View File

@ -1,4 +1,5 @@
import os
from copy import deepcopy
import faiss
import numpy as np
@ -63,9 +64,9 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
@pytest.mark.parametrize("batch_size", [4, 6])
def test_faiss_update_docs(document_store, retriever, batch_size):
def test_update_docs(document_store, retriever, batch_size):
# initial write
document_store.write_documents(DOCUMENTS)
@ -82,9 +83,35 @@ def test_faiss_update_docs(document_store, retriever, batch_size):
assert np.allclose(updated_embedding, stored_doc.embedding, rtol=0.01)
@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
def test_faiss_update_with_empty_store(document_store, retriever):
@pytest.mark.parametrize("document_store", ["milvus", "faiss"], indirect=True)
def test_update_exiting_docs(document_store, retriever):
document_store.update_existing_documents = True
old_document = Document(text="text_1")
# initial write
document_store.write_documents([old_document])
document_store.update_embeddings(retriever=retriever)
old_documents_indexed = document_store.get_all_documents()
assert len(old_documents_indexed) == 1
# Update document data
new_document = Document(text="text_2")
new_document.id = old_document.id
document_store.write_documents([new_document])
document_store.update_embeddings(retriever=retriever)
new_documents_indexed = document_store.get_all_documents()
assert len(new_documents_indexed) == 1
assert old_documents_indexed[0].id == new_documents_indexed[0].id
assert old_documents_indexed[0].text == "text_1"
assert new_documents_indexed[0].text == "text_2"
assert not np.allclose(old_documents_indexed[0].embedding, new_documents_indexed[0].embedding, rtol=0.01)
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_update_with_empty_store(document_store, retriever):
# Call update with empty doc store
document_store.update_embeddings(retriever=retriever)
@ -125,8 +152,8 @@ def test_faiss_retrieving(index_factory):
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
def test_faiss_finding(document_store, retriever):
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_finding(document_store, retriever):
document_store.write_documents(DOCUMENTS)
finder = Finder(reader=None, retriever=retriever)
@ -136,8 +163,8 @@ def test_faiss_finding(document_store, retriever):
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
def test_faiss_pipeline(document_store, retriever):
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_pipeline(document_store, retriever):
documents = [
{"name": "name_1", "text": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
{"name": "name_2", "text": "text_2", "embedding": np.random.rand(768).astype(np.float32)},

View File

@ -415,7 +415,7 @@ def test_rag_token_generator(rag_generator):
@pytest.mark.elasticsearch
@pytest.mark.parametrize(
"retriever,document_store",
[("embedding", "memory"), ("embedding", "faiss"), ("elasticsearch", "elasticsearch")],
[("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("elasticsearch", "elasticsearch")],
indirect=True,
)
def test_generator_pipeline(document_store, retriever, rag_generator):

View File

@ -1,6 +1,10 @@
def test_module_imports():
from haystack import Finder
from haystack.document_store.sql import SQLDocumentStore
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.document_store.faiss import FAISSDocumentStore
from haystack.document_store.milvus import MilvusDocumentStore
from haystack.document_store.base import BaseDocumentStore
from haystack.preprocessor.cleaning import clean_wiki_text
from haystack.preprocessor.utils import convert_files_to_dicts, fetch_archive_from_http
from haystack.reader.farm import FARMReader
@ -10,6 +14,10 @@ def test_module_imports():
assert Finder is not None
assert SQLDocumentStore is not None
assert ElasticsearchDocumentStore is not None
assert FAISSDocumentStore is not None
assert MilvusDocumentStore is not None
assert BaseDocumentStore is not None
assert clean_wiki_text is not None
assert convert_files_to_dicts is not None
assert fetch_archive_from_http is not None

View File

@ -67,7 +67,7 @@ def test_extractive_qa_answers_single_result(reader, retriever_with_docs):
@pytest.mark.elasticsearch
@pytest.mark.parametrize(
"retriever,document_store",
[("embedding", "memory"), ("embedding", "faiss"), ("embedding", "elasticsearch")],
[("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("embedding", "elasticsearch")],
indirect=True,
)
def test_faq_pipeline(retriever, document_store):
@ -97,7 +97,7 @@ def test_faq_pipeline(retriever, document_store):
@pytest.mark.elasticsearch
@pytest.mark.parametrize(
"retriever,document_store",
[("embedding", "memory"), ("embedding", "faiss"), ("embedding", "elasticsearch")],
[("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("embedding", "elasticsearch")],
indirect=True,
)
def test_document_search_pipeline(retriever, document_store):

View File

@ -57,7 +57,7 @@ def test_summarization_one_summary(summarizer):
@pytest.mark.summarizer
@pytest.mark.parametrize(
"retriever,document_store",
[("embedding", "memory"), ("embedding", "faiss"), ("elasticsearch", "elasticsearch")],
[("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("elasticsearch", "elasticsearch")],
indirect=True,
)
def test_summarization_pipeline(document_store, retriever, summarizer):
@ -79,7 +79,7 @@ def test_summarization_pipeline(document_store, retriever, summarizer):
@pytest.mark.summarizer
@pytest.mark.parametrize(
"retriever,document_store",
[("embedding", "memory"), ("embedding", "faiss"), ("elasticsearch", "elasticsearch")],
[("embedding", "memory"), ("embedding", "faiss"), ("embedding", "milvus"), ("elasticsearch", "elasticsearch")],
indirect=True,
)
def test_summarization_pipeline_one_summary(document_store, retriever, summarizer):