diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index fe7966ea9..1cadc0339 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -76,6 +76,9 @@ jobs:
- name: Run Milvus
run: docker run -d -p 19530:19530 -p 19121:19121 milvusdb/milvus:1.1.0-cpu-d050721-5e559c
+ - name: Run Weaviate
+ run: docker run -d -p 8080:8080 --name haystack_test_weaviate --env AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED='true' --env PERSISTENCE_DATA_PATH='/var/lib/weaviate' semitechnologies/weaviate:1.4.0
+
- name: Run GraphDB
run: docker run -d -p 7200:7200 --name haystack_test_graphdb deepset/graphdb-free:9.4.1-adoptopenjdk11
diff --git a/docs/_src/usage/usage/document_store.md b/docs/_src/usage/usage/document_store.md
index 43baecdb0..bd4e5b3f4 100644
--- a/docs/_src/usage/usage/document_store.md
+++ b/docs/_src/usage/usage/document_store.md
@@ -116,6 +116,27 @@ from haystack.document_store import SQLDocumentStore
document_store = SQLDocumentStore()
```
+
+
+
+
+
+
+
+
+The `WeaviateDocumentStore` requires a running Weaviate Server.
+You can start a basic instance like this (see Weaviate docs for details):
+```
+ docker run -d -p 8080:8080 --env AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED='true' --env PERSISTENCE_DATA_PATH='/var/lib/weaviate' semitechnologies/weaviate:1.4.0
+```
+
+Afterwards, you can use it in Haystack:
+```python
+from haystack.document_store import WeaviateDocumentStore
+
+document_store = WeaviateDocumentStore()
+```
+
@@ -264,6 +285,24 @@ The Document Stores have different characteristics. You should choose one depend
+
+
+
+
+
+
+**Pros:**
+- Simple vector search
+- Stores everything in one place: documents, meta data and vectors - so less network overhead when scaling this up
+- Allows combination of vector search and scalar filtering, i.e. you can filter for a certain tag and do dense retrieval on that subset
+
+**Cons:**
+- Less options for ANN algorithms than FAISS or Milvus
+- No BM25 / Tf-idf retrieval
+
+
+
+
@@ -276,4 +315,4 @@ The Document Stores have different characteristics. You should choose one depend
**Vector Specialist:** Use the `MilvusDocumentStore`, if you want to focus on dense retrieval and possibly deal with larger datasets
-
\ No newline at end of file
+
diff --git a/haystack/document_store/weaviate.py b/haystack/document_store/weaviate.py
new file mode 100644
index 000000000..e654a520d
--- /dev/null
+++ b/haystack/document_store/weaviate.py
@@ -0,0 +1,716 @@
+import logging
+from typing import Any, Dict, Generator, List, Optional, Union
+import numpy as np
+from tqdm import tqdm
+
+from haystack import Document
+from haystack.document_store.base import BaseDocumentStore
+from haystack.utils import get_batches_from_generator
+
+from weaviate import client, auth, AuthClientPassword
+from weaviate import ObjectsBatchRequest
+
+logger = logging.getLogger(__name__)
+
+class WeaviateDocumentStore(BaseDocumentStore):
+ """
+
+ Weaviate is a cloud-native, modular, real-time vector search engine built to scale your machine learning models.
+ (See https://www.semi.technology/developers/weaviate/current/index.html#what-is-weaviate)
+
+ Some of the key differences in contrast to FAISS & Milvus:
+ 1. Stores everything in one place: documents, meta data and vectors - so less network overhead when scaling this up
+ 2. Allows combination of vector search and scalar filtering, i.e. you can filter for a certain tag and do dense retrieval on that subset
+ 3. Has less variety of ANN algorithms, as of now only HNSW.
+
+ Weaviate python client is used to connect to the server, more details are here
+ https://weaviate-python-client.readthedocs.io/en/docs/weaviate.html
+
+ Usage:
+ 1. Start a Weaviate server (see https://www.semi.technology/developers/weaviate/current/getting-started/installation.html)
+ 2. Init a WeaviateDocumentStore in Haystack
+ """
+
+ def __init__(
+ self,
+ host: Union[str, List[str]] = "http://localhost",
+ port: Union[int, List[int]] = 8080,
+ timeout_config: tuple = (5, 15),
+ username: str = None,
+ password: str = None,
+ index: str = "Document",
+ embedding_dim: int = 768,
+ text_field: str = "text",
+ name_field: str = "name",
+ faq_question_field = "question",
+ similarity: str = "dot_product",
+ index_type: str = "hnsw",
+ custom_schema: Optional[dict] = None,
+ return_embedding: bool = False,
+ embedding_field: str = "embedding",
+ progress_bar: bool = True,
+ duplicate_documents: str = 'overwrite',
+ **kwargs,
+ ):
+ """
+ :param host: Weaviate server connection URL for storing and processing documents and vectors.
+ For more details, refer "https://www.semi.technology/developers/weaviate/current/getting-started/installation.html"
+ :param port: port of Weaviate instance
+ :param timeout_config: Weaviate Timeout config as a tuple of (retries, time out seconds).
+ :param username: username (standard authentication via http_auth)
+ :param password: password (standard authentication via http_auth)
+ :param index: Index name for document text, embedding and metadata (in Weaviate terminology, this is a "Class" in Weaviate schema).
+ :param embedding_dim: The embedding vector size. Default: 768.
+ :param text_field: Name of field that might contain the answer and will therefore be passed to the Reader Model (e.g. "full_text").
+ If no Reader is used (e.g. in FAQ-Style QA) the plain content of this field will just be returned.
+        :param name_field: Name of field that contains the title of the doc
+ :param faq_question_field: Name of field containing the question in case of FAQ-Style QA
+ :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default.
+ :param index_type: Index type of any vector object defined in weaviate schema. The vector index type is pluggable.
+                           Currently, only HNSW is supported.
+ See: https://www.semi.technology/developers/weaviate/current/more-resources/performance.html
+ :param custom_schema: Allows to create custom schema in Weaviate, for more details
+ See https://www.semi.technology/developers/weaviate/current/data-schema/schema-configuration.html
+        :param module_name : Vectorization module to convert data into vectors. Default is "text2vec-transformers"
+ For more details, See https://www.semi.technology/developers/weaviate/current/modules/
+ :param return_embedding: To return document embedding.
+ :param embedding_field: Name of field containing an embedding vector.
+ :param progress_bar: Whether to show a tqdm progress bar or not.
+ Can be helpful to disable in production deployments to keep the logs clean.
+        :param duplicate_documents: Handle duplicate documents based on parameter options.
+ Parameter options : ( 'skip','overwrite','fail')
+                                    skip: Ignore duplicate documents
+ overwrite: Update any existing documents with the same ID when adding documents.
+ fail: an error is raised if the document ID of the document being added already exists.
+ """
+
+ # save init parameters to enable export of component config as YAML
+ self.set_config(
+ host=host, port=port, timeout_config=timeout_config, username=username, password=password,
+ index=index, embedding_dim=embedding_dim, text_field=text_field, name_field=name_field,
+ faq_question_field=faq_question_field, similarity=similarity, index_type=index_type,
+ custom_schema=custom_schema,return_embedding=return_embedding, embedding_field=embedding_field,
+ progress_bar=progress_bar, duplicate_documents=duplicate_documents
+ )
+
+ # Connect to Weaviate server using python binding
+ weaviate_url =f"{host}:{port}"
+ if username and password:
+ secret = AuthClientPassword(username, password)
+ self.weaviate_client = client.Client(url=weaviate_url,
+ auth_client_secret=secret,
+ timeout_config=timeout_config)
+ else:
+ self.weaviate_client = client.Client(url=weaviate_url,
+ timeout_config=timeout_config)
+
+ # Test Weaviate connection
+ try:
+ status = self.weaviate_client.is_ready()
+ if not status:
+ raise ConnectionError(
+ f"Initial connection to Weaviate failed. Make sure you run Weaviate instance "
+ f"at `{weaviate_url}` and that it has finished the initial ramp up (can take > 30s)."
+ )
+ except Exception:
+ raise ConnectionError(
+ f"Initial connection to Weaviate failed. Make sure you run Weaviate instance "
+ f"at `{weaviate_url}` and that it has finished the initial ramp up (can take > 30s)."
+ )
+ self.index = index
+ self.embedding_dim = embedding_dim
+ self.text_field = text_field
+ self.name_field = name_field
+ self.faq_question_field = faq_question_field
+ self.similarity = similarity
+ self.index_type = index_type
+ self.custom_schema = custom_schema
+ self.return_embedding = return_embedding
+ self.embedding_field = embedding_field
+ self.progress_bar = progress_bar
+ self.duplicate_documents = duplicate_documents
+
+ self._create_schema_and_index_if_not_exist(self.index)
+
+ def _create_schema_and_index_if_not_exist(
+ self,
+ index: Optional[str] = None,
+ ):
+ """Create a new index (schema/class in Weaviate) for storing documents in case if an index (schema) with the name doesn't exist already."""
+ index = index or self.index
+
+ if self.custom_schema:
+ schema = self.custom_schema
+ else:
+ schema = {
+ "classes": [
+ {
+ "class": index,
+ "description": "Haystack index, it's a class in Weaviate",
+ "invertedIndexConfig": {
+ "cleanupIntervalSeconds": 60
+ },
+ "vectorizer": "none",
+ "properties": [
+ {
+ "dataType": [
+ "string"
+ ],
+ "description": "Name Field",
+ "name": self.name_field
+ },
+ {
+ "dataType": [
+ "string"
+ ],
+ "description": "Question Field",
+ "name": self.faq_question_field
+ },
+ {
+ "dataType": [
+ "text"
+ ],
+ "description": "Document Text",
+ "name": self.text_field
+ },
+ ],
+ }
+ ]
+ }
+ if not self.weaviate_client.schema.contains(schema):
+ self.weaviate_client.schema.create(schema)
+
+ def _convert_weaviate_result_to_document(
+ self,
+ result: dict,
+ return_embedding: bool
+ ) -> Document:
+ """
+ Convert weaviate result dict into haystack document object. This is more involved because
+ weaviate search result dict varies between get and query interfaces.
+ Weaviate get methods return the data items in properties key, whereas the query doesn't.
+ """
+ score = None
+ probability = None
+ text = ""
+ question = None
+
+ id = result.get("id")
+ embedding = result.get("vector")
+
+ # If properties key is present, get all the document fields from it.
+ # otherwise, a direct lookup in result root dict
+ props = result.get("properties")
+ if not props:
+ props = result
+
+ if props.get(self.text_field) is not None:
+ text = str(props.get(self.text_field))
+
+ if props.get(self.faq_question_field) is not None:
+ question = props.get(self.faq_question_field)
+
+ # Weaviate creates "_additional" key for semantic search
+ if "_additional" in props:
+ if "certainty" in props["_additional"]:
+ score = props["_additional"]['certainty']
+ probability = score
+ if "id" in props["_additional"]:
+ id = props["_additional"]['id']
+ if "vector" in props["_additional"]:
+ embedding = props["_additional"]['vector']
+ props.pop("_additional", None)
+
+ # We put all additional data of the doc into meta_data and return it in the API
+ meta_data = {k:v for k,v in props.items() if k not in (self.text_field, self.faq_question_field, self.embedding_field)}
+
+ if return_embedding and embedding:
+ embedding = np.asarray(embedding, dtype=np.float32)
+
+ document = Document(
+ id=id,
+ text=text,
+ meta=meta_data,
+ score=score,
+ probability=probability,
+ question=question,
+ embedding=embedding,
+ )
+ return document
+
+ def _create_document_field_map(self) -> Dict:
+ return {
+ self.text_field: "text",
+ self.embedding_field: "embedding",
+ self.faq_question_field if self.faq_question_field else "question": "question"
+ }
+
+ def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
+ """Fetch a document by specifying its text id string"""
+ # Sample result dict from a get method
+ '''{'class': 'Document',
+ 'creationTimeUnix': 1621075584724,
+ 'id': '1bad51b7-bd77-485d-8871-21c50fab248f',
+ 'properties': {'meta': "{'key1':'value1'}",
+ 'name': 'name_5',
+ 'text': 'text_5'},
+ 'vector': []}'''
+ index = index or self.index
+ document = None
+ result = self.weaviate_client.data_object.get_by_id(id, with_vector=True)
+ if result:
+ document = self._convert_weaviate_result_to_document(result, return_embedding=True)
+ return document
+
+ def get_documents_by_id(self, ids: List[str], index: Optional[str] = None,
+ batch_size: int = 10_000) -> List[Document]:
+ """Fetch documents by specifying a list of text id strings"""
+ index = index or self.index
+ documents = []
+ #TODO: better implementation with multiple where filters instead of chatty call below?
+ for id in ids:
+ result = self.weaviate_client.data_object.get_by_id(id, with_vector=True)
+ if result:
+ document = self._convert_weaviate_result_to_document(result, return_embedding=True)
+ documents.append(document)
+ return documents
+
+ def _get_current_properties(self, index: Optional[str] = None) -> List[str]:
+ """Get all the existing properties in the schema"""
+ index = index or self.index
+ cur_properties = []
+ for class_item in self.weaviate_client.schema.get()['classes']:
+ if class_item['class'] == index:
+ cur_properties = [item['name'] for item in class_item['properties']]
+
+ return cur_properties
+
+ def _build_filter_clause(self, filters:Dict[str, List[str]]) -> dict:
+ """Transform Haystack filter conditions to Weaviate where filter clauses"""
+ weaviate_filters = []
+ weaviate_filter = {}
+ for key, values in filters.items():
+ for value in values:
+ weaviate_filter = {
+ "path": [key],
+ "operator": "Equal",
+ "valueString": value
+ }
+ weaviate_filters.append(weaviate_filter)
+ if len(weaviate_filters) > 1:
+ filter_dict = {
+ "operator": "Or",
+ "operands": weaviate_filters
+ }
+ return filter_dict
+ else:
+ return weaviate_filter
+
+ def _update_schema(self, new_prop:str, index: Optional[str] = None):
+ """Updates the schema with a new property"""
+ index = index or self.index
+ property_dict = {
+ "dataType": [
+ "string"
+ ],
+ "description": f"dynamic property {new_prop}",
+ "name": new_prop
+ }
+ self.weaviate_client.schema.property.create(index, property_dict)
+
+ def _check_document(self, cur_props: List[str], doc: dict) -> List[str]:
+ """Find the properties in the document that don't exist in the existing schema"""
+ return [item for item in doc.keys() if item not in cur_props]
+
+ def write_documents(
+ self, documents: Union[List[dict], List[Document]], index: Optional[str] = None,
+ batch_size: int = 10_000, duplicate_documents: Optional[str] = None):
+ """
+ Add new documents to the DocumentStore.
+
+ :param documents: List of `Dicts` or List of `Documents`. Passing an Embedding/Vector is mandatory in case weaviate is not
+ configured with a module. If a module is configured, the embedding is automatically generated by Weaviate.
+ :param index: index name for storing the docs and metadata
+        :param batch_size: When working with a large number of documents, batching can help reduce memory footprint.
+        :param duplicate_documents: Handle duplicate documents based on parameter options.
+ Parameter options : ( 'skip','overwrite','fail')
+                                    skip: Ignore duplicate documents
+ overwrite: Update any existing documents with the same ID when adding documents.
+ fail: an error is raised if the document ID of the document being added already
+ exists.
+ :raises DuplicateDocumentError: Exception trigger on duplicate document
+ :return: None
+ """
+ index = index or self.index
+ self._create_schema_and_index_if_not_exist(index)
+ field_map = self._create_document_field_map()
+
+ duplicate_documents = duplicate_documents or self.duplicate_documents
+ assert duplicate_documents in self.duplicate_documents_options, \
+ f"duplicate_documents parameter must be {', '.join(self.duplicate_documents_options)}"
+
+ if len(documents) == 0:
+ logger.warning("Calling DocumentStore.write_documents() with empty list")
+ return
+
+ # Auto schema feature https://github.com/semi-technologies/weaviate/issues/1539
+ # Get and cache current properties in the schema
+ current_properties = self._get_current_properties(index)
+
+ document_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
+ document_objects = self._handle_duplicate_documents(document_objects, duplicate_documents)
+
+ batched_documents = get_batches_from_generator(document_objects, batch_size)
+ with tqdm(total=len(document_objects), disable=not self.progress_bar) as progress_bar:
+ for document_batch in batched_documents:
+ docs_batch = ObjectsBatchRequest()
+ for idx, doc in enumerate(document_batch):
+ _doc = {
+ **doc.to_dict(field_map=self._create_document_field_map())
+ }
+ _ = _doc.pop("score", None)
+ _ = _doc.pop("probability", None)
+
+ # In order to have a flat structure in elastic + similar behaviour to the other DocumentStores,
+ # we "unnest" all value within "meta"
+ if "meta" in _doc.keys():
+ for k, v in _doc["meta"].items():
+ _doc[k] = v
+ _doc.pop("meta")
+
+ doc_id = str(_doc.pop("id"))
+ vector = _doc.pop(self.embedding_field)
+ if _doc.get(self.faq_question_field) is None:
+ _doc.pop(self.faq_question_field)
+
+ # Check if additional properties are in the document, if so,
+ # append the schema with all the additional properties
+ missing_props = self._check_document(current_properties, _doc)
+ if missing_props:
+ for property in missing_props:
+ self._update_schema(property, index)
+ current_properties.append(property)
+
+ docs_batch.add(_doc, class_name=index, uuid=doc_id, vector=vector)
+
+ # Ingest a batch of documents
+ results = self.weaviate_client.batch.create(docs_batch)
+ # Weaviate returns errors for every failed document in the batch
+ if results is not None:
+ for result in results:
+ if 'result' in result and 'errors' in result['result'] \
+ and 'error' in result['result']['errors']:
+ for message in result['result']['errors']['error']:
+ logger.error(f"{message['message']}")
+ progress_bar.update(batch_size)
+ progress_bar.close()
+
+ def update_document_meta(self, id: str, meta: Dict[str, str]):
+ """
+ Update the metadata dictionary of a document by specifying its string id
+ """
+ self.weaviate_client.data_object.update(meta, class_name=self.index, uuid=id)
+
+ def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
+ """
+ Return the number of documents in the document store.
+ """
+ index = index or self.index
+ doc_count = 0
+ if filters:
+ filter_dict = self._build_filter_clause(filters=filters)
+ result = self.weaviate_client.query.aggregate(index) \
+ .with_fields("meta { count }") \
+ .with_where(filter_dict)\
+ .do()
+ else:
+ result = self.weaviate_client.query.aggregate(index)\
+ .with_fields("meta { count }")\
+ .do()
+
+ if "data" in result:
+ if "Aggregate" in result.get('data'):
+ doc_count = result.get('data').get('Aggregate').get(index)[0]['meta']['count']
+
+ return doc_count
+
+ def get_all_documents(
+ self,
+ index: Optional[str] = None,
+ filters: Optional[Dict[str, List[str]]] = None,
+ return_embedding: Optional[bool] = None,
+ batch_size: int = 10_000,
+ ) -> List[Document]:
+ """
+ Get documents from the document store.
+
+ :param index: Name of the index to get the documents from. If None, the
+ DocumentStore's default index (self.index) will be used.
+ :param filters: Optional filters to narrow down the documents to return.
+ Example: {"name": ["some", "more"], "category": ["only_one"]}
+ :param return_embedding: Whether to return the document embeddings.
+        :param batch_size: When working with a large number of documents, batching can help reduce memory footprint.
+ """
+ index = index or self.index
+ result = self.get_all_documents_generator(
+ index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
+ )
+ documents = list(result)
+ return documents
+
+ def _get_all_documents_in_index(
+ self,
+ index: Optional[str],
+ filters: Optional[Dict[str, List[str]]] = None,
+ batch_size: int = 10_000,
+ only_documents_without_embedding: bool = False,
+ ) -> Generator[dict, None, None]:
+ """
+ Return all documents in a specific index in the document store
+ """
+ index = index or self.index
+
+ # Build the properties to retrieve from Weaviate
+ properties = self._get_current_properties(index)
+ properties.append("_additional {id, certainty, vector}")
+
+ if filters:
+ filter_dict = self._build_filter_clause(filters=filters)
+ result = self.weaviate_client.query.get(class_name=index, properties=properties)\
+ .with_where(filter_dict)\
+ .do()
+ else:
+ result = self.weaviate_client.query.get(class_name=index, properties=properties)\
+ .do()
+
+ all_docs = {}
+ if result and "data" in result and "Get" in result.get("data"):
+ if result.get("data").get("Get").get(index):
+ all_docs = result.get("data").get("Get").get(index)
+
+ yield from all_docs
+
+ def get_all_documents_generator(
+ self,
+ index: Optional[str] = None,
+ filters: Optional[Dict[str, List[str]]] = None,
+ return_embedding: Optional[bool] = None,
+ batch_size: int = 10_000,
+ ) -> Generator[Document, None, None]:
+ """
+ Get documents from the document store. Under-the-hood, documents are fetched in batches from the
+ document store and yielded as individual documents. This method can be used to iteratively process
+ a large number of documents without having to load all documents in memory.
+
+ :param index: Name of the index to get the documents from. If None, the
+ DocumentStore's default index (self.index) will be used.
+ :param filters: Optional filters to narrow down the documents to return.
+ Example: {"name": ["some", "more"], "category": ["only_one"]}
+ :param return_embedding: Whether to return the document embeddings.
+        :param batch_size: When working with a large number of documents, batching can help reduce memory footprint.
+ """
+
+ if index is None:
+ index = self.index
+
+ if return_embedding is None:
+ return_embedding = self.return_embedding
+
+ results = self._get_all_documents_in_index(index=index, filters=filters, batch_size=batch_size)
+ for result in results:
+ document = self._convert_weaviate_result_to_document(result, return_embedding=return_embedding)
+ yield document
+
+ def query(
+ self,
+ query: Optional[str] = None,
+ filters: Optional[Dict[str, List[str]]] = None,
+ top_k: int = 10,
+ custom_query: Optional[str] = None,
+ index: Optional[str] = None,
+ ) -> List[Document]:
+ """
+        Scan through documents in DocumentStore and return a small number of documents
+ that are most relevant to the query as defined by Weaviate semantic search.
+
+ :param query: The query
+ :param filters: A dictionary where the keys specify a metadata field and the value is a list of accepted values for that field
+ :param top_k: How many documents to return per query.
+ :param custom_query: Custom query that will executed using query.raw method, for more details refer
+ https://www.semi.technology/developers/weaviate/current/graphql-references/filters.html
+ :param index: The name of the index in the DocumentStore from which to retrieve documents
+ """
+ index = index or self.index
+
+ # Build the properties to retrieve from Weaviate
+ properties = self._get_current_properties(index)
+ properties.append("_additional {id, certainty, vector}")
+
+ if custom_query:
+ query_output = self.weaviate_client.query.raw(custom_query)
+ elif filters:
+ filter_dict = self._build_filter_clause(filters)
+ query_output = self.weaviate_client.query\
+ .get(class_name=index, properties=properties)\
+ .with_where(filter_dict)\
+ .with_limit(top_k)\
+ .do()
+ else:
+ raise NotImplementedError("Weaviate does not support inverted index text query. However, "
+ "it allows to search by filters example : {'text': 'some text'} or "
+ "use a custom GraphQL query in text format!")
+
+ results = []
+ if query_output and "data" in query_output and "Get" in query_output.get("data"):
+ if query_output.get("data").get("Get").get(index):
+ results = query_output.get("data").get("Get").get(index)
+
+ documents = []
+ for result in results:
+ doc = self._convert_weaviate_result_to_document(result, return_embedding=True)
+ documents.append(doc)
+
+ return documents
+
+ def query_by_embedding(self,
+ query_emb: np.ndarray,
+ filters: Optional[dict] = None,
+ top_k: int = 10,
+ index: Optional[str] = None,
+ return_embedding: Optional[bool] = None) -> List[Document]:
+ """
+ Find the document that is most similar to the provided `query_emb` by using a vector similarity metric.
+
+ :param query_emb: Embedding of the query (e.g. gathered from DPR)
+ :param filters: Optional filters to narrow down the search space.
+ Example: {"name": ["some", "more"], "category": ["only_one"]}
+ :param top_k: How many documents to return
+ :param index: index name for storing the docs and metadata
+ :param return_embedding: To return document embedding
+ :return:
+ """
+ if return_embedding is None:
+ return_embedding = self.return_embedding
+ index = index or self.index
+
+ # Build the properties to retrieve from Weaviate
+ properties = self._get_current_properties(index)
+ properties.append("_additional {id, certainty, vector}")
+
+ query_emb = query_emb.reshape(1, -1).astype(np.float32)
+ query_string = {
+ "vector" : query_emb
+ }
+ if filters:
+ filter_dict = self._build_filter_clause(filters)
+ query_output = self.weaviate_client.query\
+ .get(class_name=index, properties=properties)\
+ .with_where(filter_dict)\
+ .with_near_vector(query_string)\
+ .with_limit(top_k)\
+ .do()
+ else:
+ query_output = self.weaviate_client.query\
+ .get(class_name=index, properties=properties)\
+ .with_near_vector(query_string)\
+ .with_limit(top_k)\
+ .do()
+
+ results = []
+ if query_output and "data" in query_output and "Get" in query_output.get("data"):
+ if query_output.get("data").get("Get").get(index):
+ results = query_output.get("data").get("Get").get(index)
+
+ documents = []
+ for result in results:
+ doc = self._convert_weaviate_result_to_document(result, return_embedding=return_embedding)
+ documents.append(doc)
+
+ return documents
+
+ def update_embeddings(
+ self,
+ retriever,
+ index: Optional[str] = None,
+ filters: Optional[Dict[str, List[str]]] = None,
+ update_existing_embeddings: bool = True,
+ batch_size: int = 10_000
+ ):
+ """
+        Updates the embeddings in the document store using the encoding model specified in the retriever.
+        This can be useful if you want to change the embeddings for your documents (e.g. after changing the retriever config).
+
+ :param retriever: Retriever to use to update the embeddings.
+ :param index: Index name to update
+ :param update_existing_embeddings: Weaviate mandates an embedding while creating the document itself.
+               This option must always be true for weaviate and it will update the embeddings for all the documents.
+ :param filters: Optional filters to narrow down the documents for which embeddings are to be updated.
+ Example: {"name": ["some", "more"], "category": ["only_one"]}
+        :param batch_size: When working with a large number of documents, batching can help reduce memory footprint.
+ :return: None
+ """
+ if index is None:
+ index = self.index
+
+ if not self.embedding_field:
+ raise RuntimeError("Specify the arg `embedding_field` when initializing WeaviateDocumentStore()")
+
+ if update_existing_embeddings:
+ logger.info(f"Updating embeddings for all {self.get_document_count(index=index)} docs ...")
+ else:
+ raise RuntimeError("All the documents in Weaviate store have an embedding by default. Only update is allowed!")
+
+ result = self._get_all_documents_in_index(
+ index=index,
+ filters=filters,
+ batch_size=batch_size,
+ )
+
+ for result_batch in get_batches_from_generator(result, batch_size):
+ document_batch = [self._convert_weaviate_result_to_document(hit, return_embedding=False) for hit in result_batch]
+ embeddings = retriever.embed_passages(document_batch) # type: ignore
+ assert len(document_batch) == len(embeddings)
+
+ if embeddings[0].shape[0] != self.embedding_dim:
+ raise RuntimeError(f"Embedding dim. of model ({embeddings[0].shape[0]})"
+ f" doesn't match embedding dim. in DocumentStore ({self.embedding_dim})."
+ "Specify the arg `embedding_dim` when initializing WeaviateDocumentStore()")
+ for doc, emb in zip(document_batch, embeddings):
+                # This doc processing will not be required once weaviate's update
+ # method works. To be improved.
+ _doc = {
+ **doc.to_dict(field_map=self._create_document_field_map())
+ }
+ _ = _doc.pop("score", None)
+ _ = _doc.pop("probability", None)
+
+ if "meta" in _doc.keys():
+ for k, v in _doc["meta"].items():
+ _doc[k] = v
+ _doc.pop("meta")
+
+ doc_id = str(_doc.pop("id"))
+ _ = _doc.pop(self.embedding_field)
+ keys_to_remove = [k for k,v in _doc.items() if v is None]
+ for key in keys_to_remove:
+ _doc.pop(key)
+
+ # TODO: Weaviate's update throws an error while passing a vector now, have to improve this later
+ self.weaviate_client.data_object.replace(_doc, class_name=index, uuid=doc_id, vector=emb)
+
+ def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
+ """
+ Delete documents in an index. All documents are deleted if no filters are passed.
+
+ :param index: Index name to delete the document from.
+ :param filters: Optional filters to narrow down the documents to be deleted.
+ :return: None
+ """
+ index = index or self.index
+ if filters:
+ docs_to_delete = self.get_all_documents(index, filters=filters)
+ for doc in docs_to_delete:
+ self.weaviate_client.data_object.delete(doc.id)
+ else:
+ self.weaviate_client.schema.delete_class(index)
+ self._create_schema_and_index_if_not_exist(index)
diff --git a/requirements.txt b/requirements.txt
index 7303dae0a..c83ba516e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -30,4 +30,5 @@ pymilvus
#selenium
#webdriver-manager
SPARQLWrapper
-mmh3
\ No newline at end of file
+mmh3
+weaviate-client
\ No newline at end of file
diff --git a/test/conftest.py b/test/conftest.py
index 42aac7e06..ceeb4a718 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -9,6 +9,9 @@ from elasticsearch import Elasticsearch
from haystack.knowledge_graph.graphdb import GraphDBKnowledgeGraph
from milvus import Milvus
+import weaviate
+from haystack.document_store.weaviate import WeaviateDocumentStore
+
from haystack.document_store.milvus import MilvusDocumentStore
from haystack.generator.transformers import RAGenerator, RAGeneratorType
@@ -61,6 +64,8 @@ def pytest_collection_modifyitems(items):
item.add_marker(pytest.mark.pipeline)
elif "slow" in item.nodeid:
item.add_marker(pytest.mark.slow)
+ elif "weaviate" in item.nodeid:
+ item.add_marker(pytest.mark.weaviate)
@pytest.fixture(scope="session")
@@ -98,6 +103,27 @@ def milvus_fixture():
'milvusdb/milvus:0.10.5-cpu-d010621-4eda95'], shell=True)
time.sleep(40)
+@pytest.fixture(scope="session")
+def weaviate_fixture():
+ # test if a Weaviate server is already running. If not, start Weaviate docker container locally.
+ # Make sure you have given > 6GB memory to docker engine
+ try:
+ weaviate_server = weaviate.Client(url='http://localhost:8080', timeout_config=(5, 15))
+ weaviate_server.is_ready()
+ except:
+ print("Starting Weaviate servers ...")
+ status = subprocess.run(
+ ['docker rm haystack_test_weaviate'],
+ shell=True
+ )
+ status = subprocess.run(
+ ['docker run -d --name haystack_test_weaviate -p 8080:8080 semitechnologies/weaviate:1.4.0'],
+ shell=True
+ )
+ if status.returncode:
+ raise Exception(
+ "Failed to launch Weaviate. Please check docker container logs.")
+ time.sleep(60)
@pytest.fixture(scope="session")
def graphdb_fixture():
@@ -310,7 +336,6 @@ def document_store_with_docs(request, test_docs_xs):
yield document_store
document_store.delete_all_documents()
-
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "sql", "milvus"])
def document_store(request, test_docs_xs):
document_store = get_document_store(request.param)
@@ -352,6 +377,14 @@ def get_document_store(document_store_type, embedding_field="embedding"):
if collection.startswith("haystack_test"):
document_store.milvus_server.drop_collection(collection)
return document_store
+ elif document_store_type == "weaviate":
+ document_store = WeaviateDocumentStore(
+ weaviate_url="http://localhost:8080",
+ index="Haystacktest"
+ )
+ document_store.weaviate_client.schema.delete_all()
+ document_store._create_schema_and_index_if_not_exist()
+ return document_store
else:
raise Exception(f"No document store fixture for '{document_store_type}'")
diff --git a/test/pytest.ini b/test/pytest.ini
index 9da069d5b..fe36dbb18 100644
--- a/test/pytest.ini
+++ b/test/pytest.ini
@@ -8,3 +8,4 @@ markers =
generator: marks generator tests (deselect with '-m "not generator"')
pipeline: marks tests with pipeline
summarizer: marks summarizer tests
+ weaviate: marks tests that require weaviate container
diff --git a/test/test_weaviate.py b/test/test_weaviate.py
new file mode 100644
index 000000000..56fc50f82
--- /dev/null
+++ b/test/test_weaviate.py
@@ -0,0 +1,330 @@
+import numpy as np
+import pytest
+from haystack import Document
+from conftest import get_document_store
+import uuid
+
+embedding_dim = 768
+
+def get_uuid():
+ return str(uuid.uuid4())
+
+DOCUMENTS = [
+ {"text": "text1", "id":get_uuid(), "key": "a", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+ {"text": "text2", "id":get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+ {"text": "text3", "id":get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+ {"text": "text4", "id":get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+ {"text": "text5", "id":get_uuid(), "key": "b", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+]
+
+DOCUMENTS_XS = [
+ # current "dict" format for a document
+ {"text": "My name is Carla and I live in Berlin", "id":get_uuid(), "meta": {"metafield": "test1", "name": "filename1"}, "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+ # meta_field at the top level for backward compatibility
+ {"text": "My name is Paul and I live in New York", "id":get_uuid(), "metafield": "test2", "name": "filename2", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+ # Document object for a doc
+ Document(text="My name is Christelle and I live in Paris", id=get_uuid(), meta={"metafield": "test3", "name": "filename3"}, embedding=np.random.rand(embedding_dim).astype(np.float32))
+ ]
+
+@pytest.fixture(params=["weaviate"])
+def document_store_with_docs(request):
+ document_store = get_document_store(request.param)
+ document_store.write_documents(DOCUMENTS_XS)
+ yield document_store
+ document_store.delete_all_documents()
+
+@pytest.fixture(params=["weaviate"])
+def document_store(request):
+ document_store = get_document_store(request.param)
+ yield document_store
+ document_store.delete_all_documents()
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
+def test_get_all_documents_without_filters(document_store_with_docs):
+ documents = document_store_with_docs.get_all_documents()
+ assert all(isinstance(d, Document) for d in documents)
+ assert len(documents) == 3
+ assert {d.meta["name"] for d in documents} == {"filename1", "filename2", "filename3"}
+ assert {d.meta["metafield"] for d in documents} == {"test1", "test2", "test3"}
+
+@pytest.mark.weaviate
+def test_get_all_documents_with_correct_filters(document_store_with_docs):
+ documents = document_store_with_docs.get_all_documents(filters={"metafield": ["test2"]})
+ assert len(documents) == 1
+ assert documents[0].meta["name"] == "filename2"
+
+ documents = document_store_with_docs.get_all_documents(filters={"metafield": ["test1", "test3"]})
+ assert len(documents) == 2
+ assert {d.meta["name"] for d in documents} == {"filename1", "filename3"}
+ assert {d.meta["metafield"] for d in documents} == {"test1", "test3"}
+
+@pytest.mark.weaviate
+def test_get_all_documents_with_incorrect_filter_name(document_store_with_docs):
+ documents = document_store_with_docs.get_all_documents(filters={"incorrectmetafield": ["test2"]})
+ assert len(documents) == 0
+
+@pytest.mark.weaviate
+def test_get_all_documents_with_incorrect_filter_value(document_store_with_docs):
+ documents = document_store_with_docs.get_all_documents(filters={"metafield": ["incorrect_value"]})
+ assert len(documents) == 0
+
+@pytest.mark.weaviate
+def test_get_documents_by_id(document_store_with_docs):
+ documents = document_store_with_docs.get_all_documents()
+ doc = document_store_with_docs.get_document_by_id(documents[0].id)
+ assert doc.id == documents[0].id
+ assert doc.text == documents[0].text
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+def test_get_document_count(document_store):
+ document_store.write_documents(DOCUMENTS)
+ assert document_store.get_document_count() == 5
+ assert document_store.get_document_count(filters={"key": ["a"]}) == 1
+ assert document_store.get_document_count(filters={"key": ["b"]}) == 4
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+@pytest.mark.parametrize("batch_size", [2])
+def test_weaviate_write_docs(document_store, batch_size):
+ # Write in small batches
+ for i in range(0, len(DOCUMENTS), batch_size):
+ document_store.write_documents(DOCUMENTS[i: i + batch_size])
+
+ documents_indexed = document_store.get_all_documents()
+ assert len(documents_indexed) == len(DOCUMENTS)
+
+ documents_indexed = document_store.get_all_documents(batch_size=batch_size)
+ assert len(documents_indexed) == len(DOCUMENTS)
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+def test_get_all_document_filter_duplicate_value(document_store):
+ documents = [
+ Document(
+ text="Doc1",
+ meta={"fone": "f0"},
+ id = get_uuid(),
+ embedding= np.random.rand(embedding_dim).astype(np.float32)
+ ),
+ Document(
+ text="Doc1",
+ meta={"fone": "f1", "metaid": "0"},
+ id = get_uuid(),
+ embedding = np.random.rand(embedding_dim).astype(np.float32)
+ ),
+ Document(
+ text="Doc2",
+ meta={"fthree": "f0"},
+ id = get_uuid(),
+ embedding=np.random.rand(embedding_dim).astype(np.float32)
+ )
+ ]
+ document_store.write_documents(documents)
+ documents = document_store.get_all_documents(filters={"fone": ["f1"]})
+ assert documents[0].text == "Doc1"
+ assert len(documents) == 1
+ assert {d.meta["metaid"] for d in documents} == {"0"}
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+def test_get_all_documents_generator(document_store):
+ document_store.write_documents(DOCUMENTS)
+ assert len(list(document_store.get_all_documents_generator(batch_size=2))) == 5
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+def test_write_with_duplicate_doc_ids(document_store):
+ id = get_uuid()
+ documents = [
+ Document(
+ text="Doc1",
+ id=id,
+ embedding=np.random.rand(embedding_dim).astype(np.float32)
+ ),
+ Document(
+ text="Doc2",
+ id=id,
+ embedding=np.random.rand(embedding_dim).astype(np.float32)
+ )
+ ]
+ document_store.write_documents(documents, duplicate_documents="skip")
+ with pytest.raises(Exception):
+ document_store.write_documents(documents, duplicate_documents="fail")
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+@pytest.mark.parametrize("update_existing_documents", [True, False])
+def test_update_existing_documents(document_store, update_existing_documents):
+    id = str(uuid.uuid4())
+ original_docs = [
+ {"text": "text1_orig", "id": id, "metafieldforcount": "a", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+ ]
+
+ updated_docs = [
+ {"text": "text1_new", "id": id, "metafieldforcount": "a", "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+ ]
+
+ document_store.update_existing_documents = update_existing_documents
+ document_store.write_documents(original_docs)
+ assert document_store.get_document_count() == 1
+
+ if update_existing_documents:
+ document_store.write_documents(updated_docs, duplicate_documents="overwrite")
+ else:
+ with pytest.raises(Exception):
+ document_store.write_documents(updated_docs, duplicate_documents="fail")
+
+ stored_docs = document_store.get_all_documents()
+ assert len(stored_docs) == 1
+ if update_existing_documents:
+ assert stored_docs[0].text == updated_docs[0]["text"]
+ else:
+ assert stored_docs[0].text == original_docs[0]["text"]
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+def test_write_document_meta(document_store):
+ uid1 = get_uuid()
+ uid2 = get_uuid()
+ uid3 = get_uuid()
+ uid4 = get_uuid()
+ documents = [
+ {"text": "dict_without_meta", "id": uid1, "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+ {"text": "dict_with_meta", "metafield": "test2", "name": "filename2", "id": uid2, "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+ Document(text="document_object_without_meta", id=uid3, embedding=np.random.rand(embedding_dim).astype(np.float32)),
+ Document(text="document_object_with_meta", meta={"metafield": "test4", "name": "filename3"}, id=uid4, embedding=np.random.rand(embedding_dim).astype(np.float32)),
+ ]
+ document_store.write_documents(documents)
+ documents_in_store = document_store.get_all_documents()
+ assert len(documents_in_store) == 4
+
+ assert not document_store.get_document_by_id(uid1).meta
+ assert document_store.get_document_by_id(uid2).meta["metafield"] == "test2"
+ assert not document_store.get_document_by_id(uid3).meta
+ assert document_store.get_document_by_id(uid4).meta["metafield"] == "test4"
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+def test_write_document_index(document_store):
+ documents = [
+ {"text": "text1", "id": uuid.uuid4(), "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+ {"text": "text2", "id": uuid.uuid4(), "embedding": np.random.rand(embedding_dim).astype(np.float32)},
+ ]
+
+ document_store.write_documents([documents[0]], index="Haystackone")
+ assert len(document_store.get_all_documents(index="Haystackone")) == 1
+
+ document_store.write_documents([documents[1]], index="Haystacktwo")
+ assert len(document_store.get_all_documents(index="Haystacktwo")) == 1
+
+ assert len(document_store.get_all_documents(index="Haystackone")) == 1
+ assert len(document_store.get_all_documents()) == 0
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("retriever", ["dpr", "embedding"], indirect=True)
+@pytest.mark.parametrize("document_store", ["weaviate"], indirect=True)
+def test_update_embeddings(document_store, retriever):
+ documents = []
+ for i in range(6):
+ documents.append({"text": f"text_{i}", "id": str(uuid.uuid4()), "metafield": f"value_{i}", "embedding": np.random.rand(embedding_dim).astype(np.float32)})
+ documents.append({"text": "text_0", "id": str(uuid.uuid4()), "metafield": "value_0", "embedding": np.random.rand(embedding_dim).astype(np.float32)})
+
+ document_store.write_documents(documents, index="HaystackTestOne")
+ document_store.update_embeddings(retriever, index="HaystackTestOne", batch_size=3)
+ documents = document_store.get_all_documents(index="HaystackTestOne", return_embedding=True)
+ assert len(documents) == 7
+ for doc in documents:
+ assert type(doc.embedding) is np.ndarray
+
+ documents = document_store.get_all_documents(
+ index="HaystackTestOne",
+ filters={"metafield": ["value_0"]},
+ return_embedding=True,
+ )
+ assert len(documents) == 2
+ for doc in documents:
+ assert doc.meta["metafield"] == "value_0"
+ np.testing.assert_array_almost_equal(documents[0].embedding, documents[1].embedding, decimal=4)
+
+ documents = document_store.get_all_documents(
+ index="HaystackTestOne",
+ filters={"metafield": ["value_1", "value_5"]},
+ return_embedding=True,
+ )
+ np.testing.assert_raises(
+ AssertionError,
+ np.testing.assert_array_equal,
+ documents[0].embedding,
+ documents[1].embedding
+ )
+
+ doc = {"text": "text_7", "id": str(uuid.uuid4()), "metafield": "value_7",
+ "embedding": retriever.embed_queries(texts=["a random string"])[0]}
+ document_store.write_documents([doc], index="HaystackTestOne")
+
+ doc_before_update = document_store.get_all_documents(index="HaystackTestOne", filters={"metafield": ["value_7"]})[0]
+ embedding_before_update = doc_before_update.embedding
+
+ document_store.update_embeddings(
+ retriever, index="HaystackTestOne", batch_size=3, filters={"metafield": ["value_0", "value_1"]}
+ )
+ doc_after_update = document_store.get_all_documents(index="HaystackTestOne", filters={"metafield": ["value_7"]})[0]
+ embedding_after_update = doc_after_update.embedding
+ np.testing.assert_array_equal(embedding_before_update, embedding_after_update)
+
+ # test update all embeddings
+ document_store.update_embeddings(retriever, index="HaystackTestOne", batch_size=3, update_existing_embeddings=True)
+ assert document_store.get_document_count(index="HaystackTestOne") == 8
+ doc_after_update = document_store.get_all_documents(index="HaystackTestOne", filters={"metafield": ["value_7"]})[0]
+ embedding_after_update = doc_after_update.embedding
+ np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, embedding_before_update, embedding_after_update)
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
+def test_query_by_embedding(document_store_with_docs):
+ docs = document_store_with_docs.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32))
+ assert len(docs) == 3
+
+ docs = document_store_with_docs.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32),
+ top_k=1)
+ assert len(docs) == 1
+
+ docs = document_store_with_docs.query_by_embedding(np.random.rand(embedding_dim).astype(np.float32),
+ filters = {"name": ['filename2']})
+ assert len(docs) == 1
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
+def test_query(document_store_with_docs):
+ query_text = 'My name is Carla and I live in Berlin'
+ with pytest.raises(Exception):
+ docs = document_store_with_docs.query(query_text)
+
+ docs = document_store_with_docs.query(filters = {"name": ['filename2']})
+ assert len(docs) == 1
+
+ docs = document_store_with_docs.query(filters={"text":[query_text.lower()]})
+ assert len(docs) == 1
+
+ docs = document_store_with_docs.query(filters={"text":['live']})
+ assert len(docs) == 3
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
+def test_delete_all_documents(document_store_with_docs):
+ assert len(document_store_with_docs.get_all_documents()) == 3
+
+ document_store_with_docs.delete_all_documents()
+ documents = document_store_with_docs.get_all_documents()
+ assert len(documents) == 0
+
+@pytest.mark.weaviate
+@pytest.mark.parametrize("document_store_with_docs", ["weaviate"], indirect=True)
+def test_delete_documents_with_filters(document_store_with_docs):
+ document_store_with_docs.delete_all_documents(filters={"metafield": ["test1", "test2"]})
+ documents = document_store_with_docs.get_all_documents()
+ assert len(documents) == 1
+ assert documents[0].meta["metafield"] == "test3"
+